Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,36 @@ def __init__(self):
metrics_namespace=self.metrics_namespace,
is_streaming=False,
pcollection='RunInference/BeamML_RunInference_Postprocess-0.out0')
self.is_streaming = ((self.pipeline.get_option('mode') or
'batch') == 'streaming')
self.opts = self.pipeline.get_pipeline_options().view_as(
TableRowInferenceOptions)
mode = self.opts.mode or 'batch'
self.is_streaming = mode == 'streaming'
if self.is_streaming:
self.subscription = self.pipeline.get_option('input_subscription')
self.subscription = self.opts.input_subscription

def test(self):
"""Execute the table row inference pipeline for benchmarking."""
extra_opts = {}

mode = self.pipeline.get_option('mode') or 'batch'
extra_opts['mode'] = mode
mode = self.opts.mode or 'batch'
extra_opts = {'mode': mode}

if mode == 'streaming':
extra_opts['input_subscription'] = self.pipeline.get_option(
'input_subscription')
extra_opts['window_size_sec'] = int(
self.pipeline.get_option('window_size_sec') or 60)
extra_opts['trigger_interval_sec'] = int(
self.pipeline.get_option('trigger_interval_sec') or 30)
else:
extra_opts['input_file'] = self.pipeline.get_option('input_file')

for opt in ['output_table', 'model_path', 'feature_columns']:
val = self.pipeline.get_option(opt)
if val:
extra_opts[opt] = val
if self.opts.input_subscription:
extra_opts['input_subscription'] = self.opts.input_subscription
extra_opts['window_size_sec'] = (
self.opts.window_size_sec
if self.opts.window_size_sec is not None else 60)
extra_opts['trigger_interval_sec'] = (
self.opts.trigger_interval_sec
if self.opts.trigger_interval_sec is not None else 30)
elif self.opts.input_file:
extra_opts['input_file'] = self.opts.input_file

if self.opts.output_table:
extra_opts['output_table'] = self.opts.output_table
if self.opts.model_path:
extra_opts['model_path'] = self.opts.model_path
if self.opts.feature_columns:
extra_opts['feature_columns'] = self.opts.feature_columns

self.result = table_row_inference.run(
self.pipeline.get_full_options_as_args(**extra_opts),
Expand Down
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/testing/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def get_option(self, opt_name, bool_option=False):
None if option is not found in existing option list which is generated
by parsing value of argument `test-pipeline-options`.
"""
parser = argparse.ArgumentParser()
# Parse one flag at a time; disable prefix matching so e.g. --mode does
# not satisfy --model_path when both appear in options_list.
parser = argparse.ArgumentParser(allow_abbrev=False)
opt_name = opt_name[:2] if opt_name[:2] == '--' else opt_name
# Option name should start with '--' when it's used for parsing.
if bool_option:
Expand Down
Loading