diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py index b8591a0fea83..e3de24574391 100644 --- a/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py +++ b/sdks/python/apache_beam/testing/benchmarks/inference/table_row_inference_benchmark.py @@ -72,32 +72,36 @@ def __init__(self): metrics_namespace=self.metrics_namespace, is_streaming=False, pcollection='RunInference/BeamML_RunInference_Postprocess-0.out0') - self.is_streaming = ((self.pipeline.get_option('mode') or - 'batch') == 'streaming') + self.opts = self.pipeline.get_pipeline_options().view_as( + TableRowInferenceOptions) + mode = self.opts.mode or 'batch' + self.is_streaming = mode == 'streaming' if self.is_streaming: - self.subscription = self.pipeline.get_option('input_subscription') + self.subscription = self.opts.input_subscription def test(self): """Execute the table row inference pipeline for benchmarking.""" - extra_opts = {} - - mode = self.pipeline.get_option('mode') or 'batch' - extra_opts['mode'] = mode + mode = self.opts.mode or 'batch' + extra_opts = {'mode': mode} if mode == 'streaming': - extra_opts['input_subscription'] = self.pipeline.get_option( - 'input_subscription') - extra_opts['window_size_sec'] = int( - self.pipeline.get_option('window_size_sec') or 60) - extra_opts['trigger_interval_sec'] = int( - self.pipeline.get_option('trigger_interval_sec') or 30) - else: - extra_opts['input_file'] = self.pipeline.get_option('input_file') - - for opt in ['output_table', 'model_path', 'feature_columns']: - val = self.pipeline.get_option(opt) - if val: - extra_opts[opt] = val + if self.opts.input_subscription: + extra_opts['input_subscription'] = self.opts.input_subscription + extra_opts['window_size_sec'] = ( + self.opts.window_size_sec + if self.opts.window_size_sec is not None else 60) + extra_opts['trigger_interval_sec'] = ( + self.opts.trigger_interval_sec + if self.opts.trigger_interval_sec is not None else 30) + elif self.opts.input_file: + extra_opts['input_file'] = self.opts.input_file + + if self.opts.output_table: + extra_opts['output_table'] = self.opts.output_table + if self.opts.model_path: + extra_opts['model_path'] = self.opts.model_path + if self.opts.feature_columns: + extra_opts['feature_columns'] = self.opts.feature_columns self.result = table_row_inference.run( self.pipeline.get_full_options_as_args(**extra_opts), diff --git a/sdks/python/apache_beam/testing/test_pipeline.py b/sdks/python/apache_beam/testing/test_pipeline.py index 712da8636234..1fe2d86e35b3 100644 --- a/sdks/python/apache_beam/testing/test_pipeline.py +++ b/sdks/python/apache_beam/testing/test_pipeline.py @@ -209,7 +209,9 @@ def get_option(self, opt_name, bool_option=False): None if option is not found in existing option list which is generated by parsing value of argument `test-pipeline-options`. """ - parser = argparse.ArgumentParser() + # Parse one flag at a time; disable prefix matching so e.g. --mode does + # not satisfy --model_path when both appear in options_list. + parser = argparse.ArgumentParser(allow_abbrev=False) opt_name = opt_name[:2] if opt_name[:2] == '--' else opt_name # Option name should start with '--' when it's used for parsing. if bool_option: