diff --git a/README.md b/README.md index 6fcd6fecb..f1e995ea6 100644 --- a/README.md +++ b/README.md @@ -533,11 +533,11 @@ def feature_flag_is_on(workflow_id: str | None) -> bool: def feature_flag_selector( context: temporalio.converter.StorageDriverStoreContext, _payload: Payload ) -> temporalio.converter.StorageDriver | None: - workflow_id = None - if isinstance(context.serialization_context, temporalio.converter.WorkflowSerializationContext): - workflow_id = context.serialization_context.workflow_id - elif isinstance(context.serialization_context, temporalio.converter.ActivitySerializationContext): - workflow_id = context.serialization_context.workflow_id + workflow_id = ( + context.target.id + if isinstance(context.target, temporalio.converter.StorageDriverWorkflowInfo) + else None + ) return my_driver if feature_flag_is_on(workflow_id) else None options = ExternalStorage( diff --git a/temporalio/client.py b/temporalio/client.py index cc2750ec6..9e7bc6045 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -66,6 +66,9 @@ ActivitySerializationContext, DataConverter, SerializationContext, + StorageDriverActivityInfo, + StorageDriverStoreContext, + StorageDriverWorkflowInfo, WithSerializationContext, WorkflowSerializationContext, ) @@ -6161,11 +6164,16 @@ async def _to_proto( priority: temporalio.api.common.v1.Priority | None = None if self.priority: priority = self.priority._to_proto() - data_converter = client.data_converter.with_context( + data_converter = client.data_converter._with_contexts( WorkflowSerializationContext( namespace=client.namespace, workflow_id=self.id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=self.id, type=self.workflow, namespace=client.namespace + ), + ), ) action = temporalio.api.schedule.v1.ScheduleAction( start_workflow=temporalio.api.workflow.v1.NewWorkflowExecutionInfo( @@ -6210,7 +6218,8 @@ async def _to_proto( # TODO (dan): confirm whether this be `is not None` if 
self.typed_search_attributes: temporalio.converter.encode_search_attributes( - self.typed_search_attributes, action.start_workflow.search_attributes + self.typed_search_attributes, + action.start_workflow.search_attributes, ) if self.headers: await _apply_headers( @@ -8077,11 +8086,16 @@ async def _build_signal_with_start_workflow_execution_request( self, input: StartWorkflowInput ) -> temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest: assert input.start_signal - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( WorkflowSerializationContext( namespace=self._client.namespace, workflow_id=input.id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=input.id, type=input.workflow, namespace=self._client.namespace + ), + ), ) req = temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest( signal_name=input.start_signal @@ -8108,11 +8122,16 @@ async def _populate_start_workflow_execution_request( ), input: StartWorkflowInput | UpdateWithStartStartWorkflowInput, ) -> None: - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( WorkflowSerializationContext( namespace=self._client.namespace, workflow_id=input.id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=input.id, type=input.workflow, namespace=self._client.namespace + ), + ), ) req.namespace = self._client.namespace req.workflow_id = input.id @@ -8228,11 +8247,18 @@ async def count_workflows( ) async def query_workflow(self, input: QueryWorkflowInput) -> Any: - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( WorkflowSerializationContext( namespace=self._client.namespace, workflow_id=input.id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=input.id, + run_id=input.run_id or 
None, + namespace=self._client.namespace, + ), + ), ) req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( namespace=self._client.namespace, @@ -8255,7 +8281,10 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: await self._apply_headers(input.headers, req.query.header.fields) try: resp = await self._client.workflow_service.query_workflow( - req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + req, + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, ) except RPCError as err: # If the status is INVALID_ARGUMENT, we can assume it's a query @@ -8281,11 +8310,18 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any: return results[0] async def signal_workflow(self, input: SignalWorkflowInput) -> None: - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( WorkflowSerializationContext( namespace=self._client.namespace, workflow_id=input.id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=input.id, + run_id=input.run_id or None, + namespace=self._client.namespace, + ), + ), ) req = temporalio.api.workflowservice.v1.SignalWorkflowExecutionRequest( namespace=self._client.namespace, @@ -8306,11 +8342,18 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None: ) async def terminate_workflow(self, input: TerminateWorkflowInput) -> None: - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( WorkflowSerializationContext( namespace=self._client.namespace, workflow_id=input.id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=input.id, + run_id=input.run_id or None, + namespace=self._client.namespace, + ), + ), ) req = temporalio.api.workflowservice.v1.TerminateWorkflowExecutionRequest( namespace=self._client.namespace, @@ -8365,7 +8408,7 @@ async def 
_build_start_activity_execution_request( self, input: StartActivityInput ) -> temporalio.api.workflowservice.v1.StartActivityExecutionRequest: """Build StartActivityExecutionRequest from input.""" - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( ActivitySerializationContext( namespace=self._client.namespace, activity_id=input.id, @@ -8374,7 +8417,14 @@ async def _build_start_activity_execution_request( is_local=False, workflow_id=None, workflow_type=None, - ) + ), + StorageDriverStoreContext( + target=StorageDriverActivityInfo( + id=input.id, + type=input.activity_type, + namespace=self._client.namespace, + ), + ), ) req = temporalio.api.workflowservice.v1.StartActivityExecutionRequest( @@ -8560,11 +8610,20 @@ async def _build_update_workflow_execution_request( input: StartWorkflowUpdateInput | UpdateWithStartUpdateWorkflowInput, workflow_id: str, ) -> temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest: - data_converter = self._client.data_converter.with_context( + data_converter = self._client.data_converter._with_contexts( WorkflowSerializationContext( namespace=self._client.namespace, workflow_id=workflow_id, - ) + ), + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=workflow_id, + run_id=(input.run_id or None) + if isinstance(input, StartWorkflowUpdateInput) + else None, + namespace=self._client.namespace, + ), + ), ) run_id, first_execution_run_id = ( ( @@ -8739,10 +8798,34 @@ async def _start_workflow_update_with_start( ### Async activity calls + def _get_async_activity_store_context( + self, id_or_token: AsyncActivityIDReference | bytes + ) -> StorageDriverStoreContext: + if isinstance(id_or_token, AsyncActivityIDReference): + if id_or_token.workflow_id: + return StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=id_or_token.workflow_id or None, + run_id=id_or_token.run_id or None, + namespace=self._client.namespace, + ), + ) + return 
StorageDriverStoreContext( + target=StorageDriverActivityInfo( + id=id_or_token.activity_id, + run_id=id_or_token.run_id or None, + namespace=self._client.namespace, + ), + ) + else: + return StorageDriverStoreContext(target=None) + async def heartbeat_async_activity( self, input: HeartbeatAsyncActivityInput ) -> None: - data_converter = input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) details = ( None if not input.details @@ -8797,7 +8880,9 @@ async def heartbeat_async_activity( ) async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: - data_converter = input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) result = ( None if input.result is temporalio.common._arg_unset @@ -8831,7 +8916,9 @@ async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> No ) async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: - data_converter = input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + )._with_store_context(self._get_async_activity_store_context(input.id_or_token)) failure = temporalio.api.failure.v1.Failure() await data_converter.encode_failure(input.error, failure) @@ -8872,7 +8959,9 @@ async def fail_async_activity(self, input: FailAsyncActivityInput) -> None: async def report_cancellation_async_activity( self, input: ReportCancellationAsyncActivityInput ) -> None: - data_converter = input.data_converter_override or self._client.data_converter + data_converter = ( + input.data_converter_override or self._client.data_converter + 
)._with_store_context(self._get_async_activity_store_context(input.id_or_token)) details = ( None if not input.details diff --git a/temporalio/contrib/aws/s3driver/_driver.py b/temporalio/contrib/aws/s3driver/_driver.py index 481e3a9d4..1f9d129c9 100644 --- a/temporalio/contrib/aws/s3driver/_driver.py +++ b/temporalio/contrib/aws/s3driver/_driver.py @@ -15,12 +15,12 @@ from temporalio.api.common.v1 import Payload from temporalio.contrib.aws.s3driver._client import S3StorageDriverClient from temporalio.converter import ( - ActivitySerializationContext, StorageDriver, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, - WorkflowSerializationContext, + StorageDriverWorkflowInfo, ) _T = TypeVar("_T") @@ -113,40 +113,25 @@ async def store( (e.g. proto binary). The returned list is the same length as ``payloads``. """ - workflow_id: str | None = None - activity_id: str | None = None - namespace: str | None = None - if isinstance(context.serialization_context, WorkflowSerializationContext): - workflow_id = context.serialization_context.workflow_id - namespace = context.serialization_context.namespace - if isinstance(context.serialization_context, ActivitySerializationContext): - # Prioritize workflow over activity so that the same payload that - # may be stored across workflow and activity boundaries are deduplicated. - if context.serialization_context.workflow_id: - workflow_id = context.serialization_context.workflow_id - elif context.serialization_context.activity_id: - activity_id = context.serialization_context.activity_id - namespace = context.serialization_context.namespace - - # URL encode values to avoid characters that break the key format - # e.g. spaces, forward-slashes, etc. 
- if namespace: - namespace = urllib.parse.quote(namespace, safe="") - if workflow_id: - workflow_id = urllib.parse.quote(workflow_id, safe="") - if activity_id: - activity_id = urllib.parse.quote(activity_id, safe="") - - namespace_segments = f"/ns/{namespace}" if namespace else "" + def _quote(val: str | None) -> str | None: + return urllib.parse.quote(val, safe="") if val else None + + # Build context segments from the target identity. context_segments = "" - # Prioritize workflow over activity so that the same payload that - # may be stored across workflow and activity boundaries are deduplicated. - # Workflow and Activity IDs are case sensitive. - if workflow_id: - context_segments += f"/wfi/{workflow_id}" - elif activity_id: - context_segments += f"/aci/{activity_id}" + target = context.target + namespace = _quote(target.namespace) if target is not None else None + namespace_segment = f"/ns/{namespace}" if namespace else "" + if isinstance(target, StorageDriverWorkflowInfo): + wf_type = _quote(target.type) or "null" + wf_id = _quote(target.id) or "null" + wf_run_id = _quote(target.run_id) or "null" + context_segments = f"/wt/{wf_type}/wi/{wf_id}/ri/{wf_run_id}" + elif isinstance(target, StorageDriverActivityInfo): + act_type = _quote(target.type) or "null" + act_id = _quote(target.id) or "null" + act_run_id = _quote(target.run_id) or "null" + context_segments = f"/at/{act_type}/ai/{act_id}/ri/{act_run_id}" async def _upload(payload: Payload) -> StorageDriverClaim: bucket = self._get_bucket(context, payload) @@ -162,7 +147,7 @@ async def _upload(payload: Payload) -> StorageDriverClaim: digest_segments = f"/d/sha256/{hash_digest}" - key = f"v0{namespace_segments}{context_segments}{digest_segments}" + key = f"v0{namespace_segment}{context_segments}{digest_segments}" try: if not await self._client.object_exists(bucket=bucket, key=key): diff --git a/temporalio/converter/__init__.py b/temporalio/converter/__init__.py index 2777e7e80..3ca6a3507 100644 --- 
a/temporalio/converter/__init__.py +++ b/temporalio/converter/__init__.py @@ -7,9 +7,11 @@ from temporalio.converter._extstore import ( ExternalStorage, StorageDriver, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, + StorageDriverWorkflowInfo, StorageWarning, ) from temporalio.converter._failure_converter import ( @@ -54,9 +56,11 @@ "ActivitySerializationContext", "ExternalStorage", "StorageDriver", + "StorageDriverActivityInfo", "StorageDriverClaim", "StorageDriverRetrieveContext", "StorageDriverStoreContext", + "StorageDriverWorkflowInfo", "StorageWarning", "AdvancedJSONEncoder", "BinaryNullPayloadConverter", diff --git a/temporalio/converter/_data_converter.py b/temporalio/converter/_data_converter.py index 99de876ea..0323466e7 100644 --- a/temporalio/converter/_data_converter.py +++ b/temporalio/converter/_data_converter.py @@ -17,6 +17,7 @@ from temporalio.converter._extstore import ( _REFERENCE_ENCODING, ExternalStorage, + StorageDriverStoreContext, StorageWarning, ) from temporalio.converter._failure_converter import ( @@ -199,6 +200,25 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "external_storage", external_storage) return cloned + def _with_store_context( + self, store_ctx: StorageDriverStoreContext + ) -> DataConverter: + """Return an instance with ``store_ctx`` bound into :attr:`external_storage`.""" + if self.external_storage is None: + return self + return dataclasses.replace( + self, + external_storage=self.external_storage._with_store_context(store_ctx), + ) + + def _with_contexts( + self, + serialization_ctx: SerializationContext, + store_ctx: StorageDriverStoreContext, + ) -> DataConverter: + """Return an instance with both serialization and store contexts applied.""" + return self.with_context(serialization_ctx)._with_store_context(store_ctx) + def _with_payload_error_limits( self, limits: _ServerPayloadErrorLimits | None ) -> DataConverter: 
diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index 28fad00a4..e787652a5 100644 --- a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -19,10 +19,6 @@ from temporalio.api.common.v1 import Payload, Payloads from temporalio.converter._payload_converter import JSONPlainPayloadConverter -from temporalio.converter._serialization_context import ( - SerializationContext, - WithSerializationContext, -) _T = TypeVar("_T") @@ -92,6 +88,48 @@ class StorageDriverClaim: """ +@dataclass(frozen=True, kw_only=True) +class StorageDriverWorkflowInfo: + """Workflow identity information for external storage operations. + + .. warning:: + This API is experimental. + """ + + namespace: str + """The namespace of the workflow execution.""" + + id: str | None = None + """The workflow ID.""" + + run_id: str | None = None + """The workflow run ID, if available.""" + + type: str | None = None + """The workflow type name, if available.""" + + +@dataclass(frozen=True, kw_only=True) +class StorageDriverActivityInfo: + """Activity identity information for external storage operations. + + .. warning:: + This API is experimental. + """ + + namespace: str + """The namespace of the activity execution.""" + + id: str | None = None + """The activity ID.""" + + run_id: str | None = None + """The activity run ID (only for standalone activities).""" + + type: str | None = None + """The activity type name, if available.""" + + @dataclass(frozen=True) class StorageDriverStoreContext: """Context passed to :meth:`StorageDriver.store` and ``driver_selector`` calls. @@ -100,10 +138,14 @@ class StorageDriverStoreContext: This API is experimental. """ - serialization_context: SerializationContext | None = None - """The serialization context active when this store operation was initiated, - or ``None`` if no context has been set. 
- """ + target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None + """The workflow or activity for which this payload is being stored. + + For payloads being stored on behalf of an explicit target (e.g. a child + workflow being started, an activity being scheduled, an external workflow + being signaled), this is that target's identity. When no explicit target + exists the current execution context (workflow or activity) is used as the + target instead.""" @dataclass(frozen=True) @@ -182,7 +224,7 @@ class _StorageReference: @dataclass(frozen=True) -class ExternalStorage(WithSerializationContext): +class ExternalStorage: """Configuration for external storage behavior. .. warning:: @@ -222,9 +264,13 @@ class ExternalStorage(WithSerializationContext): for retrieval lookups. """ - _context: SerializationContext | None = dataclasses.field( - init=False, default=None, repr=False, compare=False + _store_context: StorageDriverStoreContext = dataclasses.field( + default=StorageDriverStoreContext(target=None), + init=False, + repr=False, + compare=False, ) + """Store context bound to this instance via :meth:`_with_store_context`.""" _claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter( encoding=_REFERENCE_ENCODING.decode() @@ -261,12 +307,6 @@ def __post_init__(self) -> None: driver_map[name] = driver object.__setattr__(self, "_driver_map", driver_map) - def with_context(self, context: SerializationContext) -> Self: - """Return a copy of these options with the serialization context applied.""" - result = dataclasses.replace(self) - object.__setattr__(result, "_context", context) - return result - def _select_driver( self, context: StorageDriverStoreContext, payload: Payload ) -> StorageDriver | None: @@ -293,15 +333,20 @@ def _get_driver_by_name(self, name: str) -> StorageDriver: raise ValueError(f"No driver found with name '{name}'") return driver + def _with_store_context(self, ctx: StorageDriverStoreContext) -> 
ExternalStorage: + """Return a copy of this instance with ``ctx`` bound as the store context.""" + result = dataclasses.replace(self) + object.__setattr__(result, "_store_context", ctx) + return result + async def _store_payload(self, payload: Payload) -> Payload: start_time = time.monotonic() - context = StorageDriverStoreContext(serialization_context=self._context) - driver = self._select_driver(context, payload) + driver = self._select_driver(self._store_context, payload) if driver is None: return payload - claims = await driver.store(context, [payload]) + claims = await driver.store(self._store_context, [payload]) self._validate_claim_length(claims, expected=1, driver=driver) @@ -336,11 +381,10 @@ async def _store_payload_sequence( start_time = time.monotonic() results = list(payloads) - context = StorageDriverStoreContext(serialization_context=self._context) to_store: list[tuple[int, Payload, StorageDriver]] = [] for index, payload in enumerate(payloads): - driver = self._select_driver(context, payload) + driver = self._select_driver(self._store_context, payload) if driver is None: continue to_store.append((index, payload, driver)) @@ -356,7 +400,7 @@ async def _store_payload_sequence( all_claims = await _gather_cancel_on_error( [ - driver.store(context, [p for _, p in indexed_payloads]) + driver.store(self._store_context, [p for _, p in indexed_payloads]) for driver, indexed_payloads in driver_group_list ] ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index c7a1032fe..28cc1458a 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -34,6 +34,11 @@ import temporalio.converter import temporalio.converter._payload_limits import temporalio.exceptions +from temporalio.converter import ( + StorageDriverActivityInfo, + StorageDriverStoreContext, + StorageDriverWorkflowInfo, +) from ._interceptor import ( ActivityInboundInterceptor, @@ -262,7 +267,17 @@ async def _heartbeat_async( 
activity_task_queue=self._task_queue, is_local=activity.info.is_local, ) - data_converter = data_converter.with_context(context) + data_converter = data_converter._with_contexts( + context, + StorageDriverStoreContext( + target=StorageDriverActivityInfo( + id=activity.info.activity_id, + type=activity.info.activity_type, + run_id=activity.info.activity_run_id, + namespace=activity.info.namespace, + ), + ), + ) # Perform the heartbeat try: @@ -270,7 +285,6 @@ async def _heartbeat_async( task_token=task_token ) if details: - # Convert to core payloads heartbeat.details.extend(await data_converter.encode(details)) logger.debug("Recording heartbeat with details %s", details) self._bridge_worker().record_activity_heartbeat(heartbeat) @@ -316,6 +330,30 @@ async def _handle_start_activity_task( is_local=start.is_local, ) data_converter = self._data_converter.with_context(context) + + # Build store context for external storage + ns = start.workflow_namespace or self._client.namespace + # Store context is set for the full activity task lifetime (input + # decode, execution, result/failure encode). Each activity task runs + # in its own coroutine so the value won't leak to other tasks. 
+ started_by_workflow = bool(start.workflow_execution.workflow_id) + store_target: StorageDriverWorkflowInfo | StorageDriverActivityInfo + if started_by_workflow: + store_target = StorageDriverWorkflowInfo( + id=start.workflow_execution.workflow_id or None, + type=start.workflow_type or None, + run_id=start.workflow_execution.run_id or None, + namespace=ns, + ) + else: + store_target = StorageDriverActivityInfo( + id=start.activity_id or None, + type=start.activity_type or None, + namespace=ns, + ) + data_converter = self._data_converter._with_contexts( + context, StorageDriverStoreContext(target=store_target) + ) try: result = await self._execute_activity( start, running_activity, task_token, data_converter diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 327f8c68c..f77bea042 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -17,6 +17,7 @@ ResolveSignalExternalWorkflow, ) from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import ( + CompleteWorkflowExecution, ScheduleActivity, ScheduleLocalActivity, ScheduleNexusOperation, @@ -67,6 +68,14 @@ def __init__( ) # Workflow commands with payloads + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( + self, fs: VisitorFunctions, o: CompleteWorkflowExecution + ) -> None: + with current_command(CommandType.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, 0): + await super()._visit_coresdk_workflow_commands_CompleteWorkflowExecution( + fs, o + ) + async def _visit_coresdk_workflow_commands_ScheduleActivity( self, fs: VisitorFunctions, o: ScheduleActivity ) -> None: diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index fb104b414..b699e421d 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -28,6 +28,7 @@ import temporalio.workflow from temporalio.api.enums.v1 import WorkflowTaskFailedCause from 
temporalio.bridge.worker import PollShutdownError +from temporalio.converter import StorageDriverStoreContext, StorageDriverWorkflowInfo from . import _command_aware_visitor from ._interceptor import ( @@ -294,7 +295,21 @@ async def _handle_activation( namespace=self._namespace, workflow_id=workflow_id, ) - data_converter = self._data_converter.with_context(workflow_context) + data_converter = self._data_converter._with_contexts( + workflow_context, + StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=workflow_id, + run_id=act.run_id, + type=( + workflow.workflow_type + if workflow + else (init_job.workflow_type if init_job else None) + ), + namespace=self._namespace, + ), + ), + ) if workflow: data_converter = _CommandAwareDataConverter.create( instance=workflow.instance, @@ -313,6 +328,7 @@ async def _handle_activation( workflow = _RunningWorkflow( self._create_workflow_instance(act, init_job), workflow_id, + workflow_type=init_job.workflow_type, ) self._running_workflows[act.run_id] = workflow @@ -802,9 +818,15 @@ def _gen_tb_helper( class _RunningWorkflow: - def __init__(self, instance: WorkflowInstance, workflow_id: str): + def __init__( + self, + instance: WorkflowInstance, + workflow_id: str, + workflow_type: str | None = None, + ): self.instance = instance self.workflow_id = workflow_id + self.workflow_type = workflow_type self.deadlocked_activation_task: Awaitable | None = None self._deadlock_can_be_interrupted_lock = threading.Lock() self._deadlock_can_be_interrupted = False @@ -902,7 +924,10 @@ async def _encode_payload_sequence( async def _external_store_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - return await self._get_current_dc()._external_store_payload_sequence(payloads) + command_info = _command_aware_visitor.current_command_info.get() + store_ctx = self._ca_instance.get_external_store_context(command_info) + dc = 
self._get_current_dc()._with_store_context(store_ctx) + return await dc._external_store_payload_sequence(payloads) async def _external_retrieve_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1bfa77c3c..c809b3bd1 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -58,6 +58,7 @@ import temporalio.converter import temporalio.exceptions import temporalio.workflow +from temporalio.converter import StorageDriverStoreContext, StorageDriverWorkflowInfo from temporalio.service import __version__ from ..api.failure.v1.message_pb2 import Failure @@ -182,6 +183,21 @@ def get_serialization_context( """ raise NotImplementedError + @abstractmethod + def get_external_store_context( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreContext: + """Return appropriate store context for external storage operations. + + Args: + command_info: Optional information identifying the associated command. + + Returns: + The store context associated with the command. + """ + raise NotImplementedError + def get_thread_id(self) -> int | None: """Return the thread identifier that this workflow is running on. @@ -1851,7 +1867,6 @@ def workflow_register_random_seed_callback( # These are in alphabetical order and all start with "_outbound_". def _outbound_continue_as_new(self, input: ContinueAsNewInput) -> NoReturn: - # Just throw raise _ContinueAsNewError(self, input) def _outbound_schedule_activity( @@ -2222,6 +2237,73 @@ def get_serialization_context( workflow_id=self._info.workflow_id, ) + def get_external_store_context( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreContext: + # The current workflow is the default target for external store + # operations. 
For commands that target other workflows, those workflows + # are the target for that command's external store operation. For + # workflow activities, the target is the current workflow since the + # activity is bound to the lifetime of the current workflow, the + # activity run information is the same as the current workflow, and + # successfully completed activities are not involved in replay. + # Otherwise, the storage space for a given workflow would be disparate + # if stored under activity information. + current_wf = StorageDriverWorkflowInfo( + id=self._info.workflow_id, + run_id=self._info.run_id, + type=self._info.workflow_type, + namespace=self._info.namespace, + ) + + if command_info is None: + return StorageDriverStoreContext(target=current_wf) + + COMMAND_TYPE = temporalio.api.enums.v1.command_type_pb2.CommandType + + if ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION + and command_info.command_seq in self._pending_child_workflows + ): + child = self._pending_child_workflows[command_info.command_seq] + return StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=child._input.id, + type=child._input.workflow, + namespace=self._info.namespace, + ), + ) + + elif ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION + and command_info.command_seq in self._pending_external_signals + ): + _, target_id = self._pending_external_signals[command_info.command_seq] + return StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=target_id, namespace=self._info.namespace + ), + ) + + elif ( + command_info.command_type + == COMMAND_TYPE.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION + and self._info.parent is not None + ): + return StorageDriverStoreContext( + target=StorageDriverWorkflowInfo( + id=self._info.parent.workflow_id, + run_id=self._info.parent.run_id, + namespace=self._info.parent.namespace, + ), + ) + + else: + return 
StorageDriverStoreContext(target=current_wf) + def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: raise RuntimeError("Expected workflow input. This is a Python SDK bug.") diff --git a/temporalio/worker/workflow_sandbox/_in_sandbox.py b/temporalio/worker/workflow_sandbox/_in_sandbox.py index eea8f6940..d18374899 100644 --- a/temporalio/worker/workflow_sandbox/_in_sandbox.py +++ b/temporalio/worker/workflow_sandbox/_in_sandbox.py @@ -13,6 +13,7 @@ import temporalio.converter import temporalio.worker._workflow_instance import temporalio.workflow +from temporalio.converter._extstore import StorageDriverStoreContext from temporalio.worker import _command_aware_visitor logger = logging.getLogger(__name__) @@ -88,3 +89,10 @@ def get_serialization_context( ) -> temporalio.converter.SerializationContext | None: """Get serialization context.""" return self.instance.get_serialization_context(command_info) + + def get_external_store_context( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> StorageDriverStoreContext: + """Get store context for external storage.""" + return self.instance.get_external_store_context(command_info) diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index 31514e33b..7605f3054 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -17,6 +17,7 @@ import temporalio.common import temporalio.converter import temporalio.workflow +from temporalio.converter._extstore import StorageDriverStoreContext from temporalio.worker import _command_aware_visitor from ...api.common.v1.message_pb2 import Payloads @@ -205,3 +206,22 @@ def get_serialization_context( return self.globals_and_locals.pop("__temporal_context", None) # type: ignore finally: self.importer.restriction_context.is_runtime = False + + def get_external_store_context( + self, + command_info: _command_aware_visitor.CommandInfo | None, + ) -> 
StorageDriverStoreContext: + # Forward call to the sandboxed instance + self.importer.restriction_context.is_runtime = True + try: + self._run_code( + "with __temporal_importer.applied():\n" + " __temporal_context = __temporal_in_sandbox.get_external_store_context(__temporal_command_info)\n", + __temporal_importer=self.importer, + __temporal_command_info=command_info, + ) + return self.globals_and_locals.pop( + "__temporal_context", StorageDriverStoreContext(target=None) + ) # type: ignore + finally: + self.importer.restriction_context.is_runtime = False diff --git a/tests/contrib/aws/s3driver/test_s3driver.py b/tests/contrib/aws/s3driver/test_s3driver.py index 46184c8b7..c389fe07c 100644 --- a/tests/contrib/aws/s3driver/test_s3driver.py +++ b/tests/contrib/aws/s3driver/test_s3driver.py @@ -27,12 +27,12 @@ S3StorageDriverClient, ) from temporalio.converter import ( - ActivitySerializationContext, JSONPlainPayloadConverter, + StorageDriverActivityInfo, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext, - WorkflowSerializationContext, + StorageDriverWorkflowInfo, ) from tests.contrib.aws.s3driver.conftest import BUCKET @@ -51,34 +51,36 @@ def make_payload(value: str = "hello") -> Payload: def make_store_context( - serialization_context: WorkflowSerializationContext - | ActivitySerializationContext - | None = None, + target: StorageDriverActivityInfo | StorageDriverWorkflowInfo | None = None, ) -> StorageDriverStoreContext: - return StorageDriverStoreContext(serialization_context=serialization_context) + return StorageDriverStoreContext( + target=target, + ) def make_workflow_context( namespace: str = "my-namespace", workflow_id: str = "my-workflow", -) -> WorkflowSerializationContext: - return WorkflowSerializationContext(namespace=namespace, workflow_id=workflow_id) + workflow_type: str | None = None, + run_id: str | None = None, +) -> StorageDriverStoreContext: + return make_store_context( + target=StorageDriverWorkflowInfo( + 
id=workflow_id, type=workflow_type, run_id=run_id, namespace=namespace + ), + ) def make_activity_context( namespace: str = "my-namespace", activity_id: str | None = "my-activity", - workflow_id: str | None = None, - activity_task_queue: str | None = None, -) -> ActivitySerializationContext: - return ActivitySerializationContext( - namespace=namespace, - activity_id=activity_id, - activity_type=None, - activity_task_queue=activity_task_queue, - workflow_id=workflow_id, - workflow_type=None, - is_local=False, + activity_type: str | None = None, + run_id: str | None = None, +) -> StorageDriverStoreContext: + return make_store_context( + target=StorageDriverActivityInfo( + id=activity_id, type=activity_type, run_id=run_id, namespace=namespace + ), ) @@ -211,48 +213,70 @@ async def test_key_context_workflow( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="ns1", workflow_id="wf1") - ) + ctx = make_workflow_context(namespace="ns1", workflow_id="wf1") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() - assert claim.claim_data["key"] == f"v0/ns/ns1/wfi/wf1/d/sha256/{expected_hash}" + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/wt/null/wi/wf1/ri/null/d/sha256/{expected_hash}" + ) - async def test_key_context_workflow_activity( + async def test_key_context_workflow_with_type_and_run_id( self, driver_client: S3StorageDriverClient ) -> None: - """workflow_id takes priority over activity_id in ActivitySerializationContext.""" driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_activity_context( - namespace="ns1", workflow_id="wf1", activity_id="act1" - ) + ctx = make_workflow_context( + namespace="ns1", + workflow_id="wf1", + workflow_type="MyWorkflow", + run_id="run-abc", ) [claim] = await driver.store(ctx, [payload]) 
expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() - assert claim.claim_data["key"] == f"v0/ns/ns1/wfi/wf1/d/sha256/{expected_hash}" + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/wt/MyWorkflow/wi/wf1/ri/run-abc/d/sha256/{expected_hash}" + ) - async def test_key_context_standalone_activityt( + async def test_key_context_activity( self, driver_client: S3StorageDriverClient ) -> None: + """activity target uses activity key segment.""" driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_activity_context(namespace="ns1", activity_id="act1", workflow_id=None) + ctx = make_activity_context(namespace="ns1", activity_id="act1") + [claim] = await driver.store(ctx, [payload]) + expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/at/null/ai/act1/ri/null/d/sha256/{expected_hash}" + ) + + async def test_key_context_activity_with_type_and_run_id( + self, driver_client: S3StorageDriverClient + ) -> None: + driver = S3StorageDriver(client=driver_client, bucket=BUCKET) + payload = make_payload() + ctx = make_activity_context( + namespace="ns1", + activity_id="act1", + activity_type="MyActivity", + run_id="run-abc", ) [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() - assert claim.claim_data["key"] == f"v0/ns/ns1/aci/act1/d/sha256/{expected_hash}" + assert ( + claim.claim_data["key"] + == f"v0/ns/ns1/at/MyActivity/ai/act1/ri/run-abc/d/sha256/{expected_hash}" + ) async def test_key_preserves_case( self, driver_client: S3StorageDriverClient ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="MyNamespace", workflow_id="MyWorkflow") - ) + ctx = make_workflow_context(namespace="MyNamespace", workflow_id="MyWorkflow") [claim] = await 
driver.store(ctx, [payload]) key = claim.claim_data["key"] assert "MyNamespace" in key @@ -263,14 +287,12 @@ async def test_key_urlencodes_workflow_id_with_slashes( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="ns1", workflow_id="order/123/v2") - ) + ctx = make_workflow_context(namespace="ns1", workflow_id="order/123/v2") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/ns1/wfi/order%2F123%2Fv2/d/sha256/{expected_hash}" + == f"v0/ns/ns1/wt/null/wi/order%2F123%2Fv2/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencodes_workflow_id_with_special_chars( @@ -278,14 +300,12 @@ async def test_key_urlencodes_workflow_id_with_special_chars( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="ns1", workflow_id="wf#1 &foo=bar") - ) + ctx = make_workflow_context(namespace="ns1", workflow_id="wf#1 &foo=bar") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/ns1/wfi/wf%231%20%26foo%3Dbar/d/sha256/{expected_hash}" + == f"v0/ns/ns1/wt/null/wi/wf%231%20%26foo%3Dbar/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencodes_activity_id( @@ -293,16 +313,12 @@ async def test_key_urlencodes_activity_id( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_activity_context( - namespace="ns1", activity_id="act/1#2", workflow_id=None - ) - ) + ctx = make_activity_context(namespace="ns1", activity_id="act/1#2") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( 
claim.claim_data["key"] - == f"v0/ns/ns1/aci/act%2F1%232/d/sha256/{expected_hash}" + == f"v0/ns/ns1/at/null/ai/act%2F1%232/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencodes_namespace( @@ -310,14 +326,12 @@ async def test_key_urlencodes_namespace( ) -> None: driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload() - ctx = make_store_context( - make_workflow_context(namespace="my/ns#1", workflow_id="wf1") - ) + ctx = make_workflow_context(namespace="my/ns#1", workflow_id="wf1") [claim] = await driver.store(ctx, [payload]) expected_hash = hashlib.sha256(payload.SerializeToString()).hexdigest() assert ( claim.claim_data["key"] - == f"v0/ns/my%2Fns%231/wfi/wf1/d/sha256/{expected_hash}" + == f"v0/ns/my%2Fns%231/wt/null/wi/wf1/ri/null/d/sha256/{expected_hash}" ) async def test_key_urlencoded_roundtrip( @@ -326,9 +340,7 @@ async def test_key_urlencoded_roundtrip( """Payloads stored with special-char IDs can be retrieved correctly.""" driver = S3StorageDriver(client=driver_client, bucket=BUCKET) payload = make_payload("special-char-roundtrip") - ctx = make_store_context( - make_workflow_context(namespace="ns/1", workflow_id="wf/2#3") - ) + ctx = make_workflow_context(namespace="ns/1", workflow_id="wf/2#3") [claim] = await driver.store(ctx, [payload]) [retrieved] = await driver.retrieve(StorageDriverRetrieveContext(), [claim]) assert retrieved == payload @@ -528,45 +540,42 @@ def counting_selector(_ctx: StorageDriverStoreContext, _p: Payload) -> str: ) assert call_count == 3 - async def test_selector_routes_by_activity_task_queue( + async def test_selector_routes_by_activity_type( self, aioboto3_client: S3Client, driver_client: S3StorageDriverClient ) -> None: - """bucket callable can route payloads to different buckets by activity task queue.""" - bucket_a = "bucket-queue-a" - bucket_b = "bucket-queue-b" + """bucket callable can route payloads to different buckets by activity type.""" + bucket_a = "bucket-type-a" + bucket_b = 
"bucket-type-b" await aioboto3_client.create_bucket(Bucket=bucket_a) await aioboto3_client.create_bucket(Bucket=bucket_b) - queue_buckets = {"queue-a": bucket_a, "queue-b": bucket_b} + type_buckets = {"type-a": bucket_a, "type-b": bucket_b} - def queue_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: + def type_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: del p - if isinstance(ctx.serialization_context, ActivitySerializationContext): - queue = ctx.serialization_context.activity_task_queue - if queue and queue in queue_buckets: - return queue_buckets[queue] + act = ( + ctx.target + if isinstance(ctx.target, StorageDriverActivityInfo) + else None + ) + if act and act.type and act.type in type_buckets: + return type_buckets[act.type] return BUCKET - driver = S3StorageDriver(client=driver_client, bucket=queue_selector) + driver = S3StorageDriver(client=driver_client, bucket=type_selector) - ctx_a = make_store_context( - make_activity_context( - namespace="ns1", - activity_id="act1", - workflow_id="wf1", - activity_task_queue="queue-a", - ) + ctx_a = make_activity_context( + namespace="ns1", + activity_id="act1", + activity_type="type-a", ) [claim_a] = await driver.store(ctx_a, [make_payload("payload-a")]) assert claim_a.claim_data["bucket"] == bucket_a - ctx_b = make_store_context( - make_activity_context( - namespace="ns1", - activity_id="act2", - workflow_id="wf1", - activity_task_queue="queue-b", - ) + ctx_b = make_activity_context( + namespace="ns1", + activity_id="act2", + activity_type="type-b", ) [claim_b] = await driver.store(ctx_b, [make_payload("payload-b")]) assert claim_b.claim_data["bucket"] == bucket_b @@ -581,7 +590,7 @@ def capturing_selector(ctx: StorageDriverStoreContext, p: Payload) -> str: return BUCKET payload = make_payload() - store_ctx = make_store_context(make_workflow_context()) + store_ctx = make_workflow_context() driver = S3StorageDriver(client=driver_client, bucket=capturing_selector) await 
driver.store(store_ctx, [payload]) diff --git a/tests/contrib/aws/s3driver/test_s3driver_worker.py b/tests/contrib/aws/s3driver/test_s3driver_worker.py index 87ab73736..e25be5fbf 100644 --- a/tests/contrib/aws/s3driver/test_s3driver_worker.py +++ b/tests/contrib/aws/s3driver/test_s3driver_worker.py @@ -107,8 +107,17 @@ async def test_s3_driver_workflow_input_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + + # Client stores workflow input with ri=null (run ID not yet assigned); + # worker stores activity input with ri=run_id — same bytes, two S3 objects. + assert len(keys) == 2 + assert all( + f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in k for k in keys + ) + # Client-side store: ri=null because run ID is not yet known. + assert sum(1 for k in keys if "/ri/null/" in k) == 1 + # Worker-side store: ri=run_id, assigned by the server. + assert sum(1 for k in keys if "/ri/null/" not in k) == 1 async def test_s3_driver_workflow_output_key( @@ -127,8 +136,11 @@ async def test_s3_driver_workflow_output_key( ) assert result == LARGE keys = await _list_keys(aioboto3_client) + # Activity result and workflow result dedup to same key assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + assert f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in keys[0] + # Run ID is known for both activity completion and workflow completion + assert "/ri/null/" not in keys[0] async def test_s3_driver_workflow_activity_input_key( @@ -146,11 +158,14 @@ async def test_s3_driver_workflow_activity_input_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/" in keys[0] - assert ( - "/aci/" not in keys[0] - ), "Activity input should use workflow_id, not activity_id" + # Client start (ri=null) + worker schedules activity 
(ri=run_id) — same bytes, two objects. + assert len(keys) == 2 + # Both keys are under the workflow wi/ri prefix, not the activity. + assert all( + f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in k for k in keys + ) + # Activity input is keyed under the scheduling workflow, not the activity. + assert all("/ai/" not in k for k in keys) async def test_s3_driver_workflow_activity_output_key( @@ -168,8 +183,63 @@ async def test_s3_driver_workflow_activity_output_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) + # Activity result and workflow result are both LARGE so they deduplicate to one object. assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + assert f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in keys[0] + # ri=run_id for both stores (run ID is known by the time the activity completes). + assert "/ri/null/" not in keys[0] + + +async def test_s3_driver_standalone_activity_input_key( + env: WorkflowEnvironment, tmprl_client: Client, aioboto3_client: S3Client +) -> None: + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/2741" + ) + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + async with new_worker( + tmprl_client, activities=[large_io_activity], task_queue=task_queue + ): + await tmprl_client.execute_activity( + large_io_activity, + LARGE, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + keys = await _list_keys(aioboto3_client) + # Input and output are the same LARGE bytes, so they deduplicate to one key. + assert len(keys) == 1 + # Keyed under the activity, not a workflow. 
+ assert f"/ns/default/at/large_io_activity/ai/{activity_id}/ri/null/" in keys[0] + assert "/wt/" not in keys[0] + + +async def test_s3_driver_standalone_activity_output_key( + env: WorkflowEnvironment, tmprl_client: Client, aioboto3_client: S3Client +) -> None: + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/2741" + ) + activity_id = str(uuid.uuid4()) + task_queue = str(uuid.uuid4()) + async with new_worker( + tmprl_client, activities=[large_output_activity], task_queue=task_queue + ): + await tmprl_client.execute_activity( + large_output_activity, + id=activity_id, + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=5), + ) + keys = await _list_keys(aioboto3_client) + # Only the output is large; keyed under the activity. + assert len(keys) == 1 + assert f"/ns/default/at/large_output_activity/ai/{activity_id}/ri/null/" in keys[0] + assert "/wt/" not in keys[0] async def test_s3_driver_signal_arg_key( @@ -186,8 +256,12 @@ async def test_s3_driver_signal_arg_key( await handle.signal(SignalQueryUpdateWorkflow.finish, LARGE) await handle.result() keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + # Signal arg + workflow result — two distinct keys (different wt and ri). + assert len(keys) == 2 + # Signal arg: client stores with wt=null (type not known) and ri=null. + assert any(f"/wt/null/wi/{workflow_id}/ri/null/" in k for k in keys) + # Workflow result: worker stores with real type and ri=run_id. 
+ assert any(f"/wt/SignalQueryUpdateWorkflow/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_query_result_key( @@ -206,8 +280,12 @@ async def test_s3_driver_query_result_key( await handle.signal(SignalQueryUpdateWorkflow.finish, "done") await handle.result() keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + # Query arg + (query result deduplicated with workflow result) — two distinct keys. + assert len(keys) == 2 + # Query arg: client stores with wt=null (type not known) and ri=null. + assert any(f"/wt/null/wi/{workflow_id}/ri/null/" in k for k in keys) + # Query result and workflow result are both LARGE and deduplicate to one key with ri=run_id. + assert any(f"/wt/SignalQueryUpdateWorkflow/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_update_result_key( @@ -226,8 +304,12 @@ async def test_s3_driver_update_result_key( await handle.signal(SignalQueryUpdateWorkflow.finish, "done") await handle.result() keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] + # Update arg + (update result deduplicated with workflow result) — two distinct keys. + assert len(keys) == 2 + # Update arg: client stores with wt=null (type not known) and ri=null. + assert any(f"/wt/null/wi/{workflow_id}/ri/null/" in k for k in keys) + # Update result and workflow result are both LARGE and deduplicate to one key with ri=run_id. 
+ assert any(f"/wt/SignalQueryUpdateWorkflow/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_child_workflow_input_key( @@ -244,9 +326,11 @@ async def test_s3_driver_child_workflow_input_key( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 child_workflow_id = f"{workflow_id}-child" - assert f"/ns/default/wfi/{child_workflow_id}/d/sha256/" in keys[0] + # Child input is the only large payload — stored under the child's wi/ri. + assert len(keys) == 1 + # Keyed under the child: child input is stored in the child's context. + assert f"/ns/default/wt/ChildWorkflow/wi/{child_workflow_id}/ri/" in keys[0] async def test_s3_driver_identified_casing( @@ -264,10 +348,11 @@ async def test_s3_driver_identified_casing( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - assert len(keys) == 1 - assert "/ns/default/" in keys[0], "Namespace segment should be present" - assert ( - f"/wfi/{workflow_id}/" in keys[0] + # Client start (ri=null) + worker stores (ri=run_id) — two objects. + assert len(keys) == 2 + # Workflow ID is percent-encoded but casing is preserved verbatim. + assert all( + f"/ns/default/wt/LargeIOWorkflow/wi/{workflow_id}/ri/" in k for k in keys ), "Workflow ID should preserve original case in the key" @@ -290,9 +375,14 @@ async def test_s3_driver_content_dedup( execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) + # Two distinct content hashes (LARGE from download, LARGE_2 from extract) → two keys. assert len(keys) == 2 - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[0] - assert f"/ns/default/wfi/{workflow_id}/d/sha256/" in keys[1] + # Both are under the same workflow wi/ri prefix despite crossing activity boundaries. + assert all( + f"/ns/default/wt/DocumentIngestionWorkflow/wi/{workflow_id}/ri/" in k + for k in keys + ) + # The two keys differ by content hash only. 
assert keys[0] != keys[1] @@ -301,8 +391,7 @@ async def test_s3_driver_single_workflow_same_key_namespace( ) -> None: """A training job started with a large config, injected with large override parameters mid-run, and polled for large metrics — all produce S3 keys - under the same workflow ID prefix, regardless of which primitive carried - the payload.""" + containing the same workflow ID.""" workflow_id = str(uuid.uuid4()) async with new_worker(tmprl_client, ModelTrainingWorkflow) as worker: handle = await tmprl_client.start_workflow( @@ -320,19 +409,19 @@ async def test_s3_driver_single_workflow_same_key_namespace( await handle.signal(ModelTrainingWorkflow.complete) await handle.result() keys = await _list_keys(aioboto3_client) - # LARGE (input + signal arg) and LARGE_2 (metrics result) deduplicate to - # two distinct keys — both anchored under the same workflow ID prefix. - assert len(keys) == 2 - assert all(f"/ns/default/wfi/{workflow_id}/" in key for key in keys) + # Four distinct keys: client start, signal arg, update result, workflow result. + assert len(keys) == 4 + # All keys are anchored under the same workflow ID regardless of which primitive carried the payload. + assert all(f"/wi/{workflow_id}/" in k for k in keys) async def test_s3_driver_parent_child_independent_key_namespaces( tmprl_client: Client, aioboto3_client: S3Client ) -> None: - """An order fulfillment workflow spawns a child payment processor, passes it - a large order payload, and returns the child's large payment confirmation. - Each workflow accumulates S3 keys under its own workflow ID prefix — - parent and child key namespaces are fully independent.""" + """An order fulfillment workflow spawns a child payment processor and passes + it a large order payload. 
Child input is keyed under the child (it is stored in + the child's context); child output is keyed under the parent (for lifecycle + resilience — the child result lives in the parent's completion history).""" workflow_id = str(uuid.uuid4()) payment_id = f"{workflow_id}-payment" async with new_worker( @@ -346,16 +435,15 @@ execution_timeout=timedelta(seconds=5), ) keys = await _list_keys(aioboto3_client) - parent_prefix = f"/ns/default/wfi/{workflow_id}/d/" - child_prefix = f"/ns/default/wfi/{payment_id}/d/" - parent_keys = [k for k in keys if parent_prefix in k] - child_keys = [k for k in keys if child_prefix in k] - # The parent stores its input (LARGE) and the child's result propagated - # back (LARGE_2) under the parent's prefix → 2 keys. - # The child stores its input (LARGE) and its result (LARGE_2) under the - # child's prefix → 2 keys. - assert len(parent_keys) == 2 - assert len(child_keys) == 2 + parent_keys = [k for k in keys if f"/wi/{workflow_id}/" in k] + child_keys = [k for k in keys if f"/wi/{payment_id}/" in k] + # Parent accumulates 3 keys: + # 1. Client start stored in parent's key space (ri=null) + # 2. Child result stored in parent's key space + # 3. 
Parent's own workflow result + assert len(parent_keys) == 3 + # Child accumulates 1 key: its input from the parent + assert len(child_keys) == 1 async def test_s3_store_failure_surfaces_in_workflow_history( @@ -400,7 +488,6 @@ async def test_s3_store_failure_surfaces_in_workflow_history( large_payload = JSONPlainPayloadConverter().to_payload(LARGE) assert large_payload is not None expected_hash = hashlib.sha256(large_payload.SerializeToString()).hexdigest() - expected_key = f"v0/ns/default/wfi/{workflow_id}/d/sha256/{expected_hash}" assert isinstance(exc_info.value, WorkflowFailureError) activity_error = exc_info.value.__cause__ @@ -408,7 +495,8 @@ async def test_s3_store_failure_surfaces_in_workflow_history( app_error = activity_error.__cause__ assert isinstance(app_error, ApplicationError) assert app_error.type == "RuntimeError" - assert ( - app_error.message - == f"S3StorageDriver store failed [bucket={bad_bucket}, key={expected_key}]" - ) + # Key includes run_id which is only known at runtime; use substring checks. 
+ msg = app_error.message + assert f"S3StorageDriver store failed [bucket={bad_bucket}, key=" in msg + assert f"/wt/LargeOutputNoRetryWorkflow/wi/{workflow_id}/ri/" in msg + assert f"/d/sha256/{expected_hash}]" in msg diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py index 580926b4b..d3ce022f5 100644 --- a/tests/test_serialization_context.py +++ b/tests/test_serialization_context.py @@ -40,15 +40,10 @@ DefaultFailureConverter, DefaultPayloadConverter, EncodingPayloadConverter, - ExternalStorage, JSONPlainPayloadConverter, PayloadCodec, PayloadConverter, SerializationContext, - StorageDriver, - StorageDriverClaim, - StorageDriverRetrieveContext, - StorageDriverStoreContext, WithSerializationContext, WorkflowSerializationContext, ) @@ -1919,84 +1914,3 @@ async def test_user_customization_of_default_payload_converter( id=wf_id, task_queue=task_queue, ) - - -# Child workflow external storage context test - - -class ContextTrackingStorageDriver(StorageDriver): - """In-memory driver that records the serialization context on each store/retrieve.""" - - def __init__(self) -> None: - self._storage: dict[str, bytes] = {} - self.store_contexts: list[SerializationContext | None] = [] - - def name(self) -> str: - return "context-tracking" - - async def store( - self, - context: StorageDriverStoreContext, - payloads: Sequence[temporalio.api.common.v1.Payload], - ) -> list[StorageDriverClaim]: - self.store_contexts.append(context.serialization_context) - claims: list[StorageDriverClaim] = [] - for payload in payloads: - key = f"payload-{len(self._storage)}" - self._storage[key] = payload.SerializeToString() - claims.append(StorageDriverClaim(claim_data={"key": key})) - return claims - - async def retrieve( - self, - context: StorageDriverRetrieveContext, - claims: Sequence[StorageDriverClaim], - ) -> list[temporalio.api.common.v1.Payload]: - results: list[temporalio.api.common.v1.Payload] = [] - for claim in claims: - payload = 
temporalio.api.common.v1.Payload() - payload.ParseFromString(self._storage[claim.claim_data["key"]]) - results.append(payload) - return results - - -async def test_child_workflow_external_storage_with_context(client: Client): - """External storage should receive the child workflow's context, not the parent's.""" - workflow_id = str(uuid.uuid4()) - child_workflow_id = f"{workflow_id}-child" - task_queue = str(uuid.uuid4()) - - driver = ContextTrackingStorageDriver() - config = client.config() - config["data_converter"] = dataclasses.replace( - DataConverter.default, - external_storage=ExternalStorage( - drivers=[driver], - payload_size_threshold=0, - ), - ) - client = Client(**config) - - async with Worker( - client, - task_queue=task_queue, - workflows=[ChildWorkflowCodecTestWorkflow, EchoWorkflow], - workflow_runner=UnsandboxedWorkflowRunner(), - ): - await client.execute_workflow( - ChildWorkflowCodecTestWorkflow.run, - TraceData(), - id=workflow_id, - task_queue=task_queue, - ) - - child_context = WorkflowSerializationContext( - namespace=client.namespace, - workflow_id=child_workflow_id, - ) - # store_contexts[0]: parent input encode → parent context - # store_contexts[1]: child workflow input encode → child context - # store_contexts[2]: child workflow result encode → child context - # store_contexts[3]: parent result encode → parent context - child_context_count = sum(1 for c in driver.store_contexts if c == child_context) - assert child_context_count == 2 diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 8e2ee763a..5e5ebdac6 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -11,6 +11,7 @@ import temporalio import temporalio.bridge.client import temporalio.bridge.worker +import temporalio.client import temporalio.converter import temporalio.worker._workflow from temporalio import activity, workflow @@ -19,9 +20,12 @@ from temporalio.common import RetryPolicy from temporalio.converter import ( 
# ---------------------------------------------------------------------------
# Store-metadata context tests
#
# These tests verify that the SDK populates StorageDriverStoreContext.target
# with the correct workflow/activity identity at every point where payloads
# are externalized: client-side starts, signals, schedules, and worker-side
# completions. A payload_size_threshold of 0 forces every payload through the
# driver so each encode step is observable.
# ---------------------------------------------------------------------------


class ContextTrackingStorageDriver(StorageDriver):
    """In-memory driver that records the store context on each store/retrieve.

    Stored payloads live in a plain dict keyed by a synthetic
    ``payload-<n>`` string; ``store_contexts`` accumulates every
    :class:`StorageDriverStoreContext` seen, in call order, so tests can
    assert on the metadata attached to each externalized payload.
    """

    def __init__(self) -> None:
        # key -> serialized Payload bytes
        self._storage: dict[str, bytes] = {}
        # One entry per store() call, in the order the SDK invoked them.
        self.store_contexts: list[StorageDriverStoreContext] = []

    def name(self) -> str:
        return "context-tracking"

    async def store(
        self,
        context: StorageDriverStoreContext,
        payloads: Sequence[Payload],
    ) -> list[StorageDriverClaim]:
        """Record the context, then persist each payload under a fresh key."""
        self.store_contexts.append(context)
        claims: list[StorageDriverClaim] = []
        for payload in payloads:
            # len(self._storage) grows monotonically, so keys never collide.
            key = f"payload-{len(self._storage)}"
            self._storage[key] = payload.SerializeToString()
            claims.append(StorageDriverClaim(claim_data={"key": key}))
        return claims

    async def retrieve(
        self,
        context: StorageDriverRetrieveContext,
        claims: Sequence[StorageDriverClaim],
    ) -> list[Payload]:
        """Rehydrate payloads for the given claims (context is not tracked)."""
        results: list[Payload] = []
        for claim in claims:
            payload = Payload()
            payload.ParseFromString(self._storage[claim.claim_data["key"]])
            results.append(payload)
        return results


@workflow.defn
class SignalWaitWorkflow:
    """Workflow that blocks until it receives ``my_signal``, then returns the data."""

    def __init__(self) -> None:
        self._signal_data: str | None = None

    @workflow.run
    async def run(self, _arg: str) -> str:
        await workflow.wait_condition(lambda: self._signal_data is not None)
        return self._signal_data  # type: ignore

    @workflow.signal
    async def my_signal(self, data: str) -> None:
        self._signal_data = data


@workflow.defn
class EchoWorkflow:
    """Trivial workflow that returns its input unchanged."""

    @workflow.run
    async def run(self, data: str) -> str:
        return data


@workflow.defn
class ChildWorkflowStoreMetadataTestWorkflow:
    """Parent workflow that runs EchoWorkflow as a child with a derived id."""

    @workflow.run
    async def run(self, data: str) -> str:
        return await workflow.execute_child_workflow(
            EchoWorkflow.run,
            data,
            id=f"{workflow.info().workflow_id}-child",
        )


async def _make_tracking_client(
    env: WorkflowEnvironment,
) -> tuple[Client, ContextTrackingStorageDriver]:
    """Connect a client whose converter externalizes EVERY payload.

    payload_size_threshold=0 routes all payloads through the tracking driver
    so each store context is captured. Returns the client and the driver.
    """
    driver = ContextTrackingStorageDriver()
    client = await Client.connect(
        env.client.service_client.config.target_host,
        namespace=env.client.namespace,
        data_converter=dataclasses.replace(
            temporalio.converter.default(),
            external_storage=ExternalStorage(
                drivers=[driver],
                payload_size_threshold=0,
            ),
        ),
    )
    return client, driver


async def test_store_metadata_start_workflow(env: WorkflowEnvironment) -> None:
    """start_workflow should set workflow id and type on store context."""
    client, driver = await _make_tracking_client(env)
    workflow_id = str(uuid.uuid4())

    async with new_worker(client, EchoWorkflow) as worker:
        await client.execute_workflow(
            EchoWorkflow.run,
            "hello",
            id=workflow_id,
            task_queue=worker.task_queue,
        )

    assert len(driver.store_contexts) == 2

    # [0] Workflow input arg: encoded client-side before the run exists, so
    # run_id is not yet known.
    client_ctx = driver.store_contexts[0]
    assert isinstance(client_ctx.target, StorageDriverWorkflowInfo)
    assert client_ctx.target.namespace == client.namespace
    assert client_ctx.target.id == workflow_id
    assert client_ctx.target.type == "EchoWorkflow"
    assert client_ctx.target.run_id is None

    # [1] Workflow result: encoded worker-side, run_id now available.
    worker_ctx = driver.store_contexts[1]
    assert isinstance(worker_ctx.target, StorageDriverWorkflowInfo)
    assert worker_ctx.target.namespace == client.namespace
    assert worker_ctx.target.id == workflow_id
    assert worker_ctx.target.type == "EchoWorkflow"
    assert worker_ctx.target.run_id is not None


async def test_store_metadata_signal_with_start(env: WorkflowEnvironment) -> None:
    """signal_with_start should set workflow metadata for signal arg encoding."""
    client, driver = await _make_tracking_client(env)
    workflow_id = str(uuid.uuid4())

    async with new_worker(client, SignalWaitWorkflow) as worker:
        handle = await client.start_workflow(
            SignalWaitWorkflow.run,
            "hello",
            id=workflow_id,
            task_queue=worker.task_queue,
            start_signal="my_signal",
            start_signal_args=["signal-data"],
        )
        await handle.result()

    assert len(driver.store_contexts) == 3

    # [0] Workflow input arg
    input_ctx = driver.store_contexts[0]
    assert isinstance(input_ctx.target, StorageDriverWorkflowInfo)
    assert input_ctx.target.id == workflow_id
    assert input_ctx.target.type == "SignalWaitWorkflow"
    assert input_ctx.target.run_id is None

    # [1] Signal arg: same request as the start, so type is known here too.
    signal_ctx = driver.store_contexts[1]
    assert isinstance(signal_ctx.target, StorageDriverWorkflowInfo)
    assert signal_ctx.target.id == workflow_id
    assert signal_ctx.target.type == "SignalWaitWorkflow"
    assert signal_ctx.target.run_id is None

    # [2] Workflow result
    result_ctx = driver.store_contexts[2]
    assert isinstance(result_ctx.target, StorageDriverWorkflowInfo)
    assert result_ctx.target.id == workflow_id
    assert result_ctx.target.type == "SignalWaitWorkflow"
    assert result_ctx.target.run_id is not None


async def test_store_metadata_signal_workflow(env: WorkflowEnvironment) -> None:
    """signal_workflow should set workflow id on store context."""
    client, driver = await _make_tracking_client(env)
    workflow_id = str(uuid.uuid4())

    async with new_worker(client, SignalWaitWorkflow) as worker:
        handle = await client.start_workflow(
            SignalWaitWorkflow.run,
            "hello",
            id=workflow_id,
            task_queue=worker.task_queue,
        )
        # Signal separately (not signal-with-start)
        await handle.signal(SignalWaitWorkflow.my_signal, "signal-data")
        await handle.result()

    assert len(driver.store_contexts) == 3

    # [0] Client starts workflow
    start_ctx = driver.store_contexts[0]
    assert isinstance(start_ctx.target, StorageDriverWorkflowInfo)
    assert start_ctx.target.id == workflow_id
    assert start_ctx.target.type == "SignalWaitWorkflow"
    assert start_ctx.target.run_id is None

    # [1] Client sends signal: type and run_id are unknown at signal time
    signal_ctx = driver.store_contexts[1]
    assert isinstance(signal_ctx.target, StorageDriverWorkflowInfo)
    assert signal_ctx.target.id == workflow_id
    assert signal_ctx.target.type is None
    assert signal_ctx.target.run_id is None

    # [2] Workflow worker returns result
    result_ctx = driver.store_contexts[2]
    assert isinstance(result_ctx.target, StorageDriverWorkflowInfo)
    assert result_ctx.target.id == workflow_id
    assert result_ctx.target.type == "SignalWaitWorkflow"
    assert result_ctx.target.run_id is not None


async def test_store_metadata_schedule_action(env: WorkflowEnvironment) -> None:
    """Schedule action _to_proto should set workflow metadata."""
    if env.supports_time_skipping:
        pytest.skip("Java test server doesn't support schedules")
    client, driver = await _make_tracking_client(env)
    task_queue = str(uuid.uuid4())
    schedule_id = f"sched-{uuid.uuid4()}"

    try:
        await client.create_schedule(
            schedule_id,
            temporalio.client.Schedule(
                action=temporalio.client.ScheduleActionStartWorkflow(
                    EchoWorkflow.run,
                    "hello",
                    id=f"wf-{schedule_id}",
                    task_queue=task_queue,
                ),
                spec=temporalio.client.ScheduleSpec(),
            ),
        )

        assert len(driver.store_contexts) == 1

        # [0] Client encodes workflow args when creating the schedule action
        ctx = driver.store_contexts[0]
        assert isinstance(ctx.target, StorageDriverWorkflowInfo)
        assert ctx.target.namespace == client.namespace
        assert ctx.target.id == f"wf-{schedule_id}"
        assert ctx.target.type == "EchoWorkflow"
        assert ctx.target.run_id is None
    finally:
        # Best-effort cleanup so a failed assertion doesn't leak the schedule.
        try:
            handle = client.get_schedule_handle(schedule_id)
            await handle.delete()
        except Exception:
            pass


async def test_store_metadata_child_workflow(env: WorkflowEnvironment) -> None:
    """External storage should receive the child workflow as the target when scheduling."""
    client, driver = await _make_tracking_client(env)
    workflow_id = f"workflow-{uuid.uuid4()}"
    child_workflow_id = f"{workflow_id}-child"

    async with new_worker(
        client,
        ChildWorkflowStoreMetadataTestWorkflow,
        EchoWorkflow,
    ) as worker:
        await client.execute_workflow(
            ChildWorkflowStoreMetadataTestWorkflow.run,
            "hello",
            id=workflow_id,
            task_queue=worker.task_queue,
        )

    assert len(driver.store_contexts) == 4

    # [0] Client starts parent workflow
    client_ctx = driver.store_contexts[0]
    assert isinstance(client_ctx.target, StorageDriverWorkflowInfo)
    assert client_ctx.target.id == workflow_id
    assert client_ctx.target.type == "ChildWorkflowStoreMetadataTestWorkflow"
    assert client_ctx.target.run_id is None

    # [1] Parent schedules child: target = child workflow
    start_child_ctx = driver.store_contexts[1]
    assert isinstance(start_child_ctx.target, StorageDriverWorkflowInfo)
    assert start_child_ctx.target.id == child_workflow_id
    assert start_child_ctx.target.type == "EchoWorkflow"
    assert start_child_ctx.target.run_id is None

    # [2] Child returns result: target = parent workflow (child results are
    # stored in the parent's key space so they remain accessible during replay)
    child_result_ctx = driver.store_contexts[2]
    assert isinstance(child_result_ctx.target, StorageDriverWorkflowInfo)
    assert child_result_ctx.target.id == workflow_id
    # ParentInfo does not carry workflow type
    assert child_result_ctx.target.type is None
    assert child_result_ctx.target.run_id is not None

    # [3] Parent returns result: target = parent (current execution)
    parent_result_ctx = driver.store_contexts[3]
    assert isinstance(parent_result_ctx.target, StorageDriverWorkflowInfo)
    assert parent_result_ctx.target.id == workflow_id
    assert parent_result_ctx.target.type == "ChildWorkflowStoreMetadataTestWorkflow"
    assert parent_result_ctx.target.run_id is not None


# Workflow definitions for gap tests


@activity.defn
async def echo_activity(input: str) -> str:
    """Simple activity that returns its input."""
    return input


@workflow.defn
class ActivityScheduleMetadataWorkflow:
    """Workflow that schedules an activity to test activity metadata on the store context."""

    @workflow.run
    async def run(self, data: str) -> str:
        return await workflow.execute_activity(
            echo_activity,
            data,
            activity_id="my-activity-id",
            schedule_to_close_timeout=timedelta(seconds=10),
        )


@workflow.defn
class SignalExternalMetadataWorkflow:
    """Workflow that signals another workflow."""

    @workflow.run
    async def run(self, target_workflow_id: str) -> None:
        await workflow.get_external_workflow_handle(target_workflow_id).signal(
            SignalWaitWorkflow.my_signal, "signal-from-workflow"
        )


async def test_store_metadata_activity_scheduling(env: WorkflowEnvironment) -> None:
    """When a workflow schedules an activity, context.activity should be populated."""
    client, driver = await _make_tracking_client(env)
    workflow_id = f"workflow-{uuid.uuid4()}"

    async with new_worker(
        client,
        ActivityScheduleMetadataWorkflow,
        activities=[echo_activity],
    ) as worker:
        await client.execute_workflow(
            ActivityScheduleMetadataWorkflow.run,
            "hello",
            id=workflow_id,
            task_queue=worker.task_queue,
        )

    assert len(driver.store_contexts) == 4

    # [0] Client starts workflow
    client_ctx = driver.store_contexts[0]
    assert isinstance(client_ctx.target, StorageDriverWorkflowInfo)
    assert client_ctx.target.id == workflow_id
    assert client_ctx.target.type == "ActivityScheduleMetadataWorkflow"
    assert client_ctx.target.run_id is None

    # [1] Workflow worker schedules activity
    schedule_ctx = driver.store_contexts[1]
    assert isinstance(schedule_ctx.target, StorageDriverWorkflowInfo)
    assert schedule_ctx.target.namespace == client.namespace
    assert schedule_ctx.target.id == workflow_id
    assert schedule_ctx.target.type == "ActivityScheduleMetadataWorkflow"
    assert schedule_ctx.target.run_id is not None

    # [2] Activity worker completes
    execute_ctx = driver.store_contexts[2]
    assert isinstance(execute_ctx.target, StorageDriverWorkflowInfo)
    assert execute_ctx.target.namespace == client.namespace
    assert execute_ctx.target.id == workflow_id
    assert execute_ctx.target.type == "ActivityScheduleMetadataWorkflow"
    assert execute_ctx.target.run_id is not None

    # [3] Workflow returns result
    result_ctx = driver.store_contexts[3]
    assert isinstance(result_ctx.target, StorageDriverWorkflowInfo)
    assert result_ctx.target.id == workflow_id
    assert result_ctx.target.type == "ActivityScheduleMetadataWorkflow"
    assert result_ctx.target.run_id is not None


async def test_store_metadata_signal_external_workflow(
    env: WorkflowEnvironment,
) -> None:
    """Signaling an external workflow should set workflow.id to the target."""
    client, driver = await _make_tracking_client(env)
    target_workflow_id = f"target-{uuid.uuid4()}"
    sender_workflow_id = f"sender-{uuid.uuid4()}"

    async with new_worker(
        client,
        SignalExternalMetadataWorkflow,
        SignalWaitWorkflow,
    ) as worker:
        # Start the target workflow first
        target_handle = await client.start_workflow(
            SignalWaitWorkflow.run,
            "waiting",
            id=target_workflow_id,
            task_queue=worker.task_queue,
        )
        # Start the sender which will signal the target
        await client.execute_workflow(
            SignalExternalMetadataWorkflow.run,
            target_workflow_id,
            id=sender_workflow_id,
            task_queue=worker.task_queue,
        )
        await target_handle.result()

    assert len(driver.store_contexts) == 5

    # [0] Client starts target workflow (SignalWaitWorkflow)
    target_start_ctx = driver.store_contexts[0]
    assert isinstance(target_start_ctx.target, StorageDriverWorkflowInfo)
    assert target_start_ctx.target.id == target_workflow_id
    assert target_start_ctx.target.type == "SignalWaitWorkflow"
    assert target_start_ctx.target.run_id is None

    # [1] Client starts sender workflow (SignalExternalMetadataWorkflow)
    sender_start_ctx = driver.store_contexts[1]
    assert isinstance(sender_start_ctx.target, StorageDriverWorkflowInfo)
    assert sender_start_ctx.target.id == sender_workflow_id
    assert sender_start_ctx.target.type == "SignalExternalMetadataWorkflow"
    assert sender_start_ctx.target.run_id is None

    # [2] Sender signals target: target = the workflow being signaled
    signal_ctx = driver.store_contexts[2]
    assert isinstance(signal_ctx.target, StorageDriverWorkflowInfo)
    assert signal_ctx.target.id == target_workflow_id
    assert signal_ctx.target.type is None
    assert signal_ctx.target.run_id is None

    # [3] and [4] are the sender and target workflow completions in some order.
    # The sender's WFT 2 (after signal resolution) and the target's WFT (after
    # receiving the signal) are both scheduled by the server at nearly the same
    # time, so the order of their completions is non-deterministic.
    completion_ctxs = {
        ctx.target.id: ctx
        for ctx in driver.store_contexts[3:5]
        if isinstance(ctx.target, StorageDriverWorkflowInfo) and ctx.target.id
    }
    assert sender_workflow_id in completion_ctxs
    assert target_workflow_id in completion_ctxs

    sender_result_ctx = completion_ctxs[sender_workflow_id]
    assert isinstance(sender_result_ctx.target, StorageDriverWorkflowInfo)
    assert sender_result_ctx.target.run_id is not None

    target_result_ctx = completion_ctxs[target_workflow_id]
    assert isinstance(target_result_ctx.target, StorageDriverWorkflowInfo)
    assert target_result_ctx.target.run_id is not None


async def test_store_metadata_standalone_activity(env: WorkflowEnvironment) -> None:
    """Standalone activity worker should use StorageDriverActivityInfo as target."""
    if env.supports_time_skipping:
        pytest.skip(
            "Java test server: https://github.com/temporalio/sdk-java/issues/2741"
        )
    client, driver = await _make_tracking_client(env)
    activity_id = f"activity-{uuid.uuid4()}"

    async with new_worker(client, activities=[echo_activity]) as worker:
        await client.execute_activity(
            echo_activity,
            "hello",
            id=activity_id,
            task_queue=worker.task_queue,
            schedule_to_close_timeout=timedelta(seconds=30),
        )

    assert len(driver.store_contexts) == 2

    # [0] Client schedules standalone activity
    client_ctx = driver.store_contexts[0]
    assert isinstance(client_ctx.target, StorageDriverActivityInfo)
    assert client_ctx.target.namespace == client.namespace
    assert client_ctx.target.id == activity_id
    assert client_ctx.target.type == "echo_activity"
    assert client_ctx.target.run_id is None

    # [1] Activity worker completes: target = activity (no parent workflow)
    execute_ctx = driver.store_contexts[1]
    assert isinstance(execute_ctx.target, StorageDriverActivityInfo)
    assert execute_ctx.target.namespace == client.namespace
    assert execute_ctx.target.id == activity_id
    assert execute_ctx.target.type == "echo_activity"
    assert execute_ctx.target.run_id is None