diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 0580701..7073b88 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -141,7 +141,7 @@ jobs: run: docker pull ${{ inputs.serviceImage }} - name: Run test tool - uses: restatedev/sdk-test-suite@v3.4 + uses: restatedev/e2e/sdk-tests@v1.0 with: restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} serviceContainerImage: ${{ inputs.serviceImage != '' && inputs.serviceImage || 'restatedev/test-services-python' }} diff --git a/python/restate/exceptions.py b/python/restate/exceptions.py index 3884b10..1ef5695 100644 --- a/python/restate/exceptions.py +++ b/python/restate/exceptions.py @@ -19,10 +19,11 @@ class TerminalError(Exception): """This exception is thrown to indicate that Restate should not retry.""" - def __init__(self, message: str, status_code: int = 500) -> None: + def __init__(self, message: str, status_code: int = 500, metadata: Optional[dict[str, str]] = None) -> None: super().__init__(message) self.message = message self.status_code = status_code + self.metadata = metadata class RetryableError(Exception): diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 721b29b..56e8443 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -407,7 +407,7 @@ async def enter(self): self.vm.sys_write_output_success(bytes(out_buffer)) self.vm.sys_end() except TerminalError as t: - failure = Failure(code=t.status_code, message=t.message) + failure = Failure(code=t.status_code, message=t.message, metadata=t.metadata) restate_context_is_replaying.set(False) self.vm.sys_write_output_failure(failure) self.vm.sys_end() @@ -428,7 +428,7 @@ async def enter(self): cause: BaseException | None = e while cause is not None: if isinstance(cause, TerminalError): - failure = Failure(code=cause.status_code, message=cause.message) + failure = Failure(code=cause.status_code, message=cause.message, metadata=cause.metadata) restate_context_is_replaying.set(False) self.vm.sys_write_output_failure(failure) self.vm.sys_end() @@ -525,7 +525,7 @@ async def must_take_notification(self, handle): if res is None: return None if isinstance(res, Failure): - raise TerminalError(res.message, res.code) + raise TerminalError(res.message, res.code, metadata=res.metadata) return res async def create_poll_or_cancel_coroutine(self, unresolved_future: UnresolvedFuture) -> None: @@ -688,7 +688,7 @@ async def create_run_coroutine( buffer = serde.serialize(action_result) self.vm.propose_run_completion_success(handle, buffer) except TerminalError as t: - failure = Failure(code=t.status_code, message=t.message) + failure = Failure(code=t.status_code, message=t.message, metadata=t.metadata) self.vm.propose_run_completion_failure(handle, failure) except RetryableError as r: failure = Failure(code=r.status_code, message=r.message) diff --git a/python/restate/vm.py b/python/restate/vm.py index 2efaa8b..58d5369 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -75,6 +75,7 @@ class Failure: code: int message: str stacktrace: typing.Optional[str] = None + metadata: typing.Optional[typing.Dict[str, str]] = None @dataclass @@ -269,9 +270,8 @@ def take_notification(self, handle: int) -> typing.Union[NotificationType, Excep return result if isinstance(result, PyFailure): # a terminal failure - code = result.code - message = result.message - return Failure(code, message) + metadata = dict(result.metadata) if result.metadata else None + return Failure(result.code, result.message, metadata=metadata) return ValueError(f"Unknown result type: {result}") def sys_input(self) -> Invocation: @@ -314,7 +314,8 @@ def sys_write_output_failure(self, output: Failure): Returns: None """ - res = PyFailure(output.code, output.message) + metadata = list(output.metadata.items()) if output.metadata else None + res = PyFailure(output.code, output.message, metadata=metadata) self.vm.sys_write_output_failure(res) def sys_get_state(self, name) -> int: @@ -446,7 +447,8 @@ def propose_run_completion_failure(self, handle: int, output: Failure) -> None: """ Exit a side effect with a terminal failure. """ - res = PyFailure(output.code, output.message) + metadata = list(output.metadata.items()) if output.metadata else None + res = PyFailure(output.code, output.message, metadata=metadata) self.vm.propose_run_completion_failure(handle, res) # pylint: disable=line-too-long diff --git a/src/lib.rs b/src/lib.rs index e582919..9070928 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,17 +98,20 @@ struct PyFailure { message: String, #[pyo3(get, set)] stacktrace: Option, + #[pyo3(get, set)] + metadata: Option>, } #[pymethods] impl PyFailure { #[new] - #[pyo3(signature = (code, message, stacktrace=None))] - fn new(code: u16, message: String, stacktrace: Option) -> PyFailure { + #[pyo3(signature = (code, message, stacktrace=None, metadata=None))] + fn new(code: u16, message: String, stacktrace: Option, metadata: Option>) -> PyFailure { Self { code, message, stacktrace, + metadata, } } } @@ -181,6 +184,11 @@ impl From for PyFailure { code: value.code, message: value.message, stacktrace: None, + metadata: if value.metadata.is_empty() { + None + } else { + Some(value.metadata) + }, } } } @@ -190,21 +198,15 @@ impl From for TerminalFailure { TerminalFailure { code: value.code, message: value.message, - metadata: vec![], + metadata: value.metadata.unwrap_or_default(), } } } impl From for Error { - fn from( - PyFailure { - code, - message, - stacktrace, - }: PyFailure, - ) -> Self { - let mut e = Self::new(code, message); - if let Some(stacktrace) = stacktrace { + fn from(value: PyFailure) -> Self { + let mut e = Self::new(value.code, value.message); + if let Some(stacktrace) = value.stacktrace { e = e.with_stacktrace(stacktrace); } e diff --git a/test-services/services/failing.py b/test-services/services/failing.py index 3ad7a0a..add5d67 100644 --- a/test-services/services/failing.py +++ b/test-services/services/failing.py @@ -16,21 +16,32 @@ # pylint: disable=W0613 # pylint: disable=W0622 +from typing import Optional, TypedDict + from restate import VirtualObject, ObjectContext from restate.exceptions import TerminalError from restate import RunOptions + +class FailureToPropagate(TypedDict): + errorMessage: str + metadata: Optional[dict[str, str]] + + failing = VirtualObject("Failing") @failing.handler(name="terminallyFailingCall") -async def terminally_failing_call(ctx: ObjectContext, msg: str): - raise TerminalError(message=msg) +async def terminally_failing_call(ctx: ObjectContext, failure_to_propagate: FailureToPropagate): + raise TerminalError( + message=failure_to_propagate["errorMessage"], + metadata=failure_to_propagate.get("metadata"), + ) @failing.handler(name="callTerminallyFailingCall") -async def call_terminally_failing_call(ctx: ObjectContext, msg: str) -> str: - await ctx.object_call(terminally_failing_call, key="random-583e1bf2", arg=msg) +async def call_terminally_failing_call(ctx: ObjectContext, failure_to_propagate: FailureToPropagate) -> str: + await ctx.object_call(terminally_failing_call, key="random-583e1bf2", arg=failure_to_propagate) raise Exception("Should not reach here") @@ -49,9 +60,12 @@ async def failing_call_with_eventual_success(ctx: ObjectContext) -> int: @failing.handler(name="terminallyFailingSideEffect") -async def terminally_failing_side_effect(ctx: ObjectContext, error_message: str): +async def terminally_failing_side_effect(ctx: ObjectContext, failure_to_propagate: FailureToPropagate): + error_message = failure_to_propagate["errorMessage"] + metadata = failure_to_propagate.get("metadata") + def side_effect(): - raise TerminalError(message=error_message) + raise TerminalError(message=error_message, metadata=metadata) await ctx.run_typed("sideEffect", side_effect) raise ValueError("Should not reach here")