diff --git a/.gitignore b/.gitignore index 20d6890..10aea84 100644 --- a/.gitignore +++ b/.gitignore @@ -151,7 +151,5 @@ docs/temp/* /src-stats.yaml /config.yaml *.yaml.gz - -*_stories.py - /examples/opiates/* +./*_stories.py diff --git a/datafaker/install.py b/datafaker/install.py new file mode 100644 index 0000000..33cc9e1 --- /dev/null +++ b/datafaker/install.py @@ -0,0 +1,244 @@ +"""Functions to install Python file references in ``config.yaml``.""" +from collections.abc import Mapping, MutableMapping, Sequence +from inspect import Parameter, signature +from pathlib import Path +from typing import Any + +from datafaker.utils import import_file, logger + + +def _make_where_from_annotation( + query_def: Mapping[str, Any], + fn_name: str, + param_name: str, +) -> str: + """Make a where clause from ``query`` value from the annotation.""" + if "where" not in query_def: + return "" + w = query_def["where"] + if isinstance(w, str): + return f" WHERE {w}" + if isinstance(w, Sequence): + return " WHERE " + " AND ".join(f'"({clause})"' for clause in w) + logger.warning( + '"where" in the query annotation of parameter "%s" of function "%s"' + " needs to be a string or a list of strings", + param_name, + fn_name, + ) + return "" + + +def _make_vars_from_annotation( + query_def: Mapping[str, Any], + fn_name: str, + param_name: str, +) -> Mapping[str, Any]: + """Make a variables dict from ``query`` value from the annotation.""" + if "vars" not in query_def: + return {} + vars_def = query_def["vars"] + if isinstance(vars_def, Mapping): + return vars_def + if isinstance(vars_def, Sequence): + return {v: v for v in query_def["vars"]} + logger.warning( + '"vars" in the query annotation of parameter "%s" of function "%s"' + " needs to be a list of strings or a dict of strings to strings", + param_name, + fn_name, + ) + return {} + + +def _add_count_vars_from_annotation( + group_vars_out: MutableMapping[str, Any], + query_def: Mapping[str, Any], + fn_name: str, + param_name: str, +) -> None: + """Add ``GROUP BY`` clauses from ``count_vars``.""" + if "count_vars" not in query_def: + return + cntv = query_def["count_vars"] + if isinstance(cntv, Mapping): + group_vars_out.update({k: f"COUNT({v})" for k, v in cntv}) + return + logger.warning( + '"count_vars" needs to be a dict in the annotation for parameter %s of function %s', + param_name, + fn_name, + ) + + +def _add_ms_vars_from_annotation( + group_vars_out: MutableMapping[str, Any], + query_def: Mapping[str, Any], + fn_name: str, + param_name: str, +) -> None: + """Add ``GROUP BY`` clauses from ``ms_vars``.""" + if "ms_vars" not in query_def: + return + msv = query_def["ms_vars"] + if not isinstance(msv, Mapping): + logger.warning( + '"ms_vars" needs to be a dict in the annotation for parameter %s of function %s', + param_name, + fn_name, + ) + return + for k, v in msv.items(): + group_vars_out[k + "_count"] = f"COUNT({v})" + group_vars_out[k + "_mean"] = f"AVG({v})" + group_vars_out[k + "_stddev"] = f"STDDEV({v})" + + +def make_query_from_annotation( + annotation_data: Any, + fn_name: str, + param_name: str, +) -> str | None: + """ + Make new configuration items describing a query. + + The query's result will be passed as this parameter to this function. + + The annotation must be a dict with the following keys: + + ``comment``: A string describing the query in natural language. + + ``query``: Either a string containing the SQL query required, or + a dict containing the following keys: + + * ``table``: The table to query. Could be "tablename AS alias" if you like. + * ``vars`` (optional): Either a list of columns to extract from the table(s), + or a dict of keys (the names of the keys in the dict to be passed to the + annotated function) to values (the names of the columns to be extracted). + At least one of ``vars``, ``ms_vars``, ``count_vars`` must be present. + * ``where`` (optional): A SQL expression to filter the results. + * ``count_vars`` (optional): A dict of keys to be passed to the function + to values that are the names of the columns to be counted (could be + ``*``; if the name of a column the result will be the number of non-null + entries in that column). The query will be grouped by ``vars``. + * ``ms_vars`` (optional): A dict of value names to columns to be analysed. + The keys to be passed to the function will be name + ``_count`` for the + number of non-null values in that column, name + ``_mean`` for the + average value in that column and name + ``_stddev`` for the standard + deviation of values in that column. + + :param annotation_data: The ``Annotation`` attached to the parameter. + :param fn_name: The name of the function that the parameter is of. + :param param_name: The name of the parameter with the annotation. + :return: A mapping of new configuration items to add to the configuration, + if the annotation had a well-defined query and comment value; otherwise + an empty dict. + """ + if not isinstance(annotation_data, Sequence): + return None + ann = annotation_data[0] + if not isinstance(ann, Mapping) or "query" not in ann: + return None + if isinstance(ann["query"], str): + return ann["query"] + query_def = ann["query"] + if "table" not in query_def: + logger.warning( + '"table" needs to be a key in the annotation for' + ' the "query" value of parameter "%s" of function "%s"', + param_name, + fn_name, + ) + return None + table = query_def["table"] + nongroup_vars = _make_vars_from_annotation(query_def, fn_name, param_name) + where = _make_where_from_annotation(query_def, fn_name, param_name) + group_vars: dict[str, Any] = {} + _add_count_vars_from_annotation(group_vars, query_def, fn_name, param_name) + _add_ms_vars_from_annotation(group_vars, query_def, fn_name, param_name) + if group_vars and nongroup_vars: + group_by = " GROUP BY " + ", ".join(f'"{v}"' for v in nongroup_vars) + else: + group_by = "" + vars_exprs = ", ".join( + f'{v} AS "{k}"' for k, v in {**nongroup_vars, **group_vars}.items() + ) + return f"SELECT {vars_exprs} FROM {table}{group_by}{where}" + + +def _add_kwarg( + kwargs_out: dict[str, Any], fn_name: str, param: Parameter +) -> list[dict[str, Any]]: + """ + Add a kwargs configuration and return a ``src_stats`` query item. + + :param kwargs_out: The story generator's ``kwargs`` value to be updated. + :param fn_name: The name of the story generator function. + :param param: The parameter to specify. + :return: A list of configuration items to add to the ``src_stats`` config, for + all the queries this parameter requires. + """ + if param.annotation is Parameter.empty: + return [] + meta = param.annotation.__metadata__ + query = make_query_from_annotation( + param.annotation.__metadata__, fn_name, param.name + ) + if query is None: + return [] + stat_name = f"story_auto__{fn_name}__{param.name}" + if "comments" in meta[0]: + comments = [meta[0]["comment"]] + else: + comments = [] + ssc = { + "name": stat_name, + "query": query, + "comments": comments, + } + kwargs_out[param.name] = f'SRC_STATS["{stat_name}"]["results"]' + return [ssc] + + +def install_stories_from(config: MutableMapping[str, Any], story_file: Path) -> bool: + """ + Configure datafaker with the stories in a Python file. + + :param config: The contents of the configuration file, to be mutated. + :param story_file: Path to the Python file containing the story generators. + :return: True if the config was updated correctly, False if it was untouched + because problems were encountered. + """ + story_generators: list[Mapping[str, Any]] = [] + src_stats = [ + s + for s in config.get("src_stats", []) + if isinstance(s, Mapping) + and "name" in s + and not s["name"].startswith("story_auto__") + ] + story_module_name = story_file.stem + story_module = import_file(story_file, story_module_name) + for attr_name in dir(story_module): + attr = getattr(story_module, attr_name) + if ( + hasattr(attr, "__module__") + and attr.__module__ == story_module_name + and not attr_name.startswith("_") + and callable(attr) + ): + kwargs: dict[str, None] = {} + sig = signature(attr) + for param in sig.parameters.values(): + src_stats += _add_kwarg(kwargs, attr, param) + story_generators.append( + { + "name": f"{story_module_name}.{attr_name}", + "num_stories_per_pass": 1, + "kwargs": kwargs, + } + ) + config["story_generators_module"] = story_module_name + config["story_generators"] = story_generators + config["src-stats"] = src_stats + return True diff --git a/datafaker/main.py b/datafaker/main.py index baf2d49..a257ac5 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -24,13 +24,18 @@ TableWriter, get_parquet_table_writer, ) +from datafaker.install import install_stories_from from datafaker.interactive import ( update_config_generators, update_config_tables, update_missingness, ) from datafaker.interactive.base import DbCmd -from datafaker.make import make_src_stats, make_tables_file, make_vocabulary_tables +from datafaker.make import ( + make_src_stats, + make_tables_file, + make_vocabulary_tables, +) from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.settings import ( SettingsError, @@ -265,7 +270,7 @@ def create_tables( $ datafaker create-tables """ logger.debug("Creating tables.") - config = read_config_file(config_file) if config_file is not None else {} + config = read_config_file(config_file) orm_metadata = load_metadata_for_output(orm_file, config) create_db_tables(orm_metadata) logger.debug("Tables created.") @@ -710,7 +715,7 @@ def remove_data( """Truncate non-vocabulary tables in the destination schema.""" if yes: logger.debug("Truncating non-vocabulary tables.") - config = read_config_file(config_file) if config_file is not None else {} + config = read_config_file(config_file) metadata = load_metadata_for_output(orm_file, config) remove_db_data(metadata, config) logger.debug("Non-vocabulary tables truncated.") @@ -737,7 +742,7 @@ def remove_vocab( """Truncate vocabulary tables in the destination schema.""" if yes: logger.debug("Truncating vocabulary tables.") - config = read_config_file(config_file) if config_file is not None else {} + config = read_config_file(config_file) meta_dict = load_metadata_config(orm_file, config) orm_metadata = dict_to_metadata(meta_dict, config) remove_db_vocab(orm_metadata, meta_dict, config) @@ -812,7 +817,7 @@ def list_tables( tables: TableType = Option(TableType.GENERATED, help="Which tables to list"), ) -> None: """List the names of tables described in the metadata file.""" - config = read_config_file(config_file) if config_file is not None else {} + config = read_config_file(config_file) orm_metadata = load_metadata(orm_file, config) all_table_names = set(orm_metadata.tables.keys()) vocab_table_names = { @@ -830,6 +835,26 @@ def list_tables( print(name) +@app.command() +def install_stories( + config_file: Path = Option(CONFIG_FILENAME, help="The configuration file"), + story_file: Path = Argument(help="The Python file containing stories"), +) -> None: + """Add the story file's name and any contained query to the configuration file.""" + config_file_path = Path(config_file) + config = {} + if config_file_path.exists(): + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) + if not install_stories_from(config, story_file): + logger.debug("Cancelled") + sys.exit(1) + content = yaml.dump(config) + config_file_path.write_text(content, encoding="utf-8") + logger.debug("Stories configured in %s.", config_file) + + @app.command() def version() -> None: """Display version information.""" diff --git a/datafaker/utils.py b/datafaker/utils.py index c8bcc97..dd02ef2 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -67,7 +67,7 @@ def read_config_file(path: Path) -> dict: return config -def import_file(file_path: str) -> ModuleType: +def import_file(file_path: str | Path, module_name: str = "df") -> ModuleType: """Import a file. This utility function returns file_path imported as a module. @@ -78,7 +78,7 @@ def import_file(file_path: str) -> ModuleType: Returns: ModuleType """ - spec = importlib.util.spec_from_file_location("df", file_path) + spec = importlib.util.spec_from_file_location(module_name, file_path) if spec is None or spec.loader is None: raise SettingsError(f"No loadable module '{file_path}'") module = importlib.util.module_from_spec(spec) diff --git a/examples/mimic_omop/README.md b/examples/mimic_omop/README.md index bd04883..7bce1a4 100644 --- a/examples/mimic_omop/README.md +++ b/examples/mimic_omop/README.md @@ -8,10 +8,6 @@ `poetry run datafaker create-tables --orm-file ./examples/mimic_omop/orm.yaml --config-file ./examples/mimic_omop/config.yaml` -1. Create generator table - -`poetry run datafaker create-generators --orm-file ./examples/mimic_omop/orm.yaml --config-file ./examples/mimic_omop/config.yaml --df-file ./examples/mimic_omop/df.py` - 1. Create data `poetry run datafaker create-data --orm-file ./examples/mimic_omop/orm.yaml --config-file ./examples/mimic_omop/config.yaml --df-file .\examples\mimic_omop\df.py` diff --git a/tests/examples/annotated_stories.py b/tests/examples/annotated_stories.py new file mode 100644 index 0000000..fcfce1e --- /dev/null +++ b/tests/examples/annotated_stories.py @@ -0,0 +1,28 @@ +"""Story generators which describe their own queries and can therefore be installed.""" +from collections.abc import Iterable +from typing import Annotated, Any + +def string_story_one_sd( + stats: Annotated[dict, { + "query": { + "ms_vars": {"freq": "frequency"}, + "table": "string", + }, + "comment": "Frequency mean and standard deviation", + }], +) -> Iterable[tuple[str, dict[str, Any]]]: + man = yield("manufacturer", {"name": "one"}) + mod = yield ("model", { + "name": "one_sd", + "manufacturer_id": man["id"] + }) + yield("string", { + "model_id": mod["id"], + "position": 0, + "frequency": stats[0]["freq_mean"] - stats[0]["freq_stddev"], + }) + yield("string", { + "model_id": mod["id"], + "position": stats[0]["freq_count"], + "frequency": stats[0]["freq_mean"] + stats[0]["freq_stddev"], + }) diff --git a/tests/examples/install_config.yaml b/tests/examples/install_config.yaml new file mode 100644 index 0000000..c155d2b --- /dev/null +++ b/tests/examples/install_config.yaml @@ -0,0 +1,3 @@ +tables: + string: + num_rows_per_pass: 0 diff --git a/tests/test_install_stories.py b/tests/test_install_stories.py new file mode 100644 index 0000000..9ed7bd3 --- /dev/null +++ b/tests/test_install_stories.py @@ -0,0 +1,253 @@ +"""Tests for installing stories into ``config.yaml``.""" +import os +import re +import shutil +import tempfile +from pathlib import Path +from typing import Any, Mapping + +import yaml +from sqlalchemy import Row, func, select, text +from typer.testing import CliRunner, Result + +from datafaker.main import app, install_stories +from tests.utils import GeneratesDBTestCase, create_db_engine, get_sync_engine + +# pylint: disable=subprocess-run-check + + +class InstallTestCase(GeneratesDBTestCase): + """End-to-end tests that require a database.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + examples_dir = Path("examples") + use_temporary_cwd = True + copy_files: list[str] = ["annotated_stories.py", "install_config.yaml"] + copy_from_directory = examples_dir + + orm_file_path = Path("orm.yaml") + + input_file_paths = [Path("annotated_stories.py"), Path("install_config.yaml")] + stats_file_path = Path("example_stats.yaml") + + src_stats_re = re.compile(r'SRC_STATS\["(.*)"\]\["results"\]') + + def setUp(self) -> None: + """Pre-test setup.""" + super().setUp() + self.env = { + "src_dsn": self.dsn, + "src_schema": self.schema_name, + "dst_dsn": self.dsn, + "dst_schema": "dstschema", + } + self.runner = CliRunner( + mix_stderr=False, + env=self.env, + ) + + def tearDown(self) -> None: + """Tear down post test.""" + os.chdir(self.start_dir) + super().tearDown() + + def assert_silent_success(self, completed_process: Result) -> None: + """Assert that the process completed successfully without producing output.""" + self.assertNoException(completed_process) + self.assertSuccess(completed_process) + self.assertEqual(completed_process.stderr, "") + self.assertEqual(completed_process.stdout, "") + + def test_install_stories_simple(self) -> None: + """Test story gets expected parameters after installation.""" + config_path = Path("config-iss.yaml") + config_path.write_text("{}", encoding="UTF-8") + + install_stories(config_path, Path("annotated_stories.py")) + + config = yaml.load( + config_path.read_text(encoding="UTF-8"), + Loader=yaml.SafeLoader, + ) + + # Module name configured + self.assertIn("story_generators_module", config) + self.assertEqual(config["story_generators_module"], "annotated_stories") + + # Generator added with parameter + self.assertIn("story_generators", config) + st_gen = config["story_generators"] + self.assertEqual(len(st_gen), 1) + self.assertIn("name", st_gen[0]) + self.assertEqual(st_gen[0]["name"], "annotated_stories.string_story_one_sd") + self.assertIn("kwargs", st_gen[0]) + self.assertIn("stats", st_gen[0]["kwargs"]) + stats_ref = st_gen[0]["kwargs"]["stats"] + stats_result = self.src_stats_re.match(stats_ref) + self.assertIsNotNone( + stats_result, f'parameter "{stats_ref}" is not a SRC_STATS reference' + ) + assert stats_result is not None + + # Source stats query + self.assertIn("src-stats", config) + src_stats = config["src-stats"] + assert src_stats is not None + self.assertEqual(len(src_stats), 1) + assert src_stats[0] is not None + self.assertIn("name", src_stats[0]) + self.assertEqual(src_stats[0]["name"], stats_result.group(1)) + self.assertIn("query", src_stats[0]) + query = src_stats[0]["query"] + (mean, stddev, _count) = self.get_string_stats() + + # Let's run the query and see what we get. + engine = get_sync_engine( + create_db_engine( + self.env["src_dsn"], + schema_name=self.env["src_schema"], + ) + ) + with engine.connect() as conn: + rows = conn.execute(text(query)).fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0].freq_mean, mean) + self.assertEqual(rows[0].freq_stddev, stddev) + + def test_install_stories_end_to_end(self) -> None: + """Test the stories run with the expected parameters after installation.""" + completed_process = self.invoke( + "make-tables", + "--force", + ) + self.assert_silent_success(completed_process) + + completed_process = self.invoke( + "install-stories", + "--config-file", + "install_config.yaml", + "annotated_stories.py", + ) + self.assert_silent_success(completed_process) + + completed_process = self.invoke( + "make-stats", + "--config-file", + "install_config.yaml", + "--force", + ) + self.assert_silent_success(completed_process) + + completed_process = self.invoke( + "remove-tables", + "--config-file", + "install_config.yaml", + "--yes", + ) + self.assert_silent_success(completed_process) + + completed_process = self.invoke( + "create-tables", + "--config-file", + "install_config.yaml", + ) + self.assert_silent_success(completed_process) + + completed_process = self.invoke( + "create-data", + "--config-file", + "install_config.yaml", + ) + self.assertNoException(completed_process) + self.assertEqual("", completed_process.stderr) + self.assertSuccess(completed_process) + self.assertEqual( + "Generating data for story 'annotated_stories.string_story_one_sd'\n", + completed_process.stdout, + ) + + (mean, stddev, count) = self.get_string_stats() + + model_table = self.metadata.tables["model"] + string_table = self.metadata.tables["string"] + engine = get_sync_engine( + create_db_engine( + self.env["dst_dsn"], + schema_name=self.env["dst_schema"], + ) + ) + with engine.connect() as conn: + row = conn.execute( + select(model_table.c.name, model_table.c.id).where( + model_table.c.name == "one_sd" + ) + ).fetchone() + assert row is not None + strs = conn.execute( + select(string_table).where(string_table.c.model_id == row.id) + ).fetchall() + lower = None + higher = None + for s in strs: + if s.position == 0: + self.assertIsNone( + lower, "Multiple one_sd strings with zero position" + ) + lower = s.frequency + else: + self.assertIsNone( + higher, "Multiple one_sd strings with non-zero position" + ) + self.assertEqual(s.position, count) + higher = s.frequency + assert lower is not None + assert higher is not None + self.assertAlmostEqual((higher + lower) / 2, mean) + self.assertAlmostEqual((higher - lower) / 2, stddev) + + def get_string_stats(self) -> tuple[float | None, float | None, int | None]: + """Get the mean, standard deviation and count of frequencies in the string table.""" + string_table = self.metadata.tables["string"] + engine = get_sync_engine( + create_db_engine( + self.env["src_dsn"], + schema_name=self.env["src_schema"], + ) + ) + with engine.connect() as conn: + results = conn.execute( + select( + func.count(), # pylint: disable=not-callable + func.avg(string_table.c.frequency), + func.stddev(string_table.c.frequency), + ) + ).fetchone() + if not isinstance(results, Row): + return None, None, None + return results.avg_1, results.stddev_1, results.count_1 + + def invoke( + self, + *args: Any, + expected_error: str | None = None, + env: Mapping[str, str] | None = None, + ) -> Result: + """ + Run datafaker with the given arguments and environment. + + :param args: Arguments to provide to datafaker. + :param expected_error: If None, will assert that the invocation + passes successfully without throwing an exception. Otherwise, + the suggested error must be present in the standard error stream. + :param env: The environment variables to be set during invocation. + """ + res = self.runner.invoke(app, args, env=env) + if expected_error is None: + self.assertNoException(res) + self.assertSuccess(res) + else: + self.assertIn(expected_error, res.stderr) + return res diff --git a/tests/test_main.py b/tests/test_main.py index 8d80ed6..a520253 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -23,6 +23,31 @@ class TestCLI(DatafakerTestCase): """Tests for the command-line interface.""" + @patch("datafaker.main.read_config_file") + @patch("datafaker.main.dict_to_metadata") + @patch("datafaker.main.load_metadata_config") + @patch("datafaker.main.create_db_vocab") + def test_create_vocab( + self, + mock_create: MagicMock, + mock_mdict: MagicMock, + mock_meta: MagicMock, + mock_config: MagicMock, + ) -> None: + """Test the create-vocab sub-command.""" + result = runner.invoke( + app, + [ + "create-vocab", + ], + catch_exceptions=False, + ) + + mock_create.assert_called_once_with( + mock_meta.return_value, mock_mdict.return_value, mock_config.return_value + ) + self.assertSuccess(result) + @patch("datafaker.main.create_db_tables") @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") diff --git a/tests/utils.py b/tests/utils.py index 427ba67..0de8263 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -467,8 +467,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.generators_file_path = "" self.stats_fd = 0 - self.stats_file_path = "" - self.config_file_path = "" + self.stats_file_path = Path("") + self.config_file_path = Path("") self.config_fd = 0 self.dst_database: TestDatabaseBase | None = None @@ -476,7 +476,8 @@ def setUp(self) -> None: """Set up the test case with an actual orm.yaml file.""" super().setUp() # Generate the `orm.yaml` from the database - (self.orm_fd, self.orm_file_path) = mkstemp(".yaml", "orm_", text=True) + (self.orm_fd, orm_file_path) = mkstemp(".yaml", "orm_", text=True) + self.orm_file_path = Path(orm_file_path) with os.fdopen(self.orm_fd, "w", encoding="utf-8") as orm_fh: orm_fh.write(make_tables_file(self.dsn, self.schema_name)) # Create a separate empty destination database @@ -484,7 +485,8 @@ def setUp(self) -> None: def set_configuration(self, config: Mapping[str, Any]) -> None: """Accepts a configuration file, writes it out.""" - (self.config_fd, self.config_file_path) = mkstemp(".yaml", "config_", text=True) + (self.config_fd, config_file_path) = mkstemp(".yaml", "config_", text=True) + self.config_file_path = Path(config_file_path) with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh: config_fh.write(yaml.dump(config)) @@ -499,9 +501,8 @@ def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: make_src_stats(self.dsn, config, self.schema_name) ) loop.close() - (self.stats_fd, self.stats_file_path) = mkstemp( - ".yaml", "src_stats_", text=True - ) + (self.stats_fd, stats_file_path) = mkstemp(".yaml", "src_stats_", text=True) + self.stats_file_path = Path(stats_file_path) with os.fdopen(self.stats_fd, "w", encoding="utf-8") as stats_fh: stats_fh.write(yaml.dump(src_stats)) return src_stats