Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sdks/python/apache_beam/yaml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,13 @@ def _build_pipeline_yaml_from_argv(argv):
argv = _preparse_jinja_flags(argv)
known_args, pipeline_args = _parse_arguments(argv)
pipeline_template = _pipeline_spec_from_args(known_args)

search_paths = []
if known_args.yaml_pipeline_file:
search_paths.append(FileSystems.split(known_args.yaml_pipeline_file)[0])

pipeline_yaml = yaml_transform.expand_jinja(
pipeline_template, known_args.jinja_variables or {})
pipeline_template, known_args.jinja_variables or {}, search_paths)
return known_args, pipeline_args, pipeline_template, pipeline_yaml


Expand Down
11 changes: 9 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,14 @@ def _with_extra_dependencies(self, dependencies: Iterable[str]):

@ExternalProvider.register_provider_type('yaml')
class YamlProvider(Provider):
def __init__(self, transforms: Mapping[str, Mapping[str, Any]]):
def __init__(
self,
transforms: Mapping[str, Mapping[str, Any]],
provider_base_path: Optional[str] = None):
if not isinstance(transforms, dict):
raise ValueError('Transform mapping must be a dict.')
self._transforms = transforms
self._provider_base_path = provider_base_path

def available(self):
return True
Expand Down Expand Up @@ -524,7 +528,10 @@ def create_transform(
else:
body_str = yaml.safe_dump(SafeLineLoader.strip_metadata(body))
# Now re-parse resolved templatization.
body = yaml.load(expand_jinja(body_str, args), Loader=SafeLineLoader)
search_paths = [FileSystems.split(self._provider_base_path)[0]
] if self._provider_base_path else []
body = yaml.load(
expand_jinja(body_str, args, search_paths), Loader=SafeLineLoader)
if (body.get('type') == 'chain' and 'input' not in body and
spec.get('requires_inputs', True)):
body['input'] = 'input'
Expand Down
56 changes: 50 additions & 6 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,19 +1391,63 @@ def validate_transform_references(spec):
return spec


def strip_leading_comments(source: str) -> str:
lines = source.splitlines(keepends=True)
stripped_lines = []
in_leading_comments = True
for line in lines:
stripped_line = line.lstrip()
if in_leading_comments:
if stripped_line.startswith('#') or not stripped_line:
continue
else:
in_leading_comments = False
stripped_lines.append(line)
return "".join(stripped_lines)
Comment thread
derrickaw marked this conversation as resolved.


class _BeamFileIOLoader(jinja2.BaseLoader):
def __init__(self, search_paths=()):
self.search_paths = list(search_paths)

def get_source(self, environment, path):
with FileSystems.open(path) as fin:
source = fin.read().decode()
return source, path, lambda: True
candidates = [path]
if FileSystems.get_scheme(path) is None and not os.path.isabs(path):
for search_path in self.search_paths:
candidates.append(FileSystems.join(search_path, path))

for candidate in candidates:
try:
exists = FileSystems.exists(candidate)
except Exception:
exists = False

if exists:
with FileSystems.open(candidate) as fin:
source = fin.read().decode()
return strip_leading_comments(source), candidate, lambda: True

raise jinja2.TemplateNotFound(path)


def expand_jinja(
jinja_template: str, jinja_variables: Mapping[str, Any]) -> str:
jinja_template: str,
jinja_variables: Mapping[str, Any],
search_paths: Iterable[str] = ()) -> str:
beam_root_dir = os.path.dirname(
os.path.dirname(os.path.abspath(beam.__file__)))

all_search_paths = list(search_paths)
if beam_root_dir not in all_search_paths:
all_search_paths.append(beam_root_dir)
if '.' not in all_search_paths:
all_search_paths.append('.')

return ( # keep formatting
jinja2.Environment(
undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
.from_string(jinja_template)
undefined=jinja2.StrictUndefined,
loader=_BeamFileIOLoader(all_search_paths))
.from_string(strip_leading_comments(jinja_template))
.render(datetime=datetime, **jinja_variables))


Expand Down
119 changes: 119 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
import tempfile
import unittest

import yaml
Comment thread
derrickaw marked this conversation as resolved.

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.utils import python_callable
from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_transform import SafeLineLoader
from apache_beam.yaml.yaml_transform import YamlTransform
from apache_beam.yaml.yaml_transform import expand_jinja

try:
import jsonschema
Expand Down Expand Up @@ -1467,6 +1471,65 @@ def test_must_consume_error_output(self):
''',
providers=merged_providers)

def test_provider_with_jinja_imports(self):
# Create a macro file in the same temp directory as the provider
macro_path = os.path.join(self.temp_dir, 'my_macros.yaml')
with open(macro_path, 'w') as f:
f.write(
"""
{%- macro power_expr(var, n) -%}
{{ var }} ** {{ n }}
{%- endmacro -%}
""")

# Create a provider that imports and uses the macro
templated_provider_path = os.path.join(
self.temp_dir, 'templated_provider.yaml')
with open(templated_provider_path, 'w') as f:
f.write(
"""
- type: yaml
transforms:
CustomPower:
config_schema:
properties:
n: {type: integer}
body: |
type: MapToFields
config:
language: python
append: true
fields:
power: "{% import 'my_macros.yaml' as m %}{{ m.power_expr('element', n) }}"
""")

loaded_providers = yaml_provider.load_providers(templated_provider_path)
test_providers = yaml_provider.InlineProvider(TEST_PROVIDERS)
merged_providers = yaml_provider.merge_providers(
loaded_providers, [test_providers])

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
results = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
config:
elements: [2, 3]
- type: CustomPower
input: Create
config:
n: 3
output: CustomPower
''',
providers=merged_providers)

assert_that(
results,
equal_to(
[beam.Row(element=2, power=8), beam.Row(element=3, power=27)]))


@beam.transforms.ptransform.annotate_yaml
class LinearTransform(beam.PTransform):
Expand All @@ -1481,6 +1544,62 @@ def expand(self, pcoll):
return pcoll | beam.Map(lambda x: a * x.element + b)


class TestYamlExpandJinja(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.mkdtemp()
# Create a macro file with leading comments (license header)
self.macro_path = os.path.join(self.temp_dir, 'my_macros.yaml')
with open(self.macro_path, 'w') as f:
f.write(
"""# coding=utf-8
# Licensed to the Apache Software Foundation...
# Some leading comment line

{%- macro add_n(val, n) -%}
{{ val }} + {{ n }}
{%- endmacro -%}
""")

# Create a pipeline template that includes/imports the macro
self.pipeline_path = os.path.join(self.temp_dir, 'my_pipeline.yaml')
with open(self.pipeline_path, 'w') as f:
f.write(
"""# coding=utf-8
# Licensed to the Apache Software Foundation...

{% import 'my_macros.yaml' as macros %}
type: composite
transforms:
- type: Create
config:
elements: [1, 2, 3]
- type: MapToFields
config:
language: python
fields:
result: {{ macros.add_n('element', 10) }}
""")

def tearDown(self):
shutil.rmtree(self.temp_dir)

def test_expand_jinja_with_leading_comments_and_imports(self):
# Read the pipeline template
with open(self.pipeline_path, 'r') as f:
template_content = f.read()

# Expand the jinja using our temp_dir as a search path
expanded = expand_jinja(template_content, {}, [self.temp_dir])

# Parse the expanded YAML
parsed = yaml.load(expanded, Loader=SafeLineLoader)

# Verify the comment-stripping and import resolution was successful
self.assertEqual(parsed['type'], 'composite')
self.assertEqual(
parsed['transforms'][1]['config']['fields']['result'], 'element + 10')


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
Loading