From bbd9af74275c4b2276bdeb46a08c86262f0ac1ed Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 19 May 2026 15:45:45 -0400 Subject: [PATCH] Support infer types involving dataclass fields --- sdks/python/apache_beam/typehints/opcodes.py | 6 ++++++ .../apache_beam/typehints/row_type_test.py | 21 +++++++++++++++++++ .../typehints/trivial_inference_test.py | 12 +++++++++++ 3 files changed, 39 insertions(+) diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 8e5d7b1e40c8..963b5e0850b6 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ b/sdks/python/apache_beam/typehints/opcodes.py @@ -29,6 +29,7 @@ """ # pytype: skip-file +import dataclasses import inspect import logging import sys @@ -447,6 +448,11 @@ def _getattr(o, name): return Const(BoundMethod(func, o)) elif isinstance(o, row_type.RowTypeConstraint): return o.get_type_for(name) + elif inspect.isclass(o) and dataclasses.is_dataclass(o): + field = o.__dataclass_fields__.get(name) + if field is not None: + return field.type + return Any else: return Any diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py index 54e64caf6fa7..30bda0cd98ba 100644 --- a/sdks/python/apache_beam/typehints/row_type_test.py +++ b/sdks/python/apache_beam/typehints/row_type_test.py @@ -172,6 +172,27 @@ class DerivedDataClass(BaseDataClass): getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID)) self.assertNotEqual(schema_for_derived.id, schema_for_base.id) + def test_dataclass_map_typehints(self): + @beam.coders.typecoders.registry.register_row + @dataclass(frozen=True) + class MyDataClass: + id: int + name: str + + p = beam.Pipeline() + pa = (p | beam.Create([MyDataClass(1, "a"), MyDataClass(2, "b")])) + self.assertEqual(pa.element_type, MyDataClass) + + pb = ( + pa | beam.Map( + lambda x: beam.Row(id=x.id, name=x.name, name_hash=hash(x.name)))) + self.assertTrue( + isinstance(pb.element_type, row_type.GeneratedClassRowTypeConstraint)) + self.assertEqual( + pb.element_type, + row_type.GeneratedClassRowTypeConstraint( + fields=[('id', int), ('name', str), ('name_hash', int)])) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index fe60974e2806..f421819bdcae 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -19,6 +19,7 @@ # pytype: skip-file +import dataclasses import types import unittest @@ -487,6 +488,17 @@ def testPyCallable(self): python_callable.PythonCallableWithSource("lambda x: (x, str(x))"), [int]) + def testDataClassFields(self): + @dataclasses.dataclass + class MyDataClass: + id: int + name: str + + self.assertReturnType( + typehints.Tuple[int, str], + python_callable.PythonCallableWithSource("lambda x: (x.id, x.name)"), + [MyDataClass]) + if __name__ == '__main__': unittest.main()