From 0e182511db5e2b16360ad201f6877a0d95db6202 Mon Sep 17 00:00:00 2001 From: Derrick Williams Date: Tue, 19 May 2026 18:00:42 +0000 Subject: [PATCH 1/4] expand jinja --- sdks/python/apache_beam/yaml/main.py | 7 +- sdks/python/apache_beam/yaml/yaml_provider.py | 11 +- .../python/apache_beam/yaml/yaml_transform.py | 58 ++++++++- .../apache_beam/yaml/yaml_transform_test.py | 118 ++++++++++++++++++ 4 files changed, 185 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py index dc928dec7941..804798b82e02 100644 --- a/sdks/python/apache_beam/yaml/main.py +++ b/sdks/python/apache_beam/yaml/main.py @@ -235,8 +235,13 @@ def _build_pipeline_yaml_from_argv(argv): argv = _preparse_jinja_flags(argv) known_args, pipeline_args = _parse_arguments(argv) pipeline_template = _pipeline_spec_from_args(known_args) + + search_paths = [] + if known_args.yaml_pipeline_file: + search_paths.append(FileSystems.split(known_args.yaml_pipeline_file)[0]) + pipeline_yaml = yaml_transform.expand_jinja( - pipeline_template, known_args.jinja_variables or {}) + pipeline_template, known_args.jinja_variables or {}, search_paths) return known_args, pipeline_args, pipeline_template, pipeline_yaml diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index ce5a6e320589..324ae0c2e734 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -469,10 +469,14 @@ def _with_extra_dependencies(self, dependencies: Iterable[str]): @ExternalProvider.register_provider_type('yaml') class YamlProvider(Provider): - def __init__(self, transforms: Mapping[str, Mapping[str, Any]]): + def __init__( + self, + transforms: Mapping[str, Mapping[str, Any]], + provider_base_path: Optional[str] = None): if not isinstance(transforms, dict): raise ValueError('Transform mapping must be a dict.') self._transforms = transforms + self._provider_base_path = provider_base_path def available(self): return True @@ -524,7 +528,10 @@ def create_transform( else: body_str = yaml.safe_dump(SafeLineLoader.strip_metadata(body)) # Now re-parse resolved templatization. - body = yaml.load(expand_jinja(body_str, args), Loader=SafeLineLoader) + search_paths = [FileSystems.split(self._provider_base_path)[0] + ] if self._provider_base_path else [] + body = yaml.load( + expand_jinja(body_str, args, search_paths), Loader=SafeLineLoader) if (body.get('type') == 'chain' and 'input' not in body and spec.get('requires_inputs', True)): body['input'] = 'input' diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 2b745babad02..b8539cdb80e6 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -1391,19 +1391,65 @@ def validate_transform_references(spec): return spec +def strip_leading_comments(source: str) -> str: + lines = source.splitlines(keepends=True) + stripped_lines = [] + in_leading_comments = True + for line in lines: + stripped_line = line.lstrip() + if in_leading_comments: + if stripped_line.startswith('#') or not stripped_line: + continue + else: + in_leading_comments = False + stripped_lines.append(line) + return "".join(stripped_lines) + + class _BeamFileIOLoader(jinja2.BaseLoader): + def __init__(self, search_paths=()): + self.search_paths = list(search_paths) + def get_source(self, environment, path): - with FileSystems.open(path) as fin: - source = fin.read().decode() - return source, path, lambda: True + candidates = [] + if FileSystems.get_scheme(path) is not None or path.startswith('/'): + candidates.append(path) + else: + candidates.append(path) + for search_path in self.search_paths: + candidates.append(FileSystems.join(search_path, path)) + + for candidate in candidates: + try: + if FileSystems.exists(candidate): + with FileSystems.open(candidate) as fin: + source = fin.read().decode() + return strip_leading_comments(source), candidate, lambda: True + except Exception: + pass + + raise jinja2.TemplateNotFound(path) def expand_jinja( - jinja_template: str, jinja_variables: Mapping[str, Any]) -> str: + jinja_template: str, + jinja_variables: Mapping[str, Any], + search_paths: Iterable[str] = ()) -> str: + import apache_beam + beam_root_dir = os.path.dirname( + os.path.dirname(os.path.abspath(apache_beam.__file__))) + + all_search_paths = list(search_paths) + if beam_root_dir not in all_search_paths: + all_search_paths.append(beam_root_dir) + if '.' not in all_search_paths: + all_search_paths.append('.') + return ( # keep formatting jinja2.Environment( - undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader()) - .from_string(jinja_template) + undefined=jinja2.StrictUndefined, + loader=_BeamFileIOLoader(all_search_paths)) + .from_string(strip_leading_comments(jinja_template)) .render(datetime=datetime, **jinja_variables)) diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index bbb60b185c01..e6513b1525dd 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -22,13 +22,16 @@ import shutil import tempfile import unittest +import yaml import apache_beam as beam from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.utils import python_callable from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import YamlTransform +from apache_beam.yaml.yaml_transform import expand_jinja try: import jsonschema @@ -1467,6 +1470,65 @@ def test_must_consume_error_output(self): ''', providers=merged_providers) + def test_provider_with_jinja_imports(self): + # Create a macro file in the same temp directory as the provider + macro_path = os.path.join(self.temp_dir, 'my_macros.yaml') + with open(macro_path, 'w') as f: + f.write( + """ +{%- macro power_expr(var, n) -%} +{{ var }} ** {{ n }} +{%- endmacro -%} +""") + + # Create a provider that imports and uses the macro + templated_provider_path = os.path.join( + self.temp_dir, 'templated_provider.yaml') + with open(templated_provider_path, 'w') as f: + f.write( + """ +- type: yaml + transforms: + CustomPower: + config_schema: + properties: + n: {type: integer} + body: | + type: MapToFields + config: + language: python + append: true + fields: + power: "{% import 'my_macros.yaml' as m %}{{ m.power_expr('element', n) }}" +""") + + loaded_providers = yaml_provider.load_providers(templated_provider_path) + test_providers = yaml_provider.InlineProvider(TEST_PROVIDERS) + merged_providers = yaml_provider.merge_providers( + loaded_providers, [test_providers]) + + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + results = p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + config: + elements: [2, 3] + - type: CustomPower + input: Create + config: + n: 3 + output: CustomPower + ''', + providers=merged_providers) + + assert_that( + results, + equal_to( + [beam.Row(element=2, power=8), beam.Row(element=3, power=27)])) + @beam.transforms.ptransform.annotate_yaml class LinearTransform(beam.PTransform): @@ -1481,6 +1543,62 @@ def expand(self, pcoll): return pcoll | beam.Map(lambda x: a * x.element + b) +class TestYamlExpandJinja(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + # Create a macro file with leading comments (license header) + self.macro_path = os.path.join(self.temp_dir, 'my_macros.yaml') + with open(self.macro_path, 'w') as f: + f.write( + """# coding=utf-8 +# Licensed to the Apache Software Foundation... +# Some leading comment line + +{%- macro add_n(val, n) -%} +{{ val }} + {{ n }} +{%- endmacro -%} +""") + + # Create a pipeline template that includes/imports the macro + self.pipeline_path = os.path.join(self.temp_dir, 'my_pipeline.yaml') + with open(self.pipeline_path, 'w') as f: + f.write( + """# coding=utf-8 +# Licensed to the Apache Software Foundation... + +{% import 'my_macros.yaml' as macros %} +type: composite +transforms: + - type: Create + config: + elements: [1, 2, 3] + - type: MapToFields + config: + language: python + fields: + result: {{ macros.add_n('element', 10) }} +""") + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_expand_jinja_with_leading_comments_and_imports(self): + # Read the pipeline template + with open(self.pipeline_path, 'r') as f: + template_content = f.read() + + # Expand the jinja using our temp_dir as a search path + expanded = expand_jinja(template_content, {}, [self.temp_dir]) + + # Parse the expanded YAML + parsed = yaml.load(expanded, Loader=SafeLineLoader) + + # Verify the comment-stripping and import resolution was successful + self.assertEqual(parsed['type'], 'composite') + self.assertEqual( + parsed['transforms'][1]['config']['fields']['result'], 'element + 10') + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() From 72cee62c5e2e79ba41359ef06105b63b5467d466 Mon Sep 17 00:00:00 2001 From: Derrick Williams Date: Tue, 19 May 2026 18:13:42 +0000 Subject: [PATCH 2/4] minor tweaks --- sdks/python/apache_beam/yaml/yaml_transform.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index b8539cdb80e6..bd7c2fe3c388 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -1411,11 +1411,8 @@ def __init__(self, search_paths=()): self.search_paths = list(search_paths) def get_source(self, environment, path): - candidates = [] - if FileSystems.get_scheme(path) is not None or path.startswith('/'): - candidates.append(path) - else: - candidates.append(path) + candidates = [path] + if FileSystems.get_scheme(path) is None and not path.startswith('/'): for search_path in self.search_paths: candidates.append(FileSystems.join(search_path, path)) @@ -1435,9 +1432,8 @@ def expand_jinja( jinja_template: str, jinja_variables: Mapping[str, Any], search_paths: Iterable[str] = ()) -> str: - import apache_beam beam_root_dir = os.path.dirname( - os.path.dirname(os.path.abspath(apache_beam.__file__))) + os.path.dirname(os.path.abspath(beam.__file__))) all_search_paths = list(search_paths) if beam_root_dir not in all_search_paths: From 172d8c5ee9986735f5e3140c298426bb7fad75b3 Mon Sep 17 00:00:00 2001 From: Derrick Williams Date: Tue, 19 May 2026 19:00:28 +0000 Subject: [PATCH 3/4] fix lint --- sdks/python/apache_beam/yaml/yaml_transform_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index e6513b1525dd..53c68566f0c2 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -22,6 +22,7 @@ import shutil import tempfile import unittest + import yaml import apache_beam as beam From 7a9ce5013e54332442de77788f5047de7d069723 Mon Sep 17 00:00:00 2001 From: Derrick Williams Date: Wed, 20 May 2026 21:11:08 +0000 Subject: [PATCH 4/4] fix gemini comments --- sdks/python/apache_beam/yaml/yaml_transform.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index bd7c2fe3c388..a4bb9144f8e6 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -1412,18 +1412,20 @@ def __init__(self, search_paths=()): def get_source(self, environment, path): candidates = [path] - if FileSystems.get_scheme(path) is None and not path.startswith('/'): + if FileSystems.get_scheme(path) is None and not os.path.isabs(path): for search_path in self.search_paths: candidates.append(FileSystems.join(search_path, path)) for candidate in candidates: try: - if FileSystems.exists(candidate): - with FileSystems.open(candidate) as fin: - source = fin.read().decode() - return strip_leading_comments(source), candidate, lambda: True + exists = FileSystems.exists(candidate) except Exception: - pass + exists = False + + if exists: + with FileSystems.open(candidate) as fin: + source = fin.read().decode() + return strip_leading_comments(source), candidate, lambda: True raise jinja2.TemplateNotFound(path)