Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 133 additions & 2 deletions src/firebase_functions/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import enum as _enum
import json as _json
import sys as _sys
import traceback as _traceback
import typing as _typing

import typing_extensions as _typing_extensions
Expand Down Expand Up @@ -43,6 +44,7 @@ class LogEntry(_typing.TypedDict):

severity: _typing_extensions.Required[LogSeverity]
message: _typing_extensions.NotRequired[str]
stack_trace: _typing_extensions.NotRequired[str]


def _entry_from_args(severity: LogSeverity, *args, **kwargs) -> LogEntry:
Expand Down Expand Up @@ -71,7 +73,116 @@ def _entry_from_args(severity: LogSeverity, *args, **kwargs) -> LogEntry:
return _typing.cast(LogEntry, entry)


def _remove_circular(obj: _typing.Any, refs: set[_typing.Any] | None = None):
def _exception_from_args(
exception: BaseException, refs: set[int] | None = None
) -> dict[str, _typing.Any]:
Comment thread
IzaakGough marked this conversation as resolved.
"""
Creates a JSON-safe representation of an exception.
"""

details: dict[str, _typing.Any] = {
"type": exception.__class__.__name__,
"message": _safe_exception_string(exception),
}
if exception.args:
details["args"] = _json_safe_exception_args(exception.args, refs)
if exception.__traceback__ is not None:
try:
details["stack_trace"] = "".join(
_traceback.format_exception(exception.__class__, exception, exception.__traceback__)
)
except Exception:
details["stack_trace"] = "".join(_traceback.format_tb(exception.__traceback__))
details["stack_trace"] += f"{exception.__class__.__name__}: {details['message']}\n"
return details


def _exception_type_from_args(
exception_type: type[BaseException],
) -> dict[str, _typing.Any]:
"""
Creates a JSON-safe representation of an exception class.

If the class matches the active exception from `sys.exc_info()`, include
the current exception message and stack trace as well.
"""

details: dict[str, _typing.Any] = {
"type": exception_type.__name__,
"message": exception_type.__name__,
}
exc_type, exc_value, exc_traceback = _sys.exc_info()
if exc_type is exception_type and exc_value is not None:
details["message"] = _safe_exception_string(exc_value)
if exc_traceback is not None:
details["stack_trace"] = "".join(
_traceback.format_exception(exc_type, exc_value, exc_traceback)
)
return details


def _safe_exception_string(exception: BaseException) -> str:
"""
Returns a string representation of an exception without propagating repr/str errors.
"""

try:
return str(exception)
except Exception:
return exception.__class__.__name__


def _json_safe_exception_args(args: tuple[_typing.Any, ...], refs: set[int] | None = None):
"""
Returns exception args in a form that can be encoded as JSON.
"""

return _coerce_json_safe(_remove_circular(args, refs))


def _coerce_json_safe(obj: _typing.Any):
"""
Converts values that survive circular-reference removal into JSON-safe values.
"""

if isinstance(obj, str | int | float | bool | type(None)):
return obj
if isinstance(obj, dict):
return {
_coerce_json_safe_dict_key(key): _coerce_json_safe(value) for key, value in obj.items()
}
if isinstance(obj, list):
return [_coerce_json_safe(item) for item in obj]
if isinstance(obj, tuple):
return tuple(_coerce_json_safe(item) for item in obj)
return _safe_repr(obj)


def _coerce_json_safe_dict_key(obj: _typing.Any):
"""
Converts dictionary keys into values accepted by JSON object encoding.
"""

if isinstance(obj, str | int | float | bool | type(None)):
return obj
coerced = _coerce_json_safe(obj)
if isinstance(coerced, str | int | float | bool | type(None)):
return coerced
return _safe_repr(coerced)


def _safe_repr(obj: _typing.Any) -> str:
"""
Returns a repr without propagating repr errors.
"""

try:
return repr(obj)
except Exception:
return obj.__class__.__name__


def _remove_circular(obj: _typing.Any, refs: set[int] | None = None):
"""
Removes circular references from the given object and replaces them with "[CIRCULAR]".
"""
Expand All @@ -89,7 +200,11 @@ def _remove_circular(obj: _typing.Any, refs: set[_typing.Any] | None = None):

# Recursively process the object based on its type
result: _typing.Any
if isinstance(obj, dict):
if isinstance(obj, BaseException):
result = _exception_from_args(obj, refs)
elif isinstance(obj, type) and issubclass(obj, BaseException):
result = _exception_type_from_args(obj)
elif isinstance(obj, dict):
result = {key: _remove_circular(value, refs) for key, value in obj.items()}
elif isinstance(obj, list):
result = [_remove_circular(item, refs) for item in obj]
Expand Down Expand Up @@ -149,3 +264,19 @@ def error(*args, **kwargs) -> None:
Logs an error message.
"""
write(_entry_from_args(LogSeverity.ERROR, *args, **kwargs))


def exception(*args, **kwargs) -> None:
"""
Logs an error message and includes the active stack trace.
"""
raw_error = kwargs.get("error")
entry = _entry_from_args(LogSeverity.ERROR, *args, **kwargs)
exc_type, exc_value, exc_traceback = _sys.exc_info()
if exc_type is not None and exc_value is not None and exc_traceback is not None:
uses_active_error_traceback = raw_error is exc_value or raw_error is exc_type
if not uses_active_error_traceback:
entry["stack_trace"] = "".join(
_traceback.format_exception(exc_type, exc_value, exc_traceback)
)
write(entry)
205 changes: 205 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import json
import sys

import pytest

Expand Down Expand Up @@ -59,6 +60,141 @@ def test_severity_should_be_error(self, capsys: pytest.CaptureFixture[str]):
log_output = json.loads(raw_log_output)
assert log_output["severity"] == "ERROR"

def test_error_should_accept_exception(self, capsys: pytest.CaptureFixture[str]):
try:
raise ValueError("boom")
except ValueError as exception:
logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["message"] == "boom"
assert "stack_trace" in log_output["error"]
assert "ValueError: boom" in log_output["error"]["stack_trace"]

def test_error_should_accept_exception_type(self, capsys: pytest.CaptureFixture[str]):
try:
raise TypeError("boom")
except TypeError:
logger.error("failed", error=sys.exc_info()[0])

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "TypeError"
assert log_output["error"]["message"] == "boom"
assert "stack_trace" in log_output["error"]
assert "TypeError: boom" in log_output["error"]["stack_trace"]

def test_error_should_accept_self_referential_exception(
self, capsys: pytest.CaptureFixture[str]
):
class SelfArgError(Exception):
pass

exception = SelfArgError("boom")
exception.args = (exception,)

logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "SelfArgError"
assert log_output["error"]["args"] == ["[CIRCULAR]"]

def test_error_should_accept_exception_with_cyclic_payload(
self, capsys: pytest.CaptureFixture[str]
):
payload = {}
payload["self"] = payload
exception = ValueError(payload)

logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["args"] == [{"self": "[CIRCULAR]"}]

def test_error_should_accept_exception_with_non_json_serializable_args(
self, capsys: pytest.CaptureFixture[str]
):
payload = object()
exception = ValueError(payload)

logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["args"] == [repr(payload)]

def test_error_should_accept_exception_with_repr_raising_arg(
self, capsys: pytest.CaptureFixture[str]
):
class BadRepr:
def __repr__(self):
raise RuntimeError("boom")

exception = ValueError(BadRepr())

logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["args"] == ["BadRepr"]

def test_error_should_accept_exception_with_non_json_serializable_dict_key(
self, capsys: pytest.CaptureFixture[str]
):
payload = {object(): "value"}
exception = ValueError(payload)

logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["args"] == [{repr(next(iter(payload.keys()))): "value"}]

def test_error_should_accept_exception_with_tuple_dict_key(
self, capsys: pytest.CaptureFixture[str]
):
payload = {(1, "two"): "value"}
exception = ValueError(payload)

logger.error("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["args"] == [{"(1, 'two')": "value"}]

def test_log_should_have_message(self, capsys: pytest.CaptureFixture[str]):
logger.log("bar")
raw_log_output = capsys.readouterr().out
Expand All @@ -78,6 +214,75 @@ def test_message_should_be_space_separated(self, capsys: pytest.CaptureFixture[s
log_output = json.loads(raw_log_output)
assert log_output["message"] == expected_message

def test_exception_should_include_stack_trace(self, capsys: pytest.CaptureFixture[str]):
try:
raise ValueError("boom")
except ValueError:
logger.exception("failed")

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert "stack_trace" in log_output
assert "ValueError: boom" in log_output["stack_trace"]

def test_exception_should_not_duplicate_stack_trace_for_exception_error(
self, capsys: pytest.CaptureFixture[str]
):
try:
raise ValueError("boom")
except ValueError as exception:
logger.exception("failed", error=exception)

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert "stack_trace" not in log_output
assert log_output["error"]["type"] == "ValueError"
assert log_output["error"]["message"] == "boom"
assert "stack_trace" in log_output["error"]
assert "ValueError: boom" in log_output["error"]["stack_trace"]

def test_exception_should_not_duplicate_stack_trace_for_exception_type_error(
self, capsys: pytest.CaptureFixture[str]
):
try:
raise TypeError("boom")
except TypeError:
logger.exception("failed", error=sys.exc_info()[0])

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert "stack_trace" not in log_output
assert log_output["error"]["type"] == "TypeError"
assert log_output["error"]["message"] == "boom"
assert "stack_trace" in log_output["error"]
assert "TypeError: boom" in log_output["error"]["stack_trace"]

def test_exception_should_include_active_stack_trace_for_error_dict(
self, capsys: pytest.CaptureFixture[str]
):
try:
raise ValueError("boom")
except ValueError:
logger.exception("failed", error={"stack_trace": "custom traceback"})

raw_log_output = capsys.readouterr().err
log_output = json.loads(raw_log_output)

assert log_output["severity"] == "ERROR"
assert log_output["message"] == "failed"
assert log_output["error"] == {"stack_trace": "custom traceback"}
assert "stack_trace" in log_output
assert "ValueError: boom" in log_output["stack_trace"]

def test_remove_circular_references(self, capsys: pytest.CaptureFixture[str]):
# Create an object with a circular reference.
circ = {"b": "foo"}
Expand Down
Loading