diff --git a/cadence/client.py b/cadence/client.py index a75d7b5..8a0622b 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -2,7 +2,7 @@ import socket import uuid from datetime import timedelta -from typing import TypedDict, Unpack, Any, cast, Union, Callable +from typing import TypedDict, Unpack, Any, cast, Union from grpc import ChannelCredentials, Compression from google.protobuf.duration_pb2 import Duration @@ -17,11 +17,14 @@ from cadence.api.v1.service_workflow_pb2 import ( StartWorkflowExecutionRequest, StartWorkflowExecutionResponse, + SignalWithStartWorkflowExecutionRequest, + SignalWithStartWorkflowExecutionResponse, ) from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution from cadence.api.v1.tasklist_pb2 import TaskList from cadence.data_converter import DataConverter, DefaultDataConverter from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter +from cadence.workflow import WorkflowDefinition class StartWorkflowOptions(TypedDict, total=False): @@ -132,7 +135,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: def _build_start_workflow_request( self, - workflow: Union[str, Callable], + workflow: Union[str, WorkflowDefinition], args: tuple[Any, ...], options: StartWorkflowOptions, ) -> StartWorkflowExecutionRequest: @@ -144,8 +147,8 @@ def _build_start_workflow_request( if isinstance(workflow, str): workflow_type_name = workflow else: - # For callable, use function name or __name__ attribute - workflow_type_name = getattr(workflow, "__name__", str(workflow)) + # For WorkflowDefinition, use the name property + workflow_type_name = workflow.name # Encode input arguments input_payload = None @@ -186,7 +189,7 @@ def _build_start_workflow_request( async def start_workflow( self, - workflow: Union[str, Callable], + workflow: Union[str, WorkflowDefinition], *args, **options_kwargs: Unpack[StartWorkflowOptions], ) -> WorkflowExecution: @@ -194,7 +197,7 @@ async def start_workflow( Start a workflow execution asynchronously. Args: - workflow: Workflow function or workflow type name string + workflow: WorkflowDefinition or workflow type name string *args: Arguments to pass to the workflow **options_kwargs: StartWorkflowOptions as keyword arguments @@ -229,6 +232,69 @@ async def start_workflow( except Exception: raise + async def signal_with_start_workflow( + self, + workflow: Union[str, WorkflowDefinition], + signal_name: str, + signal_args: list[Any], + *workflow_args: Any, + **options_kwargs: Unpack[StartWorkflowOptions], + ) -> WorkflowExecution: + """ + Signal a workflow execution, starting it if it is not already running. + + Args: + workflow: WorkflowDefinition or workflow type name string + signal_name: Name of the signal + signal_args: List of arguments to pass to the signal handler + *workflow_args: Arguments to pass to the workflow if it needs to be started + **options_kwargs: StartWorkflowOptions as keyword arguments + + Returns: + WorkflowExecution with workflow_id and run_id + + Raises: + ValueError: If required parameters are missing or invalid + Exception: If the gRPC call fails + """ + # Convert kwargs to StartWorkflowOptions and validate + options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs)) + + # Build the start workflow request + start_request = self._build_start_workflow_request( + workflow, workflow_args, options + ) + + # Encode signal input + signal_payload = None + if signal_args: + try: + signal_payload = self.data_converter.to_data(signal_args) + except Exception as e: + raise ValueError(f"Failed to encode signal input: {e}") + + # Build the SignalWithStartWorkflowExecution request + request = SignalWithStartWorkflowExecutionRequest( + start_request=start_request, + signal_name=signal_name, + ) + + if signal_payload: + request.signal_input.CopyFrom(signal_payload) + + # Execute the gRPC call + try: + response: SignalWithStartWorkflowExecutionResponse = ( + await self.workflow_stub.SignalWithStartWorkflowExecution(request) + ) + + execution = WorkflowExecution() + execution.workflow_id = start_request.workflow_id + execution.run_id = response.run_id + return execution + except Exception: + raise + def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: if "target" not in options: diff --git a/cadence/workflow.py b/cadence/workflow.py index a8b257f..b6ebbbd 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -4,16 +4,16 @@ from dataclasses import dataclass from datetime import timedelta from typing import ( + Iterator, Callable, + TypeVar, + TypedDict, + Type, cast, + Any, Optional, Union, - Iterator, - TypedDict, - TypeVar, - Type, Unpack, - Any, Generic, ) import inspect diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index cdf7a2c..acb1a98 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -10,6 +10,7 @@ ) from cadence.client import Client, StartWorkflowOptions, _validate_and_apply_defaults from cadence.data_converter import DefaultDataConverter +from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions @pytest.fixture @@ -96,11 +97,17 @@ async def test_build_request_with_string_workflow(self, mock_client): uuid.UUID(request.request_id) # This will raise if not valid UUID @pytest.mark.asyncio - async def test_build_request_with_callable_workflow(self, mock_client): - """Test building request with callable workflow.""" + async def test_build_request_with_workflow_definition(self, mock_client): + """Test building request with WorkflowDefinition.""" + from cadence import workflow - def test_workflow(): - pass + class TestWorkflow: + @workflow.run + async def run(self): + pass + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(TestWorkflow, workflow_opts) client = Client(domain="test-domain", target="localhost:7933") @@ -110,7 +117,7 @@ def test_workflow(): task_start_to_close_timeout=timedelta(seconds=30), ) - request = client._build_start_workflow_request(test_workflow, (), options) + request = client._build_start_workflow_request(workflow_definition, (), options) assert request.workflow_type.name == "test_workflow" diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index 5b4e785..746100c 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -7,7 +7,10 @@ ) from cadence.error import EntityNotExistsError from tests.integration_tests.helper import CadenceHelper, DOMAIN_NAME -from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionRequest +from cadence.api.v1.service_workflow_pb2 import ( + DescribeWorkflowExecutionRequest, + GetWorkflowExecutionHistoryRequest, +) from cadence.api.v1.common_pb2 import WorkflowExecution @@ -135,3 +138,66 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper): assert task_timeout_seconds == task_timeout.total_seconds(), ( f"task_start_to_close_timeout mismatch: expected {task_timeout.total_seconds()}s, got {task_timeout_seconds}s" ) + + +@pytest.mark.usefixtures("helper") +async def test_signal_with_start_workflow(helper: CadenceHelper): + """Test signal_with_start_workflow method. + + This integration test verifies: + 1. Starting a workflow via signal_with_start_workflow + 2. Sending a signal to the workflow + 3. Signal appears in the workflow's history with correct name and payload + """ + async with helper.client() as client: + workflow_type = "test-workflow-signal-with-start" + task_list_name = "test-task-list-signal-with-start" + workflow_id = "test-workflow-signal-with-start-123" + execution_timeout = timedelta(minutes=5) + signal_name = "test-signal" + signal_arg = {"data": "test-signal-data"} + + execution = await client.signal_with_start_workflow( + workflow_type, + signal_name, + [signal_arg], + "arg1", + "arg2", + task_list=task_list_name, + execution_start_to_close_timeout=execution_timeout, + workflow_id=workflow_id, + ) + + assert execution is not None + assert execution.workflow_id == workflow_id + assert execution.run_id is not None + assert execution.run_id != "" + + # Fetch workflow history to verify signal was recorded + history_response = await client.workflow_stub.GetWorkflowExecutionHistory( + GetWorkflowExecutionHistoryRequest( + domain=DOMAIN_NAME, + workflow_execution=execution, + skip_archival=True, + ) + ) + + # Verify signal event appears in history with correct name and payload + signal_events = [ + event + for event in history_response.history.events + if event.HasField("workflow_execution_signaled_event_attributes") + ] + + assert len(signal_events) == 1, "Expected exactly one signal event in history" + signal_event = signal_events[0] + assert ( + signal_event.workflow_execution_signaled_event_attributes.signal_name + == signal_name + ), f"Expected signal name '{signal_name}'" + + # Verify signal payload matches what we sent + signal_payload_data = signal_event.workflow_execution_signaled_event_attributes.input.data.decode() + assert signal_arg["data"] in signal_payload_data, ( + f"Expected signal payload to contain '{signal_arg['data']}'" + )