Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,11 @@ async def decode_activation(
"""
metrics = temporalio.converter._extstore.StorageOperationMetrics()
with metrics.track():
# Always retrieve headers from external storage regardless of
# decode_headers — external storage is orthogonal to codec encoding.
await CommandAwarePayloadVisitor(
skip_search_attributes=True,
skip_headers=not decode_headers,
skip_headers=False,
concurrency_limit=storage_concurrency_limit,
).visit(
_Visitor(data_converter._external_retrieve_payload_sequence), activation
Expand Down Expand Up @@ -353,9 +355,11 @@ async def _store_and_validate(

metrics = temporalio.converter._extstore.StorageOperationMetrics()
with metrics.track():
# Always store and validate headers regardless of encode_headers —
# external storage is orthogonal to codec encoding.
await CommandAwarePayloadVisitor(
skip_search_attributes=True,
skip_headers=not encode_headers,
skip_headers=False,
concurrency_limit=storage_concurrency_limit,
).visit(_Visitor(_store_and_validate), completion)

Expand Down
11 changes: 8 additions & 3 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9183,9 +9183,14 @@ async def _apply_headers(
) -> None:
if source is None:
return
if encode_headers:
for payload in source.values():
payload.CopyFrom(await data_converter._transform_outbound_payload(payload))
# Always run headers through external storage and validation, but only
# codec-encode when encode_headers is True.
for payload in source.values():
payload.CopyFrom(
await data_converter._transform_outbound_payload(
payload, encode=encode_headers
)
)
temporalio.common._apply_headers(source, dest)


Expand Down
14 changes: 10 additions & 4 deletions temporalio/converter/_data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,12 @@ async def _encode_memo_existing(
)

async def _transform_outbound_payload(
self, payload: temporalio.api.common.v1.Payload
self,
payload: temporalio.api.common.v1.Payload,
*,
encode: bool = True,
) -> temporalio.api.common.v1.Payload:
if self.payload_codec:
if encode and self.payload_codec:
payload = (await self.payload_codec.encode([payload]))[0]
if self.external_storage:
payload = await self.external_storage._store_payload(payload)
Expand All @@ -273,11 +276,14 @@ async def _transform_outbound_payloads(
self._validate_payload_limits(payloads.payloads)

async def _transform_inbound_payload(
    self,
    payload: temporalio.api.common.v1.Payload,
    *,
    decode: bool = True,
) -> temporalio.api.common.v1.Payload:
    """Retrieve a payload from external storage, then optionally decode it.

    Args:
        payload: Inbound payload to transform.
        decode: When False, the payload codec is skipped. Retrieval from
            external storage still runs whenever external storage is
            configured on this converter.

    Returns:
        The payload after retrieval and, if requested, codec decoding.
    """
    # Retrieval is deliberately not gated by ``decode``: only the codec
    # step is optional.
    if self.external_storage:
        payload = await self.external_storage._retrieve_payload(payload)
    if decode and self.payload_codec:
        payload = (await self.payload_codec.decode([payload]))[0]
    return payload

Expand Down
11 changes: 7 additions & 4 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,14 @@ async def _execute_activity(
else None,
)

if self._encode_headers:
for payload in start.header_fields.values():
payload.CopyFrom(
await data_converter._transform_inbound_payload(payload)
# Always retrieve headers from external storage regardless of
# encode_headers — external storage is orthogonal to codec encoding.
for payload in start.header_fields.values():
payload.CopyFrom(
await data_converter._transform_inbound_payload(
payload, decode=self._encode_headers
)
)

running_activity.info = info
input = ExecuteActivityInput(
Expand Down
64 changes: 63 additions & 1 deletion tests/test_extstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,5 +679,67 @@ def test_duplicate_driver_names_raises(self):
)


class TestHeadersAlwaysExternalStorage:
    """Headers must always pass through external storage and validation,
    regardless of the encode_headers/decode_headers flag. Codec encoding
    and decoding remain gated by that flag."""

    async def test_transform_outbound_payload_encode_false_still_stores(
        self,
    ):
        """encode=False skips codec but still stores externally."""
        driver = InMemoryTestDriver()
        dc = DataConverter(
            payload_codec=RecordingPayloadCodec("test-codec"),
            external_storage=ExternalStorage(
                drivers=[driver],
                payload_size_threshold=0,
            ),
        )

        payload = Payload(data=b"x" * 100)

        # encode=True: codec encodes AND external storage stores
        await dc._transform_outbound_payload(payload, encode=True)
        assert driver._store_calls == 1

        # encode=False: external storage must still store.
        # NOTE(review): the codec being skipped is not asserted here;
        # consider checking the recording codec once its API is confirmed.
        driver._store_calls = 0
        await dc._transform_outbound_payload(payload, encode=False)
        assert driver._store_calls == 1

    async def test_transform_inbound_payload_decode_false_still_retrieves(
        self,
    ):
        """decode=False skips codec but still retrieves from external storage."""
        driver = InMemoryTestDriver()
        dc = DataConverter(
            payload_codec=RecordingPayloadCodec("test-codec"),
            external_storage=ExternalStorage(
                drivers=[driver],
                payload_size_threshold=0,
            ),
        )

        # First store a payload externally so there is something to retrieve
        payload = Payload(data=b"x" * 100)
        stored = await dc._transform_outbound_payload(payload, encode=True)
        assert driver._store_calls == 1

        # decode=True: retrieve AND decode. Reset the counter that is
        # actually asserted below (the retrieve counter, not the store one).
        driver._retrieve_calls = 0
        await dc._transform_inbound_payload(stored, decode=True)
        assert driver._retrieve_calls == 1

        # decode=False: retrieval must still happen.
        driver._retrieve_calls = 0
        await dc._transform_inbound_payload(stored, decode=False)
        assert driver._retrieve_calls == 1


# Allow running this test module directly; pytest receives the module path
# and the verbose flag as a single argument list.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])