From 5a197120f84cb41e2fded5f2bd874298b5871ed3 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Thu, 26 Mar 2026 10:04:00 -0700 Subject: [PATCH 01/16] Storage driver store context metadata --- README.md | 7 +- temporalio/client.py | 1052 ++++++++++------- temporalio/contrib/aws/s3driver/_driver.py | 57 +- temporalio/converter/__init__.py | 4 + temporalio/converter/_extstore.py | 136 ++- temporalio/worker/_activity.py | 323 ++--- temporalio/worker/_workflow.py | 56 +- temporalio/worker/_workflow_instance.py | 84 +- .../worker/workflow_sandbox/_in_sandbox.py | 8 + temporalio/worker/workflow_sandbox/_runner.py | 18 + tests/contrib/aws/s3driver/test_s3driver.py | 177 +-- tests/test_serialization_context.py | 86 -- tests/worker/test_extstore.py | 504 ++++++++ tests/worker/test_workflow.py | 7 + 14 files changed, 1693 insertions(+), 826 deletions(-) diff --git a/README.md b/README.md index ca8a000f6..7cbbfa89c 100644 --- a/README.md +++ b/README.md @@ -533,11 +533,8 @@ def feature_flag_is_on(workflow_id: str | None) -> bool: def feature_flag_selector( context: temporalio.converter.StorageDriverStoreContext, _payload: Payload ) -> temporalio.converter.StorageDriver | None: - workflow_id = None - if isinstance(context.serialization_context, temporalio.converter.WorkflowSerializationContext): - workflow_id = context.serialization_context.workflow_id - elif isinstance(context.serialization_context, temporalio.converter.ActivitySerializationContext): - workflow_id = context.serialization_context.workflow_id + wf = context.current_workflow or context.target_workflow + workflow_id = wf.id if wf else None return my_driver if feature_flag_is_on(workflow_id) else None options = ExternalStorage( diff --git a/temporalio/client.py b/temporalio/client.py index cc2750ec6..8d9e9ac97 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -66,9 +66,15 @@ ActivitySerializationContext, DataConverter, SerializationContext, + 
StorageDriverActivityInfo, + StorageDriverWorkflowInfo, WithSerializationContext, WorkflowSerializationContext, ) +from temporalio.converter._extstore import ( + StorageDriverStoreMetadata, + store_metadata_context, +) from temporalio.service import ( ConnectConfig, HttpConnectProxyConfig, @@ -6161,67 +6167,80 @@ async def _to_proto( priority: temporalio.api.common.v1.Priority | None = None if self.priority: priority = self.priority._to_proto() - data_converter = client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=client.namespace, - workflow_id=self.id, - ) - ) - action = temporalio.api.schedule.v1.ScheduleAction( - start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( - workflow_id=self.id, - workflow_type=temporalio.api.common.v1.WorkflowType(name=self.workflow), - task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=self.task_queue), - input=( - temporalio.api.common.v1.Payloads( - payloads=[ - a - if isinstance(a, temporalio.api.common.v1.Payload) - else (await data_converter.encode([a]))[0] - for a in self.args - ] - ) - if self.args - else None - ), - workflow_execution_timeout=execution_timeout, - workflow_run_timeout=run_timeout, - workflow_task_timeout=task_timeout, - retry_policy=retry_policy, - memo=await data_converter._encode_memo(self.memo) - if self.memo - else None, - user_metadata=await _encode_user_metadata( - data_converter, self.static_summary, self.static_details + target_workflow=StorageDriverWorkflowInfo( + id=self.id, type=self.workflow ), - priority=priority, - ), - ) - # Add any untyped attributes that are not also in the typed set - untyped_not_in_typed = { - k: v - for k, v in self.untyped_search_attributes.items() - if k not in self.typed_search_attributes - } - if untyped_not_in_typed: - temporalio.converter.encode_search_attributes( - untyped_not_in_typed, action.start_workflow.search_attributes ) - # TODO (dan): confirm whether this 
be `is not None` - if self.typed_search_attributes: - temporalio.converter.encode_search_attributes( - self.typed_search_attributes, action.start_workflow.search_attributes + ): + data_converter = client.data_converter.with_context( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=self.id, + ) ) - if self.headers: - await _apply_headers( - self.headers, - action.start_workflow.header.fields, - client.config(active_config=True)["header_codec_behavior"] - == HeaderCodecBehavior.CODEC - and not self._from_raw, - client.data_converter, + action = temporalio.api.schedule.v1.ScheduleAction( + start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( + workflow_id=self.id, + workflow_type=temporalio.api.common.v1.WorkflowType( + name=self.workflow + ), + task_queue=temporalio.api.taskqueue.v1.TaskQueue( + name=self.task_queue + ), + input=( + temporalio.api.common.v1.Payloads( + payloads=[ + a + if isinstance(a, temporalio.api.common.v1.Payload) + else (await data_converter.encode([a]))[0] + for a in self.args + ] + ) + if self.args + else None + ), + workflow_execution_timeout=execution_timeout, + workflow_run_timeout=run_timeout, + workflow_task_timeout=task_timeout, + retry_policy=retry_policy, + memo=await data_converter._encode_memo(self.memo) + if self.memo + else None, + user_metadata=await _encode_user_metadata( + data_converter, self.static_summary, self.static_details + ), + priority=priority, + ), ) - return action + # Add any untyped attributes that are not also in the typed set + untyped_not_in_typed = { + k: v + for k, v in self.untyped_search_attributes.items() + if k not in self.typed_search_attributes + } + if untyped_not_in_typed: + temporalio.converter.encode_search_attributes( + untyped_not_in_typed, action.start_workflow.search_attributes + ) + # TODO (dan): confirm whether this be `is not None` + if self.typed_search_attributes: + temporalio.converter.encode_search_attributes( + self.typed_search_attributes, + 
action.start_workflow.search_attributes, + ) + if self.headers: + await _apply_headers( + self.headers, + action.start_workflow.header.fields, + client.config(active_config=True)["header_codec_behavior"] + == HeaderCodecBehavior.CODEC + and not self._from_raw, + client.data_converter, + ) + return action class ScheduleOverlapPolicy(IntEnum): @@ -8077,21 +8096,29 @@ async def _build_signal_with_start_workflow_execution_request( self, input: StartWorkflowInput ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: assert input.start_signal - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - workflow_id=input.id, + target_workflow=StorageDriverWorkflowInfo( + id=input.id, type=input.workflow + ), ) - ) - req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( - signal_name=input.start_signal - ) - if input.start_signal_args: - req.signal_input.payloads.extend( - await data_converter.encode(input.start_signal_args) + ): + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) ) - await self._populate_start_workflow_execution_request(req, input) - return req + req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( + signal_name=input.start_signal + ) + if input.start_signal_args: + req.signal_input.payloads.extend( + await data_converter.encode(input.start_signal_args) + ) + await self._populate_start_workflow_execution_request(req, input) + return req async def _build_update_with_start_start_workflow_execution_request( self, input: UpdateWithStartStartWorkflowInput @@ -8108,57 +8135,65 @@ async def _populate_start_workflow_execution_request( ), input: StartWorkflowInput | UpdateWithStartStartWorkflowInput, ) -> None: - data_converter = 
self._client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - workflow_id=input.id, + target_workflow=StorageDriverWorkflowInfo( + id=input.id, type=input.workflow + ), ) - ) - req.namespace = self._client.namespace - req.workflow_id = input.id - req.workflow_type.name = input.workflow - req.task_queue.name = input.task_queue - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) - if input.execution_timeout is not None: - req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) - if input.run_timeout is not None: - req.workflow_run_timeout.FromTimedelta(input.run_timeout) - if input.task_timeout is not None: - req.workflow_task_timeout.FromTimedelta(input.task_timeout) - req.identity = self._client.identity - req.request_id = str(uuid.uuid4()) - req.workflow_id_reuse_policy = cast( - "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", - int(input.id_reuse_policy), - ) - req.workflow_id_conflict_policy = cast( - "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", - int(input.id_conflict_policy), - ) - - if input.retry_policy is not None: - input.retry_policy.apply_to_proto(req.retry_policy) - req.cron_schedule = input.cron_schedule - if input.memo is not None: - await data_converter._encode_memo_existing(input.memo, req.memo) - if input.search_attributes is not None: - temporalio.converter.encode_search_attributes( - input.search_attributes, req.search_attributes + ): + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) ) - metadata = await _encode_user_metadata( - data_converter, input.static_summary, input.static_details - ) - if metadata is not None: - req.user_metadata.CopyFrom(metadata) - if input.start_delay is not None: - req.workflow_start_delay.FromTimedelta(input.start_delay) - if 
input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.header.fields) - if input.priority is not None: # type:ignore[reportUnnecessaryComparison] - req.priority.CopyFrom(input.priority._to_proto()) - if input.versioning_override is not None: - req.versioning_override.CopyFrom(input.versioning_override._to_proto()) + req.namespace = self._client.namespace + req.workflow_id = input.id + req.workflow_type.name = input.workflow + req.task_queue.name = input.task_queue + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) + if input.execution_timeout is not None: + req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) + if input.run_timeout is not None: + req.workflow_run_timeout.FromTimedelta(input.run_timeout) + if input.task_timeout is not None: + req.workflow_task_timeout.FromTimedelta(input.task_timeout) + req.identity = self._client.identity + req.request_id = str(uuid.uuid4()) + req.workflow_id_reuse_policy = cast( + "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", + int(input.id_reuse_policy), + ) + req.workflow_id_conflict_policy = cast( + "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", + int(input.id_conflict_policy), + ) + + if input.retry_policy is not None: + input.retry_policy.apply_to_proto(req.retry_policy) + req.cron_schedule = input.cron_schedule + if input.memo is not None: + await data_converter._encode_memo_existing(input.memo, req.memo) + if input.search_attributes is not None: + temporalio.converter.encode_search_attributes( + input.search_attributes, req.search_attributes + ) + metadata = await _encode_user_metadata( + data_converter, input.static_summary, input.static_details + ) + if metadata is not None: + req.user_metadata.CopyFrom(metadata) + if input.start_delay is not None: + req.workflow_start_delay.FromTimedelta(input.start_delay) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + 
await self._apply_headers(input.headers, req.header.fields) + if input.priority is not None: # type:ignore[reportUnnecessaryComparison] + req.priority.CopyFrom(input.priority._to_proto()) + if input.versioning_override is not None: + req.versioning_override.CopyFrom(input.versioning_override._to_proto()) async def cancel_workflow(self, input: CancelWorkflowInput) -> None: await self._client.workflow_service.request_cancel_workflow_execution( @@ -8228,105 +8263,134 @@ async def count_workflows( ) async def query_workflow(self, input: QueryWorkflowInput) -> Any: - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - workflow_id=input.id, - ) - ) - req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( - namespace=self._client.namespace, - execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", - ), - ) - if input.reject_condition: - req.query_reject_condition = cast( - "temporalio.api.enums.v1.QueryRejectCondition.ValueType", - int(input.reject_condition), + target_workflow=StorageDriverWorkflowInfo( + id=input.id, run_id=input.run_id or None + ), ) - req.query.query_type = input.query - if input.args: - req.query.query_args.payloads.extend( - await data_converter.encode(input.args) + ): + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) ) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.query.header.fields) - try: - resp = await self._client.workflow_service.query_workflow( - req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( + namespace=self._client.namespace, + execution=temporalio.api.common.v1.WorkflowExecution( + 
workflow_id=input.id, + run_id=input.run_id or "", + ), ) - except RPCError as err: - # If the status is INVALID_ARGUMENT, we can assume it's a query - # failed error - if err.status == RPCStatusCode.INVALID_ARGUMENT: - raise WorkflowQueryFailedError(err.message) - else: - raise - if resp.HasField("query_rejected"): - raise WorkflowQueryRejectedError( - WorkflowExecutionStatus(resp.query_rejected.status) - if resp.query_rejected.status - else None + if input.reject_condition: + req.query_reject_condition = cast( + "temporalio.api.enums.v1.QueryRejectCondition.ValueType", + int(input.reject_condition), + ) + req.query.query_type = input.query + if input.args: + req.query.query_args.payloads.extend( + await data_converter.encode(input.args) + ) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers(input.headers, req.query.header.fields) + try: + resp = await self._client.workflow_service.query_workflow( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + except RPCError as err: + # If the status is INVALID_ARGUMENT, we can assume it's a query + # failed error + if err.status == RPCStatusCode.INVALID_ARGUMENT: + raise WorkflowQueryFailedError(err.message) + else: + raise + if resp.HasField("query_rejected"): + raise WorkflowQueryRejectedError( + WorkflowExecutionStatus(resp.query_rejected.status) + if resp.query_rejected.status + else None + ) + if not resp.query_result.payloads: + return None + type_hints = [input.ret_type] if input.ret_type else None + results = await data_converter.decode( + resp.query_result.payloads, type_hints ) - if not resp.query_result.payloads: - return None - type_hints = [input.ret_type] if input.ret_type else None - results = await data_converter.decode(resp.query_result.payloads, type_hints) - if not results: - return None - elif len(results) > 1: - warnings.warn(f"Expected single query result, got {len(results)}") - return results[0] + if not results: + 
return None + elif len(results) > 1: + warnings.warn(f"Expected single query result, got {len(results)}") + return results[0] async def signal_workflow(self, input: SignalWorkflowInput) -> None: - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - workflow_id=input.id, + target_workflow=StorageDriverWorkflowInfo( + id=input.id, run_id=input.run_id or None + ), + ) + ): + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) + req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=input.id, + run_id=input.run_id or "", + ), + signal_name=input.signal, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + ) + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers(input.headers, req.header.fields) + await self._client.workflow_service.signal_workflow_execution( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) - ) - req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest( - namespace=self._client.namespace, - workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", - ), - signal_name=input.signal, - identity=self._client.identity, - request_id=str(uuid.uuid4()), - ) - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.header.fields) - await self._client.workflow_service.signal_workflow_execution( - req, 
retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout - ) async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - workflow_id=input.id, + target_workflow=StorageDriverWorkflowInfo( + id=input.id, run_id=input.run_id or None + ), + ) + ): + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ) + ) + req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=input.id, + run_id=input.run_id or "", + ), + reason=input.reason or "", + identity=self._client.identity, + first_execution_run_id=input.first_execution_run_id or "", + ) + if input.args: + req.details.payloads.extend(await data_converter.encode(input.args)) + await self._client.workflow_service.terminate_workflow_execution( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout ) - ) - req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest( - namespace=self._client.namespace, - workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", - ), - reason=input.reason or "", - identity=self._client.identity, - first_execution_run_id=input.first_execution_run_id or "", - ) - if input.args: - req.details.payloads.extend(await data_converter.encode(input.args)) - await self._client.workflow_service.terminate_workflow_execution( - req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout - ) async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: """Start an activity and return a handle to it.""" @@ -8365,70 +8429,82 @@ async def 
_build_start_activity_execution_request( self, input: StartActivityInput ) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest: """Build StartActivityExecutionRequest from input.""" - data_converter = self._client.data_converter.with_context( - ActivitySerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - activity_id=input.id, - activity_type=input.activity_type, - activity_task_queue=input.task_queue, - is_local=False, - workflow_id=None, - workflow_type=None, + target_activity=StorageDriverActivityInfo( + id=input.id, type=input.activity_type + ), + ) + ): + data_converter = self._client.data_converter.with_context( + ActivitySerializationContext( + namespace=self._client.namespace, + activity_id=input.id, + activity_type=input.activity_type, + activity_task_queue=input.task_queue, + is_local=False, + workflow_id=None, + workflow_type=None, + ) ) - ) - req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( - namespace=self._client.namespace, - identity=self._client.identity, - activity_id=input.id, - activity_type=temporalio.api.common.v1.ActivityType( - name=input.activity_type - ), - task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=input.task_queue), - id_reuse_policy=cast( - "temporalio.api.enums.v1.ActivityIdReusePolicy.ValueType", - int(input.id_reuse_policy), - ), - id_conflict_policy=cast( - "temporalio.api.enums.v1.ActivityIdConflictPolicy.ValueType", - int(input.id_conflict_policy), - ), - ) + req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( + namespace=self._client.namespace, + identity=self._client.identity, + activity_id=input.id, + activity_type=temporalio.api.common.v1.ActivityType( + name=input.activity_type + ), + task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=input.task_queue), + id_reuse_policy=cast( + "temporalio.api.enums.v1.ActivityIdReusePolicy.ValueType", + int(input.id_reuse_policy), + ), + 
id_conflict_policy=cast( + "temporalio.api.enums.v1.ActivityIdConflictPolicy.ValueType", + int(input.id_conflict_policy), + ), + ) - if input.schedule_to_close_timeout is not None: - req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout) - if input.start_to_close_timeout is not None: - req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) - if input.schedule_to_start_timeout is not None: - req.schedule_to_start_timeout.FromTimedelta(input.schedule_to_start_timeout) - if input.heartbeat_timeout is not None: - req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) - if input.retry_policy is not None: - input.retry_policy.apply_to_proto(req.retry_policy) + if input.schedule_to_close_timeout is not None: + req.schedule_to_close_timeout.FromTimedelta( + input.schedule_to_close_timeout + ) + if input.start_to_close_timeout is not None: + req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) + if input.schedule_to_start_timeout is not None: + req.schedule_to_start_timeout.FromTimedelta( + input.schedule_to_start_timeout + ) + if input.heartbeat_timeout is not None: + req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) + if input.retry_policy is not None: + input.retry_policy.apply_to_proto(req.retry_policy) - # Set input payloads - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) + # Set input payloads + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) - # Set search attributes - if input.search_attributes is not None: - temporalio.converter.encode_search_attributes( - input.search_attributes, req.search_attributes - ) + # Set search attributes + if input.search_attributes is not None: + temporalio.converter.encode_search_attributes( + input.search_attributes, req.search_attributes + ) - # Set user metadata - metadata = await _encode_user_metadata(data_converter, input.summary, None) - if metadata is not None: - 
req.user_metadata.CopyFrom(metadata) + # Set user metadata + metadata = await _encode_user_metadata(data_converter, input.summary, None) + if metadata is not None: + req.user_metadata.CopyFrom(metadata) - # Set headers - if input.headers: - await self._apply_headers(input.headers, req.header.fields) + # Set headers + if input.headers: + await self._apply_headers(input.headers, req.header.fields) - # Set priority - req.priority.CopyFrom(input.priority._to_proto()) + # Set priority + req.priority.CopyFrom(input.priority._to_proto()) - return req + return req async def cancel_activity(self, input: CancelActivityInput) -> None: """Cancel an activity.""" @@ -8560,49 +8636,62 @@ async def _build_update_workflow_execution_request( input: StartWorkflowUpdateInput | UpdateWithStartUpdateWorkflowInput, workflow_id: str, ) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest: - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( + with store_metadata_context( + StorageDriverStoreMetadata( namespace=self._client.namespace, - workflow_id=workflow_id, + target_workflow=StorageDriverWorkflowInfo( + id=workflow_id, + run_id=(input.run_id or None) + if isinstance(input, StartWorkflowUpdateInput) + else None, + ), ) - ) - run_id, first_execution_run_id = ( - ( - input.run_id, - input.first_execution_run_id, + ): + data_converter = self._client.data_converter.with_context( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=workflow_id, + ) ) - if isinstance(input, StartWorkflowUpdateInput) - else (None, None) - ) - req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( - namespace=self._client.namespace, - workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=workflow_id, - run_id=run_id or "", - ), - first_execution_run_id=first_execution_run_id or "", - request=temporalio.api.update.v1.Request( - meta=temporalio.api.update.v1.Meta( - update_id=input.update_id 
or str(uuid.uuid4()), - identity=self._client.identity, + run_id, first_execution_run_id = ( + ( + input.run_id, + input.first_execution_run_id, + ) + if isinstance(input, StartWorkflowUpdateInput) + else (None, None) + ) + req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=workflow_id, + run_id=run_id or "", ), - input=temporalio.api.update.v1.Input( - name=input.update, + first_execution_run_id=first_execution_run_id or "", + request=temporalio.api.update.v1.Request( + meta=temporalio.api.update.v1.Meta( + update_id=input.update_id or str(uuid.uuid4()), + identity=self._client.identity, + ), + input=temporalio.api.update.v1.Input( + name=input.update, + ), + ), + wait_policy=temporalio.api.update.v1.WaitPolicy( + lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType( + input.wait_for_stage + ) ), - ), - wait_policy=temporalio.api.update.v1.WaitPolicy( - lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType( - input.wait_for_stage - ) - ), - ) - if input.args: - req.request.input.args.payloads.extend( - await data_converter.encode(input.args) ) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.request.input.header.fields) - return req + if input.args: + req.request.input.args.payloads.extend( + await data_converter.encode(input.args) + ) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers( + input.headers, req.request.input.header.fields + ) + return req async def start_update_with_start_workflow( self, input: StartWorkflowUpdateWithStartInput @@ -8739,171 +8828,212 @@ async def _start_workflow_update_with_start( ### Async activity calls + def _get_async_activity_store_metadata( + self, id_or_token: AsyncActivityIDReference | bytes + ) -> 
StorageDriverStoreMetadata: + if isinstance(id_or_token, AsyncActivityIDReference): + return StorageDriverStoreMetadata( + namespace=self._client.namespace, + target_workflow=StorageDriverWorkflowInfo( + id=id_or_token.workflow_id or None, + run_id=id_or_token.run_id or None, + ) + if id_or_token.workflow_id + else None, + target_activity=StorageDriverActivityInfo( + id=id_or_token.activity_id, + ), + ) + else: + return StorageDriverStoreMetadata( + namespace=self._client.namespace, + ) + async def heartbeat_async_activity( self, input: HeartbeatAsyncActivityInput ) -> None: - data_converter = input.data_converter_override or self._client.data_converter - details = ( - None - if not input.details - else await data_converter.encode_wrapper(input.details) - ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( - temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + with store_metadata_context( + self._get_async_activity_store_metadata(input.id_or_token) + ): + data_converter = ( + input.data_converter_override or self._client.data_converter ) - if ( - resp_by_id.cancel_requested - or resp_by_id.activity_paused - or resp_by_id.activity_reset - ): - raise AsyncActivityCancelledError( - details=ActivityCancellationDetails( - cancel_requested=resp_by_id.cancel_requested, - paused=resp_by_id.activity_paused, - reset=resp_by_id.activity_reset, - ) - ) - - else: - resp = await self._client.workflow_service.record_activity_task_heartbeat( - temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest( - task_token=input.id_or_token, - 
namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + details = ( + None + if not input.details + else await data_converter.encode_wrapper(input.details) ) - if resp.cancel_requested or resp.activity_paused: - raise AsyncActivityCancelledError( - details=ActivityCancellationDetails( - cancel_requested=resp.cancel_requested, - paused=resp.activity_paused, - reset=resp.activity_reset, + if isinstance(input.id_or_token, AsyncActivityIDReference): + resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( + temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + if ( + resp_by_id.cancel_requested + or resp_by_id.activity_paused + or resp_by_id.activity_reset + ): + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp_by_id.cancel_requested, + paused=resp_by_id.activity_paused, + reset=resp_by_id.activity_reset, + ) ) + + else: + resp = await self._client.workflow_service.record_activity_task_heartbeat( + temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) + if resp.cancel_requested or resp.activity_paused: + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp.cancel_requested, + paused=resp.activity_paused, + reset=resp.activity_reset, + ) + ) async def complete_async_activity(self, input: 
CompleteAsyncActivityInput) -> None: - data_converter = input.data_converter_override or self._client.data_converter - result = ( - None - if input.result is temporalio.common._arg_unset - else await data_converter.encode_wrapper([input.result]) - ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - await self._client.workflow_service.respond_activity_task_completed_by_id( - temporalio.api.workflowservice.v1.RespondActivityTaskCompletedByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - result=result, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + with store_metadata_context( + self._get_async_activity_store_metadata(input.id_or_token) + ): + data_converter = ( + input.data_converter_override or self._client.data_converter ) - else: - await self._client.workflow_service.respond_activity_task_completed( - temporalio.api.workflowservice.v1.RespondActivityTaskCompletedRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - result=result, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + result = ( + None + if input.result is temporalio.common._arg_unset + else await data_converter.encode_wrapper([input.result]) ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + await self._client.workflow_service.respond_activity_task_completed_by_id( + temporalio.api.workflowservice.v1.RespondActivityTaskCompletedByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + result=result, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + else: + await 
self._client.workflow_service.respond_activity_task_completed( + temporalio.api.workflowservice.v1.RespondActivityTaskCompletedRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + result=result, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: - data_converter = input.data_converter_override or self._client.data_converter - - failure = temporalio.api.failure.v1.Failure() - await data_converter.encode_failure(input.error, failure) - last_heartbeat_details = ( - await data_converter.encode_wrapper(input.last_heartbeat_details) - if input.last_heartbeat_details - else None - ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - await self._client.workflow_service.respond_activity_task_failed_by_id( - temporalio.api.workflowservice.v1.RespondActivityTaskFailedByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - failure=failure, - last_heartbeat_details=last_heartbeat_details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + with store_metadata_context( + self._get_async_activity_store_metadata(input.id_or_token) + ): + data_converter = ( + input.data_converter_override or self._client.data_converter ) - else: - await self._client.workflow_service.respond_activity_task_failed( - temporalio.api.workflowservice.v1.RespondActivityTaskFailedRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - failure=failure, - last_heartbeat_details=last_heartbeat_details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + + failure = temporalio.api.failure.v1.Failure() + await data_converter.encode_failure(input.error, 
failure) + last_heartbeat_details = ( + await data_converter.encode_wrapper(input.last_heartbeat_details) + if input.last_heartbeat_details + else None ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + await self._client.workflow_service.respond_activity_task_failed_by_id( + temporalio.api.workflowservice.v1.RespondActivityTaskFailedByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + failure=failure, + last_heartbeat_details=last_heartbeat_details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + else: + await self._client.workflow_service.respond_activity_task_failed( + temporalio.api.workflowservice.v1.RespondActivityTaskFailedRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + failure=failure, + last_heartbeat_details=last_heartbeat_details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) async def report_cancellation_async_activity( self, input: ReportCancellationAsyncActivityInput ) -> None: - data_converter = input.data_converter_override or self._client.data_converter - details = ( - None - if not input.details - else await data_converter.encode_wrapper(input.details) - ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - await self._client.workflow_service.respond_activity_task_canceled_by_id( - temporalio.api.workflowservice.v1.RespondActivityTaskCanceledByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + with store_metadata_context( + 
self._get_async_activity_store_metadata(input.id_or_token) + ): + data_converter = ( + input.data_converter_override or self._client.data_converter ) - else: - await self._client.workflow_service.respond_activity_task_canceled( - temporalio.api.workflowservice.v1.RespondActivityTaskCanceledRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + details = ( + None + if not input.details + else await data_converter.encode_wrapper(input.details) ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + await self._client.workflow_service.respond_activity_task_canceled_by_id( + temporalio.api.workflowservice.v1.RespondActivityTaskCanceledByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + else: + await self._client.workflow_service.respond_activity_task_canceled( + temporalio.api.workflowservice.v1.RespondActivityTaskCanceledRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) ### Schedule calls @@ -8940,27 +9070,35 @@ async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle: else None, ) try: - request = temporalio.api.workflowservice.v1.CreateScheduleRequest( - namespace=self._client.namespace, - schedule_id=input.id, - schedule=await input.schedule._to_proto(self._client), - initial_patch=initial_patch, - identity=self._client.identity, - request_id=str(uuid.uuid4()), - memo=await self._client.data_converter._encode_memo(input.memo) - if input.memo - else None, 
- ) - if input.search_attributes: - temporalio.converter.encode_search_attributes( - input.search_attributes, request.search_attributes + # Set namespace-level store metadata as a baseline for schedule + # encoding. The schedule action's _to_proto will override with + # workflow-specific metadata for its own encoding. + with store_metadata_context( + StorageDriverStoreMetadata( + namespace=self._client.namespace, + ) + ): + request = temporalio.api.workflowservice.v1.CreateScheduleRequest( + namespace=self._client.namespace, + schedule_id=input.id, + schedule=await input.schedule._to_proto(self._client), + initial_patch=initial_patch, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + memo=await self._client.data_converter._encode_memo(input.memo) + if input.memo + else None, + ) + if input.search_attributes: + temporalio.converter.encode_search_attributes( + input.search_attributes, request.search_attributes + ) + await self._client.workflow_service.create_schedule( + request, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - await self._client.workflow_service.create_schedule( - request, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) except RPCError as err: already_started = ( err.status == RPCStatusCode.ALREADY_EXISTS diff --git a/temporalio/contrib/aws/s3driver/_driver.py b/temporalio/contrib/aws/s3driver/_driver.py index 481e3a9d4..d6488cf64 100644 --- a/temporalio/contrib/aws/s3driver/_driver.py +++ b/temporalio/contrib/aws/s3driver/_driver.py @@ -15,12 +15,10 @@ from temporalio.api.common.v1 import Payload from temporalio.contrib.aws.s3driver._client import S3StorageDriverClient from temporalio.converter import ( - ActivitySerializationContext, StorageDriver, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, - WorkflowSerializationContext, ) _T = TypeVar("_T") @@ -113,40 +111,29 @@ async def store( (e.g. proto binary). 
The returned list is the same length as ``payloads``. """ - workflow_id: str | None = None - activity_id: str | None = None - namespace: str | None = None - if isinstance(context.serialization_context, WorkflowSerializationContext): - workflow_id = context.serialization_context.workflow_id - namespace = context.serialization_context.namespace - if isinstance(context.serialization_context, ActivitySerializationContext): - # Prioritize workflow over activity so that the same payload that - # may be stored across workflow and activity boundaries are deduplicated. - if context.serialization_context.workflow_id: - workflow_id = context.serialization_context.workflow_id - elif context.serialization_context.activity_id: - activity_id = context.serialization_context.activity_id - namespace = context.serialization_context.namespace - - # URL encode values to avoid characters that break the key format - # e.g. spaces, forward-slashes, etc. - if namespace: - namespace = urllib.parse.quote(namespace, safe="") - if workflow_id: - workflow_id = urllib.parse.quote(workflow_id, safe="") - if activity_id: - activity_id = urllib.parse.quote(activity_id, safe="") - - namespace_segments = f"/ns/{namespace}" if namespace else "" + def _quote(val: str | None) -> str | None: + return urllib.parse.quote(val, safe="") if val else None + + namespace = _quote(context.namespace) + namespace_segment = f"/ns/{namespace}" if namespace else "" + + # Build context segments from structured metadata. + # Prefer current workflow context; fall back to target_workflow for + # client-initiated operations where there is no current workflow. context_segments = "" - # Prioritize workflow over activity so that the same payload that - # may be stored across workflow and activity boundaries are deduplicated. - # Workflow and Activity IDs are case sensitive. 
- if workflow_id: - context_segments += f"/wfi/{workflow_id}" - elif activity_id: - context_segments += f"/aci/{activity_id}" + wf = context.current_workflow or context.target_workflow + act = context.current_activity or context.target_activity + if wf and wf.id: + wf_type = _quote(wf.type) or "null" + wf_id = _quote(wf.id) + wf_run_id = _quote(wf.run_id) or "null" + context_segments = f"/wt/{wf_type}/wi/{wf_id}/ri/{wf_run_id}" + elif act and act.id: + act_type = _quote(act.type) or "null" + act_id = _quote(act.id) + act_run_id = _quote(act.run_id) or "null" + context_segments = f"/at/{act_type}/ai/{act_id}/ri/{act_run_id}" async def _upload(payload: Payload) -> StorageDriverClaim: bucket = self._get_bucket(context, payload) @@ -162,7 +149,7 @@ async def _upload(payload: Payload) -> StorageDriverClaim: digest_segments = f"/d/sha256/{hash_digest}" - key = f"v0{namespace_segments}{context_segments}{digest_segments}" + key = f"v0{namespace_segment}{context_segments}{digest_segments}" try: if not await self._client.object_exists(bucket=bucket, key=key): diff --git a/temporalio/converter/__init__.py b/temporalio/converter/__init__.py index 2777e7e80..3ca6a3507 100644 --- a/temporalio/converter/__init__.py +++ b/temporalio/converter/__init__.py @@ -7,9 +7,11 @@ from temporalio.converter._extstore import ( ExternalStorage, StorageDriver, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, + StorageDriverWorkflowInfo, StorageWarning, ) from temporalio.converter._failure_converter import ( @@ -54,9 +56,11 @@ "ActivitySerializationContext", "ExternalStorage", "StorageDriver", + "StorageDriverActivityInfo", "StorageDriverClaim", "StorageDriverRetrieveContext", "StorageDriverStoreContext", + "StorageDriverWorkflowInfo", "StorageWarning", "AdvancedJSONEncoder", "BinaryNullPayloadConverter", diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index 078d36d98..588a1038f 100644 --- 
a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -19,10 +19,6 @@ from temporalio.api.common.v1 import Payload, Payloads from temporalio.converter._payload_converter import JSONPlainPayloadConverter -from temporalio.converter._serialization_context import ( - SerializationContext, - WithSerializationContext, -) _T = TypeVar("_T") @@ -92,6 +88,91 @@ class StorageDriverClaim: """ +@dataclass(frozen=True) +class StorageDriverWorkflowInfo: + """Workflow identity information for external storage operations. + + .. warning:: + This API is experimental. + """ + + id: str | None = None + """The workflow ID.""" + + run_id: str | None = None + """The workflow run ID, if available.""" + + type: str | None = None + """The workflow type name, if available.""" + + +@dataclass(frozen=True) +class StorageDriverActivityInfo: + """Activity identity information for external storage operations. + + .. warning:: + This API is experimental. + """ + + id: str | None = None + """The activity ID.""" + + run_id: str | None = None + """The activity run ID (only for standalone activities).""" + + type: str | None = None + """The activity type name, if available.""" + + +@dataclass(frozen=True) +class StorageDriverStoreMetadata: + """Store-only metadata available during external storage operations. + + .. warning:: + This API is experimental. + """ + + namespace: str | None = None + """The namespace of the current execution context.""" + + current_workflow: StorageDriverWorkflowInfo | None = None + """The workflow execution context from which this payload is being stored, if any.""" + + current_activity: StorageDriverActivityInfo | None = None + """The activity execution context from which this payload is being stored, if any. + Set only when running inside an activity worker.""" + + target_workflow: StorageDriverWorkflowInfo | None = None + """The workflow for which this payload is being stored (e.g. 
child workflow being + started, external workflow being signaled).""" + + target_activity: StorageDriverActivityInfo | None = None + """The activity for which this payload is being stored.""" + + +_current_store_metadata: contextvars.ContextVar[StorageDriverStoreMetadata | None] = ( + contextvars.ContextVar("_current_store_metadata", default=None) +) + + +@contextlib.contextmanager +def store_metadata_context( + metadata: StorageDriverStoreMetadata | None, +) -> Generator[None, None, None]: + """Context manager that sets store metadata and resets it on exit. + + If metadata is None, yields without setting anything. + """ + if metadata is None: + yield + return + token = _current_store_metadata.set(metadata) + try: + yield + finally: + _current_store_metadata.reset(token) + + @dataclass(frozen=True) class StorageDriverStoreContext: """Context passed to :meth:`StorageDriver.store` and ``driver_selector`` calls. @@ -100,10 +181,22 @@ class StorageDriverStoreContext: This API is experimental. """ - serialization_context: SerializationContext | None = None - """The serialization context active when this store operation was initiated, - or ``None`` if no context has been set. - """ + namespace: str | None = None + """The namespace of the current execution context.""" + + current_workflow: StorageDriverWorkflowInfo | None = None + """The workflow execution context from which this payload is being stored, if any.""" + + current_activity: StorageDriverActivityInfo | None = None + """The activity execution context from which this payload is being stored, if any. + Set only when running inside an activity worker.""" + + target_workflow: StorageDriverWorkflowInfo | None = None + """The workflow for which this payload is being stored (e.g. 
child workflow being + started, external workflow being signaled).""" + + target_activity: StorageDriverActivityInfo | None = None + """The activity for which this payload is being stored.""" @dataclass(frozen=True) @@ -182,7 +275,7 @@ class _StorageReference: @dataclass(frozen=True) -class ExternalStorage(WithSerializationContext): +class ExternalStorage: """Configuration for external storage behavior. .. warning:: @@ -223,10 +316,6 @@ class ExternalStorage(WithSerializationContext): for retrieval lookups. """ - _context: SerializationContext | None = dataclasses.field( - init=False, default=None, repr=False, compare=False - ) - _claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter( encoding=_REFERENCE_ENCODING.decode() ) @@ -257,12 +346,6 @@ def __post_init__(self) -> None: driver_map[name] = driver object.__setattr__(self, "_driver_map", driver_map) - def with_context(self, context: SerializationContext) -> Self: - """Return a copy of these options with the serialization context applied.""" - result = dataclasses.replace(self) - object.__setattr__(result, "_context", context) - return result - def _select_driver( self, context: StorageDriverStoreContext, payload: Payload ) -> StorageDriver | None: @@ -292,9 +375,20 @@ def _get_driver_by_name(self, name: str) -> StorageDriver: raise ValueError(f"No driver found with name '{name}'") return driver + @staticmethod + def _build_store_context() -> StorageDriverStoreContext: + meta = _current_store_metadata.get() + return StorageDriverStoreContext( + namespace=meta.namespace if meta else None, + current_workflow=meta.current_workflow if meta else None, + current_activity=meta.current_activity if meta else None, + target_workflow=meta.target_workflow if meta else None, + target_activity=meta.target_activity if meta else None, + ) + async def _store_payload(self, payload: Payload) -> Payload: start_time = time.monotonic() - context = 
StorageDriverStoreContext(serialization_context=self._context) + context = self._build_store_context() driver = self._select_driver(context, payload) if driver is None: @@ -335,7 +429,7 @@ async def _store_payload_sequence( start_time = time.monotonic() results = list(payloads) - context = StorageDriverStoreContext(serialization_context=self._context) + context = self._build_store_context() to_store: list[tuple[int, Payload, StorageDriver]] = [] for index, payload in enumerate(payloads): diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index c7a1032fe..a19895261 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -34,6 +34,11 @@ import temporalio.converter import temporalio.converter._payload_limits import temporalio.exceptions +from temporalio.converter import StorageDriverActivityInfo, StorageDriverWorkflowInfo +from temporalio.converter._extstore import ( + StorageDriverStoreMetadata, + store_metadata_context, +) from ._interceptor import ( ActivityInboundInterceptor, @@ -252,6 +257,7 @@ async def _heartbeat_async( return data_converter = self._data_converter + store_metadata: StorageDriverStoreMetadata | None = None if activity.info: context = temporalio.converter.ActivitySerializationContext( namespace=activity.info.namespace, @@ -264,14 +270,33 @@ async def _heartbeat_async( ) data_converter = data_converter.with_context(context) + wf_info = ( + StorageDriverWorkflowInfo( + id=activity.info.workflow_id, + run_id=activity.info.workflow_run_id, + type=activity.info.workflow_type, + ) + if activity.info.workflow_id + else None + ) + store_metadata = StorageDriverStoreMetadata( + namespace=activity.info.namespace, + current_workflow=wf_info, + current_activity=StorageDriverActivityInfo( + id=activity.info.activity_id, + type=activity.info.activity_type, + run_id=activity.info.activity_run_id, + ), + ) + # Perform the heartbeat try: heartbeat = temporalio.bridge.proto.ActivityHeartbeat( # type: 
ignore[reportAttributeAccessIssue] task_token=task_token ) if details: - # Convert to core payloads - heartbeat.details.extend(await data_converter.encode(details)) + with store_metadata_context(store_metadata): + heartbeat.details.extend(await data_converter.encode(details)) logger.debug("Recording heartbeat with details %s", details) self._bridge_worker().record_activity_heartbeat(heartbeat) except Exception as err: @@ -316,153 +341,185 @@ async def _handle_start_activity_task( is_local=start.is_local, ) data_converter = self._data_converter.with_context(context) - try: - result = await self._execute_activity( - start, running_activity, task_token, data_converter + + # Build store metadata for external storage + ns = start.workflow_namespace or self._client.namespace + started_by_workflow = bool(start.workflow_execution.workflow_id) + wf_info = ( + StorageDriverWorkflowInfo( + id=start.workflow_execution.workflow_id or None, + run_id=start.workflow_execution.run_id or None, + type=start.workflow_type or None, ) - [payload] = await data_converter.encode([result]) - completion.result.completed.result.CopyFrom(payload) - except BaseException as err: + if started_by_workflow + else None + ) + act_info = StorageDriverActivityInfo( + id=start.activity_id or None, + type=start.activity_type or None, + run_id=None, + ) + # Store metadata is set for the full activity task lifetime (input + # decode, execution, result/failure encode). Each activity task runs + # in its own coroutine so the value won't leak to other tasks. 
+ with store_metadata_context( + StorageDriverStoreMetadata( + namespace=ns, + current_workflow=wf_info, + current_activity=act_info, + ) + ): try: + result = await self._execute_activity( + start, running_activity, task_token, data_converter + ) + [payload] = await data_converter.encode([result]) + completion.result.completed.result.CopyFrom(payload) + except BaseException as err: try: - if isinstance(err, temporalio.activity._CompleteAsyncError): - temporalio.activity.logger.debug("Completing asynchronously") - completion.result.will_complete_async.SetInParent() - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancelled_due_to_heartbeat_error - ): - err = running_activity.cancelled_due_to_heartbeat_error - temporalio.activity.logger.warning( - f"Completing as failure during heartbeat with error of type {type(err)}: {err}", - ) - await data_converter.encode_failure( - err, completion.result.failed.failure - ) - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancellation_details.details - and running_activity.cancellation_details.details.paused - ): - temporalio.activity.logger.warning( - "Completing as failure due to unhandled cancel error produced by activity pause", - ) - await data_converter.encode_failure( - temporalio.exceptions.ApplicationError( - type="ActivityPause", - message="Unhandled activity cancel error produced by activity pause", - ), - completion.result.failed.failure, - ) - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancellation_details.details - and running_activity.cancellation_details.details.reset - ): - temporalio.activity.logger.warning( - "Completing as failure due to unhandled cancel error produced by activity reset", - ) - await data_converter.encode_failure( - 
temporalio.exceptions.ApplicationError( - type="ActivityReset", - message="Unhandled activity cancel error produced by activity reset", - ), - completion.result.failed.failure, - ) - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancelled_by_request - ): - temporalio.activity.logger.debug("Completing as cancelled") - await data_converter.encode_failure( - # TODO(cretz): Should use some other message? - temporalio.exceptions.CancelledError("Cancelled"), - completion.result.cancelled.failure, - ) - elif isinstance( - err, - temporalio.converter._payload_limits._PayloadSizeError, - ): - temporalio.activity.logger.warning( - err.message, - extra={"__temporal_error_identifier": "PayloadSizeError"}, - ) - await data_converter.encode_failure( - err, completion.result.failed.failure - ) - else: - if ( + try: + if isinstance(err, temporalio.activity._CompleteAsyncError): + temporalio.activity.logger.debug( + "Completing asynchronously" + ) + completion.result.will_complete_async.SetInParent() + elif ( isinstance( err, - temporalio.exceptions.ApplicationError, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), ) - and err.category - == temporalio.exceptions.ApplicationErrorCategory.BENIGN + and running_activity.cancelled_due_to_heartbeat_error ): - # Downgrade log level to DEBUG for BENIGN application errors. 
- temporalio.activity.logger.debug( - "Completing activity as failed", - exc_info=True, - extra={ - "__temporal_error_identifier": "ActivityFailure" - }, + err = running_activity.cancelled_due_to_heartbeat_error + temporalio.activity.logger.warning( + f"Completing as failure during heartbeat with error of type {type(err)}: {err}", ) - else: + await data_converter.encode_failure( + err, completion.result.failed.failure + ) + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.paused + ): + temporalio.activity.logger.warning( + "Completing as failure due to unhandled cancel error produced by activity pause", + ) + await data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="ActivityPause", + message="Unhandled activity cancel error produced by activity pause", + ), + completion.result.failed.failure, + ) + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.reset + ): + temporalio.activity.logger.warning( + "Completing as failure due to unhandled cancel error produced by activity reset", + ) + await data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="ActivityReset", + message="Unhandled activity cancel error produced by activity reset", + ), + completion.result.failed.failure, + ) + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancelled_by_request + ): + temporalio.activity.logger.debug("Completing as cancelled") + await data_converter.encode_failure( + # TODO(cretz): Should use some other message? 
+ temporalio.exceptions.CancelledError("Cancelled"), + completion.result.cancelled.failure, + ) + elif isinstance( + err, + temporalio.converter._payload_limits._PayloadSizeError, + ): temporalio.activity.logger.warning( - "Completing activity as failed", - exc_info=True, + err.message, extra={ - "__temporal_error_identifier": "ActivityFailure" + "__temporal_error_identifier": "PayloadSizeError" }, ) + await data_converter.encode_failure( + err, completion.result.failed.failure + ) + else: + if ( + isinstance( + err, + temporalio.exceptions.ApplicationError, + ) + and err.category + == temporalio.exceptions.ApplicationErrorCategory.BENIGN + ): + # Downgrade log level to DEBUG for BENIGN application errors. + temporalio.activity.logger.debug( + "Completing activity as failed", + exc_info=True, + extra={ + "__temporal_error_identifier": "ActivityFailure" + }, + ) + else: + temporalio.activity.logger.warning( + "Completing activity as failed", + exc_info=True, + extra={ + "__temporal_error_identifier": "ActivityFailure" + }, + ) + await data_converter.encode_failure( + err, completion.result.failed.failure + ) + # For broken executors, we have to fail the entire worker + if isinstance(err, concurrent.futures.BrokenExecutor): + self._fail_worker_exception_queue.put_nowait(err) + # Handle PayloadSizeError from attempting to encode failure information + except ( + temporalio.converter._payload_limits._PayloadSizeError + ) as inner_err: + temporalio.activity.logger.exception(inner_err.message) + completion.result.Clear() await data_converter.encode_failure( - err, completion.result.failed.failure + inner_err, completion.result.failed.failure ) - # For broken executors, we have to fail the entire worker - if isinstance(err, concurrent.futures.BrokenExecutor): - self._fail_worker_exception_queue.put_nowait(err) - # Handle PayloadSizeError from attempting to encode failure information - except ( - temporalio.converter._payload_limits._PayloadSizeError - ) as inner_err: - 
temporalio.activity.logger.exception(inner_err.message) + except Exception as inner_err: + temporalio.activity.logger.exception( + f"Exception handling failed, original error: {err}" + ) completion.result.Clear() - await data_converter.encode_failure( - inner_err, completion.result.failed.failure + completion.result.failed.failure.message = ( + f"Failed building exception result: {inner_err}" ) - except Exception as inner_err: - temporalio.activity.logger.exception( - f"Exception handling failed, original error: {err}" - ) - completion.result.Clear() - completion.result.failed.failure.message = ( - f"Failed building exception result: {inner_err}" - ) - completion.result.failed.failure.application_failure_info.SetInParent() + completion.result.failed.failure.application_failure_info.SetInParent() # Do final completion try: diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 914b14370..226126a48 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -4,13 +4,14 @@ import asyncio import concurrent.futures +import contextlib import dataclasses import logging import os import sys import threading import time -from collections.abc import Awaitable, Callable, MutableMapping, Sequence +from collections.abc import Awaitable, Callable, Iterator, MutableMapping, Sequence from dataclasses import dataclass from datetime import timedelta, timezone from types import TracebackType @@ -28,6 +29,11 @@ import temporalio.workflow from temporalio.api.enums.v1 import WorkflowTaskFailedCause from temporalio.bridge.worker import PollShutdownError +from temporalio.converter import StorageDriverWorkflowInfo +from temporalio.converter._extstore import ( + StorageDriverStoreMetadata, + store_metadata_context, +) from . 
import _command_aware_visitor from ._interceptor import ( @@ -296,17 +302,33 @@ async def _handle_activation( workflow_context_dc=data_converter, workflow_context=workflow_context, ) - download_metrics = await temporalio.bridge.worker.decode_activation( - act, - data_converter, - decode_headers=self._encode_headers, - storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency, - ) + # Set default store metadata for decode_activation + with store_metadata_context( + StorageDriverStoreMetadata( + namespace=self._namespace, + current_workflow=StorageDriverWorkflowInfo( + id=workflow_id, + run_id=act.run_id, + type=( + workflow.workflow_type + if workflow + else (init_job.workflow_type if init_job else None) + ), + ), + ) + ): + download_metrics = await temporalio.bridge.worker.decode_activation( + act, + data_converter, + decode_headers=self._encode_headers, + storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency, + ) if not workflow: assert init_job workflow = _RunningWorkflow( self._create_workflow_instance(act, init_job), workflow_id, + workflow_type=init_job.workflow_type, ) self._running_workflows[act.run_id] = workflow @@ -796,9 +818,15 @@ def _gen_tb_helper( class _RunningWorkflow: - def __init__(self, instance: WorkflowInstance, workflow_id: str): + def __init__( + self, + instance: WorkflowInstance, + workflow_id: str, + workflow_type: str | None = None, + ): self.instance = instance self.workflow_id = workflow_id + self.workflow_type = workflow_type self.deadlocked_activation_task: Awaitable | None = None self._deadlock_can_be_interrupted_lock = threading.Lock() self._deadlock_can_be_interrupted = False @@ -888,6 +916,13 @@ def _get_current_dc(self) -> temporalio.converter.DataConverter: return self._ca_workflow_context_dc return self._ca_context_free_dc.with_context(context) + @contextlib.contextmanager + def _store_metadata_context(self) -> Iterator[None]: + command_info = 
_command_aware_visitor.current_command_info.get() + metadata = self._ca_instance.get_external_store_metadata(command_info) + with store_metadata_context(metadata): + yield + async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: @@ -896,7 +931,10 @@ async def _encode_payload_sequence( async def _external_store_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - return await self._get_current_dc()._external_store_payload_sequence(payloads) + with self._store_metadata_context(): + return await self._get_current_dc()._external_store_payload_sequence( + payloads + ) async def _external_retrieve_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1bfa77c3c..bb779b25c 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -58,6 +58,8 @@ import temporalio.converter import temporalio.exceptions import temporalio.workflow +from temporalio.converter import StorageDriverActivityInfo, StorageDriverWorkflowInfo +from temporalio.converter._extstore import StorageDriverStoreMetadata from temporalio.service import __version__ from ..api.failure.v1.message_pb2 import Failure @@ -182,6 +184,21 @@ def get_serialization_context( """ raise NotImplementedError + @abstractmethod + def get_external_store_metadata( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreMetadata | None: + """Return appropriate store metadata for external storage operations. + + Args: + command_info: Optional information identifying the associated command. + + Returns: + The store metadata, or None if no metadata should be set. 
+ """ + raise NotImplementedError + def get_thread_id(self) -> int | None: """Return the thread identifier that this workflow is running on. @@ -1851,7 +1868,6 @@ def workflow_register_random_seed_callback( # These are in alphabetical order and all start with "_outbound_". def _outbound_continue_as_new(self, input: ContinueAsNewInput) -> NoReturn: - # Just throw raise _ContinueAsNewError(self, input) def _outbound_schedule_activity( @@ -2222,6 +2238,72 @@ def get_serialization_context( workflow_id=self._info.workflow_id, ) + def get_external_store_metadata( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreMetadata | None: + ns = self._info.namespace + current_wf = StorageDriverWorkflowInfo( + id=self._info.workflow_id, + run_id=self._info.run_id, + type=self._info.workflow_type, + ) + + if command_info is None: + return StorageDriverStoreMetadata( + namespace=ns, + current_workflow=current_wf, + ) + + COMMAND_TYPE = temporalio.api.enums.v1.command_type_pb2.CommandType + + if ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK + and command_info.command_seq in self._pending_activities + ): + handle = self._pending_activities[command_info.command_seq] + return StorageDriverStoreMetadata( + namespace=ns, + current_workflow=current_wf, + target_activity=StorageDriverActivityInfo( + id=handle._input.activity_id, + type=handle._input.activity, + ), + ) + + elif ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION + and command_info.command_seq in self._pending_child_workflows + ): + child = self._pending_child_workflows[command_info.command_seq] + return StorageDriverStoreMetadata( + namespace=ns, + current_workflow=current_wf, + target_workflow=StorageDriverWorkflowInfo( + id=child._input.id, type=child._input.workflow + ), + ) + + elif ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION + and 
command_info.command_seq in self._pending_external_signals + ): + _, target_id = self._pending_external_signals[command_info.command_seq] + return StorageDriverStoreMetadata( + namespace=ns, + current_workflow=current_wf, + target_workflow=StorageDriverWorkflowInfo(id=target_id), + ) + + else: + return StorageDriverStoreMetadata( + namespace=ns, + current_workflow=current_wf, + ) + def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: raise RuntimeError("Expected workflow input. This is a Python SDK bug.") diff --git a/temporalio/worker/workflow_sandbox/_in_sandbox.py b/temporalio/worker/workflow_sandbox/_in_sandbox.py index eea8f6940..44a6fb351 100644 --- a/temporalio/worker/workflow_sandbox/_in_sandbox.py +++ b/temporalio/worker/workflow_sandbox/_in_sandbox.py @@ -13,6 +13,7 @@ import temporalio.converter import temporalio.worker._workflow_instance import temporalio.workflow +from temporalio.converter._extstore import StorageDriverStoreMetadata from temporalio.worker import _command_aware_visitor logger = logging.getLogger(__name__) @@ -88,3 +89,10 @@ def get_serialization_context( ) -> temporalio.converter.SerializationContext | None: """Get serialization context.""" return self.instance.get_serialization_context(command_info) + + def get_external_store_metadata( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreMetadata | None: + """Get store metadata for external storage.""" + return self.instance.get_external_store_metadata(command_info) diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index 31514e33b..0d943feb1 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -17,6 +17,7 @@ import temporalio.common import temporalio.converter import temporalio.workflow +from temporalio.converter._extstore import StorageDriverStoreMetadata from temporalio.worker import _command_aware_visitor 
from ...api.common.v1.message_pb2 import Payloads @@ -205,3 +206,20 @@ def get_serialization_context( return self.globals_and_locals.pop("__temporal_context", None) # type: ignore finally: self.importer.restriction_context.is_runtime = False + + def get_external_store_metadata( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreMetadata | None: + # Forward call to the sandboxed instance + self.importer.restriction_context.is_runtime = True + try: + self._run_code( + "with __temporal_importer.applied():\n" + " __temporal_metadata = __temporal_in_sandbox.get_external_store_metadata(__temporal_command_info)\n", + __temporal_importer=self.importer, + __temporal_command_info=command_info, + ) + return self.globals_and_locals.pop("__temporal_metadata", None) # type: ignore + finally: + self.importer.restriction_context.is_runtime = False diff --git a/tests/contrib/aws/s3driver/test_s3driver.py b/tests/contrib/aws/s3driver/test_s3driver.py index 46184c8b7..d47831603 100644 --- a/tests/contrib/aws/s3driver/test_s3driver.py +++ b/tests/contrib/aws/s3driver/test_s3driver.py @@ -27,12 +27,12 @@ S3StorageDriverClient, ) from temporalio.converter import ( - ActivitySerializationContext, JSONPlainPayloadConverter, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, - WorkflowSerializationContext, + StorageDriverWorkflowInfo, ) from tests.contrib.aws.s3driver.conftest import BUCKET @@ -51,34 +51,47 @@ def make_payload(value: str = "hello") -> Payload: def make_store_context( - serialization_context: WorkflowSerializationContext - | ActivitySerializationContext - | None = None, + namespace: str | None = None, + current_workflow: StorageDriverWorkflowInfo | None = None, + current_activity: StorageDriverActivityInfo | None = None, + target_workflow: StorageDriverWorkflowInfo | None = None, + target_activity: StorageDriverActivityInfo | None = None, ) -> StorageDriverStoreContext: - return 
StorageDriverStoreContext(serialization_context=serialization_context) + return StorageDriverStoreContext( + namespace=namespace, + current_workflow=current_workflow, + current_activity=current_activity, + target_workflow=target_workflow, + target_activity=target_activity, + ) def make_workflow_context( namespace: str = "my-namespace", workflow_id: str = "my-workflow", -) -> WorkflowSerializationContext: - return WorkflowSerializationContext(namespace=namespace, workflow_id=workflow_id) + workflow_type: str | None = None, + run_id: str | None = None, +) -> StorageDriverStoreContext: + return make_store_context( + namespace=namespace, + current_workflow=StorageDriverWorkflowInfo( + id=workflow_id, type=workflow_type, run_id=run_id + ), + ) def make_activity_context( namespace: str = "my-namespace", activity_id: str | None = "my-activity", workflow_id: str | None = None, - activity_task_queue: str | None = None, -) -> ActivitySerializationContext: - return ActivitySerializationContext( + activity_type: str | None = None, +) -> StorageDriverStoreContext: + return make_store_context( namespace=namespace, - activity_id=activity_id, - activity_type=None, - activity_task_queue=activity_task_queue, - workflow_id=workflow_id, - workflow_type=None, - is_local=False, + current_workflow=( + StorageDriverWorkflowInfo(id=workflow_id) if workflow_id else None + ), + current_activity=StorageDriverActivityInfo(id=activity_id, type=activity_type), ) @@ -211,48 +224,69 @@ async def test_key_context_workflow( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="ns1", workflow_id="wf1") + ctx = make_workflow_context(namespace="ns1", workflow_id="wf1") + [claim] = await driver.store(ctx, [payload]) + expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/wt/null/wi/wf1/ri/null/d/sha256/{expected_hash}" + ) + + 
async def test_key_context_workflow_with_type_and_run_id( + self, driver_client: S3StorageDriverClient + ) -> None: + driver = S3StorageDriver(client=driver_client, bucket=BUCKET) + payload = make_payload() + ctx = make_workflow_context( + namespace="ns1", + workflow_id="wf1", + workflow_type="MyWorkflow", + run_id="run-abc", ) [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() - assert claim.claim_data["key"] == f"v0/ns/ns1/wfi/wf1/d/sha256/{expected_hash}" + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/wt/MyWorkflow/wi/wf1/ri/run-abc/d/sha256/{expected_hash}" + ) async def test_key_context_workflow_activity( self, driver_client: S3StorageDriverClient ) -> None: - """workflow_id takes priority over activity_id in ActivitySerializationContext.""" + """workflow takes priority over activity in store context.""" driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_activity_context( - namespace="ns1", workflow_id="wf1", activity_id="act1" - ) + ctx = make_activity_context( + namespace="ns1", workflow_id="wf1", activity_id="act1" ) [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() - assert claim.claim_data["key"] == f"v0/ns/ns1/wfi/wf1/d/sha256/{expected_hash}" + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/wt/null/wi/wf1/ri/null/d/sha256/{expected_hash}" + ) - async def test_key_context_standalone_activityt( + async def test_key_context_standalone_activity( self, driver_client: S3StorageDriverClient ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_activity_context(namespace="ns1", activity_id="act1", workflow_id=None) + ctx = make_activity_context( + namespace="ns1", activity_id="act1", workflow_id=None ) [claim] = await driver.store(ctx, [payload]) expected_hash = 
hashlib.sha256(payload.SerializeToString()).hexdigest() - assert claim.claim_data["key"] == f"v0/ns/ns1/aci/act1/d/sha256/{expected_hash}" + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/at/null/ai/act1/ri/null/d/sha256/{expected_hash}" + ) async def test_key_preserves_case( self, driver_client: S3StorageDriverClient ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="MyNamespace", workflow_id="MyWorkflow") - ) + ctx = make_workflow_context(namespace="MyNamespace", workflow_id="MyWorkflow") [claim] = await driver.store(ctx, [payload]) key = claim.claim_data["key"] assert "MyNamespace" in key @@ -263,14 +297,12 @@ async def test_key_urlencodes_workflow_id_with_slashes( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="ns1", workflow_id="order/123/v2") - ) + ctx = make_workflow_context(namespace="ns1", workflow_id="order/123/v2") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/ns1/wfi/order%2F123%2Fv2/d/sha256/{expected_hash}" + == f"v0/ns/ns1/wt/null/wi/order%2F123%2Fv2/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencodes_workflow_id_with_special_chars( @@ -278,14 +310,12 @@ async def test_key_urlencodes_workflow_id_with_special_chars( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="ns1", workflow_id="wf#1 &foo=bar") - ) + ctx = make_workflow_context(namespace="ns1", workflow_id="wf#1 &foo=bar") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == 
f"v0/ns/ns1/wfi/wf%231%20%26foo%3Dbar/d/sha256/{expected_hash}" + == f"v0/ns/ns1/wt/null/wi/wf%231%20%26foo%3Dbar/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencodes_activity_id( @@ -293,16 +323,14 @@ async def test_key_urlencodes_activity_id( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_activity_context( - namespace="ns1", activity_id="act/1#2", workflow_id=None - ) + ctx = make_activity_context( + namespace="ns1", activity_id="act/1#2", workflow_id=None ) [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/ns1/aci/act%2F1%232/d/sha256/{expected_hash}" + == f"v0/ns/ns1/at/null/ai/act%2F1%232/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencodes_namespace( @@ -310,14 +338,12 @@ async def test_key_urlencodes_namespace( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="my/ns#1", workflow_id="wf1") - ) + ctx = make_workflow_context(namespace="my/ns#1", workflow_id="wf1") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/my%2Fns%231/wfi/wf1/d/sha256/{expected_hash}" + == f"v0/ns/my%2Fns%231/wt/null/wi/wf1/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencoded_roundtrip( @@ -326,9 +352,7 @@ async def test_key_urlencoded_roundtrip( """Payloads stored with special-char IDs can be retrieved correctly.""" driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload("special-char-roundtrip") - ctx = make_store_context( - make_workflow_context(namespace="ns/1", workflow_id="wf/2#3") - ) + ctx = make_workflow_context(namespace="ns/1", workflow_id="wf/2#3") [claim] = await driver.store(ctx, [payload]) 
[retrieved] = await driver.retrieve(StorageDriverRetrieveContext(), [claim]) assert retrieved == payload @@ -528,45 +552,40 @@ def counting_selector(_ctx: StorageDriverStoreContext, _p: Payload) -> str: ) assert call_count == 3 - async def test_selector_routes_by_activity_task_queue( + async def test_selector_routes_by_activity_type( self, aioboto3_client: S3Client, driver_client: S3StorageDriverClient ) -> None: - """bucket callable can route payloads to different buckets by activity task queue.""" - bucket_a = "bucket-queue-a" - bucket_b = "bucket-queue-b" + """bucket callable can route payloads to different buckets by activity type.""" + bucket_a = "bucket-type-a" + bucket_b = "bucket-type-b" await aioboto3_client.create_bucket(Bucket=bucket_a) await aioboto3_client.create_bucket(Bucket=bucket_b) - queue_buckets = {"queue-a": bucket_a, "queue-b": bucket_b} + type_buckets = {"type-a": bucket_a, "type-b": bucket_b} - def queue_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: + def type_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: del p - if isinstance(ctx.serialization_context, ActivitySerializationContext): - queue = ctx.serialization_context.activity_task_queue - if queue and queue in queue_buckets: - return queue_buckets[queue] + act = ctx.current_activity or ctx.target_activity + if act and act.type and act.type in type_buckets: + return type_buckets[act.type] return BUCKET - driver = S3StorageDriver(client=driver_client, bucket=queue_selector) + driver = S3StorageDriver(client=driver_client, bucket=type_selector) - ctx_a = make_store_context( - make_activity_context( - namespace="ns1", - activity_id="act1", - workflow_id="wf1", - activity_task_queue="queue-a", - ) + ctx_a = make_activity_context( + namespace="ns1", + activity_id="act1", + workflow_id="wf1", + activity_type="type-a", ) [claim_a] = await driver.store(ctx_a, [make_payload("payload-a")]) assert claim_a.claim_data["bucket"] == bucket_a - ctx_b = make_store_context( - 
make_activity_context( - namespace="ns1", - activity_id="act2", - workflow_id="wf1", - activity_task_queue="queue-b", - ) + ctx_b = make_activity_context( + namespace="ns1", + activity_id="act2", + workflow_id="wf1", + activity_type="type-b", ) [claim_b] = await driver.store(ctx_b, [make_payload("payload-b")]) assert claim_b.claim_data["bucket"] == bucket_b @@ -581,7 +600,7 @@ def capturing_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: return BUCKET payload = make_payload() - store_ctx = make_store_context(make_workflow_context()) + store_ctx = make_workflow_context() driver = S3StorageDriver(client=driver_client, bucket=capturing_selector) await driver.store(store_ctx, [payload]) diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py index ad3768ec8..d3ce022f5 100644 --- a/tests/test_serialization_context.py +++ b/tests/test_serialization_context.py @@ -40,15 +40,10 @@ DefaultFailureConverter, DefaultPayloadConverter, EncodingPayloadConverter, - ExternalStorage, JSONPlainPayloadConverter, PayloadCodec, PayloadConverter, SerializationContext, - StorageDriver, - StorageDriverClaim, - StorageDriverRetrieveContext, - StorageDriverStoreContext, WithSerializationContext, WorkflowSerializationContext, ) @@ -1919,84 +1914,3 @@ async def test_user_customization_of_default_payload_converter( id=wf_id, task_queue=task_queue, ) - - -# Child workflow external storage context test - - -class ContextTrackingStorageDriver(StorageDriver): - """In-memory driver that records the serialization context on each store/retrieve.""" - - def __init__(self) -> None: - self._storage: dict[str, bytes] = {} - self.store_contexts: list[SerializationContext | None] = [] - - def name(self) -> str: - return "context-tracking" - - async def store( - self, - context: StorageDriverStoreContext, - payloads: Sequence[temporalio.api.common.v1.Payload], - ) -> list[StorageDriverClaim]: - self.store_contexts.append(context.serialization_context) - claims: 
list[StorageDriverClaim] = [] - for payload in payloads: - key = f"payload-{len(self._storage)}" - self._storage[key] = payload.SerializeToString() - claims.append(StorageDriverClaim(claim_data={"key": key})) - return claims - - async def retrieve( - self, - context: StorageDriverRetrieveContext, - claims: Sequence[StorageDriverClaim], - ) -> list[temporalio.api.common.v1.Payload]: - results: list[temporalio.api.common.v1.Payload] = [] - for claim in claims: - payload = temporalio.api.common.v1.Payload() - payload.ParseFromString(self._storage[claim.claim_data["key"]]) - results.append(payload) - return results - - -async def test_child_workflow_external_storage_with_context(client: Client): - """External storage should receive the child workflow's context, not the parent's.""" - workflow_id = str(uuid.uuid4()) - child_workflow_id = f"{workflow_id}-child" - task_queue = str(uuid.uuid4()) - - driver = ContextTrackingStorageDriver() - config = client.config() - config["data_converter"] = dataclasses.replace( - DataConverter.default, - external_storage=ExternalStorage( - drivers=[driver], - payload_size_threshold=None, - ), - ) - client = Client(**config) - - async with Worker( - client, - task_queue=task_queue, - workflows=[ChildWorkflowCodecTestWorkflow, EchoWorkflow], - workflow_runner=UnsandboxedWorkflowRunner(), - ): - await client.execute_workflow( - ChildWorkflowCodecTestWorkflow.run, - TraceData(), - id=workflow_id, - task_queue=task_queue, - ) - - child_context = WorkflowSerializationContext( - namespace=client.namespace, - workflow_id=child_workflow_id, - ) - # store_contexts[0]: parent input encode → parent context - # store_contexts[1]: child workflow input encode → child context - # store_contexts[2]: child workflow result encode → child context - # store_contexts[3]: parent result encode → parent context - child_context_count = sum(1 for c in driver.store_contexts if c == child_context) - assert child_context_count == 2 diff --git 
a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index eb4270d08..97571883b 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -11,6 +11,7 @@ import temporalio import temporalio.bridge.client import temporalio.bridge.worker +import temporalio.client import temporalio.converter import temporalio.worker._workflow from temporalio import activity, workflow @@ -19,6 +20,7 @@ from temporalio.common import RetryPolicy from temporalio.converter import ( ExternalStorage, + StorageDriver, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, @@ -859,3 +861,505 @@ async def test_tmprl1104_with_extstore_download_and_upload( assert getattr(records[1], "payload_upload_count") == 1 assert getattr(records[1], "payload_upload_size") == expected_output_size assert getattr(records[1], "payload_upload_duration") > timedelta(0) + + +# --------------------------------------------------------------------------- +# Store-metadata context tests +# --------------------------------------------------------------------------- + + +class ContextTrackingStorageDriver(StorageDriver): + """In-memory driver that records the store context on each store/retrieve.""" + + def __init__(self) -> None: + self._storage: dict[str, bytes] = {} + self.store_contexts: list[StorageDriverStoreContext] = [] + + def name(self) -> str: + return "context-tracking" + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + self.store_contexts.append(context) + claims: list[StorageDriverClaim] = [] + for payload in payloads: + key = f"payload-{len(self._storage)}" + self._storage[key] = payload.SerializeToString() + claims.append(StorageDriverClaim(claim_data={"key": key})) + return claims + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + results: list[Payload] = [] + for claim in claims: + 
payload = Payload() + payload.ParseFromString(self._storage[claim.claim_data["key"]]) + results.append(payload) + return results + + +@workflow.defn +class SignalWaitWorkflow: + def __init__(self) -> None: + self._signal_data: str | None = None + + @workflow.run + async def run(self, _arg: str) -> str: + await workflow.wait_condition(lambda: self._signal_data is not None) + return self._signal_data # type: ignore + + @workflow.signal + async def my_signal(self, data: str) -> None: + self._signal_data = data + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, data: str) -> str: + return data + + +@workflow.defn +class ChildWorkflowStoreMetadataTestWorkflow: + @workflow.run + async def run(self, data: str) -> str: + return await workflow.execute_child_workflow( + EchoWorkflow.run, + data, + id=f"{workflow.info().workflow_id}-child", + ) + + +async def _make_tracking_client( + env: WorkflowEnvironment, +) -> tuple[Client, ContextTrackingStorageDriver]: + driver = ContextTrackingStorageDriver() + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=None, + ), + ), + ) + return client, driver + + +async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None: + """start_workflow should set workflow id and type on store metadata.""" + client, driver = await _make_tracking_client(env) + workflow_id = str(uuid.uuid4()) + + async with new_worker(client, EchoWorkflow) as worker: + await client.execute_workflow( + EchoWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + ) + + assert len(driver.store_contexts) == 2 + + client_ctx = driver.store_contexts[0] + assert client_ctx.namespace == client.namespace + assert client_ctx.current_workflow is None + assert client_ctx.target_workflow is not None + 
assert client_ctx.target_workflow.id == workflow_id + assert client_ctx.target_workflow.type == "EchoWorkflow" + assert client_ctx.target_workflow.run_id is None + assert client_ctx.target_activity is None + + worker_ctx = driver.store_contexts[1] + assert worker_ctx.namespace == client.namespace + assert worker_ctx.current_workflow is not None + assert worker_ctx.current_workflow.id == workflow_id + assert worker_ctx.current_workflow.type == "EchoWorkflow" + assert worker_ctx.current_workflow.run_id is not None + assert worker_ctx.target_workflow is None + + +async def test_store_metadata_signal_with_start(env: WorkflowEnvironment) -> None: + """signal_with_start should set workflow metadata for signal arg encoding.""" + client, driver = await _make_tracking_client(env) + workflow_id = str(uuid.uuid4()) + + async with new_worker(client, SignalWaitWorkflow) as worker: + handle = await client.start_workflow( + SignalWaitWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + start_signal="my_signal", + start_signal_args=["signal-data"], + ) + await handle.result() + + assert len(driver.store_contexts) == 3 + + # [0] Workflow input arg + input_ctx = driver.store_contexts[0] + assert input_ctx.current_workflow is None + assert input_ctx.target_workflow is not None + assert input_ctx.target_workflow.id == workflow_id + assert input_ctx.target_workflow.type == "SignalWaitWorkflow" + assert input_ctx.target_workflow.run_id is None + + # [1] Signal arg + signal_ctx = driver.store_contexts[1] + assert signal_ctx.current_workflow is None + assert signal_ctx.target_workflow is not None + assert signal_ctx.target_workflow.id == workflow_id + assert signal_ctx.target_workflow.type == "SignalWaitWorkflow" + assert signal_ctx.target_workflow.run_id is None + + # [2] Workflow result + result_ctx = driver.store_contexts[2] + assert result_ctx.current_workflow is not None + assert result_ctx.current_workflow.run_id is not None + assert 
result_ctx.target_workflow is None + + +async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None: + """signal_workflow should set workflow id on store metadata.""" + client, driver = await _make_tracking_client(env) + workflow_id = str(uuid.uuid4()) + + async with new_worker(client, SignalWaitWorkflow) as worker: + handle = await client.start_workflow( + SignalWaitWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + ) + # Signal separately (not signal-with-start) + await handle.signal(SignalWaitWorkflow.my_signal, "signal-data") + await handle.result() + + assert len(driver.store_contexts) == 3 + + signal_ctx = driver.store_contexts[1] + assert signal_ctx.current_workflow is None + assert signal_ctx.target_workflow is not None + assert signal_ctx.target_workflow.id == workflow_id + assert signal_ctx.target_workflow.type is None + + +async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: + """Schedule action _to_proto should set workflow metadata.""" + client, driver = await _make_tracking_client(env) + task_queue = str(uuid.uuid4()) + schedule_id = f"sched-{uuid.uuid4()}" + + try: + await client.create_schedule( + schedule_id, + temporalio.client.Schedule( + action=temporalio.client.ScheduleActionStartWorkflow( + EchoWorkflow.run, + "hello", + id=f"wf-{schedule_id}", + task_queue=task_queue, + ), + spec=temporalio.client.ScheduleSpec(), + ), + ) + + assert len(driver.store_contexts) == 1 + ctx = driver.store_contexts[0] + assert ctx.namespace == client.namespace + assert ctx.current_workflow is None + assert ctx.target_workflow is not None + assert ctx.target_workflow.id == f"wf-{schedule_id}" + assert ctx.target_workflow.type == "EchoWorkflow" + assert ctx.target_activity is None + finally: + try: + handle = client.get_schedule_handle(schedule_id) + await handle.delete() + except Exception: + pass + + +async def test_store_metadata_child_workflow(env: WorkflowEnvironment) -> None: + 
"""External storage should receive the parent as workflow and child as target_workflow.""" + client, driver = await _make_tracking_client(env) + workflow_id = f"workflow-{uuid.uuid4()}" + child_workflow_id = f"{workflow_id}-child" + + async with new_worker( + client, + ChildWorkflowStoreMetadataTestWorkflow, + EchoWorkflow, + ) as worker: + await client.execute_workflow( + ChildWorkflowStoreMetadataTestWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + ) + + assert len(driver.store_contexts) == 4 + + # [0] Client starts parent workflow + client_ctx = driver.store_contexts[0] + assert client_ctx.current_workflow is None + assert client_ctx.target_workflow is not None + assert client_ctx.target_workflow.id == workflow_id + assert client_ctx.target_workflow.type == "ChildWorkflowStoreMetadataTestWorkflow" + + # [1] Parent schedules child: current_workflow = parent, target_workflow = child + start_child_ctx = driver.store_contexts[1] + assert start_child_ctx.current_workflow is not None + assert start_child_ctx.current_workflow.id == workflow_id + assert ( + start_child_ctx.current_workflow.type + == "ChildWorkflowStoreMetadataTestWorkflow" + ) + assert start_child_ctx.current_workflow.run_id is not None + assert start_child_ctx.target_workflow is not None + assert start_child_ctx.target_workflow.id == child_workflow_id + assert start_child_ctx.target_workflow.type == "EchoWorkflow" + assert start_child_ctx.target_workflow.run_id is None + + # [2] Child returns result: current_workflow = child, no target + child_result_ctx = driver.store_contexts[2] + assert child_result_ctx.current_workflow is not None + assert child_result_ctx.current_workflow.id == child_workflow_id + assert child_result_ctx.current_workflow.type == "EchoWorkflow" + assert child_result_ctx.current_workflow.run_id is not None + assert child_result_ctx.target_workflow is None + + # [3] Parent returns result: current_workflow = parent, no target + parent_result_ctx = 
driver.store_contexts[3] + assert parent_result_ctx.current_workflow is not None + assert parent_result_ctx.current_workflow.id == workflow_id + assert ( + parent_result_ctx.current_workflow.type + == "ChildWorkflowStoreMetadataTestWorkflow" + ) + assert parent_result_ctx.current_workflow.run_id is not None + assert parent_result_ctx.target_workflow is None + + +# Workflow definitions for gap tests + + +@activity.defn +async def echo_activity(input: str) -> str: + """Simple activity that returns its input.""" + return input + + +@workflow.defn +class ActivityScheduleMetadataWorkflow: + """Workflow that schedules an activity to test activity metadata on the store context.""" + + @workflow.run + async def run(self, data: str) -> str: + return await workflow.execute_activity( + echo_activity, + data, + activity_id="my-activity-id", + schedule_to_close_timeout=timedelta(seconds=10), + ) + + +@workflow.defn +class SignalExternalMetadataWorkflow: + """Workflow that signals another workflow.""" + + @workflow.run + async def run(self, target_workflow_id: str) -> None: + await workflow.get_external_workflow_handle(target_workflow_id).signal( + SignalWaitWorkflow.my_signal, "signal-from-workflow" + ) + + +async def test_store_metadata_activity_scheduling(env: WorkflowEnvironment) -> None: + """When a workflow schedules an activity, context.activity should be populated.""" + client, driver = await _make_tracking_client(env) + workflow_id = f"workflow-{uuid.uuid4()}" + + async with new_worker( + client, + ActivityScheduleMetadataWorkflow, + activities=[echo_activity], + ) as worker: + await client.execute_workflow( + ActivityScheduleMetadataWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + ) + + assert len(driver.store_contexts) == 4 + + # [0] Client starts workflow + client_ctx = driver.store_contexts[0] + assert client_ctx.current_workflow is None + assert client_ctx.target_workflow is not None + assert client_ctx.target_workflow.id == workflow_id 
+ assert client_ctx.target_workflow.type == "ActivityScheduleMetadataWorkflow" + + # [1] Workflow worker schedules activity: current_workflow + target_activity + schedule_ctx = driver.store_contexts[1] + assert schedule_ctx.namespace == client.namespace + assert schedule_ctx.current_workflow is not None + assert schedule_ctx.current_workflow.id == workflow_id + assert schedule_ctx.current_workflow.type == "ActivityScheduleMetadataWorkflow" + assert schedule_ctx.target_activity is not None + assert schedule_ctx.target_activity.id == "my-activity-id" + assert schedule_ctx.target_activity.type == "echo_activity" + + # [2] Activity worker completes: current_workflow + current_activity + execute_ctx = driver.store_contexts[2] + assert execute_ctx.namespace == client.namespace + assert execute_ctx.current_workflow is not None + assert execute_ctx.current_workflow.id == workflow_id + assert execute_ctx.current_workflow.type == "ActivityScheduleMetadataWorkflow" + assert execute_ctx.current_activity is not None + assert execute_ctx.current_activity.id == "my-activity-id" + assert execute_ctx.current_activity.type == "echo_activity" + + # [3] Workflow returns result + result_ctx = driver.store_contexts[3] + assert result_ctx.current_workflow is not None + assert result_ctx.current_workflow.id == workflow_id + assert result_ctx.target_activity is None + + +async def test_store_metadata_signal_external_workflow( + env: WorkflowEnvironment, +) -> None: + """Signaling an external workflow should set workflow.id to the target.""" + client, driver = await _make_tracking_client(env) + target_workflow_id = f"target-{uuid.uuid4()}" + sender_workflow_id = f"sender-{uuid.uuid4()}" + + async with new_worker( + client, + SignalExternalMetadataWorkflow, + SignalWaitWorkflow, + ) as worker: + # Start the target workflow first + target_handle = await client.start_workflow( + SignalWaitWorkflow.run, + "waiting", + id=target_workflow_id, + task_queue=worker.task_queue, + ) + # Start the 
sender which will signal the target + await client.execute_workflow( + SignalExternalMetadataWorkflow.run, + target_workflow_id, + id=sender_workflow_id, + task_queue=worker.task_queue, + ) + await target_handle.result() + + assert len(driver.store_contexts) == 5 + + # [0] Client starts target workflow (SignalWaitWorkflow) + target_start_ctx = driver.store_contexts[0] + assert target_start_ctx.current_workflow is None + assert target_start_ctx.target_workflow is not None + assert target_start_ctx.target_workflow.id == target_workflow_id + assert target_start_ctx.target_workflow.type == "SignalWaitWorkflow" + + # [1] Client starts sender workflow (SignalExternalMetadataWorkflow) + sender_start_ctx = driver.store_contexts[1] + assert sender_start_ctx.current_workflow is None + assert sender_start_ctx.target_workflow is not None + assert sender_start_ctx.target_workflow.id == sender_workflow_id + assert sender_start_ctx.target_workflow.type == "SignalExternalMetadataWorkflow" + + # [2] Sender signals target: current_workflow = sender, target_workflow = target + signal_ctx = driver.store_contexts[2] + assert signal_ctx.current_workflow is not None + assert signal_ctx.current_workflow.id == sender_workflow_id + assert signal_ctx.target_workflow is not None + assert signal_ctx.target_workflow.id == target_workflow_id + assert signal_ctx.target_workflow.type is None + assert signal_ctx.target_workflow.run_id is None + + # [3] Sender workflow returns + sender_result_ctx = driver.store_contexts[3] + assert sender_result_ctx.current_workflow is not None + assert sender_result_ctx.current_workflow.id == sender_workflow_id + assert sender_result_ctx.target_workflow is None + + # [4] Target workflow returns signal data + target_result_ctx = driver.store_contexts[4] + assert target_result_ctx.current_workflow is not None + assert target_result_ctx.current_workflow.id == target_workflow_id + assert target_result_ctx.target_workflow is None + + +async def 
test_store_metadata_activity_worker(env: WorkflowEnvironment) -> None: + """Activity worker should set workflow and activity metadata during encode.""" + client, driver = await _make_tracking_client(env) + workflow_id = f"workflow-{uuid.uuid4()}" + + async with new_worker( + client, + ActivityScheduleMetadataWorkflow, + activities=[echo_activity], + ) as worker: + await client.execute_workflow( + ActivityScheduleMetadataWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + ) + + assert len(driver.store_contexts) == 4 + + # [2] Activity worker completes: current_workflow has run_id, current_activity set + execute_ctx = driver.store_contexts[2] + assert execute_ctx.current_workflow is not None + assert execute_ctx.current_workflow.id == workflow_id + assert execute_ctx.current_workflow.run_id is not None + assert execute_ctx.current_activity is not None + assert execute_ctx.current_activity.type == "echo_activity" + + +async def test_store_metadata_decode_activation(env: WorkflowEnvironment) -> None: + """All store contexts should have namespace set.""" + client, driver = await _make_tracking_client(env) + workflow_id = f"workflow-{uuid.uuid4()}" + + async with new_worker(client, EchoWorkflow) as worker: + await client.execute_workflow( + EchoWorkflow.run, + "hello", + id=workflow_id, + task_queue=worker.task_queue, + ) + + assert len(driver.store_contexts) == 2 + + # [0] Client starts workflow + client_ctx = driver.store_contexts[0] + assert client_ctx.namespace == client.namespace + assert client_ctx.current_workflow is None + assert client_ctx.target_workflow is not None + assert client_ctx.target_workflow.id == workflow_id + assert client_ctx.target_workflow.type == "EchoWorkflow" + + # [1] Workflow returns result + worker_ctx = driver.store_contexts[1] + assert worker_ctx.namespace == client.namespace + assert worker_ctx.current_workflow is not None + assert worker_ctx.current_workflow.id == workflow_id + assert 
worker_ctx.current_workflow.type == "EchoWorkflow" + assert worker_ctx.target_workflow is None diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 4ef5c29fa..9da3ea4b9 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -38,6 +38,7 @@ import temporalio.api.sdk.v1 import temporalio.client import temporalio.converter +import temporalio.converter._extstore import temporalio.worker import temporalio.worker._command_aware_visitor import temporalio.workflow @@ -1626,6 +1627,12 @@ def get_serialization_context( ) -> temporalio.converter.SerializationContext | None: return self._unsandboxed.get_serialization_context(command_info) + def get_external_store_metadata( + self, + command_info: temporalio.worker._command_aware_visitor.CommandInfo | None, + ) -> temporalio.converter._extstore.StorageDriverStoreMetadata | None: + return self._unsandboxed.get_external_store_metadata(command_info) + async def test_workflow_with_custom_runner(client: Client): runner = CustomWorkflowRunner() From 8f14f6b0361378591c1bd52c947a51543f3c5c60 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:07:31 -0700 Subject: [PATCH 02/16] Fix test assertions --- .../aws/s3driver/test_s3driver_worker.py | 132 +++++++++++------- 1 file changed, 84 insertions(+), 48 deletions(-) diff --git a/tests/contrib/aws/s3driver/test_s3driver_worker.py b/tests/contrib/aws/s3driver/test_s3driver_worker.py index 87ab73736..3fd1f8fd8 100644 --- a/tests/contrib/aws/s3driver/test_s3driver_worker.py +++ b/tests/contrib/aws/s3driver/test_s3driver_worker.py @@ -107,8 +107,17 @@ async def test_s3_driver_workflow_input_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + + # Client stores workflow input with ri=null (run ID not yet assigned); + # worker stores activity input with 
ri=run_id — same bytes, two S3 objects. + assert len(keys) == 2 + assert all( + f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in k for k in keys + ) + # Client-side store: ri=null because run ID is not yet known. + assert sum(1 for k in keys if "/ri/null/" in k) == 1 + # Worker-side store: ri=run_id, assigned by the server. + assert sum(1 for k in keys if "/ri/null/" not in k) == 1 async def test_s3_driver_workflow_output_key( @@ -127,8 +136,11 @@ async def test_s3_driver_workflow_output_key( ) assert result == LARGE keys = await _list_keys(aioboto3_client) + # Activity result and workflow result dedup to same key assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + assert f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in keys[0] + # Run ID is known for both activity completion and workflow completion + assert "/ri/null/" not in keys[0] async def test_s3_driver_workflow_activity_input_key( @@ -146,11 +158,14 @@ async def test_s3_driver_workflow_activity_input_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/" in keys[0] - assert ( - "/aci/" not in keys[0] - ), "Activity input should use workflow_id, not activity_id" + # Client start (ri=null) + worker schedules activity (ri=run_id) — same bytes, two objects. + assert len(keys) == 2 + # Both keys are under the workflow wi/ri prefix, not the activity. + assert all( + f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in k for k in keys + ) + # Activity input is keyed under the scheduling workflow, not the activity. + assert all("/ai/" not in k for k in keys) async def test_s3_driver_workflow_activity_output_key( @@ -168,8 +183,11 @@ async def test_s3_driver_workflow_activity_output_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) + # Activity result and workflow result are both LARGE so they deduplicate to one object. 
assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + assert f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in keys[0] + # ri=run_id for both stores (run ID is known by the time the activity completes). + assert "/ri/null/" not in keys[0] async def test_s3_driver_signal_arg_key( @@ -186,8 +204,12 @@ async def test_s3_driver_signal_arg_key( await handle.signal(SignalQueryUpdateWorkflow.finish, LARGE) await handle.result() keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + # Signal arg + workflow result — two distinct keys (different wt and ri). + assert len(keys) == 2 + # Signal arg: client stores with wt=null (type not known) and ri=null. + assert any(f"/wt/null/wi/{workflow_id}/ri/null/" in k for k in keys) + # Workflow result: worker stores with real type and ri=run_id. + assert any(f"/wt/SignalQueryUpdateWorkflow/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_query_result_key( @@ -206,8 +228,12 @@ async def test_s3_driver_query_result_key( await handle.signal(SignalQueryUpdateWorkflow.finish, "done") await handle.result() keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + # Query arg + (query result deduplicated with workflow result) — two distinct keys. + assert len(keys) == 2 + # Query arg: client stores with wt=null (type not known) and ri=null. + assert any(f"/wt/null/wi/{workflow_id}/ri/null/" in k for k in keys) + # Query result and workflow result are both LARGE and deduplicate to one key with ri=run_id. 
+ assert any(f"/wt/SignalQueryUpdateWorkflow/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_update_result_key( @@ -226,8 +252,12 @@ async def test_s3_driver_update_result_key( await handle.signal(SignalQueryUpdateWorkflow.finish, "done") await handle.result() keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + # Update arg + (update result deduplicated with workflow result) — two distinct keys. + assert len(keys) == 2 + # Update arg: client stores with wt=null (type not known) and ri=null. + assert any(f"/wt/null/wi/{workflow_id}/ri/null/" in k for k in keys) + # Update result and workflow result are both LARGE and deduplicate to one key with ri=run_id. + assert any(f"/wt/SignalQueryUpdateWorkflow/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_child_workflow_input_key( @@ -244,9 +274,13 @@ async def test_s3_driver_child_workflow_input_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 child_workflow_id = f"{workflow_id}-child" - assert f"/ns/default/wfi/{child_workflow_id}/d/sha256/" in keys[0] + # Child input is the only large payload — stored under the parent's wi/ri. + assert len(keys) == 1 + # Keyed under the parent: it is the current execution context when scheduling the child. + assert f"/ns/default/wt/ParentWithChildWorkflow/wi/{workflow_id}/ri/" in keys[0] + # Not keyed under the child: the payload lives in the parent's history, not the child's. 
+ assert f"/wi/{child_workflow_id}/" not in keys[0] async def test_s3_driver_identified_casing( @@ -264,10 +298,11 @@ async def test_s3_driver_identified_casing( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert "/ns/default/" in keys[0], "Namespace segment should be present" - assert ( - f"/wfi/{workflow_id}/" in keys[0] + # Client start (ri=null) + worker stores (ri=run_id) — two objects. + assert len(keys) == 2 + # Workflow ID is percent-encoded but casing is preserved verbatim. + assert all( + f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in k for k in keys ), "Workflow ID should preserve original case in the key" @@ -290,9 +325,14 @@ async def test_s3_driver_content_dedup( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) + # Two distinct content hashes (LARGE from download, LARGE_2 from extract) → two keys. assert len(keys) == 2 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[1] + # Both are under the same workflow wi/ri prefix despite crossing activity boundaries. + assert all( + f"/ns/default/wt/DocumentIngestionWorkflow/wi/{workflow_id}/ri/" in k + for k in keys + ) + # The two keys differ by content hash only. 
assert keys[0] != keys[1] @@ -301,8 +341,7 @@ async def test_s3_driver_single_workflow_same_key_namespace( ) -> None: """A training job started with a large config, injected with large override parameters mid-run, and polled for large metrics — all produce S3 keys - under the same workflow ID prefix, regardless of which primitive carried - the payload.""" + containing the same workflow ID.""" workflow_id = str(uuid.uuid4()) async with new_worker(tmprl_client, ModelTrainingWorkflow) as worker: handle = await tmprl_client.start_workflow( @@ -320,19 +359,18 @@ async def test_s3_driver_single_workflow_same_key_namespace( await handle.signal(ModelTrainingWorkflow.complete) await handle.result() keys = await _list_keys(aioboto3_client) - # LARGE (input + signal arg) and LARGE_2 (metrics result) deduplicate to - # two distinct keys — both anchored under the same workflow ID prefix. - assert len(keys) == 2 - assert all(f"/ns/default/wfi/{workflow_id}/" in key for key in keys) + # Four distinct keys: client start, signal arg, update result, workflow result. + assert len(keys) == 4 + # All keys are anchored under the same workflow ID regardless of which primitive carried the payload. + assert all(f"/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_parent_child_independent_key_namespaces( tmprl_client: Client, aioboto3_client: S3Client ) -> None: - """An order fulfillment workflow spawns a child payment processor, passes it - a large order payload, and returns the child's large payment confirmation. - Each workflow accumulates S3 keys under its own workflow ID prefix — - parent and child key namespaces are fully independent.""" + """An order fulfillment workflow spawns a child payment processor and passes + it a large order payload. 
Child input is keyed under the parent (it lives in + the parent's history); child output is keyed under the child.""" workflow_id = str(uuid.uuid4()) payment_id = f"{workflow_id}-payment" async with new_worker( @@ -346,16 +384,14 @@ async def test_s3_driver_parent_child_independent_key_namespaces( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - parent_prefix = f"/ns/default/wfi/{workflow_id}/d/" - child_prefix = f"/ns/default/wfi/{payment_id}/d/" - parent_keys = [k for k in keys if parent_prefix in k] - child_keys = [k for k in keys if child_prefix in k] - # The parent stores its input (LARGE) and the child's result propagated - # back (LARGE_2) under the parent's prefix → 2 keys. - # The child stores its input (LARGE) and its result (LARGE_2) under the - # child's prefix → 2 keys. - assert len(parent_keys) == 2 - assert len(child_keys) == 2 + parent_keys = [k for k in keys if f"/wi/{workflow_id}/" in k] + child_keys = [k for k in keys if f"/wi/{payment_id}/" in k] + # Parent accumulates 3 keys: client start (ri=null), child input (ri=run_id, + # keyed under parent because parent is scheduling context), and child result + # propagated back to parent (ri=run_id). + assert len(parent_keys) == 3 + # Child accumulates 1 key: its own result (LARGE_2, ri=child_run_id). 
+ assert len(child_keys) == 1 async def test_s3_store_failure_surfaces_in_workflow_history( @@ -400,7 +436,6 @@ async def test_s3_store_failure_surfaces_in_workflow_history( large_payload = JSONPlainPayloadConverter().to_payload(LARGE) assert large_payload is not None expected_hash = hashlib.sha256(large_payload.SerializeToString()).hexdigest() - expected_key = f"v0/ns/default/wfi/{workflow_id}/d/sha256/{expected_hash}" assert isinstance(exc_info.value, WorkflowFailureError) activity_error = exc_info.value.__cause__ @@ -408,7 +443,8 @@ async def test_s3_store_failure_surfaces_in_workflow_history( app_error = activity_error.__cause__ assert isinstance(app_error, ApplicationError) assert app_error.type == "RuntimeError" - assert ( - app_error.message - == f"S3StorageDriver store failed [bucket={bad_bucket}, key={expected_key}]" - ) + # Key includes run_id which is only known at runtime; use substring checks. + msg = app_error.message + assert f"S3StorageDriver store failed [bucket={bad_bucket}, key=" in msg + assert f"/wt/LargeOutputNoRetryWorkflow/wi/{workflow_id}/ri/" in msg + assert f"/d/sha256/{expected_hash}]" in msg From 50ffea280a2368957297f8f8c1195f2f36d9f721 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:00:40 -0700 Subject: [PATCH 03/16] Fix test for non-deterministic ordering --- tests/worker/test_extstore.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 97571883b..6810f7296 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -1292,16 +1292,22 @@ async def test_store_metadata_signal_external_workflow( assert signal_ctx.target_workflow.type is None assert signal_ctx.target_workflow.run_id is None - # [3] Sender workflow returns - sender_result_ctx = driver.store_contexts[3] - assert sender_result_ctx.current_workflow is not None - assert 
sender_result_ctx.current_workflow.id == sender_workflow_id + # [3] and [4] are the sender and target workflow completions in some order. + # The sender's WFT 2 (after signal resolution) and the target's WFT (after + # receiving the signal) are both scheduled by the server at nearly the same + # time, so the order of their completions is non-deterministic. + completion_ctxs = { + ctx.current_workflow.id: ctx + for ctx in driver.store_contexts[3:5] + if ctx.current_workflow is not None + } + assert sender_workflow_id in completion_ctxs + assert target_workflow_id in completion_ctxs + + sender_result_ctx = completion_ctxs[sender_workflow_id] assert sender_result_ctx.target_workflow is None - # [4] Target workflow returns signal data - target_result_ctx = driver.store_contexts[4] - assert target_result_ctx.current_workflow is not None - assert target_result_ctx.current_workflow.id == target_workflow_id + target_result_ctx = completion_ctxs[target_workflow_id] assert target_result_ctx.target_workflow is None From f7ed14960e3c5b82a14350baf7f986bbd9368ac4 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:32:54 -0700 Subject: [PATCH 04/16] Consolidate to single target field --- temporalio/client.py | 41 ++-- temporalio/contrib/aws/s3driver/_driver.py | 25 +- temporalio/converter/_extstore.py | 39 +-- temporalio/worker/_activity.py | 44 ++-- temporalio/worker/_workflow.py | 2 +- temporalio/worker/_workflow_instance.py | 27 +-- tests/contrib/aws/s3driver/test_s3driver.py | 50 ++-- .../aws/s3driver/test_s3driver_worker.py | 60 ++++- tests/worker/test_extstore.py | 225 ++++++++---------- 9 files changed, 234 insertions(+), 279 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 8d9e9ac97..3a2d594b9 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -6170,9 +6170,7 @@ async def _to_proto( with store_metadata_context( StorageDriverStoreMetadata( namespace=client.namespace, - 
target_workflow=StorageDriverWorkflowInfo( - id=self.id, type=self.workflow - ), + target=StorageDriverWorkflowInfo(id=self.id, type=self.workflow), ) ): data_converter = client.data_converter.with_context( @@ -8099,9 +8097,7 @@ async def _build_signal_with_start_workflow_execution_request( with store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( - id=input.id, type=input.workflow - ), + target=StorageDriverWorkflowInfo(id=input.id, type=input.workflow), ) ): data_converter = self._client.data_converter.with_context( @@ -8138,9 +8134,7 @@ async def _populate_start_workflow_execution_request( with store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( - id=input.id, type=input.workflow - ), + target=StorageDriverWorkflowInfo(id=input.id, type=input.workflow), ) ): data_converter = self._client.data_converter.with_context( @@ -8266,7 +8260,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: with store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=input.id, run_id=input.run_id or None ), ) @@ -8332,7 +8326,7 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None: with store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=input.id, run_id=input.run_id or None ), ) @@ -8365,7 +8359,7 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: with store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=input.id, run_id=input.run_id or None ), ) @@ -8432,9 +8426,7 @@ async def _build_start_activity_execution_request( with 
store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_activity=StorageDriverActivityInfo( - id=input.id, type=input.activity_type - ), + target=StorageDriverActivityInfo(id=input.id, type=input.activity_type), ) ): data_converter = self._client.data_converter.with_context( @@ -8639,7 +8631,7 @@ async def _build_update_workflow_execution_request( with store_metadata_context( StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=workflow_id, run_id=(input.run_id or None) if isinstance(input, StartWorkflowUpdateInput) @@ -8832,16 +8824,19 @@ def _get_async_activity_store_metadata( self, id_or_token: AsyncActivityIDReference | bytes ) -> StorageDriverStoreMetadata: if isinstance(id_or_token, AsyncActivityIDReference): + if id_or_token.workflow_id: + return StorageDriverStoreMetadata( + namespace=self._client.namespace, + target=StorageDriverWorkflowInfo( + id=id_or_token.workflow_id or None, + run_id=id_or_token.run_id or None, + ), + ) return StorageDriverStoreMetadata( namespace=self._client.namespace, - target_workflow=StorageDriverWorkflowInfo( - id=id_or_token.workflow_id or None, - run_id=id_or_token.run_id or None, - ) - if id_or_token.workflow_id - else None, - target_activity=StorageDriverActivityInfo( + target=StorageDriverActivityInfo( id=id_or_token.activity_id, + run_id=id_or_token.run_id or None, ), ) else: diff --git a/temporalio/contrib/aws/s3driver/_driver.py b/temporalio/contrib/aws/s3driver/_driver.py index d6488cf64..a28e91503 100644 --- a/temporalio/contrib/aws/s3driver/_driver.py +++ b/temporalio/contrib/aws/s3driver/_driver.py @@ -16,9 +16,11 @@ from temporalio.contrib.aws.s3driver._client import S3StorageDriverClient from temporalio.converter import ( StorageDriver, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, + StorageDriverWorkflowInfo, ) _T = 
TypeVar("_T") @@ -118,21 +120,18 @@ def _quote(val: str | None) -> str | None: namespace = _quote(context.namespace) namespace_segment = f"/ns/{namespace}" if namespace else "" - # Build context segments from structured metadata. - # Prefer current workflow context; fall back to target_workflow for - # client-initiated operations where there is no current workflow. + # Build context segments from the target identity. context_segments = "" - wf = context.current_workflow or context.target_workflow - act = context.current_activity or context.target_activity - if wf and wf.id: - wf_type = _quote(wf.type) or "null" - wf_id = _quote(wf.id) - wf_run_id = _quote(wf.run_id) or "null" + target = context.target + if isinstance(target, StorageDriverWorkflowInfo): + wf_type = _quote(target.type) or "null" + wf_id = _quote(target.id) or "null" + wf_run_id = _quote(target.run_id) or "null" context_segments = f"/wt/{wf_type}/wi/{wf_id}/ri/{wf_run_id}" - elif act and act.id: - act_type = _quote(act.type) or "null" - act_id = _quote(act.id) - act_run_id = _quote(act.run_id) or "null" + elif isinstance(target, StorageDriverActivityInfo): + act_type = _quote(target.type) or "null" + act_id = _quote(target.id) or "null" + act_run_id = _quote(target.run_id) or "null" context_segments = f"/at/{act_type}/ai/{act_id}/ri/{act_run_id}" async def _upload(payload: Payload) -> StorageDriverClaim: diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index 588a1038f..df669d139 100644 --- a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -135,19 +135,8 @@ class StorageDriverStoreMetadata: namespace: str | None = None """The namespace of the current execution context.""" - current_workflow: StorageDriverWorkflowInfo | None = None - """The workflow execution context from which this payload is being stored, if any.""" - - current_activity: StorageDriverActivityInfo | None = None - """The activity execution context from which this payload is 
being stored, if any. - Set only when running inside an activity worker.""" - - target_workflow: StorageDriverWorkflowInfo | None = None - """The workflow for which this payload is being stored (e.g. child workflow being - started, external workflow being signaled).""" - - target_activity: StorageDriverActivityInfo | None = None - """The activity for which this payload is being stored.""" + target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None + """The workflow or activity for which this payload is being stored.""" _current_store_metadata: contextvars.ContextVar[StorageDriverStoreMetadata | None] = ( @@ -184,19 +173,14 @@ class StorageDriverStoreContext: namespace: str | None = None """The namespace of the current execution context.""" - current_workflow: StorageDriverWorkflowInfo | None = None - """The workflow execution context from which this payload is being stored, if any.""" - - current_activity: StorageDriverActivityInfo | None = None - """The activity execution context from which this payload is being stored, if any. - Set only when running inside an activity worker.""" - - target_workflow: StorageDriverWorkflowInfo | None = None - """The workflow for which this payload is being stored (e.g. child workflow being - started, external workflow being signaled).""" + target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None + """The workflow or activity for which this payload is being stored. - target_activity: StorageDriverActivityInfo | None = None - """The activity for which this payload is being stored.""" + For payloads being stored on behalf of an explicit target (e.g. a child + workflow being started, an activity being scheduled, an external workflow + being signaled), this is that target's identity. 
When no explicit target + exists the current execution context (workflow or activity) is used as the + target instead.""" @dataclass(frozen=True) @@ -380,10 +364,7 @@ def _build_store_context() -> StorageDriverStoreContext: meta = _current_store_metadata.get() return StorageDriverStoreContext( namespace=meta.namespace if meta else None, - current_workflow=meta.current_workflow if meta else None, - current_activity=meta.current_activity if meta else None, - target_workflow=meta.target_workflow if meta else None, - target_activity=meta.target_activity if meta else None, + target=meta.target if meta else None, ) async def _store_payload(self, payload: Payload) -> Payload: diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index a19895261..857a14656 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -270,19 +270,9 @@ async def _heartbeat_async( ) data_converter = data_converter.with_context(context) - wf_info = ( - StorageDriverWorkflowInfo( - id=activity.info.workflow_id, - run_id=activity.info.workflow_run_id, - type=activity.info.workflow_type, - ) - if activity.info.workflow_id - else None - ) store_metadata = StorageDriverStoreMetadata( namespace=activity.info.namespace, - current_workflow=wf_info, - current_activity=StorageDriverActivityInfo( + target=StorageDriverActivityInfo( id=activity.info.activity_id, type=activity.info.activity_type, run_id=activity.info.activity_run_id, @@ -344,30 +334,24 @@ async def _handle_start_activity_task( # Build store metadata for external storage ns = start.workflow_namespace or self._client.namespace + # Store metadata is set for the full activity task lifetime (input + # decode, execution, result/failure encode). Each activity task runs + # in its own coroutine so the value won't leak to other tasks. 
started_by_workflow = bool(start.workflow_execution.workflow_id) - wf_info = ( - StorageDriverWorkflowInfo( + store_target: StorageDriverWorkflowInfo | StorageDriverActivityInfo + if started_by_workflow: + store_target = StorageDriverWorkflowInfo( id=start.workflow_execution.workflow_id or None, - run_id=start.workflow_execution.run_id or None, type=start.workflow_type or None, + run_id=start.workflow_execution.run_id or None, ) - if started_by_workflow - else None - ) - act_info = StorageDriverActivityInfo( - id=start.activity_id or None, - type=start.activity_type or None, - run_id=None, - ) - # Store metadata is set for the full activity task lifetime (input - # decode, execution, result/failure encode). Each activity task runs - # in its own coroutine so the value won't leak to other tasks. - with store_metadata_context( - StorageDriverStoreMetadata( - namespace=ns, - current_workflow=wf_info, - current_activity=act_info, + else: + store_target = StorageDriverActivityInfo( + id=start.activity_id or None, + type=start.activity_type or None, ) + with store_metadata_context( + StorageDriverStoreMetadata(namespace=ns, target=store_target) ): try: result = await self._execute_activity( diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 226126a48..c9a1a8193 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -306,7 +306,7 @@ async def _handle_activation( with store_metadata_context( StorageDriverStoreMetadata( namespace=self._namespace, - current_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=workflow_id, run_id=act.run_id, type=( diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index bb779b25c..c532e114e 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -58,7 +58,7 @@ import temporalio.converter import temporalio.exceptions import temporalio.workflow -from temporalio.converter 
import StorageDriverActivityInfo, StorageDriverWorkflowInfo +from temporalio.converter import StorageDriverWorkflowInfo from temporalio.converter._extstore import StorageDriverStoreMetadata from temporalio.service import __version__ @@ -2252,27 +2252,12 @@ def get_external_store_metadata( if command_info is None: return StorageDriverStoreMetadata( namespace=ns, - current_workflow=current_wf, + target=current_wf, ) COMMAND_TYPE = temporalio.api.enums.v1.command_type_pb2.CommandType if ( - command_info.command_type - == COMMAND_TYPE.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK - and command_info.command_seq in self._pending_activities - ): - handle = self._pending_activities[command_info.command_seq] - return StorageDriverStoreMetadata( - namespace=ns, - current_workflow=current_wf, - target_activity=StorageDriverActivityInfo( - id=handle._input.activity_id, - type=handle._input.activity, - ), - ) - - elif ( command_info.command_type == COMMAND_TYPE.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION and command_info.command_seq in self._pending_child_workflows @@ -2280,8 +2265,7 @@ def get_external_store_metadata( child = self._pending_child_workflows[command_info.command_seq] return StorageDriverStoreMetadata( namespace=ns, - current_workflow=current_wf, - target_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=child._input.id, type=child._input.workflow ), ) @@ -2294,14 +2278,13 @@ def get_external_store_metadata( _, target_id = self._pending_external_signals[command_info.command_seq] return StorageDriverStoreMetadata( namespace=ns, - current_workflow=current_wf, - target_workflow=StorageDriverWorkflowInfo(id=target_id), + target=StorageDriverWorkflowInfo(id=target_id), ) else: return StorageDriverStoreMetadata( namespace=ns, - current_workflow=current_wf, + target=current_wf, ) def _instantiate_workflow_object(self) -> Any: diff --git a/tests/contrib/aws/s3driver/test_s3driver.py b/tests/contrib/aws/s3driver/test_s3driver.py index d47831603..57e2072f7 
100644 --- a/tests/contrib/aws/s3driver/test_s3driver.py +++ b/tests/contrib/aws/s3driver/test_s3driver.py @@ -52,17 +52,11 @@ def make_payload(value: str = "hello") -> Payload: def make_store_context( namespace: str | None = None, - current_workflow: StorageDriverWorkflowInfo | None = None, - current_activity: StorageDriverActivityInfo | None = None, - target_workflow: StorageDriverWorkflowInfo | None = None, - target_activity: StorageDriverActivityInfo | None = None, + target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None, ) -> StorageDriverStoreContext: return StorageDriverStoreContext( namespace=namespace, - current_workflow=current_workflow, - current_activity=current_activity, - target_workflow=target_workflow, - target_activity=target_activity, + target=target, ) @@ -74,7 +68,7 @@ def make_workflow_context( ) -> StorageDriverStoreContext: return make_store_context( namespace=namespace, - current_workflow=StorageDriverWorkflowInfo( + target=StorageDriverWorkflowInfo( id=workflow_id, type=workflow_type, run_id=run_id ), ) @@ -83,15 +77,14 @@ def make_workflow_context( def make_activity_context( namespace: str = "my-namespace", activity_id: str | None = "my-activity", - workflow_id: str | None = None, activity_type: str | None = None, + run_id: str | None = None, ) -> StorageDriverStoreContext: return make_store_context( namespace=namespace, - current_workflow=( - StorageDriverWorkflowInfo(id=workflow_id) if workflow_id else None + target=StorageDriverActivityInfo( + id=activity_id, type=activity_type, run_id=run_id ), - current_activity=StorageDriverActivityInfo(id=activity_id, type=activity_type), ) @@ -250,35 +243,36 @@ async def test_key_context_workflow_with_type_and_run_id( == f"v0/ns/ns1/wt/MyWorkflow/wi/wf1/ri/run-abc/d/sha256/{expected_hash}" ) - async def test_key_context_workflow_activity( + async def test_key_context_activity( self, driver_client: S3StorageDriverClient ) -> None: - """workflow takes priority over activity in 
store context.""" + """activity target uses activity key segment.""" driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_activity_context( - namespace="ns1", workflow_id="wf1", activity_id="act1" - ) + ctx = make_activity_context(namespace="ns1", activity_id="act1") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/ns1/wt/null/wi/wf1/ri/null/d/sha256/{expected_hash}" + == f"v0/ns/ns1/at/null/ai/act1/ri/null/d/sha256/{expected_hash}" ) - async def test_key_context_standalone_activity( + async def test_key_context_activity_with_type_and_run_id( self, driver_client: S3StorageDriverClient ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() ctx = make_activity_context( - namespace="ns1", activity_id="act1", workflow_id=None + namespace="ns1", + activity_id="act1", + activity_type="MyActivity", + run_id="run-abc", ) [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/ns1/at/null/ai/act1/ri/null/d/sha256/{expected_hash}" + == f"v0/ns/ns1/at/MyActivity/ai/act1/ri/run-abc/d/sha256/{expected_hash}" ) async def test_key_preserves_case( @@ -323,9 +317,7 @@ async def test_key_urlencodes_activity_id( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_activity_context( - namespace="ns1", activity_id="act/1#2", workflow_id=None - ) + ctx = make_activity_context(namespace="ns1", activity_id="act/1#2") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( @@ -565,7 +557,11 @@ async def test_selector_routes_by_activity_type( def type_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: del p - act = ctx.current_activity or 
ctx.target_activity + act = ( + ctx.target + if isinstance(ctx.target, StorageDriverActivityInfo) + else None + ) if act and act.type and act.type in type_buckets: return type_buckets[act.type] return BUCKET @@ -575,7 +571,6 @@ def type_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: ctx_a = make_activity_context( namespace="ns1", activity_id="act1", - workflow_id="wf1", activity_type="type-a", ) [claim_a] = await driver.store(ctx_a, [make_payload("payload-a")]) @@ -584,7 +579,6 @@ def type_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: ctx_b = make_activity_context( namespace="ns1", activity_id="act2", - workflow_id="wf1", activity_type="type-b", ) [claim_b] = await driver.store(ctx_b, [make_payload("payload-b")]) diff --git a/tests/contrib/aws/s3driver/test_s3driver_worker.py b/tests/contrib/aws/s3driver/test_s3driver_worker.py index 3fd1f8fd8..af69e76f0 100644 --- a/tests/contrib/aws/s3driver/test_s3driver_worker.py +++ b/tests/contrib/aws/s3driver/test_s3driver_worker.py @@ -190,6 +190,46 @@ async def test_s3_driver_workflow_activity_output_key( assert "/ri/null/" not in keys[0] +async def test_s3_driver_standalone_activity_input_key( + tmprl_client: Client, aioboto3_client: S3Client +) -> None: + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + async with new_worker(tmprl_client, activities=[large_io_activity], task_queue=task_queue): + await tmprl_client.execute_activity( + large_io_activity, + LARGE, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + keys = await _list_keys(aioboto3_client) + # Input and output are the same LARGE bytes, so they deduplicate to one key. + assert len(keys) == 1 + # Keyed under the activity, not a workflow. 
+ assert f"/ns/default/at/large_io_activity/ai/{activity_id}/ri/null/" in keys[0] + assert "/wt/" not in keys[0] + + +async def test_s3_driver_standalone_activity_output_key( + tmprl_client: Client, aioboto3_client: S3Client +) -> None: + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + async with new_worker(tmprl_client, activities=[large_output_activity], task_queue=task_queue): + await tmprl_client.execute_activity( + large_output_activity, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + keys = await _list_keys(aioboto3_client) + # Only the output is large; keyed under the activity. + assert len(keys) == 1 + assert f"/ns/default/at/large_output_activity/ai/{activity_id}/ri/null/" in keys[0] + assert "/wt/" not in keys[0] + + async def test_s3_driver_signal_arg_key( tmprl_client: Client, aioboto3_client: S3Client ) -> None: @@ -275,12 +315,10 @@ async def test_s3_driver_child_workflow_input_key( ) keys = await _list_keys(aioboto3_client) child_workflow_id = f"{workflow_id}-child" - # Child input is the only large payload — stored under the parent's wi/ri. + # Child input is the only large payload — stored under the child's wi/ri. assert len(keys) == 1 - # Keyed under the parent: it is the current execution context when scheduling the child. - assert f"/ns/default/wt/ParentWithChildWorkflow/wi/{workflow_id}/ri/" in keys[0] - # Not keyed under the child: the payload lives in the parent's history, not the child's. - assert f"/wi/{child_workflow_id}/" not in keys[0] + # Keyed under the child: child input is stored in the child's context. 
+ assert f"/ns/default/wt/ChildWorkflow/wi/{child_workflow_id}/ri/" in keys[0] async def test_s3_driver_identified_casing( @@ -386,12 +424,12 @@ async def test_s3_driver_parent_child_independent_key_namespaces( keys = await _list_keys(aioboto3_client) parent_keys = [k for k in keys if f"/wi/{workflow_id}/" in k] child_keys = [k for k in keys if f"/wi/{payment_id}/" in k] - # Parent accumulates 3 keys: client start (ri=null), child input (ri=run_id, - # keyed under parent because parent is scheduling context), and child result - # propagated back to parent (ri=run_id). - assert len(parent_keys) == 3 - # Child accumulates 1 key: its own result (LARGE_2, ri=child_run_id). - assert len(child_keys) == 1 + # Parent accumulates 2 keys: client start (ri=null, LARGE input) and child + # result propagated back to parent as workflow result (ri=run_id, LARGE_2). + assert len(parent_keys) == 2 + # Child accumulates 2 keys: its input from parent (ri=null, LARGE) and its + # own result (ri=child_run_id, LARGE_2). 
+ assert len(child_keys) == 2 async def test_s3_store_failure_surfaces_in_workflow_history( diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 6810f7296..c92610bb4 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -21,9 +21,11 @@ from temporalio.converter import ( ExternalStorage, StorageDriver, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, + StorageDriverWorkflowInfo, StorageWarning, ) from temporalio.exceptions import ActivityError, ApplicationError @@ -972,20 +974,17 @@ async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None: client_ctx = driver.store_contexts[0] assert client_ctx.namespace == client.namespace - assert client_ctx.current_workflow is None - assert client_ctx.target_workflow is not None - assert client_ctx.target_workflow.id == workflow_id - assert client_ctx.target_workflow.type == "EchoWorkflow" - assert client_ctx.target_workflow.run_id is None - assert client_ctx.target_activity is None + assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) + assert client_ctx.target.id == workflow_id + assert client_ctx.target.type == "EchoWorkflow" + assert client_ctx.target.run_id is None worker_ctx = driver.store_contexts[1] assert worker_ctx.namespace == client.namespace - assert worker_ctx.current_workflow is not None - assert worker_ctx.current_workflow.id == workflow_id - assert worker_ctx.current_workflow.type == "EchoWorkflow" - assert worker_ctx.current_workflow.run_id is not None - assert worker_ctx.target_workflow is None + assert isinstance(worker_ctx.target, StorageDriverWorkflowInfo) + assert worker_ctx.target.id == workflow_id + assert worker_ctx.target.type == "EchoWorkflow" + assert worker_ctx.target.run_id is not None async def test_store_metadata_signal_with_start(env: WorkflowEnvironment) -> None: @@ -1008,25 +1007,24 @@ async def test_store_metadata_signal_with_start(env: 
WorkflowEnvironment) -> Non # [0] Workflow input arg input_ctx = driver.store_contexts[0] - assert input_ctx.current_workflow is None - assert input_ctx.target_workflow is not None - assert input_ctx.target_workflow.id == workflow_id - assert input_ctx.target_workflow.type == "SignalWaitWorkflow" - assert input_ctx.target_workflow.run_id is None + assert isinstance(input_ctx.target, StorageDriverWorkflowInfo) + assert input_ctx.target.id == workflow_id + assert input_ctx.target.type == "SignalWaitWorkflow" + assert input_ctx.target.run_id is None # [1] Signal arg signal_ctx = driver.store_contexts[1] - assert signal_ctx.current_workflow is None - assert signal_ctx.target_workflow is not None - assert signal_ctx.target_workflow.id == workflow_id - assert signal_ctx.target_workflow.type == "SignalWaitWorkflow" - assert signal_ctx.target_workflow.run_id is None + assert isinstance(signal_ctx.target, StorageDriverWorkflowInfo) + assert signal_ctx.target.id == workflow_id + assert signal_ctx.target.type == "SignalWaitWorkflow" + assert signal_ctx.target.run_id is None # [2] Workflow result result_ctx = driver.store_contexts[2] - assert result_ctx.current_workflow is not None - assert result_ctx.current_workflow.run_id is not None - assert result_ctx.target_workflow is None + assert isinstance(result_ctx.target, StorageDriverWorkflowInfo) + assert result_ctx.target.id == workflow_id + assert result_ctx.target.type == "SignalWaitWorkflow" + assert result_ctx.target.run_id is not None async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None: @@ -1048,10 +1046,10 @@ async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None: assert len(driver.store_contexts) == 3 signal_ctx = driver.store_contexts[1] - assert signal_ctx.current_workflow is None - assert signal_ctx.target_workflow is not None - assert signal_ctx.target_workflow.id == workflow_id - assert signal_ctx.target_workflow.type is None + assert isinstance(signal_ctx.target, 
StorageDriverWorkflowInfo) + assert signal_ctx.target.id == workflow_id + assert signal_ctx.target.type is None + assert signal_ctx.target.run_id is None async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: @@ -1077,11 +1075,10 @@ async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: assert len(driver.store_contexts) == 1 ctx = driver.store_contexts[0] assert ctx.namespace == client.namespace - assert ctx.current_workflow is None - assert ctx.target_workflow is not None - assert ctx.target_workflow.id == f"wf-{schedule_id}" - assert ctx.target_workflow.type == "EchoWorkflow" - assert ctx.target_activity is None + assert isinstance(ctx.target, StorageDriverWorkflowInfo) + assert ctx.target.id == f"wf-{schedule_id}" + assert ctx.target.type == "EchoWorkflow" + assert ctx.target.run_id is None finally: try: handle = client.get_schedule_handle(schedule_id) @@ -1091,7 +1088,7 @@ async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: async def test_store_metadata_child_workflow(env: WorkflowEnvironment) -> None: - """External storage should receive the parent as workflow and child as target_workflow.""" + """External storage should receive the child workflow as the target when scheduling.""" client, driver = await _make_tracking_client(env) workflow_id = f"workflow-{uuid.uuid4()}" child_workflow_id = f"{workflow_id}-child" @@ -1112,43 +1109,31 @@ async def test_store_metadata_child_workflow(env: WorkflowEnvironment) -> None: # [0] Client starts parent workflow client_ctx = driver.store_contexts[0] - assert client_ctx.current_workflow is None - assert client_ctx.target_workflow is not None - assert client_ctx.target_workflow.id == workflow_id - assert client_ctx.target_workflow.type == "ChildWorkflowStoreMetadataTestWorkflow" + assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) + assert client_ctx.target.id == workflow_id + assert client_ctx.target.type == 
"ChildWorkflowStoreMetadataTestWorkflow" + assert client_ctx.target.run_id is None - # [1] Parent schedules child: current_workflow = parent, target_workflow = child + # [1] Parent schedules child: target = child workflow start_child_ctx = driver.store_contexts[1] - assert start_child_ctx.current_workflow is not None - assert start_child_ctx.current_workflow.id == workflow_id - assert ( - start_child_ctx.current_workflow.type - == "ChildWorkflowStoreMetadataTestWorkflow" - ) - assert start_child_ctx.current_workflow.run_id is not None - assert start_child_ctx.target_workflow is not None - assert start_child_ctx.target_workflow.id == child_workflow_id - assert start_child_ctx.target_workflow.type == "EchoWorkflow" - assert start_child_ctx.target_workflow.run_id is None + assert isinstance(start_child_ctx.target, StorageDriverWorkflowInfo) + assert start_child_ctx.target.id == child_workflow_id + assert start_child_ctx.target.type == "EchoWorkflow" + assert start_child_ctx.target.run_id is None - # [2] Child returns result: current_workflow = child, no target + # [2] Child returns result: target = child (current execution) child_result_ctx = driver.store_contexts[2] - assert child_result_ctx.current_workflow is not None - assert child_result_ctx.current_workflow.id == child_workflow_id - assert child_result_ctx.current_workflow.type == "EchoWorkflow" - assert child_result_ctx.current_workflow.run_id is not None - assert child_result_ctx.target_workflow is None + assert isinstance(child_result_ctx.target, StorageDriverWorkflowInfo) + assert child_result_ctx.target.id == child_workflow_id + assert child_result_ctx.target.type == "EchoWorkflow" + assert child_result_ctx.target.run_id is not None - # [3] Parent returns result: current_workflow = parent, no target + # [3] Parent returns result: target = parent (current execution) parent_result_ctx = driver.store_contexts[3] - assert parent_result_ctx.current_workflow is not None - assert 
parent_result_ctx.current_workflow.id == workflow_id - assert ( - parent_result_ctx.current_workflow.type - == "ChildWorkflowStoreMetadataTestWorkflow" - ) - assert parent_result_ctx.current_workflow.run_id is not None - assert parent_result_ctx.target_workflow is None + assert isinstance(parent_result_ctx.target, StorageDriverWorkflowInfo) + assert parent_result_ctx.target.id == workflow_id + assert parent_result_ctx.target.type == "ChildWorkflowStoreMetadataTestWorkflow" + assert parent_result_ctx.target.run_id is not None # Workflow definitions for gap tests @@ -1206,36 +1191,33 @@ async def test_store_metadata_activity_scheduling(env: WorkflowEnvironment) -> N # [0] Client starts workflow client_ctx = driver.store_contexts[0] - assert client_ctx.current_workflow is None - assert client_ctx.target_workflow is not None - assert client_ctx.target_workflow.id == workflow_id - assert client_ctx.target_workflow.type == "ActivityScheduleMetadataWorkflow" + assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) + assert client_ctx.target.id == workflow_id + assert client_ctx.target.type == "ActivityScheduleMetadataWorkflow" + assert client_ctx.target.run_id is None - # [1] Workflow worker schedules activity: current_workflow + target_activity + # [1] Workflow worker schedules activity: target = activity schedule_ctx = driver.store_contexts[1] assert schedule_ctx.namespace == client.namespace - assert schedule_ctx.current_workflow is not None - assert schedule_ctx.current_workflow.id == workflow_id - assert schedule_ctx.current_workflow.type == "ActivityScheduleMetadataWorkflow" - assert schedule_ctx.target_activity is not None - assert schedule_ctx.target_activity.id == "my-activity-id" - assert schedule_ctx.target_activity.type == "echo_activity" - - # [2] Activity worker completes: current_workflow + current_activity + assert isinstance(schedule_ctx.target, StorageDriverActivityInfo) + assert schedule_ctx.target.id == "my-activity-id" + assert 
schedule_ctx.target.type == "echo_activity" + assert schedule_ctx.target.run_id is None + + # [2] Activity worker completes: target = activity execute_ctx = driver.store_contexts[2] assert execute_ctx.namespace == client.namespace - assert execute_ctx.current_workflow is not None - assert execute_ctx.current_workflow.id == workflow_id - assert execute_ctx.current_workflow.type == "ActivityScheduleMetadataWorkflow" - assert execute_ctx.current_activity is not None - assert execute_ctx.current_activity.id == "my-activity-id" - assert execute_ctx.current_activity.type == "echo_activity" - - # [3] Workflow returns result + assert isinstance(execute_ctx.target, StorageDriverActivityInfo) + assert execute_ctx.target.id == "my-activity-id" + assert execute_ctx.target.type == "echo_activity" + assert execute_ctx.target.run_id is None + + # [3] Workflow returns result: target = workflow (current execution) result_ctx = driver.store_contexts[3] - assert result_ctx.current_workflow is not None - assert result_ctx.current_workflow.id == workflow_id - assert result_ctx.target_activity is None + assert isinstance(result_ctx.target, StorageDriverWorkflowInfo) + assert result_ctx.target.id == workflow_id + assert result_ctx.target.type == "ActivityScheduleMetadataWorkflow" + assert result_ctx.target.run_id is not None async def test_store_metadata_signal_external_workflow( @@ -1271,44 +1253,44 @@ async def test_store_metadata_signal_external_workflow( # [0] Client starts target workflow (SignalWaitWorkflow) target_start_ctx = driver.store_contexts[0] - assert target_start_ctx.current_workflow is None - assert target_start_ctx.target_workflow is not None - assert target_start_ctx.target_workflow.id == target_workflow_id - assert target_start_ctx.target_workflow.type == "SignalWaitWorkflow" + assert isinstance(target_start_ctx.target, StorageDriverWorkflowInfo) + assert target_start_ctx.target.id == target_workflow_id + assert target_start_ctx.target.type == "SignalWaitWorkflow" + 
assert target_start_ctx.target.run_id is None # [1] Client starts sender workflow (SignalExternalMetadataWorkflow) sender_start_ctx = driver.store_contexts[1] - assert sender_start_ctx.current_workflow is None - assert sender_start_ctx.target_workflow is not None - assert sender_start_ctx.target_workflow.id == sender_workflow_id - assert sender_start_ctx.target_workflow.type == "SignalExternalMetadataWorkflow" + assert isinstance(sender_start_ctx.target, StorageDriverWorkflowInfo) + assert sender_start_ctx.target.id == sender_workflow_id + assert sender_start_ctx.target.type == "SignalExternalMetadataWorkflow" + assert sender_start_ctx.target.run_id is None - # [2] Sender signals target: current_workflow = sender, target_workflow = target + # [2] Sender signals target: target = the workflow being signaled signal_ctx = driver.store_contexts[2] - assert signal_ctx.current_workflow is not None - assert signal_ctx.current_workflow.id == sender_workflow_id - assert signal_ctx.target_workflow is not None - assert signal_ctx.target_workflow.id == target_workflow_id - assert signal_ctx.target_workflow.type is None - assert signal_ctx.target_workflow.run_id is None + assert isinstance(signal_ctx.target, StorageDriverWorkflowInfo) + assert signal_ctx.target.id == target_workflow_id + assert signal_ctx.target.type is None + assert signal_ctx.target.run_id is None # [3] and [4] are the sender and target workflow completions in some order. # The sender's WFT 2 (after signal resolution) and the target's WFT (after # receiving the signal) are both scheduled by the server at nearly the same # time, so the order of their completions is non-deterministic. 
completion_ctxs = { - ctx.current_workflow.id: ctx + ctx.target.id: ctx for ctx in driver.store_contexts[3:5] - if ctx.current_workflow is not None + if isinstance(ctx.target, StorageDriverWorkflowInfo) and ctx.target.id } assert sender_workflow_id in completion_ctxs assert target_workflow_id in completion_ctxs sender_result_ctx = completion_ctxs[sender_workflow_id] - assert sender_result_ctx.target_workflow is None + assert isinstance(sender_result_ctx.target, StorageDriverWorkflowInfo) + assert sender_result_ctx.target.run_id is not None target_result_ctx = completion_ctxs[target_workflow_id] - assert target_result_ctx.target_workflow is None + assert isinstance(target_result_ctx.target, StorageDriverWorkflowInfo) + assert target_result_ctx.target.run_id is not None async def test_store_metadata_activity_worker(env: WorkflowEnvironment) -> None: @@ -1330,13 +1312,12 @@ async def test_store_metadata_activity_worker(env: WorkflowEnvironment) -> None: assert len(driver.store_contexts) == 4 - # [2] Activity worker completes: current_workflow has run_id, current_activity set + # [2] Activity worker completes: target = activity execute_ctx = driver.store_contexts[2] - assert execute_ctx.current_workflow is not None - assert execute_ctx.current_workflow.id == workflow_id - assert execute_ctx.current_workflow.run_id is not None - assert execute_ctx.current_activity is not None - assert execute_ctx.current_activity.type == "echo_activity" + assert isinstance(execute_ctx.target, StorageDriverActivityInfo) + assert execute_ctx.target.id == "my-activity-id" + assert execute_ctx.target.type == "echo_activity" + assert execute_ctx.target.run_id is None async def test_store_metadata_decode_activation(env: WorkflowEnvironment) -> None: @@ -1357,15 +1338,15 @@ async def test_store_metadata_decode_activation(env: WorkflowEnvironment) -> Non # [0] Client starts workflow client_ctx = driver.store_contexts[0] assert client_ctx.namespace == client.namespace - assert 
client_ctx.current_workflow is None - assert client_ctx.target_workflow is not None - assert client_ctx.target_workflow.id == workflow_id - assert client_ctx.target_workflow.type == "EchoWorkflow" + assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) + assert client_ctx.target.id == workflow_id + assert client_ctx.target.type == "EchoWorkflow" + assert client_ctx.target.run_id is None # [1] Workflow returns result worker_ctx = driver.store_contexts[1] assert worker_ctx.namespace == client.namespace - assert worker_ctx.current_workflow is not None - assert worker_ctx.current_workflow.id == workflow_id - assert worker_ctx.current_workflow.type == "EchoWorkflow" - assert worker_ctx.target_workflow is None + assert isinstance(worker_ctx.target, StorageDriverWorkflowInfo) + assert worker_ctx.target.id == workflow_id + assert worker_ctx.target.type == "EchoWorkflow" + assert worker_ctx.target.run_id is not None From 700b73c51e614338cc2f42a88149a60767568f1c Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 30 Mar 2026 21:43:10 -0700 Subject: [PATCH 05/16] Format --- tests/contrib/aws/s3driver/test_s3driver_worker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/contrib/aws/s3driver/test_s3driver_worker.py b/tests/contrib/aws/s3driver/test_s3driver_worker.py index af69e76f0..8bd40dc2d 100644 --- a/tests/contrib/aws/s3driver/test_s3driver_worker.py +++ b/tests/contrib/aws/s3driver/test_s3driver_worker.py @@ -195,7 +195,9 @@ async def test_s3_driver_standalone_activity_input_key( ) -> None: activity_id = str(uuid.uuid4()) task_queue = str(uuid.uuid4()) - async with new_worker(tmprl_client, activities=[large_io_activity], task_queue=task_queue): + async with new_worker( + tmprl_client, activities=[large_io_activity], task_queue=task_queue + ): await tmprl_client.execute_activity( large_io_activity, LARGE, @@ -216,7 +218,9 @@ async def test_s3_driver_standalone_activity_output_key( 
) -> None: activity_id = str(uuid.uuid4()) task_queue = str(uuid.uuid4()) - async with new_worker(tmprl_client, activities=[large_output_activity], task_queue=task_queue): + async with new_worker( + tmprl_client, activities=[large_output_activity], task_queue=task_queue + ): await tmprl_client.execute_activity( large_output_activity, id=activity_id, From 31d6509046a2b50a6bd5b9b76907d2d61b4bdb9d Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 31 Mar 2026 09:21:48 -0700 Subject: [PATCH 06/16] Fix workflow activity target and remove redundant tests --- tests/worker/test_extstore.py | 106 ++++++++++++++++------------------ 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index c92610bb4..7752fa2fa 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -972,6 +972,7 @@ async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None: assert len(driver.store_contexts) == 2 + # [0] Workflow input arg client_ctx = driver.store_contexts[0] assert client_ctx.namespace == client.namespace assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) @@ -979,6 +980,7 @@ async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None: assert client_ctx.target.type == "EchoWorkflow" assert client_ctx.target.run_id is None + # [1] Workflow result worker_ctx = driver.store_contexts[1] assert worker_ctx.namespace == client.namespace assert isinstance(worker_ctx.target, StorageDriverWorkflowInfo) @@ -1045,12 +1047,27 @@ async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None: assert len(driver.store_contexts) == 3 + # [0] Client starts workflow + start_ctx = driver.store_contexts[0] + assert isinstance(start_ctx.target, StorageDriverWorkflowInfo) + assert start_ctx.target.id == workflow_id + assert start_ctx.target.type == "SignalWaitWorkflow" + assert start_ctx.target.run_id 
is None + + # [1] Client sends signal: type and run_id are unknown at signal time signal_ctx = driver.store_contexts[1] assert isinstance(signal_ctx.target, StorageDriverWorkflowInfo) assert signal_ctx.target.id == workflow_id assert signal_ctx.target.type is None assert signal_ctx.target.run_id is None + # [2] Workflow worker returns result + result_ctx = driver.store_contexts[2] + assert isinstance(result_ctx.target, StorageDriverWorkflowInfo) + assert result_ctx.target.id == workflow_id + assert result_ctx.target.type == "SignalWaitWorkflow" + assert result_ctx.target.run_id is not None + async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: """Schedule action _to_proto should set workflow metadata.""" @@ -1073,6 +1090,8 @@ async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: ) assert len(driver.store_contexts) == 1 + + # [0] Client encodes workflow args when creating the schedule action ctx = driver.store_contexts[0] assert ctx.namespace == client.namespace assert isinstance(ctx.target, StorageDriverWorkflowInfo) @@ -1196,23 +1215,23 @@ async def test_store_metadata_activity_scheduling(env: WorkflowEnvironment) -> N assert client_ctx.target.type == "ActivityScheduleMetadataWorkflow" assert client_ctx.target.run_id is None - # [1] Workflow worker schedules activity: target = activity + # [1] Workflow worker schedules activity schedule_ctx = driver.store_contexts[1] assert schedule_ctx.namespace == client.namespace - assert isinstance(schedule_ctx.target, StorageDriverActivityInfo) - assert schedule_ctx.target.id == "my-activity-id" - assert schedule_ctx.target.type == "echo_activity" - assert schedule_ctx.target.run_id is None + assert isinstance(schedule_ctx.target, StorageDriverWorkflowInfo) + assert schedule_ctx.target.id == workflow_id + assert schedule_ctx.target.type == "ActivityScheduleMetadataWorkflow" + assert schedule_ctx.target.run_id is not None - # [2] Activity worker completes: target = 
activity + # [2] Activity worker completes execute_ctx = driver.store_contexts[2] assert execute_ctx.namespace == client.namespace - assert isinstance(execute_ctx.target, StorageDriverActivityInfo) - assert execute_ctx.target.id == "my-activity-id" - assert execute_ctx.target.type == "echo_activity" - assert execute_ctx.target.run_id is None + assert isinstance(execute_ctx.target, StorageDriverWorkflowInfo) + assert execute_ctx.target.id == workflow_id + assert execute_ctx.target.type == "ActivityScheduleMetadataWorkflow" + assert execute_ctx.target.run_id is not None - # [3] Workflow returns result: target = workflow (current execution) + # [3] Workflow returns result result_ctx = driver.store_contexts[3] assert isinstance(result_ctx.target, StorageDriverWorkflowInfo) assert result_ctx.target.id == workflow_id @@ -1293,60 +1312,35 @@ async def test_store_metadata_signal_external_workflow( assert target_result_ctx.target.run_id is not None -async def test_store_metadata_activity_worker(env: WorkflowEnvironment) -> None: - """Activity worker should set workflow and activity metadata during encode.""" +async def test_store_metadata_standalone_activity(env: WorkflowEnvironment) -> None: + """Standalone activity worker should use StorageDriverActivityInfo as target.""" client, driver = await _make_tracking_client(env) - workflow_id = f"workflow-{uuid.uuid4()}" - - async with new_worker( - client, - ActivityScheduleMetadataWorkflow, - activities=[echo_activity], - ) as worker: - await client.execute_workflow( - ActivityScheduleMetadataWorkflow.run, - "hello", - id=workflow_id, - task_queue=worker.task_queue, - ) - - assert len(driver.store_contexts) == 4 + activity_id = f"activity-{uuid.uuid4()}" - # [2] Activity worker completes: target = activity - execute_ctx = driver.store_contexts[2] - assert isinstance(execute_ctx.target, StorageDriverActivityInfo) - assert execute_ctx.target.id == "my-activity-id" - assert execute_ctx.target.type == "echo_activity" - assert 
execute_ctx.target.run_id is None - - -async def test_store_metadata_decode_activation(env: WorkflowEnvironment) -> None: - """All store contexts should have namespace set.""" - client, driver = await _make_tracking_client(env) - workflow_id = f"workflow-{uuid.uuid4()}" - - async with new_worker(client, EchoWorkflow) as worker: - await client.execute_workflow( - EchoWorkflow.run, + async with new_worker(client, activities=[echo_activity]) as worker: + await client.execute_activity( + echo_activity, "hello", - id=workflow_id, + id=activity_id, task_queue=worker.task_queue, + schedule_to_close_timeout=timedelta(seconds=30), ) assert len(driver.store_contexts) == 2 - # [0] Client starts workflow client_ctx = driver.store_contexts[0] + # [0] Client schedules standalone activity assert client_ctx.namespace == client.namespace - assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) - assert client_ctx.target.id == workflow_id - assert client_ctx.target.type == "EchoWorkflow" + assert isinstance(client_ctx.target, StorageDriverActivityInfo) + assert client_ctx.target.id == activity_id + assert client_ctx.target.type == "echo_activity" assert client_ctx.target.run_id is None - # [1] Workflow returns result - worker_ctx = driver.store_contexts[1] - assert worker_ctx.namespace == client.namespace - assert isinstance(worker_ctx.target, StorageDriverWorkflowInfo) - assert worker_ctx.target.id == workflow_id - assert worker_ctx.target.type == "EchoWorkflow" - assert worker_ctx.target.run_id is not None + # [1] Activity worker completes: target = activity (no parent workflow) + execute_ctx = driver.store_contexts[1] + assert execute_ctx.namespace == client.namespace + assert isinstance(execute_ctx.target, StorageDriverActivityInfo) + assert execute_ctx.target.id == activity_id + assert execute_ctx.target.type == "echo_activity" + # TODO: Fix after information is provided by sdk-core + assert client_ctx.target.run_id is not None From 
9c392c86141321c44c1ff7e93d18354e943a66c4 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 31 Mar 2026 09:58:08 -0700 Subject: [PATCH 07/16] Fix assertion --- tests/worker/test_extstore.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 7752fa2fa..823f4699c 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -1342,5 +1342,4 @@ async def test_store_metadata_standalone_activity(env: WorkflowEnvironment) -> N assert isinstance(execute_ctx.target, StorageDriverActivityInfo) assert execute_ctx.target.id == activity_id assert execute_ctx.target.type == "echo_activity" - # TODO: Fix after information is provided by sdk-core - assert client_ctx.target.run_id is not None + assert execute_ctx.target.run_id is None From d80338bd1fa34c22beec230df14a94a757e44065 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 31 Mar 2026 10:09:11 -0700 Subject: [PATCH 08/16] Child workflows store payload in parent context --- temporalio/worker/_command_aware_visitor.py | 9 +++++++++ temporalio/worker/_workflow_instance.py | 13 +++++++++++++ tests/worker/test_extstore.py | 7 ++++--- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 327f8c68c..f77bea042 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -17,6 +17,7 @@ ResolveSignalExternalWorkflow, ) from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import ( + CompleteWorkflowExecution, ScheduleActivity, ScheduleLocalActivity, ScheduleNexusOperation, @@ -67,6 +68,14 @@ def __init__( ) # Workflow commands with payloads + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( + self, fs: VisitorFunctions, o: CompleteWorkflowExecution 
+ ) -> None: + with current_command(CommandType.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, 0): + await super()._visit_coresdk_workflow_commands_CompleteWorkflowExecution( + fs, o + ) + async def _visit_coresdk_workflow_commands_ScheduleActivity( self, fs: VisitorFunctions, o: ScheduleActivity ) -> None: diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index c532e114e..e289240de 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -2281,6 +2281,19 @@ def get_external_store_metadata( target=StorageDriverWorkflowInfo(id=target_id), ) + elif ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION + and self._info.parent is not None + ): + return StorageDriverStoreMetadata( + namespace=ns, + target=StorageDriverWorkflowInfo( + id=self._info.parent.workflow_id, + run_id=self._info.parent.run_id, + ), + ) + else: return StorageDriverStoreMetadata( namespace=ns, diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 823f4699c..f40a91aef 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -1140,11 +1140,12 @@ async def test_store_metadata_child_workflow(env: WorkflowEnvironment) -> None: assert start_child_ctx.target.type == "EchoWorkflow" assert start_child_ctx.target.run_id is None - # [2] Child returns result: target = child (current execution) + # [2] Child returns result: target = parent workflow (child results are + # stored in the parent's key space so they remain accessible during replay) child_result_ctx = driver.store_contexts[2] assert isinstance(child_result_ctx.target, StorageDriverWorkflowInfo) - assert child_result_ctx.target.id == child_workflow_id - assert child_result_ctx.target.type == "EchoWorkflow" + assert child_result_ctx.target.id == workflow_id + assert child_result_ctx.target.type is None # ParentInfo does not carry workflow type assert child_result_ctx.target.run_id 
is not None # [3] Parent returns result: target = parent (current execution) From 3d841758838e40131daefd6935e1aa06b97ca767 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 31 Mar 2026 10:21:46 -0700 Subject: [PATCH 09/16] Format --- tests/worker/test_extstore.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index f40a91aef..df4e14414 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -1145,7 +1145,8 @@ async def test_store_metadata_child_workflow(env: WorkflowEnvironment) -> None: child_result_ctx = driver.store_contexts[2] assert isinstance(child_result_ctx.target, StorageDriverWorkflowInfo) assert child_result_ctx.target.id == workflow_id - assert child_result_ctx.target.type is None # ParentInfo does not carry workflow type + # ParentInfo does not carry workflow type + assert child_result_ctx.target.type is None assert child_result_ctx.target.run_id is not None # [3] Parent returns result: target = parent (current execution) From ff30607b94fe748a18e8d28e55a7d50f1f7091a2 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 31 Mar 2026 11:04:57 -0700 Subject: [PATCH 10/16] Fix S3 test --- .../contrib/aws/s3driver/test_s3driver_worker.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/contrib/aws/s3driver/test_s3driver_worker.py b/tests/contrib/aws/s3driver/test_s3driver_worker.py index 8bd40dc2d..b15193f77 100644 --- a/tests/contrib/aws/s3driver/test_s3driver_worker.py +++ b/tests/contrib/aws/s3driver/test_s3driver_worker.py @@ -412,7 +412,8 @@ async def test_s3_driver_parent_child_independent_key_namespaces( ) -> None: """An order fulfillment workflow spawns a child payment processor and passes it a large order payload. 
Child input is keyed under the parent (it lives in - the parent's history); child output is keyed under the child.""" + the parent's history); child output is keyed under the parent (for lifecycle + resilience — the child result lives in the parent's completion history).""" workflow_id = str(uuid.uuid4()) payment_id = f"{workflow_id}-payment" async with new_worker( @@ -428,12 +429,13 @@ async def test_s3_driver_parent_child_independent_key_namespaces( keys = await _list_keys(aioboto3_client) parent_keys = [k for k in keys if f"/wi/{workflow_id}/" in k] child_keys = [k for k in keys if f"/wi/{payment_id}/" in k] - # Parent accumulates 2 keys: client start (ri=null, LARGE input) and child - # result propagated back to parent as workflow result (ri=run_id, LARGE_2). - assert len(parent_keys) == 2 - # Child accumulates 2 keys: its input from parent (ri=null, LARGE) and its - # own result (ri=child_run_id, LARGE_2). - assert len(child_keys) == 2 + # Parent accumulates 3 keys: + # 1. Client start stored in parent's key space (ri=null) + # 2. Child result stored in parent's key space + # 3. 
Parent's own workflow result + assert len(parent_keys) == 3 + # Child accumulates 1 key: its input from the parent + assert len(child_keys) == 1 async def test_s3_store_failure_surfaces_in_workflow_history( From 77ae64d744c462688a14899531e3c5f68d32c580 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 31 Mar 2026 11:44:06 -0700 Subject: [PATCH 11/16] Skip standalone activities and schedules in time-skipping environment --- tests/contrib/aws/s3driver/test_s3driver_worker.py | 12 ++++++++++-- tests/worker/test_extstore.py | 6 ++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/contrib/aws/s3driver/test_s3driver_worker.py b/tests/contrib/aws/s3driver/test_s3driver_worker.py index b15193f77..e25be5fbf 100644 --- a/tests/contrib/aws/s3driver/test_s3driver_worker.py +++ b/tests/contrib/aws/s3driver/test_s3driver_worker.py @@ -191,8 +191,12 @@ async def test_s3_driver_workflow_activity_output_key( async def test_s3_driver_standalone_activity_input_key( - tmprl_client: Client, aioboto3_client: S3Client + env: WorkflowEnvironment, tmprl_client: Client, aioboto3_client: S3Client ) -> None: + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/2741" + ) activity_id = str(uuid.uuid4()) task_queue = str(uuid.uuid4()) async with new_worker( @@ -214,8 +218,12 @@ async def test_s3_driver_standalone_activity_input_key( async def test_s3_driver_standalone_activity_output_key( - tmprl_client: Client, aioboto3_client: S3Client + env: WorkflowEnvironment, tmprl_client: Client, aioboto3_client: S3Client ) -> None: + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/2741" + ) activity_id = str(uuid.uuid4()) task_queue = str(uuid.uuid4()) async with new_worker( diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index df4e14414..62594e74c 100644 --- 
a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -1071,6 +1071,8 @@ async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None: async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: """Schedule action _to_proto should set workflow metadata.""" + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support schedules") client, driver = await _make_tracking_client(env) task_queue = str(uuid.uuid4()) schedule_id = f"sched-{uuid.uuid4()}" @@ -1316,6 +1318,10 @@ async def test_store_metadata_signal_external_workflow( async def test_store_metadata_standalone_activity(env: WorkflowEnvironment) -> None: """Standalone activity worker should use StorageDriverActivityInfo as target.""" + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/2741" + ) client, driver = await _make_tracking_client(env) activity_id = f"activity-{uuid.uuid4()}" From b57c6ecf5211923b9c01c4b530e2dd8947488133 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 1 Apr 2026 08:33:40 -0700 Subject: [PATCH 12/16] Move namespace to info classes --- temporalio/client.py | 95 ++++++++++----------- temporalio/contrib/aws/s3driver/_driver.py | 5 +- temporalio/converter/_extstore.py | 25 +++--- temporalio/worker/_activity.py | 8 +- temporalio/worker/_workflow.py | 2 +- temporalio/worker/_workflow_instance.py | 19 ++--- tests/contrib/aws/s3driver/test_s3driver.py | 8 +- tests/worker/test_extstore.py | 14 +-- 8 files changed, 83 insertions(+), 93 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 3a2d594b9..851adacbb 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -6169,8 +6169,9 @@ async def _to_proto( priority = self.priority._to_proto() with store_metadata_context( StorageDriverStoreMetadata( - namespace=client.namespace, - target=StorageDriverWorkflowInfo(id=self.id, 
type=self.workflow), + target=StorageDriverWorkflowInfo( + id=self.id, type=self.workflow, namespace=client.namespace + ), ) ): data_converter = client.data_converter.with_context( @@ -8096,8 +8097,9 @@ async def _build_signal_with_start_workflow_execution_request( assert input.start_signal with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, - target=StorageDriverWorkflowInfo(id=input.id, type=input.workflow), + target=StorageDriverWorkflowInfo( + id=input.id, type=input.workflow, namespace=self._client.namespace + ), ) ): data_converter = self._client.data_converter.with_context( @@ -8133,8 +8135,9 @@ async def _populate_start_workflow_execution_request( ) -> None: with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, - target=StorageDriverWorkflowInfo(id=input.id, type=input.workflow), + target=StorageDriverWorkflowInfo( + id=input.id, type=input.workflow, namespace=self._client.namespace + ), ) ): data_converter = self._client.data_converter.with_context( @@ -8259,9 +8262,10 @@ async def count_workflows( async def query_workflow(self, input: QueryWorkflowInput) -> Any: with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, target=StorageDriverWorkflowInfo( - id=input.id, run_id=input.run_id or None + id=input.id, + run_id=input.run_id or None, + namespace=self._client.namespace, ), ) ): @@ -8325,9 +8329,10 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: async def signal_workflow(self, input: SignalWorkflowInput) -> None: with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, target=StorageDriverWorkflowInfo( - id=input.id, run_id=input.run_id or None + id=input.id, + run_id=input.run_id or None, + namespace=self._client.namespace, ), ) ): @@ -8358,9 +8363,10 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None: async def terminate_workflow(self, input: 
TerminateWorkflowInput) -> None: with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, target=StorageDriverWorkflowInfo( - id=input.id, run_id=input.run_id or None + id=input.id, + run_id=input.run_id or None, + namespace=self._client.namespace, ), ) ): @@ -8425,8 +8431,11 @@ async def _build_start_activity_execution_request( """Build StartActivityExecutionRequest from input.""" with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, - target=StorageDriverActivityInfo(id=input.id, type=input.activity_type), + target=StorageDriverActivityInfo( + id=input.id, + type=input.activity_type, + namespace=self._client.namespace, + ), ) ): data_converter = self._client.data_converter.with_context( @@ -8630,12 +8639,12 @@ async def _build_update_workflow_execution_request( ) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest: with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._client.namespace, target=StorageDriverWorkflowInfo( id=workflow_id, run_id=(input.run_id or None) if isinstance(input, StartWorkflowUpdateInput) else None, + namespace=self._client.namespace, ), ) ): @@ -8826,23 +8835,21 @@ def _get_async_activity_store_metadata( if isinstance(id_or_token, AsyncActivityIDReference): if id_or_token.workflow_id: return StorageDriverStoreMetadata( - namespace=self._client.namespace, target=StorageDriverWorkflowInfo( id=id_or_token.workflow_id or None, run_id=id_or_token.run_id or None, + namespace=self._client.namespace, ), ) return StorageDriverStoreMetadata( - namespace=self._client.namespace, target=StorageDriverActivityInfo( id=id_or_token.activity_id, run_id=id_or_token.run_id or None, + namespace=self._client.namespace, ), ) else: - return StorageDriverStoreMetadata( - namespace=self._client.namespace, - ) + return StorageDriverStoreMetadata() async def heartbeat_async_activity( self, input: HeartbeatAsyncActivityInput @@ -9065,35 +9072,27 @@ async def 
create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle: else None, ) try: - # Set namespace-level store metadata as a baseline for schedule - # encoding. The schedule action's _to_proto will override with - # workflow-specific metadata for its own encoding. - with store_metadata_context( - StorageDriverStoreMetadata( - namespace=self._client.namespace, - ) - ): - request = temporalio.api.workflowservice.v1.CreateScheduleRequest( - namespace=self._client.namespace, - schedule_id=input.id, - schedule=await input.schedule._to_proto(self._client), - initial_patch=initial_patch, - identity=self._client.identity, - request_id=str(uuid.uuid4()), - memo=await self._client.data_converter._encode_memo(input.memo) - if input.memo - else None, - ) - if input.search_attributes: - temporalio.converter.encode_search_attributes( - input.search_attributes, request.search_attributes - ) - await self._client.workflow_service.create_schedule( - request, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, + request = temporalio.api.workflowservice.v1.CreateScheduleRequest( + namespace=self._client.namespace, + schedule_id=input.id, + schedule=await input.schedule._to_proto(self._client), + initial_patch=initial_patch, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + memo=await self._client.data_converter._encode_memo(input.memo) + if input.memo + else None, + ) + if input.search_attributes: + temporalio.converter.encode_search_attributes( + input.search_attributes, request.search_attributes ) + await self._client.workflow_service.create_schedule( + request, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) except RPCError as err: already_started = ( err.status == RPCStatusCode.ALREADY_EXISTS diff --git a/temporalio/contrib/aws/s3driver/_driver.py b/temporalio/contrib/aws/s3driver/_driver.py index a28e91503..1f9d129c9 100644 --- a/temporalio/contrib/aws/s3driver/_driver.py +++ 
b/temporalio/contrib/aws/s3driver/_driver.py @@ -117,12 +117,11 @@ async def store( def _quote(val: str | None) -> str | None: return urllib.parse.quote(val, safe="") if val else None - namespace = _quote(context.namespace) - namespace_segment = f"/ns/{namespace}" if namespace else "" - # Build context segments from the target identity. context_segments = "" target = context.target + namespace = _quote(target.namespace) if target is not None else None + namespace_segment = f"/ns/{namespace}" if namespace else "" if isinstance(target, StorageDriverWorkflowInfo): wf_type = _quote(target.type) or "null" wf_id = _quote(target.id) or "null" diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index df669d139..f2ac0abb3 100644 --- a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -88,7 +88,7 @@ class StorageDriverClaim: """ -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StorageDriverWorkflowInfo: """Workflow identity information for external storage operations. @@ -96,6 +96,9 @@ class StorageDriverWorkflowInfo: This API is experimental. """ + namespace: str + """The namespace of the workflow execution.""" + id: str | None = None """The workflow ID.""" @@ -106,7 +109,7 @@ class StorageDriverWorkflowInfo: """The workflow type name, if available.""" -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StorageDriverActivityInfo: """Activity identity information for external storage operations. @@ -114,6 +117,9 @@ class StorageDriverActivityInfo: This API is experimental. """ + namespace: str + """The namespace of the activity execution.""" + id: str | None = None """The activity ID.""" @@ -124,7 +130,7 @@ class StorageDriverActivityInfo: """The activity type name, if available.""" -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StorageDriverStoreMetadata: """Store-only metadata available during external storage operations. 
@@ -132,9 +138,6 @@ class StorageDriverStoreMetadata: This API is experimental. """ - namespace: str | None = None - """The namespace of the current execution context.""" - target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None """The workflow or activity for which this payload is being stored.""" @@ -170,9 +173,6 @@ class StorageDriverStoreContext: This API is experimental. """ - namespace: str | None = None - """The namespace of the current execution context.""" - target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None """The workflow or activity for which this payload is being stored. @@ -180,7 +180,11 @@ class StorageDriverStoreContext: workflow being started, an activity being scheduled, an external workflow being signaled), this is that target's identity. When no explicit target exists the current execution context (workflow or activity) is used as the - target instead.""" + target instead. + + The :attr:`StorageDriverWorkflowInfo.namespace` or + :attr:`StorageDriverActivityInfo.namespace` field on the target carries the + namespace for the execution, when available.""" @dataclass(frozen=True) @@ -363,7 +367,6 @@ def _get_driver_by_name(self, name: str) -> StorageDriver: def _build_store_context() -> StorageDriverStoreContext: meta = _current_store_metadata.get() return StorageDriverStoreContext( - namespace=meta.namespace if meta else None, target=meta.target if meta else None, ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 857a14656..9cdf89823 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -271,11 +271,11 @@ async def _heartbeat_async( data_converter = data_converter.with_context(context) store_metadata = StorageDriverStoreMetadata( - namespace=activity.info.namespace, target=StorageDriverActivityInfo( id=activity.info.activity_id, type=activity.info.activity_type, run_id=activity.info.activity_run_id, + namespace=activity.info.namespace, ), ) 
@@ -344,15 +344,15 @@ async def _handle_start_activity_task( id=start.workflow_execution.workflow_id or None, type=start.workflow_type or None, run_id=start.workflow_execution.run_id or None, + namespace=ns, ) else: store_target = StorageDriverActivityInfo( id=start.activity_id or None, type=start.activity_type or None, + namespace=ns, ) - with store_metadata_context( - StorageDriverStoreMetadata(namespace=ns, target=store_target) - ): + with store_metadata_context(StorageDriverStoreMetadata(target=store_target)): try: result = await self._execute_activity( start, running_activity, task_token, data_converter diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index c9a1a8193..d5145971b 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -305,7 +305,6 @@ async def _handle_activation( # Set default store metadata for decode_activation with store_metadata_context( StorageDriverStoreMetadata( - namespace=self._namespace, target=StorageDriverWorkflowInfo( id=workflow_id, run_id=act.run_id, @@ -314,6 +313,7 @@ async def _handle_activation( if workflow else (init_job.workflow_type if init_job else None) ), + namespace=self._namespace, ), ) ): diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index e289240de..f6a7ee849 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -2247,13 +2247,11 @@ def get_external_store_metadata( id=self._info.workflow_id, run_id=self._info.run_id, type=self._info.workflow_type, + namespace=ns, ) if command_info is None: - return StorageDriverStoreMetadata( - namespace=ns, - target=current_wf, - ) + return StorageDriverStoreMetadata(target=current_wf) COMMAND_TYPE = temporalio.api.enums.v1.command_type_pb2.CommandType @@ -2264,9 +2262,8 @@ def get_external_store_metadata( ): child = self._pending_child_workflows[command_info.command_seq] return StorageDriverStoreMetadata( - namespace=ns, 
target=StorageDriverWorkflowInfo( - id=child._input.id, type=child._input.workflow + id=child._input.id, type=child._input.workflow, namespace=ns ), ) @@ -2277,8 +2274,7 @@ def get_external_store_metadata( ): _, target_id = self._pending_external_signals[command_info.command_seq] return StorageDriverStoreMetadata( - namespace=ns, - target=StorageDriverWorkflowInfo(id=target_id), + target=StorageDriverWorkflowInfo(id=target_id, namespace=ns), ) elif ( @@ -2287,18 +2283,15 @@ def get_external_store_metadata( and self._info.parent is not None ): return StorageDriverStoreMetadata( - namespace=ns, target=StorageDriverWorkflowInfo( id=self._info.parent.workflow_id, run_id=self._info.parent.run_id, + namespace=ns, ), ) else: - return StorageDriverStoreMetadata( - namespace=ns, - target=current_wf, - ) + return StorageDriverStoreMetadata(target=current_wf) def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: diff --git a/tests/contrib/aws/s3driver/test_s3driver.py b/tests/contrib/aws/s3driver/test_s3driver.py index 57e2072f7..c389fe07c 100644 --- a/tests/contrib/aws/s3driver/test_s3driver.py +++ b/tests/contrib/aws/s3driver/test_s3driver.py @@ -51,11 +51,9 @@ def make_payload(value: str = "hello") -> Payload: def make_store_context( - namespace: str | None = None, target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None, ) -> StorageDriverStoreContext: return StorageDriverStoreContext( - namespace=namespace, target=target, ) @@ -67,9 +65,8 @@ def make_workflow_context( run_id: str | None = None, ) -> StorageDriverStoreContext: return make_store_context( - namespace=namespace, target=StorageDriverWorkflowInfo( - id=workflow_id, type=workflow_type, run_id=run_id + id=workflow_id, type=workflow_type, run_id=run_id, namespace=namespace ), ) @@ -81,9 +78,8 @@ def make_activity_context( run_id: str | None = None, ) -> StorageDriverStoreContext: return make_store_context( - namespace=namespace, target=StorageDriverActivityInfo( - 
id=activity_id, type=activity_type, run_id=run_id + id=activity_id, type=activity_type, run_id=run_id, namespace=namespace ), ) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 62594e74c..1c2a97725 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -974,16 +974,16 @@ async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None: # [0] Workflow input arg client_ctx = driver.store_contexts[0] - assert client_ctx.namespace == client.namespace assert isinstance(client_ctx.target, StorageDriverWorkflowInfo) + assert client_ctx.target.namespace == client.namespace assert client_ctx.target.id == workflow_id assert client_ctx.target.type == "EchoWorkflow" assert client_ctx.target.run_id is None # [1] Workflow result worker_ctx = driver.store_contexts[1] - assert worker_ctx.namespace == client.namespace assert isinstance(worker_ctx.target, StorageDriverWorkflowInfo) + assert worker_ctx.target.namespace == client.namespace assert worker_ctx.target.id == workflow_id assert worker_ctx.target.type == "EchoWorkflow" assert worker_ctx.target.run_id is not None @@ -1095,8 +1095,8 @@ async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None: # [0] Client encodes workflow args when creating the schedule action ctx = driver.store_contexts[0] - assert ctx.namespace == client.namespace assert isinstance(ctx.target, StorageDriverWorkflowInfo) + assert ctx.target.namespace == client.namespace assert ctx.target.id == f"wf-{schedule_id}" assert ctx.target.type == "EchoWorkflow" assert ctx.target.run_id is None @@ -1221,16 +1221,16 @@ async def test_store_metadata_activity_scheduling(env: WorkflowEnvironment) -> N # [1] Workflow worker schedules activity schedule_ctx = driver.store_contexts[1] - assert schedule_ctx.namespace == client.namespace assert isinstance(schedule_ctx.target, StorageDriverWorkflowInfo) + assert schedule_ctx.target.namespace == client.namespace assert 
schedule_ctx.target.id == workflow_id assert schedule_ctx.target.type == "ActivityScheduleMetadataWorkflow" assert schedule_ctx.target.run_id is not None # [2] Activity worker completes execute_ctx = driver.store_contexts[2] - assert execute_ctx.namespace == client.namespace assert isinstance(execute_ctx.target, StorageDriverWorkflowInfo) + assert execute_ctx.target.namespace == client.namespace assert execute_ctx.target.id == workflow_id assert execute_ctx.target.type == "ActivityScheduleMetadataWorkflow" assert execute_ctx.target.run_id is not None @@ -1338,16 +1338,16 @@ async def test_store_metadata_standalone_activity(env: WorkflowEnvironment) -> N client_ctx = driver.store_contexts[0] # [0] Client schedules standalone activity - assert client_ctx.namespace == client.namespace assert isinstance(client_ctx.target, StorageDriverActivityInfo) + assert client_ctx.target.namespace == client.namespace assert client_ctx.target.id == activity_id assert client_ctx.target.type == "echo_activity" assert client_ctx.target.run_id is None # [1] Activity worker completes: target = activity (no parent workflow) execute_ctx = driver.store_contexts[1] - assert execute_ctx.namespace == client.namespace assert isinstance(execute_ctx.target, StorageDriverActivityInfo) + assert execute_ctx.target.namespace == client.namespace assert execute_ctx.target.id == activity_id assert execute_ctx.target.type == "echo_activity" assert execute_ctx.target.run_id is None From a6c5ca2769f2434fe8438ceb597adde937b633ac Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:33:53 -0700 Subject: [PATCH 13/16] Replace context var with context transform methods --- temporalio/client.py | 1027 ++++++++--------- temporalio/converter/_data_converter.py | 20 + temporalio/converter/_extstore.py | 64 +- temporalio/worker/_activity.py | 305 ++--- temporalio/worker/_workflow.py | 59 +- temporalio/worker/_workflow_instance.py | 25 +- 
.../worker/workflow_sandbox/_in_sandbox.py | 8 +- temporalio/worker/workflow_sandbox/_runner.py | 12 +- tests/worker/test_workflow.py | 6 +- 9 files changed, 737 insertions(+), 789 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 851adacbb..37a615e22 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -67,14 +67,11 @@ DataConverter, SerializationContext, StorageDriverActivityInfo, + StorageDriverStoreContext, StorageDriverWorkflowInfo, WithSerializationContext, WorkflowSerializationContext, ) -from temporalio.converter._extstore import ( - StorageDriverStoreMetadata, - store_metadata_context, -) from temporalio.service import ( ConnectConfig, HttpConnectProxyConfig, @@ -6167,79 +6164,77 @@ async def _to_proto( priority: temporalio.api.common.v1.Priority | None = None if self.priority: priority = self.priority._to_proto() - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = client.data_converter._with_contexts( + WorkflowSerializationContext( + namespace=client.namespace, + workflow_id=self.id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=self.id, type=self.workflow, namespace=client.namespace ), + ), + ) + action = temporalio.api.schedule.v1.ScheduleAction( + start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( + workflow_id=self.id, + workflow_type=temporalio.api.common.v1.WorkflowType( + name=self.workflow + ), + task_queue=temporalio.api.taskqueue.v1.TaskQueue( + name=self.task_queue + ), + input=( + temporalio.api.common.v1.Payloads( + payloads=[ + a + if isinstance(a, temporalio.api.common.v1.Payload) + else (await data_converter.encode([a]))[0] + for a in self.args + ] + ) + if self.args + else None + ), + workflow_execution_timeout=execution_timeout, + workflow_run_timeout=run_timeout, + workflow_task_timeout=task_timeout, + retry_policy=retry_policy, + memo=await data_converter._encode_memo(self.memo) + if self.memo + else None, + user_metadata=await 
_encode_user_metadata( + data_converter, self.static_summary, self.static_details + ), + priority=priority, + ), + ) + # Add any untyped attributes that are not also in the typed set + untyped_not_in_typed = { + k: v + for k, v in self.untyped_search_attributes.items() + if k not in self.typed_search_attributes + } + if untyped_not_in_typed: + temporalio.converter.encode_search_attributes( + untyped_not_in_typed, action.start_workflow.search_attributes ) - ): - data_converter = client.data_converter.with_context( - WorkflowSerializationContext( - namespace=client.namespace, - workflow_id=self.id, - ) + # TODO (dan): confirm whether this be `is not None` + if self.typed_search_attributes: + temporalio.converter.encode_search_attributes( + self.typed_search_attributes, + action.start_workflow.search_attributes, ) - action = temporalio.api.schedule.v1.ScheduleAction( - start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( - workflow_id=self.id, - workflow_type=temporalio.api.common.v1.WorkflowType( - name=self.workflow - ), - task_queue=temporalio.api.taskqueue.v1.TaskQueue( - name=self.task_queue - ), - input=( - temporalio.api.common.v1.Payloads( - payloads=[ - a - if isinstance(a, temporalio.api.common.v1.Payload) - else (await data_converter.encode([a]))[0] - for a in self.args - ] - ) - if self.args - else None - ), - workflow_execution_timeout=execution_timeout, - workflow_run_timeout=run_timeout, - workflow_task_timeout=task_timeout, - retry_policy=retry_policy, - memo=await data_converter._encode_memo(self.memo) - if self.memo - else None, - user_metadata=await _encode_user_metadata( - data_converter, self.static_summary, self.static_details - ), - priority=priority, - ), + if self.headers: + await _apply_headers( + self.headers, + action.start_workflow.header.fields, + client.config(active_config=True)["header_codec_behavior"] + == HeaderCodecBehavior.CODEC + and not self._from_raw, + client.data_converter, ) - # Add any untyped attributes that 
are not also in the typed set - untyped_not_in_typed = { - k: v - for k, v in self.untyped_search_attributes.items() - if k not in self.typed_search_attributes - } - if untyped_not_in_typed: - temporalio.converter.encode_search_attributes( - untyped_not_in_typed, action.start_workflow.search_attributes - ) - # TODO (dan): confirm whether this be `is not None` - if self.typed_search_attributes: - temporalio.converter.encode_search_attributes( - self.typed_search_attributes, - action.start_workflow.search_attributes, - ) - if self.headers: - await _apply_headers( - self.headers, - action.start_workflow.header.fields, - client.config(active_config=True)["header_codec_behavior"] - == HeaderCodecBehavior.CODEC - and not self._from_raw, - client.data_converter, - ) - return action + return action class ScheduleOverlapPolicy(IntEnum): @@ -8095,28 +8090,26 @@ async def _build_signal_with_start_workflow_execution_request( self, input: StartWorkflowInput ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: assert input.start_signal - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=input.id, type=input.workflow, namespace=self._client.namespace ), + ), + ) + req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( + signal_name=input.start_signal + ) + if input.start_signal_args: + req.signal_input.payloads.extend( + await data_converter.encode(input.start_signal_args) ) - ): - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( - namespace=self._client.namespace, - workflow_id=input.id, - ) - ) - req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( - signal_name=input.start_signal - ) - if input.start_signal_args: - 
req.signal_input.payloads.extend( - await data_converter.encode(input.start_signal_args) - ) - await self._populate_start_workflow_execution_request(req, input) - return req + await self._populate_start_workflow_execution_request(req, input) + return req async def _build_update_with_start_start_workflow_execution_request( self, input: UpdateWithStartStartWorkflowInput @@ -8133,64 +8126,62 @@ async def _populate_start_workflow_execution_request( ), input: StartWorkflowInput | UpdateWithStartStartWorkflowInput, ) -> None: - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=input.id, type=input.workflow, namespace=self._client.namespace ), + ), + ) + req.namespace = self._client.namespace + req.workflow_id = input.id + req.workflow_type.name = input.workflow + req.task_queue.name = input.task_queue + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) + if input.execution_timeout is not None: + req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) + if input.run_timeout is not None: + req.workflow_run_timeout.FromTimedelta(input.run_timeout) + if input.task_timeout is not None: + req.workflow_task_timeout.FromTimedelta(input.task_timeout) + req.identity = self._client.identity + req.request_id = str(uuid.uuid4()) + req.workflow_id_reuse_policy = cast( + "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", + int(input.id_reuse_policy), + ) + req.workflow_id_conflict_policy = cast( + "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", + int(input.id_conflict_policy), + ) + + if input.retry_policy is not None: + input.retry_policy.apply_to_proto(req.retry_policy) + req.cron_schedule = input.cron_schedule + if input.memo is not None: + await 
data_converter._encode_memo_existing(input.memo, req.memo) + if input.search_attributes is not None: + temporalio.converter.encode_search_attributes( + input.search_attributes, req.search_attributes ) - ): - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( - namespace=self._client.namespace, - workflow_id=input.id, - ) - ) - req.namespace = self._client.namespace - req.workflow_id = input.id - req.workflow_type.name = input.workflow - req.task_queue.name = input.task_queue - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) - if input.execution_timeout is not None: - req.workflow_execution_timeout.FromTimedelta(input.execution_timeout) - if input.run_timeout is not None: - req.workflow_run_timeout.FromTimedelta(input.run_timeout) - if input.task_timeout is not None: - req.workflow_task_timeout.FromTimedelta(input.task_timeout) - req.identity = self._client.identity - req.request_id = str(uuid.uuid4()) - req.workflow_id_reuse_policy = cast( - "temporalio.api.enums.v1.WorkflowIdReusePolicy.ValueType", - int(input.id_reuse_policy), - ) - req.workflow_id_conflict_policy = cast( - "temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType", - int(input.id_conflict_policy), - ) - - if input.retry_policy is not None: - input.retry_policy.apply_to_proto(req.retry_policy) - req.cron_schedule = input.cron_schedule - if input.memo is not None: - await data_converter._encode_memo_existing(input.memo, req.memo) - if input.search_attributes is not None: - temporalio.converter.encode_search_attributes( - input.search_attributes, req.search_attributes - ) - metadata = await _encode_user_metadata( - data_converter, input.static_summary, input.static_details - ) - if metadata is not None: - req.user_metadata.CopyFrom(metadata) - if input.start_delay is not None: - req.workflow_start_delay.FromTimedelta(input.start_delay) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await 
self._apply_headers(input.headers, req.header.fields) - if input.priority is not None: # type:ignore[reportUnnecessaryComparison] - req.priority.CopyFrom(input.priority._to_proto()) - if input.versioning_override is not None: - req.versioning_override.CopyFrom(input.versioning_override._to_proto()) + metadata = await _encode_user_metadata( + data_converter, input.static_summary, input.static_details + ) + if metadata is not None: + req.user_metadata.CopyFrom(metadata) + if input.start_delay is not None: + req.workflow_start_delay.FromTimedelta(input.start_delay) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers(input.headers, req.header.fields) + if input.priority is not None: # type:ignore[reportUnnecessaryComparison] + req.priority.CopyFrom(input.priority._to_proto()) + if input.versioning_override is not None: + req.versioning_override.CopyFrom(input.versioning_override._to_proto()) async def cancel_workflow(self, input: CancelWorkflowInput) -> None: await self._client.workflow_service.request_cancel_workflow_execution( @@ -8260,137 +8251,131 @@ async def count_workflows( ) async def query_workflow(self, input: QueryWorkflowInput) -> Any: - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=input.id, run_id=input.run_id or None, namespace=self._client.namespace, ), + ), + ) + req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( + namespace=self._client.namespace, + execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=input.id, + run_id=input.run_id or "", + ), + ) + if input.reject_condition: + req.query_reject_condition = cast( + "temporalio.api.enums.v1.QueryRejectCondition.ValueType", + int(input.reject_condition), ) - ): - data_converter = 
self._client.data_converter.with_context( - WorkflowSerializationContext( - namespace=self._client.namespace, - workflow_id=input.id, - ) + req.query.query_type = input.query + if input.args: + req.query.query_args.payloads.extend( + await data_converter.encode(input.args) ) - req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( - namespace=self._client.namespace, - execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", - ), + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers(input.headers, req.query.header.fields) + try: + resp = await self._client.workflow_service.query_workflow( + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - if input.reject_condition: - req.query_reject_condition = cast( - "temporalio.api.enums.v1.QueryRejectCondition.ValueType", - int(input.reject_condition), - ) - req.query.query_type = input.query - if input.args: - req.query.query_args.payloads.extend( - await data_converter.encode(input.args) - ) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, req.query.header.fields) - try: - resp = await self._client.workflow_service.query_workflow( - req, - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - except RPCError as err: - # If the status is INVALID_ARGUMENT, we can assume it's a query - # failed error - if err.status == RPCStatusCode.INVALID_ARGUMENT: - raise WorkflowQueryFailedError(err.message) - else: - raise - if resp.HasField("query_rejected"): - raise WorkflowQueryRejectedError( - WorkflowExecutionStatus(resp.query_rejected.status) - if resp.query_rejected.status - else None - ) - if not resp.query_result.payloads: - return None - type_hints = [input.ret_type] if input.ret_type else None - results = await data_converter.decode( - resp.query_result.payloads, type_hints + except RPCError as err: + 
# If the status is INVALID_ARGUMENT, we can assume it's a query + # failed error + if err.status == RPCStatusCode.INVALID_ARGUMENT: + raise WorkflowQueryFailedError(err.message) + else: + raise + if resp.HasField("query_rejected"): + raise WorkflowQueryRejectedError( + WorkflowExecutionStatus(resp.query_rejected.status) + if resp.query_rejected.status + else None ) - if not results: - return None - elif len(results) > 1: - warnings.warn(f"Expected single query result, got {len(results)}") - return results[0] + if not resp.query_result.payloads: + return None + type_hints = [input.ret_type] if input.ret_type else None + results = await data_converter.decode( + resp.query_result.payloads, type_hints + ) + if not results: + return None + elif len(results) > 1: + warnings.warn(f"Expected single query result, got {len(results)}") + return results[0] async def signal_workflow(self, input: SignalWorkflowInput) -> None: - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=input.id, run_id=input.run_id or None, namespace=self._client.namespace, ), - ) - ): - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( - namespace=self._client.namespace, - workflow_id=input.id, - ) - ) - req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest( - namespace=self._client.namespace, - workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", - ), - signal_name=input.signal, - identity=self._client.identity, - request_id=str(uuid.uuid4()), - ) - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers(input.headers, 
req.header.fields) - await self._client.workflow_service.signal_workflow_execution( - req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout - ) + ), + ) + req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=input.id, + run_id=input.run_id or "", + ), + signal_name=input.signal, + identity=self._client.identity, + request_id=str(uuid.uuid4()), + ) + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers(input.headers, req.header.fields) + await self._client.workflow_service.signal_workflow_execution( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + ) async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + WorkflowSerializationContext( + namespace=self._client.namespace, + workflow_id=input.id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=input.id, run_id=input.run_id or None, namespace=self._client.namespace, ), - ) - ): - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( - namespace=self._client.namespace, - workflow_id=input.id, - ) - ) - req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest( - namespace=self._client.namespace, - workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=input.id, - run_id=input.run_id or "", - ), - reason=input.reason or "", - identity=self._client.identity, - first_execution_run_id=input.first_execution_run_id or "", - ) - if input.args: - req.details.payloads.extend(await data_converter.encode(input.args)) - await self._client.workflow_service.terminate_workflow_execution( - 
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout - ) + ), + ) + req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=input.id, + run_id=input.run_id or "", + ), + reason=input.reason or "", + identity=self._client.identity, + first_execution_run_id=input.first_execution_run_id or "", + ) + if input.args: + req.details.payloads.extend(await data_converter.encode(input.args)) + await self._client.workflow_service.terminate_workflow_execution( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + ) async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any]: """Start an activity and return a handle to it.""" @@ -8429,83 +8414,81 @@ async def _build_start_activity_execution_request( self, input: StartActivityInput ) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest: """Build StartActivityExecutionRequest from input.""" - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + ActivitySerializationContext( + namespace=self._client.namespace, + activity_id=input.id, + activity_type=input.activity_type, + activity_task_queue=input.task_queue, + is_local=False, + workflow_id=None, + workflow_type=None, + ), + StorageDriverStoreContext( target=StorageDriverActivityInfo( id=input.id, type=input.activity_type, namespace=self._client.namespace, ), - ) - ): - data_converter = self._client.data_converter.with_context( - ActivitySerializationContext( - namespace=self._client.namespace, - activity_id=input.id, - activity_type=input.activity_type, - activity_task_queue=input.task_queue, - is_local=False, - workflow_id=None, - workflow_type=None, - ) - ) + ), + ) - req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( - namespace=self._client.namespace, - 
identity=self._client.identity, - activity_id=input.id, - activity_type=temporalio.api.common.v1.ActivityType( - name=input.activity_type - ), - task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=input.task_queue), - id_reuse_policy=cast( - "temporalio.api.enums.v1.ActivityIdReusePolicy.ValueType", - int(input.id_reuse_policy), - ), - id_conflict_policy=cast( - "temporalio.api.enums.v1.ActivityIdConflictPolicy.ValueType", - int(input.id_conflict_policy), - ), - ) + req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( + namespace=self._client.namespace, + identity=self._client.identity, + activity_id=input.id, + activity_type=temporalio.api.common.v1.ActivityType( + name=input.activity_type + ), + task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=input.task_queue), + id_reuse_policy=cast( + "temporalio.api.enums.v1.ActivityIdReusePolicy.ValueType", + int(input.id_reuse_policy), + ), + id_conflict_policy=cast( + "temporalio.api.enums.v1.ActivityIdConflictPolicy.ValueType", + int(input.id_conflict_policy), + ), + ) - if input.schedule_to_close_timeout is not None: - req.schedule_to_close_timeout.FromTimedelta( - input.schedule_to_close_timeout - ) - if input.start_to_close_timeout is not None: - req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) - if input.schedule_to_start_timeout is not None: - req.schedule_to_start_timeout.FromTimedelta( - input.schedule_to_start_timeout - ) - if input.heartbeat_timeout is not None: - req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) - if input.retry_policy is not None: - input.retry_policy.apply_to_proto(req.retry_policy) + if input.schedule_to_close_timeout is not None: + req.schedule_to_close_timeout.FromTimedelta( + input.schedule_to_close_timeout + ) + if input.start_to_close_timeout is not None: + req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) + if input.schedule_to_start_timeout is not None: + req.schedule_to_start_timeout.FromTimedelta( + 
input.schedule_to_start_timeout + ) + if input.heartbeat_timeout is not None: + req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) + if input.retry_policy is not None: + input.retry_policy.apply_to_proto(req.retry_policy) - # Set input payloads - if input.args: - req.input.payloads.extend(await data_converter.encode(input.args)) + # Set input payloads + if input.args: + req.input.payloads.extend(await data_converter.encode(input.args)) - # Set search attributes - if input.search_attributes is not None: - temporalio.converter.encode_search_attributes( - input.search_attributes, req.search_attributes - ) + # Set search attributes + if input.search_attributes is not None: + temporalio.converter.encode_search_attributes( + input.search_attributes, req.search_attributes + ) - # Set user metadata - metadata = await _encode_user_metadata(data_converter, input.summary, None) - if metadata is not None: - req.user_metadata.CopyFrom(metadata) + # Set user metadata + metadata = await _encode_user_metadata(data_converter, input.summary, None) + if metadata is not None: + req.user_metadata.CopyFrom(metadata) - # Set headers - if input.headers: - await self._apply_headers(input.headers, req.header.fields) + # Set headers + if input.headers: + await self._apply_headers(input.headers, req.header.fields) - # Set priority - req.priority.CopyFrom(input.priority._to_proto()) + # Set priority + req.priority.CopyFrom(input.priority._to_proto()) - return req + return req async def cancel_activity(self, input: CancelActivityInput) -> None: """Cancel an activity.""" @@ -8637,8 +8620,12 @@ async def _build_update_workflow_execution_request( input: StartWorkflowUpdateInput | UpdateWithStartUpdateWorkflowInput, workflow_id: str, ) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest: - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._client.data_converter._with_contexts( + WorkflowSerializationContext( + 
namespace=self._client.namespace, + workflow_id=workflow_id, + ), + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=workflow_id, run_id=(input.run_id or None) @@ -8646,53 +8633,47 @@ async def _build_update_workflow_execution_request( else None, namespace=self._client.namespace, ), + ), + ) + run_id, first_execution_run_id = ( + ( + input.run_id, + input.first_execution_run_id, ) - ): - data_converter = self._client.data_converter.with_context( - WorkflowSerializationContext( - namespace=self._client.namespace, - workflow_id=workflow_id, - ) - ) - run_id, first_execution_run_id = ( - ( - input.run_id, - input.first_execution_run_id, - ) - if isinstance(input, StartWorkflowUpdateInput) - else (None, None) - ) - req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( - namespace=self._client.namespace, - workflow_execution=temporalio.api.common.v1.WorkflowExecution( - workflow_id=workflow_id, - run_id=run_id or "", - ), - first_execution_run_id=first_execution_run_id or "", - request=temporalio.api.update.v1.Request( - meta=temporalio.api.update.v1.Meta( - update_id=input.update_id or str(uuid.uuid4()), - identity=self._client.identity, - ), - input=temporalio.api.update.v1.Input( - name=input.update, - ), + if isinstance(input, StartWorkflowUpdateInput) + else (None, None) + ) + req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=workflow_id, + run_id=run_id or "", + ), + first_execution_run_id=first_execution_run_id or "", + request=temporalio.api.update.v1.Request( + meta=temporalio.api.update.v1.Meta( + update_id=input.update_id or str(uuid.uuid4()), + identity=self._client.identity, ), - wait_policy=temporalio.api.update.v1.WaitPolicy( - lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType( - input.wait_for_stage - ) + input=temporalio.api.update.v1.Input( + 
name=input.update, ), - ) - if input.args: - req.request.input.args.payloads.extend( - await data_converter.encode(input.args) - ) - if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers( - input.headers, req.request.input.header.fields + ), + wait_policy=temporalio.api.update.v1.WaitPolicy( + lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.ValueType( + input.wait_for_stage ) - return req + ), + ) + if input.args: + req.request.input.args.payloads.extend( + await data_converter.encode(input.args) + ) + if input.headers is not None: # type:ignore[reportUnnecessaryComparison] + await self._apply_headers( + input.headers, req.request.input.header.fields + ) + return req async def start_update_with_start_workflow( self, input: StartWorkflowUpdateWithStartInput @@ -8829,19 +8810,19 @@ async def _start_workflow_update_with_start( ### Async activity calls - def _get_async_activity_store_metadata( + def _get_async_activity_store_context( self, id_or_token: AsyncActivityIDReference | bytes - ) -> StorageDriverStoreMetadata: + ) -> StorageDriverStoreContext: if isinstance(id_or_token, AsyncActivityIDReference): if id_or_token.workflow_id: - return StorageDriverStoreMetadata( + return StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=id_or_token.workflow_id or None, run_id=id_or_token.run_id or None, namespace=self._client.namespace, ), ) - return StorageDriverStoreMetadata( + return StorageDriverStoreContext( target=StorageDriverActivityInfo( id=id_or_token.activity_id, run_id=id_or_token.run_id or None, @@ -8849,193 +8830,181 @@ def _get_async_activity_store_metadata( ), ) else: - return StorageDriverStoreMetadata() + return StorageDriverStoreContext(target=None) async def heartbeat_async_activity( self, input: HeartbeatAsyncActivityInput ) -> None: - with store_metadata_context( - self._get_async_activity_store_metadata(input.id_or_token) - ): - data_converter = ( - 
input.data_converter_override or self._client.data_converter - ) - details = ( - None - if not input.details - else await data_converter.encode_wrapper(input.details) + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) + details = ( + None + if not input.details + else await data_converter.encode_wrapper(input.details) + ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( + temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - resp_by_id = await self._client.workflow_service.record_activity_task_heartbeat_by_id( - temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - if ( - resp_by_id.cancel_requested - or resp_by_id.activity_paused - or resp_by_id.activity_reset - ): - raise AsyncActivityCancelledError( - details=ActivityCancellationDetails( - cancel_requested=resp_by_id.cancel_requested, - paused=resp_by_id.activity_paused, - reset=resp_by_id.activity_reset, - ) + if ( + resp_by_id.cancel_requested + or resp_by_id.activity_paused + or resp_by_id.activity_reset + ): + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + 
cancel_requested=resp_by_id.cancel_requested, + paused=resp_by_id.activity_paused, + reset=resp_by_id.activity_reset, ) - - else: - resp = await self._client.workflow_service.record_activity_task_heartbeat( - temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, ) - if resp.cancel_requested or resp.activity_paused: - raise AsyncActivityCancelledError( - details=ActivityCancellationDetails( - cancel_requested=resp.cancel_requested, - paused=resp.activity_paused, - reset=resp.activity_reset, - ) + + else: + resp = await self._client.workflow_service.record_activity_task_heartbeat( + temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + if resp.cancel_requested or resp.activity_paused: + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp.cancel_requested, + paused=resp.activity_paused, + reset=resp.activity_reset, ) + ) async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: - with store_metadata_context( - self._get_async_activity_store_metadata(input.id_or_token) - ): - data_converter = ( - input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) + result = ( + None + if input.result is temporalio.common._arg_unset + else await data_converter.encode_wrapper([input.result]) + ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + await 
self._client.workflow_service.respond_activity_task_completed_by_id( + temporalio.api.workflowservice.v1.RespondActivityTaskCompletedByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + result=result, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - result = ( - None - if input.result is temporalio.common._arg_unset - else await data_converter.encode_wrapper([input.result]) + else: + await self._client.workflow_service.respond_activity_task_completed( + temporalio.api.workflowservice.v1.RespondActivityTaskCompletedRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + result=result, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - await self._client.workflow_service.respond_activity_task_completed_by_id( - temporalio.api.workflowservice.v1.RespondActivityTaskCompletedByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - result=result, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - else: - await self._client.workflow_service.respond_activity_task_completed( - temporalio.api.workflowservice.v1.RespondActivityTaskCompletedRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - result=result, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: - with store_metadata_context( - self._get_async_activity_store_metadata(input.id_or_token) - ): - data_converter 
= ( - input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) + + failure = temporalio.api.failure.v1.Failure() + await data_converter.encode_failure(input.error, failure) + last_heartbeat_details = ( + await data_converter.encode_wrapper(input.last_heartbeat_details) + if input.last_heartbeat_details + else None + ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + await self._client.workflow_service.respond_activity_task_failed_by_id( + temporalio.api.workflowservice.v1.RespondActivityTaskFailedByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + failure=failure, + last_heartbeat_details=last_heartbeat_details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - - failure = temporalio.api.failure.v1.Failure() - await data_converter.encode_failure(input.error, failure) - last_heartbeat_details = ( - await data_converter.encode_wrapper(input.last_heartbeat_details) - if input.last_heartbeat_details - else None + else: + await self._client.workflow_service.respond_activity_task_failed( + temporalio.api.workflowservice.v1.RespondActivityTaskFailedRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + failure=failure, + last_heartbeat_details=last_heartbeat_details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - await self._client.workflow_service.respond_activity_task_failed_by_id( - temporalio.api.workflowservice.v1.RespondActivityTaskFailedByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - 
run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - failure=failure, - last_heartbeat_details=last_heartbeat_details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - else: - await self._client.workflow_service.respond_activity_task_failed( - temporalio.api.workflowservice.v1.RespondActivityTaskFailedRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - failure=failure, - last_heartbeat_details=last_heartbeat_details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) async def report_cancellation_async_activity( self, input: ReportCancellationAsyncActivityInput ) -> None: - with store_metadata_context( - self._get_async_activity_store_metadata(input.id_or_token) - ): - data_converter = ( - input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) + details = ( + None + if not input.details + else await data_converter.encode_wrapper(input.details) + ) + if isinstance(input.id_or_token, AsyncActivityIDReference): + await self._client.workflow_service.respond_activity_task_canceled_by_id( + temporalio.api.workflowservice.v1.RespondActivityTaskCanceledByIdRequest( + workflow_id=input.id_or_token.workflow_id or "", + run_id=input.id_or_token.run_id or "", + activity_id=input.id_or_token.activity_id, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - details = ( - None - if not input.details - else await data_converter.encode_wrapper(input.details) + else: + await self._client.workflow_service.respond_activity_task_canceled( + 
temporalio.api.workflowservice.v1.RespondActivityTaskCanceledRequest( + task_token=input.id_or_token, + namespace=self._client.namespace, + identity=self._client.identity, + details=details, + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) - if isinstance(input.id_or_token, AsyncActivityIDReference): - await self._client.workflow_service.respond_activity_task_canceled_by_id( - temporalio.api.workflowservice.v1.RespondActivityTaskCanceledByIdRequest( - workflow_id=input.id_or_token.workflow_id or "", - run_id=input.id_or_token.run_id or "", - activity_id=input.id_or_token.activity_id, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) - else: - await self._client.workflow_service.respond_activity_task_canceled( - temporalio.api.workflowservice.v1.RespondActivityTaskCanceledRequest( - task_token=input.id_or_token, - namespace=self._client.namespace, - identity=self._client.identity, - details=details, - ), - retry=True, - metadata=input.rpc_metadata, - timeout=input.rpc_timeout, - ) ### Schedule calls diff --git a/temporalio/converter/_data_converter.py b/temporalio/converter/_data_converter.py index 99de876ea..0323466e7 100644 --- a/temporalio/converter/_data_converter.py +++ b/temporalio/converter/_data_converter.py @@ -17,6 +17,7 @@ from temporalio.converter._extstore import ( _REFERENCE_ENCODING, ExternalStorage, + StorageDriverStoreContext, StorageWarning, ) from temporalio.converter._failure_converter import ( @@ -199,6 +200,25 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "external_storage", external_storage) return cloned + def _with_store_context( + self, store_ctx: StorageDriverStoreContext + ) -> DataConverter: + """Return an instance with ``store_ctx`` bound into :attr:`external_storage`.""" + if self.external_storage is None: + return self + return 
dataclasses.replace( + self, + external_storage=self.external_storage._with_store_context(store_ctx), + ) + + def _with_contexts( + self, + serialization_ctx: SerializationContext, + store_ctx: StorageDriverStoreContext, + ) -> DataConverter: + """Return an instance with both serialization and store contexts applied.""" + return self.with_context(serialization_ctx)._with_store_context(store_ctx) + def _with_payload_error_limits( self, limits: _ServerPayloadErrorLimits | None ) -> DataConverter: diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index f2ac0abb3..068d3a94e 100644 --- a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -130,41 +130,6 @@ class StorageDriverActivityInfo: """The activity type name, if available.""" -@dataclass(frozen=True, kw_only=True) -class StorageDriverStoreMetadata: - """Store-only metadata available during external storage operations. - - .. warning:: - This API is experimental. - """ - - target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None - """The workflow or activity for which this payload is being stored.""" - - -_current_store_metadata: contextvars.ContextVar[StorageDriverStoreMetadata | None] = ( - contextvars.ContextVar("_current_store_metadata", default=None) -) - - -@contextlib.contextmanager -def store_metadata_context( - metadata: StorageDriverStoreMetadata | None, -) -> Generator[None, None, None]: - """Context manager that sets store metadata and resets it on exit. - - If metadata is None, yields without setting anything. - """ - if metadata is None: - yield - return - token = _current_store_metadata.set(metadata) - try: - yield - finally: - _current_store_metadata.reset(token) - - @dataclass(frozen=True) class StorageDriverStoreContext: """Context passed to :meth:`StorageDriver.store` and ``driver_selector`` calls. @@ -304,6 +269,14 @@ class ExternalStorage: for retrieval lookups. 
""" + _store_context: StorageDriverStoreContext = dataclasses.field( + default=StorageDriverStoreContext(target=None), + init=False, + repr=False, + compare=False, + ) + """Store context bound to this instance via :meth:`_with_store_context`.""" + _claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter( encoding=_REFERENCE_ENCODING.decode() ) @@ -363,22 +336,20 @@ def _get_driver_by_name(self, name: str) -> StorageDriver: raise ValueError(f"No driver found with name '{name}'") return driver - @staticmethod - def _build_store_context() -> StorageDriverStoreContext: - meta = _current_store_metadata.get() - return StorageDriverStoreContext( - target=meta.target if meta else None, - ) + def _with_store_context(self, ctx: StorageDriverStoreContext) -> ExternalStorage: + """Return a copy of this instance with ``ctx`` bound as the store context.""" + result = dataclasses.replace(self) + object.__setattr__(result, "_store_context", ctx) + return result async def _store_payload(self, payload: Payload) -> Payload: start_time = time.monotonic() - context = self._build_store_context() - driver = self._select_driver(context, payload) + driver = self._select_driver(self._store_context, payload) if driver is None: return payload - claims = await driver.store(context, [payload]) + claims = await driver.store(self._store_context, [payload]) self._validate_claim_length(claims, expected=1, driver=driver) @@ -413,11 +384,10 @@ async def _store_payload_sequence( start_time = time.monotonic() results = list(payloads) - context = self._build_store_context() to_store: list[tuple[int, Payload, StorageDriver]] = [] for index, payload in enumerate(payloads): - driver = self._select_driver(context, payload) + driver = self._select_driver(self._store_context, payload) if driver is None: continue to_store.append((index, payload, driver)) @@ -433,7 +403,7 @@ async def _store_payload_sequence( all_claims = await _gather_cancel_on_error( [ - driver.store(context, [p for 
_, p in indexed_payloads]) + driver.store(self._store_context, [p for _, p in indexed_payloads]) for driver, indexed_payloads in driver_group_list ] ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 9cdf89823..3b449dd7b 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -34,10 +34,10 @@ import temporalio.converter import temporalio.converter._payload_limits import temporalio.exceptions -from temporalio.converter import StorageDriverActivityInfo, StorageDriverWorkflowInfo -from temporalio.converter._extstore import ( - StorageDriverStoreMetadata, - store_metadata_context, +from temporalio.converter import ( + StorageDriverActivityInfo, + StorageDriverStoreContext, + StorageDriverWorkflowInfo, ) from ._interceptor import ( @@ -257,7 +257,6 @@ async def _heartbeat_async( return data_converter = self._data_converter - store_metadata: StorageDriverStoreMetadata | None = None if activity.info: context = temporalio.converter.ActivitySerializationContext( namespace=activity.info.namespace, @@ -268,14 +267,15 @@ async def _heartbeat_async( activity_task_queue=self._task_queue, is_local=activity.info.is_local, ) - data_converter = data_converter.with_context(context) - - store_metadata = StorageDriverStoreMetadata( - target=StorageDriverActivityInfo( - id=activity.info.activity_id, - type=activity.info.activity_type, - run_id=activity.info.activity_run_id, - namespace=activity.info.namespace, + data_converter = data_converter._with_contexts( + context, + StorageDriverStoreContext( + target=StorageDriverActivityInfo( + id=activity.info.activity_id, + type=activity.info.activity_type, + run_id=activity.info.activity_run_id, + namespace=activity.info.namespace, + ), ), ) @@ -285,8 +285,7 @@ async def _heartbeat_async( task_token=task_token ) if details: - with store_metadata_context(store_metadata): - heartbeat.details.extend(await data_converter.encode(details)) + heartbeat.details.extend(await 
data_converter.encode(details)) logger.debug("Recording heartbeat with details %s", details) self._bridge_worker().record_activity_heartbeat(heartbeat) except Exception as err: @@ -352,158 +351,160 @@ async def _handle_start_activity_task( type=start.activity_type or None, namespace=ns, ) - with store_metadata_context(StorageDriverStoreMetadata(target=store_target)): + data_converter = self._data_converter._with_contexts( + context, StorageDriverStoreContext(target=store_target) + ) + try: + result = await self._execute_activity( + start, running_activity, task_token, data_converter + ) + [payload] = await data_converter.encode([result]) + completion.result.completed.result.CopyFrom(payload) + except BaseException as err: try: - result = await self._execute_activity( - start, running_activity, task_token, data_converter - ) - [payload] = await data_converter.encode([result]) - completion.result.completed.result.CopyFrom(payload) - except BaseException as err: try: - try: - if isinstance(err, temporalio.activity._CompleteAsyncError): - temporalio.activity.logger.debug( - "Completing asynchronously" - ) - completion.result.will_complete_async.SetInParent() - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancelled_due_to_heartbeat_error - ): - err = running_activity.cancelled_due_to_heartbeat_error - temporalio.activity.logger.warning( - f"Completing as failure during heartbeat with error of type {type(err)}: {err}", - ) - await data_converter.encode_failure( - err, completion.result.failed.failure - ) - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancellation_details.details - and running_activity.cancellation_details.details.paused - ): - temporalio.activity.logger.warning( - "Completing as failure due to unhandled cancel error produced by activity pause", - ) - await data_converter.encode_failure( 
- temporalio.exceptions.ApplicationError( - type="ActivityPause", - message="Unhandled activity cancel error produced by activity pause", - ), - completion.result.failed.failure, - ) - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancellation_details.details - and running_activity.cancellation_details.details.reset - ): - temporalio.activity.logger.warning( - "Completing as failure due to unhandled cancel error produced by activity reset", - ) - await data_converter.encode_failure( - temporalio.exceptions.ApplicationError( - type="ActivityReset", - message="Unhandled activity cancel error produced by activity reset", - ), - completion.result.failed.failure, - ) - elif ( + if isinstance(err, temporalio.activity._CompleteAsyncError): + temporalio.activity.logger.debug( + "Completing asynchronously" + ) + completion.result.will_complete_async.SetInParent() + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancelled_due_to_heartbeat_error + ): + err = running_activity.cancelled_due_to_heartbeat_error + temporalio.activity.logger.warning( + f"Completing as failure during heartbeat with error of type {type(err)}: {err}", + ) + await data_converter.encode_failure( + err, completion.result.failed.failure + ) + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.paused + ): + temporalio.activity.logger.warning( + "Completing as failure due to unhandled cancel error produced by activity pause", + ) + await data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="ActivityPause", + message="Unhandled activity cancel error produced by activity pause", + ), + completion.result.failed.failure, + ) + elif ( + isinstance( + err, + ( + 
asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.reset + ): + temporalio.activity.logger.warning( + "Completing as failure due to unhandled cancel error produced by activity reset", + ) + await data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="ActivityReset", + message="Unhandled activity cancel error produced by activity reset", + ), + completion.result.failed.failure, + ) + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.CancelledError, + ), + ) + and running_activity.cancelled_by_request + ): + temporalio.activity.logger.debug("Completing as cancelled") + await data_converter.encode_failure( + # TODO(cretz): Should use some other message? + temporalio.exceptions.CancelledError("Cancelled"), + completion.result.cancelled.failure, + ) + elif isinstance( + err, + temporalio.converter._payload_limits._PayloadSizeError, + ): + temporalio.activity.logger.warning( + err.message, + extra={ + "__temporal_error_identifier": "PayloadSizeError" + }, + ) + await data_converter.encode_failure( + err, completion.result.failed.failure + ) + else: + if ( isinstance( err, - ( - asyncio.CancelledError, - temporalio.exceptions.CancelledError, - ), - ) - and running_activity.cancelled_by_request - ): - temporalio.activity.logger.debug("Completing as cancelled") - await data_converter.encode_failure( - # TODO(cretz): Should use some other message? - temporalio.exceptions.CancelledError("Cancelled"), - completion.result.cancelled.failure, + temporalio.exceptions.ApplicationError, ) - elif isinstance( - err, - temporalio.converter._payload_limits._PayloadSizeError, + and err.category + == temporalio.exceptions.ApplicationErrorCategory.BENIGN ): - temporalio.activity.logger.warning( - err.message, + # Downgrade log level to DEBUG for BENIGN application errors. 
+ temporalio.activity.logger.debug( + "Completing activity as failed", + exc_info=True, extra={ - "__temporal_error_identifier": "PayloadSizeError" + "__temporal_error_identifier": "ActivityFailure" }, ) - await data_converter.encode_failure( - err, completion.result.failed.failure - ) else: - if ( - isinstance( - err, - temporalio.exceptions.ApplicationError, - ) - and err.category - == temporalio.exceptions.ApplicationErrorCategory.BENIGN - ): - # Downgrade log level to DEBUG for BENIGN application errors. - temporalio.activity.logger.debug( - "Completing activity as failed", - exc_info=True, - extra={ - "__temporal_error_identifier": "ActivityFailure" - }, - ) - else: - temporalio.activity.logger.warning( - "Completing activity as failed", - exc_info=True, - extra={ - "__temporal_error_identifier": "ActivityFailure" - }, - ) - await data_converter.encode_failure( - err, completion.result.failed.failure + temporalio.activity.logger.warning( + "Completing activity as failed", + exc_info=True, + extra={ + "__temporal_error_identifier": "ActivityFailure" + }, ) - # For broken executors, we have to fail the entire worker - if isinstance(err, concurrent.futures.BrokenExecutor): - self._fail_worker_exception_queue.put_nowait(err) - # Handle PayloadSizeError from attempting to encode failure information - except ( - temporalio.converter._payload_limits._PayloadSizeError - ) as inner_err: - temporalio.activity.logger.exception(inner_err.message) - completion.result.Clear() await data_converter.encode_failure( - inner_err, completion.result.failed.failure + err, completion.result.failed.failure ) - except Exception as inner_err: - temporalio.activity.logger.exception( - f"Exception handling failed, original error: {err}" - ) + # For broken executors, we have to fail the entire worker + if isinstance(err, concurrent.futures.BrokenExecutor): + self._fail_worker_exception_queue.put_nowait(err) + # Handle PayloadSizeError from attempting to encode failure information + except 
( + temporalio.converter._payload_limits._PayloadSizeError + ) as inner_err: + temporalio.activity.logger.exception(inner_err.message) completion.result.Clear() - completion.result.failed.failure.message = ( - f"Failed building exception result: {inner_err}" + await data_converter.encode_failure( + inner_err, completion.result.failed.failure ) - completion.result.failed.failure.application_failure_info.SetInParent() + except Exception as inner_err: + temporalio.activity.logger.exception( + f"Exception handling failed, original error: {err}" + ) + completion.result.Clear() + completion.result.failed.failure.message = ( + f"Failed building exception result: {inner_err}" + ) + completion.result.failed.failure.application_failure_info.SetInParent() # Do final completion try: diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index d5145971b..b8de9461c 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -4,14 +4,13 @@ import asyncio import concurrent.futures -import contextlib import dataclasses import logging import os import sys import threading import time -from collections.abc import Awaitable, Callable, Iterator, MutableMapping, Sequence +from collections.abc import Awaitable, Callable, MutableMapping, Sequence from dataclasses import dataclass from datetime import timedelta, timezone from types import TracebackType @@ -29,11 +28,7 @@ import temporalio.workflow from temporalio.api.enums.v1 import WorkflowTaskFailedCause from temporalio.bridge.worker import PollShutdownError -from temporalio.converter import StorageDriverWorkflowInfo -from temporalio.converter._extstore import ( - StorageDriverStoreMetadata, - store_metadata_context, -) +from temporalio.converter import StorageDriverStoreContext, StorageDriverWorkflowInfo from . 
import _command_aware_visitor from ._interceptor import ( @@ -294,17 +289,9 @@ async def _handle_activation( namespace=self._namespace, workflow_id=workflow_id, ) - data_converter = self._data_converter.with_context(workflow_context) - if workflow: - data_converter = _CommandAwareDataConverter.create( - instance=workflow.instance, - context_free_dc=self._data_converter, - workflow_context_dc=data_converter, - workflow_context=workflow_context, - ) - # Set default store metadata for decode_activation - with store_metadata_context( - StorageDriverStoreMetadata( + data_converter = self._data_converter._with_contexts( + workflow_context, + StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=workflow_id, run_id=act.run_id, @@ -315,14 +302,21 @@ async def _handle_activation( ), namespace=self._namespace, ), + ), + ) + if workflow: + data_converter = _CommandAwareDataConverter.create( + instance=workflow.instance, + context_free_dc=self._data_converter, + workflow_context_dc=data_converter, + workflow_context=workflow_context, ) - ): - download_metrics = await temporalio.bridge.worker.decode_activation( - act, - data_converter, - decode_headers=self._encode_headers, - storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency, - ) + download_metrics = await temporalio.bridge.worker.decode_activation( + act, + data_converter, + decode_headers=self._encode_headers, + storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency, + ) if not workflow: assert init_job workflow = _RunningWorkflow( @@ -916,13 +910,6 @@ def _get_current_dc(self) -> temporalio.converter.DataConverter: return self._ca_workflow_context_dc return self._ca_context_free_dc.with_context(context) - @contextlib.contextmanager - def _store_metadata_context(self) -> Iterator[None]: - command_info = _command_aware_visitor.current_command_info.get() - metadata = self._ca_instance.get_external_store_metadata(command_info) - with store_metadata_context(metadata): 
- yield - async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: @@ -931,10 +918,10 @@ async def _encode_payload_sequence( async def _external_store_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - with self._store_metadata_context(): - return await self._get_current_dc()._external_store_payload_sequence( - payloads - ) + command_info = _command_aware_visitor.current_command_info.get() + store_ctx = self._ca_instance.get_external_store_context(command_info) + dc = self._get_current_dc()._with_store_context(store_ctx) + return await dc._external_store_payload_sequence(payloads) async def _external_retrieve_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f6a7ee849..1521f24a0 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -58,8 +58,7 @@ import temporalio.converter import temporalio.exceptions import temporalio.workflow -from temporalio.converter import StorageDriverWorkflowInfo -from temporalio.converter._extstore import StorageDriverStoreMetadata +from temporalio.converter import StorageDriverStoreContext, StorageDriverWorkflowInfo from temporalio.service import __version__ from ..api.failure.v1.message_pb2 import Failure @@ -185,17 +184,17 @@ def get_serialization_context( raise NotImplementedError @abstractmethod - def get_external_store_metadata( + def get_external_store_context( self, command_info: _command_aware_visitor.CommandInfo | None, - ) -> StorageDriverStoreMetadata | None: - """Return appropriate store metadata for external storage operations. + ) -> StorageDriverStoreContext: + """Return appropriate store context for external storage operations. 
Args: command_info: Optional information identifying the associated command. Returns: - The store metadata, or None if no metadata should be set. + The store context associated with the command. """ raise NotImplementedError @@ -2238,10 +2237,10 @@ def get_serialization_context( workflow_id=self._info.workflow_id, ) - def get_external_store_metadata( + def get_external_store_context( self, command_info: _command_aware_visitor.CommandInfo | None, - ) -> StorageDriverStoreMetadata | None: + ) -> StorageDriverStoreContext: ns = self._info.namespace current_wf = StorageDriverWorkflowInfo( id=self._info.workflow_id, @@ -2251,7 +2250,7 @@ def get_external_store_metadata( ) if command_info is None: - return StorageDriverStoreMetadata(target=current_wf) + return StorageDriverStoreContext(target=current_wf) COMMAND_TYPE = temporalio.api.enums.v1.command_type_pb2.CommandType @@ -2261,7 +2260,7 @@ def get_external_store_metadata( and command_info.command_seq in self._pending_child_workflows ): child = self._pending_child_workflows[command_info.command_seq] - return StorageDriverStoreMetadata( + return StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=child._input.id, type=child._input.workflow, namespace=ns ), @@ -2273,7 +2272,7 @@ def get_external_store_metadata( and command_info.command_seq in self._pending_external_signals ): _, target_id = self._pending_external_signals[command_info.command_seq] - return StorageDriverStoreMetadata( + return StorageDriverStoreContext( target=StorageDriverWorkflowInfo(id=target_id, namespace=ns), ) @@ -2282,7 +2281,7 @@ def get_external_store_metadata( == COMMAND_TYPE.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION and self._info.parent is not None ): - return StorageDriverStoreMetadata( + return StorageDriverStoreContext( target=StorageDriverWorkflowInfo( id=self._info.parent.workflow_id, run_id=self._info.parent.run_id, @@ -2291,7 +2290,7 @@ def get_external_store_metadata( ) else: - return 
StorageDriverStoreMetadata(target=current_wf) + return StorageDriverStoreContext(target=current_wf) def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: diff --git a/temporalio/worker/workflow_sandbox/_in_sandbox.py b/temporalio/worker/workflow_sandbox/_in_sandbox.py index 44a6fb351..d9415cea5 100644 --- a/temporalio/worker/workflow_sandbox/_in_sandbox.py +++ b/temporalio/worker/workflow_sandbox/_in_sandbox.py @@ -13,7 +13,7 @@ import temporalio.converter import temporalio.worker._workflow_instance import temporalio.workflow -from temporalio.converter._extstore import StorageDriverStoreMetadata +from temporalio.converter._extstore import StorageDriverStoreContext from temporalio.worker import _command_aware_visitor logger = logging.getLogger(__name__) @@ -90,9 +90,9 @@ def get_serialization_context( """Get serialization context.""" return self.instance.get_serialization_context(command_info) - def get_external_store_metadata( + def get_external_store_context( self, command_info: _command_aware_visitor.CommandInfo | None, - ) -> StorageDriverStoreMetadata | None: + ) -> StorageDriverStoreContext: """Get store metadata for external storage.""" - return self.instance.get_external_store_metadata(command_info) + return self.instance.get_external_store_context(command_info) diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index 0d943feb1..7605f3054 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -17,7 +17,7 @@ import temporalio.common import temporalio.converter import temporalio.workflow -from temporalio.converter._extstore import StorageDriverStoreMetadata +from temporalio.converter._extstore import StorageDriverStoreContext from temporalio.worker import _command_aware_visitor from ...api.common.v1.message_pb2 import Payloads @@ -207,19 +207,21 @@ def get_serialization_context( finally: 
self.importer.restriction_context.is_runtime = False - def get_external_store_metadata( + def get_external_store_context( self, command_info: _command_aware_visitor.CommandInfo | None, - ) -> StorageDriverStoreMetadata | None: + ) -> StorageDriverStoreContext: # Forward call to the sandboxed instance self.importer.restriction_context.is_runtime = True try: self._run_code( "with __temporal_importer.applied():\n" - " __temporal_metadata = __temporal_in_sandbox.get_external_store_metadata(__temporal_command_info)\n", + " __temporal_context = __temporal_in_sandbox.get_external_store_context(__temporal_command_info)\n", __temporal_importer=self.importer, __temporal_command_info=command_info, ) - return self.globals_and_locals.pop("__temporal_metadata", None) # type: ignore + return self.globals_and_locals.pop( + "__temporal_context", StorageDriverStoreContext(target=None) + ) # type: ignore finally: self.importer.restriction_context.is_runtime = False diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 9da3ea4b9..f123e5c61 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -1627,11 +1627,11 @@ def get_serialization_context( ) -> temporalio.converter.SerializationContext | None: return self._unsandboxed.get_serialization_context(command_info) - def get_external_store_metadata( + def get_external_store_context( self, command_info: temporalio.worker._command_aware_visitor.CommandInfo | None, - ) -> temporalio.converter._extstore.StorageDriverStoreMetadata | None: - return self._unsandboxed.get_external_store_metadata(command_info) + ) -> temporalio.converter._extstore.StorageDriverStoreContext: + return self._unsandboxed.get_external_store_context(command_info) async def test_workflow_with_custom_runner(client: Client): From e2881d0ae7f82c60232c245cb6989fab1035b252 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:48:36 -0700 Subject: [PATCH 14/16] 
Document the store target for workflow commands --- temporalio/client.py | 24 ++++++------------------ temporalio/worker/_activity.py | 8 ++------ temporalio/worker/_workflow_instance.py | 22 +++++++++++++++++----- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 37a615e22..9e7bc6045 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -6178,12 +6178,8 @@ async def _to_proto( action = temporalio.api.schedule.v1.ScheduleAction( start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( workflow_id=self.id, - workflow_type=temporalio.api.common.v1.WorkflowType( - name=self.workflow - ), - task_queue=temporalio.api.taskqueue.v1.TaskQueue( - name=self.task_queue - ), + workflow_type=temporalio.api.common.v1.WorkflowType(name=self.workflow), + task_queue=temporalio.api.taskqueue.v1.TaskQueue(name=self.task_queue), input=( temporalio.api.common.v1.Payloads( payloads=[ @@ -8306,9 +8302,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: if not resp.query_result.payloads: return None type_hints = [input.ret_type] if input.ret_type else None - results = await data_converter.decode( - resp.query_result.payloads, type_hints - ) + results = await data_converter.decode(resp.query_result.payloads, type_hints) if not results: return None elif len(results) > 1: @@ -8452,15 +8446,11 @@ async def _build_start_activity_execution_request( ) if input.schedule_to_close_timeout is not None: - req.schedule_to_close_timeout.FromTimedelta( - input.schedule_to_close_timeout - ) + req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout) if input.start_to_close_timeout is not None: req.start_to_close_timeout.FromTimedelta(input.start_to_close_timeout) if input.schedule_to_start_timeout is not None: - req.schedule_to_start_timeout.FromTimedelta( - input.schedule_to_start_timeout - ) + req.schedule_to_start_timeout.FromTimedelta(input.schedule_to_start_timeout) if 
input.heartbeat_timeout is not None: req.heartbeat_timeout.FromTimedelta(input.heartbeat_timeout) if input.retry_policy is not None: @@ -8670,9 +8660,7 @@ async def _build_update_workflow_execution_request( await data_converter.encode(input.args) ) if input.headers is not None: # type:ignore[reportUnnecessaryComparison] - await self._apply_headers( - input.headers, req.request.input.header.fields - ) + await self._apply_headers(input.headers, req.request.input.header.fields) return req async def start_update_with_start_workflow( diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 3b449dd7b..6979641a6 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -364,9 +364,7 @@ async def _handle_start_activity_task( try: try: if isinstance(err, temporalio.activity._CompleteAsyncError): - temporalio.activity.logger.debug( - "Completing asynchronously" - ) + temporalio.activity.logger.debug("Completing asynchronously") completion.result.will_complete_async.SetInParent() elif ( isinstance( @@ -449,9 +447,7 @@ async def _handle_start_activity_task( ): temporalio.activity.logger.warning( err.message, - extra={ - "__temporal_error_identifier": "PayloadSizeError" - }, + extra={"__temporal_error_identifier": "PayloadSizeError"}, ) await data_converter.encode_failure( err, completion.result.failed.failure diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1521f24a0..c809b3bd1 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -2241,12 +2241,20 @@ def get_external_store_context( self, command_info: _command_aware_visitor.CommandInfo | None, ) -> StorageDriverStoreContext: - ns = self._info.namespace + # The current workflow is the default target for external store + # operations. For commands that target other workflows, those workflows + # are the target for that command's external store operation. 
For + # workflow activities, the target is the current workflow since the + # activity is bound to the lifetime of the current workflow, the + # activity run information is the same as the current workflow, and + # successfully completed activities are not involved in replay. + # Otherwise, the storage space for a given workflow would be disparate + # if stored under activity information. current_wf = StorageDriverWorkflowInfo( id=self._info.workflow_id, run_id=self._info.run_id, type=self._info.workflow_type, - namespace=ns, + namespace=self._info.namespace, ) if command_info is None: @@ -2262,7 +2270,9 @@ def get_external_store_context( child = self._pending_child_workflows[command_info.command_seq] return StorageDriverStoreContext( target=StorageDriverWorkflowInfo( - id=child._input.id, type=child._input.workflow, namespace=ns + id=child._input.id, + type=child._input.workflow, + namespace=self._info.namespace, ), ) @@ -2273,7 +2283,9 @@ def get_external_store_context( ): _, target_id = self._pending_external_signals[command_info.command_seq] return StorageDriverStoreContext( - target=StorageDriverWorkflowInfo(id=target_id, namespace=ns), + target=StorageDriverWorkflowInfo( + id=target_id, namespace=self._info.namespace + ), ) elif ( @@ -2285,7 +2297,7 @@ def get_external_store_context( target=StorageDriverWorkflowInfo( id=self._info.parent.workflow_id, run_id=self._info.parent.run_id, - namespace=ns, + namespace=self._info.parent.namespace, ), ) From 98677f62b53ba6db96d5e873ba2dbf381d5a5b25 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:53:52 -0700 Subject: [PATCH 15/16] Update readme --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7cbbfa89c..965537208 100644 --- a/README.md +++ b/README.md @@ -533,8 +533,11 @@ def feature_flag_is_on(workflow_id: str | None) -> bool: def feature_flag_selector( context: 
temporalio.converter.StorageDriverStoreContext, _payload: Payload ) -> temporalio.converter.StorageDriver | None: - wf = context.current_workflow or context.target_workflow - workflow_id = wf.id if wf else None + workflow_id = ( + context.target.id + if isinstance(context.target, temporalio.converter.StorageDriverWorkflowInfo) + else None + ) return my_driver if feature_flag_is_on(workflow_id) else None options = ExternalStorage( From 019da3ee537933acef54e7e6fd1e594c09581b35 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:24:46 -0700 Subject: [PATCH 16/16] Comment updates --- temporalio/converter/_extstore.py | 6 +----- temporalio/worker/_activity.py | 4 ++-- temporalio/worker/workflow_sandbox/_in_sandbox.py | 2 +- tests/worker/test_extstore.py | 4 ++-- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index 068d3a94e..44541336c 100644 --- a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -145,11 +145,7 @@ class StorageDriverStoreContext: workflow being started, an activity being scheduled, an external workflow being signaled), this is that target's identity. When no explicit target exists the current execution context (workflow or activity) is used as the - target instead. 
- - The :attr:`StorageDriverWorkflowInfo.namespace` or - :attr:`StorageDriverActivityInfo.namespace` field on the target carries the - namespace for the execution, when available.""" + target instead.""" @dataclass(frozen=True) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 6979641a6..28cc1458a 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -331,9 +331,9 @@ async def _handle_start_activity_task( ) data_converter = self._data_converter.with_context(context) - # Build store metadata for external storage + # Build store context for external storage ns = start.workflow_namespace or self._client.namespace - # Store metadata is set for the full activity task lifetime (input + # Store context is set for the full activity task lifetime (input # decode, execution, result/failure encode). Each activity task runs # in its own coroutine so the value won't leak to other tasks. started_by_workflow = bool(start.workflow_execution.workflow_id) diff --git a/temporalio/worker/workflow_sandbox/_in_sandbox.py b/temporalio/worker/workflow_sandbox/_in_sandbox.py index d9415cea5..d18374899 100644 --- a/temporalio/worker/workflow_sandbox/_in_sandbox.py +++ b/temporalio/worker/workflow_sandbox/_in_sandbox.py @@ -94,5 +94,5 @@ def get_external_store_context( self, command_info: _command_aware_visitor.CommandInfo | None, ) -> StorageDriverStoreContext: - """Get store metadata for external storage.""" + """Get store context for external storage.""" return self.instance.get_external_store_context(command_info) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 1c2a97725..e714cbf40 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -958,7 +958,7 @@ async def _make_tracking_client( async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None: - """start_workflow should set workflow id and type on store metadata.""" + """start_workflow should set 
workflow id and type on store context.""" client, driver = await _make_tracking_client(env) workflow_id = str(uuid.uuid4()) @@ -1030,7 +1030,7 @@ async def test_store_metadata_signal_with_start(env: WorkflowEnvironment) -> Non async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None: - """signal_workflow should set workflow id on store metadata.""" + """signal_workflow should set workflow id on store context.""" client, driver = await _make_tracking_client(env) workflow_id = str(uuid.uuid4())