From c735db05dba54233d75b5965da0f07368e694fb7 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 27 May 2025 23:37:31 -0400 Subject: [PATCH 1/5] feat: Redo data setting in simulation --- policyengine/constants.py | 27 ------- policyengine/simulation.py | 107 +++++++++++++++++++--------- policyengine/utils/data/datasets.py | 48 +++++++++++++ tests/fixtures/simulation.py | 61 ++++++++++++++++ tests/test_simulation.py | 55 ++++++++++++++ 5 files changed, 239 insertions(+), 59 deletions(-) delete mode 100644 policyengine/constants.py create mode 100644 policyengine/utils/data/datasets.py create mode 100644 tests/fixtures/simulation.py create mode 100644 tests/test_simulation.py diff --git a/policyengine/constants.py b/policyengine/constants.py deleted file mode 100644 index de9b6799..00000000 --- a/policyengine/constants.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Mainly simulation options and parameters.""" - -from policyengine_core.data import Dataset -from policyengine.utils.data_download import download -from typing import Tuple, Optional - -EFRS_2022 = "gcs://policyengine-uk-data-private/enhanced_frs_2022_23.h5" -FRS_2022 = "gcs://policyengine-uk-data-private/frs_2022_23.h5" -CPS_2023_POOLED = "gcs://policyengine-us-data/pooled_3_year_cps_2023.h5" -CPS_2023 = "gcs://policyengine-us-data/cps_2023.h5" -ECPS_2024 = "gcs://policyengine-us-data/ecps_2024.h5" - - -def get_default_dataset( - country: str, region: str, version: Optional[str] = None -) -> str: - if country == "uk": - return EFRS_2022 - elif country == "us": - if region is not None and region != "us": - return CPS_2023_POOLED - else: - return CPS_2023 - - raise ValueError( - f"Unable to select a default dataset for country {country} and region {region}." - ) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 670cff67..67a9cca6 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -1,8 +1,9 @@ """Simulate tax-benefit policy and derive society-level output statistics.""" +import sys from pydantic import BaseModel, Field from typing import Literal -from .constants import get_default_dataset +from .utils.data.datasets import get_default_dataset, process_gs_path, POLICYENGINE_DATASETS, DATASET_TIME_PERIODS from policyengine_core.simulations import Simulation as CountrySimulation from policyengine_core.simulations import ( Microsimulation as CountryMicrosimulation, @@ -31,8 +32,8 @@ CountryType = Literal["uk", "us"] ScopeType = Literal["household", "macro"] DataType = ( - str | dict | Any | None -) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. + str | Dataset | None +) TimePeriodType = int ReformType = ParametricReform | Type[StructuralReform] | None RegionType = Optional[str] @@ -72,6 +73,10 @@ class SimulationOptions(BaseModel): description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.", ) + model_config = { + "arbitrary_types_allowed": True, + } + class Simulation: """Simulate tax-benefit policy and derive society-level output statistics.""" @@ -89,7 +94,8 @@ class Simulation: def __init__(self, **options: SimulationOptions): self.options = SimulationOptions(**options) self.check_model_version() - self._set_data() + if not isinstance(self.options.data, Dataset): + self._set_data(self.options.data) self._initialise_simulations() self.check_data_version() self._add_output_functions() @@ -125,39 +131,42 @@ def _add_output_functions(self): wrapped_func, ) - def _set_data(self): - if self.options.data is None: - self.options.data = get_default_dataset( + def _set_data(self, file_address: str | None = None) -> None: + + # filename refers to file's unique name + extension; + # file_address refers to URI + filename + + # If None is passed, user wants default dataset; get URL, then continue initializing. + if file_address is None: + file_address = get_default_dataset( country=self.options.country, - region=self.options.region, + region=self.options.region + ) + print( + f"No data provided, using default dataset: {file_address}", + file=sys.stderr, ) - if isinstance(self.options.data, str): - filename = self.options.data - if self.options.data[:6] == "gcs://": - bucket, filename = self.options.data.split("://")[-1].split( - "/" - ) - version = self.options.data_version + if file_address not in POLICYENGINE_DATASETS: + # If it's a local file, no URI present and unable to infer version. + filename = file_address + version = None - file_path, version = download( - filepath=filename, - gcs_bucket=bucket, - version=version, - return_version=True, - ) - self.data_version = version - filename = str(Path(file_path)) - else: - # If it's a local file, we can't infer the version. - version = None - if "cps_2023" in filename: - time_period = 2023 - else: - time_period = None - self.options.data = Dataset.from_file( - filename, time_period=time_period + else: + # All official PolicyEngine datasets are stored in GCS; + # load accordingly + filename, version = self._set_data_from_gs( + file_address ) + self.data_version = version + + time_period = self._set_data_time_period( + file_address + ) + + self.options.data = Dataset.from_file( + filename, time_period=time_period + ) def _initialise_simulations(self): self.baseline_simulation = self._initialise_simulation( @@ -361,3 +370,37 @@ def check_data_version(self) -> None: raise ValueError( f"Data version {self.data_version} does not match expected version {self.options.data_version}." ) + + def _set_data_time_period(self, file_address: str) -> Optional[int]: + """ + Set the time period based on the file address. + If the file address is a PE dataset, return the time period from the dataset. + If it's a local file, return None. + """ + if file_address in DATASET_TIME_PERIODS: + return DATASET_TIME_PERIODS[file_address] + else: + # Local file, no time period available + return None + + def _set_data_from_gs( + self, file_address: str + ) -> tuple[str, str | None]: + """ + Set the data from a GCS path and return the filename and version. + """ + + bucket, filename = process_gs_path(file_address) + version = self.options.data_version + + print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr) + + filepath, version = download( + filepath=filename, + gcs_bucket=bucket, + version=version, + return_version=True, + ) + + return filename, version + \ No newline at end of file diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py new file mode 100644 index 00000000..079f1ee1 --- /dev/null +++ b/policyengine/utils/data/datasets.py @@ -0,0 +1,48 @@ +"""Mainly simulation options and parameters.""" + +from typing import Tuple, Optional + +EFRS_2022 = "gs://policyengine-uk-data-private/enhanced_frs_2022_23.h5" +FRS_2022 = "gs://policyengine-uk-data-private/frs_2022_23.h5" +CPS_2023 = "gs://policyengine-us-data/cps_2023.h5" +CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" +ECPS_2024 = "gs://policyengine-us-data/ecps_2024.h5" + +POLICYENGINE_DATASETS = [ + EFRS_2022, + FRS_2022, + CPS_2023, + CPS_2023_POOLED, + ECPS_2024, +] + +# Contains datasets that map to particular time_period values +DATASET_TIME_PERIODS = { + CPS_2023: 2023, + CPS_2023_POOLED: 2023, + ECPS_2024: 2023, +} + +def get_default_dataset( + country: str, region: str, version: Optional[str] = None +) -> str: + if country == "uk": + return EFRS_2022 + elif country == "us": + if region is not None and region != "us": + return CPS_2023_POOLED + else: + return CPS_2023 + + raise ValueError( + f"Unable to select a default dataset for country {country} and region {region}." + ) + +def process_gs_path(path: str) -> Tuple[str, str]: + """Process a GS path to return bucket and object.""" + if not path.startswith("gs://"): + raise ValueError(f"Invalid GS path: {path}") + + path = path[5:] # Remove 'gs://' + bucket, obj = path.split("/", 1) + return bucket, obj \ No newline at end of file diff --git a/tests/fixtures/simulation.py b/tests/fixtures/simulation.py new file mode 100644 index 00000000..0dbb9554 --- /dev/null +++ b/tests/fixtures/simulation.py @@ -0,0 +1,61 @@ +from policyengine.simulation import SimulationOptions +from unittest.mock import patch, Mock +import pytest +from policyengine.utils.data.datasets import CPS_2023 + +non_data_uk_sim_options = { + "country": "uk", + "scope": "macro", + "region": "uk", + "time_period": 2025, + "reform": None, + "baseline": None, +} + +non_data_us_sim_options = { + "country": "us", + "scope": "macro", + "region": "us", + "time_period": 2025, + "reform": None, + "baseline": None, +} + +uk_sim_options_no_data = SimulationOptions.model_validate({ + **non_data_uk_sim_options, + "data": None, +}) + +us_sim_options_cps_dataset = SimulationOptions.model_validate({ + **non_data_us_sim_options, + "data": CPS_2023 +}) + +SAMPLE_DATASET_FILENAME = "sample_value.h5" +SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private" +SAMPLE_DATASET_URI_PREFIX = "gs://" +SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}" + +uk_sim_options_pe_dataset = SimulationOptions.model_validate({ + **non_data_uk_sim_options, + "data": SAMPLE_DATASET_FILE_ADDRESS +}) + +@pytest.fixture +def mock_get_default_dataset(): + with patch( + "policyengine.simulation.get_default_dataset", + return_value=SAMPLE_DATASET_FILE_ADDRESS + ) as mock_get_default_dataset: + yield mock_get_default_dataset + +@pytest.fixture +def mock_dataset(): + """Simple Dataset mock fixture""" + with patch('policyengine.simulation.Dataset') as mock_dataset_class: + mock_instance = Mock() + # Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths + mock_instance.from_file = Mock() + mock_instance.file_path = SAMPLE_DATASET_FILENAME + mock_dataset_class.from_file.return_value = mock_instance + yield mock_instance \ No newline at end of file diff --git a/tests/test_simulation.py b/tests/test_simulation.py new file mode 100644 index 00000000..e5b05b05 --- /dev/null +++ b/tests/test_simulation.py @@ -0,0 +1,55 @@ +from .fixtures.simulation import ( + uk_sim_options_no_data, + uk_sim_options_pe_dataset, + us_sim_options_cps_dataset, + mock_get_default_dataset, + mock_dataset, + SAMPLE_DATASET_FILENAME +) +import sys +from copy import deepcopy + +from policyengine import Simulation + +class TestSimulation: + class TestSetData: + def test__given_no_data_option__sets_default_dataset(self, mock_get_default_dataset, mock_dataset): + + # Don't run entire init script + sim = object.__new__(Simulation) + sim.options = deepcopy(uk_sim_options_no_data) + sim._set_data(uk_sim_options_no_data.data) + + assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME + def test__given_pe_dataset__sets_data_option_to_dataset(self, mock_dataset): + + sim = object.__new__(Simulation) + sim.options = deepcopy(uk_sim_options_pe_dataset) + sim._set_data(uk_sim_options_pe_dataset.data) + + assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME + def test__given_cps_2023_in_filename__sets_time_period_to_2023(self, mock_dataset): + from policyengine import Simulation + + sim = object.__new__(Simulation) + sim.options = deepcopy(us_sim_options_cps_dataset) + sim._set_data(us_sim_options_cps_dataset.data) + + assert mock_dataset.from_file.called_with( + us_sim_options_cps_dataset.data, + time_period=2023 + ) + class TestSetDataTimePeriod: + def test__given_dataset_with_time_period__sets_time_period(self): + from policyengine import Simulation + + sim = object.__new__(Simulation) + + print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr) + assert sim._set_data_time_period(us_sim_options_cps_dataset.data) == 2023 + + def test__given_dataset_without_time_period__does_not_set_time_period(self): + from policyengine import Simulation + + sim = object.__new__(Simulation) + assert sim._set_data_time_period(uk_sim_options_pe_dataset.data) == None From 1e10d5607d0561e6973463629fdabd3cc7e4894f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 28 May 2025 16:38:28 -0400 Subject: [PATCH 2/5] chore: Lint and changelog --- changelog_entry.yaml | 8 +++++++ policyengine/simulation.py | 33 ++++++++++++--------------- policyengine/utils/data/datasets.py | 6 +++-- tests/fixtures/simulation.py | 34 +++++++++++++++------------- tests/test_simulation.py | 35 +++++++++++++++++++++-------- 5 files changed, 70 insertions(+), 46 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..760395ca 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,8 @@ +- bump: patch + changes: + changed: + - Disambiguated filepath management in Simulation._set_data() + - Refactored Simulation._set_data() to divide functionality into smaller methods + - Prevented passage of non-Path URIs to Dataset.from_file() at end of Simulation._set_data() execution + added: + - Tests for Simulation._set_data() \ No newline at end of file diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 67a9cca6..70286fa0 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -3,7 +3,12 @@ import sys from pydantic import BaseModel, Field from typing import Literal -from .utils.data.datasets import get_default_dataset, process_gs_path, POLICYENGINE_DATASETS, DATASET_TIME_PERIODS +from .utils.data.datasets import ( + get_default_dataset, + process_gs_path, + POLICYENGINE_DATASETS, + DATASET_TIME_PERIODS, +) from policyengine_core.simulations import Simulation as CountrySimulation from policyengine_core.simulations import ( Microsimulation as CountryMicrosimulation, @@ -31,9 +36,7 @@ CountryType = Literal["uk", "us"] ScopeType = Literal["household", "macro"] -DataType = ( - str | Dataset | None -) +DataType = str | Dataset | None TimePeriodType = int ReformType = ParametricReform | Type[StructuralReform] | None RegionType = Optional[str] @@ -95,7 +98,7 @@ def __init__(self, **options: SimulationOptions): self.options = SimulationOptions(**options) self.check_model_version() if not isinstance(self.options.data, Dataset): - self._set_data(self.options.data) + self._set_data(self.options.data) self._initialise_simulations() self.check_data_version() self._add_output_functions() @@ -139,8 +142,7 @@ def _set_data(self, file_address: str | None = None) -> None: # If None is passed, user wants default dataset; get URL, then continue initializing. if file_address is None: file_address = get_default_dataset( - country=self.options.country, - region=self.options.region + country=self.options.country, region=self.options.region ) print( f"No data provided, using default dataset: {file_address}", @@ -155,15 +157,11 @@ def _set_data(self, file_address: str | None = None) -> None: else: # All official PolicyEngine datasets are stored in GCS; # load accordingly - filename, version = self._set_data_from_gs( - file_address - ) + filename, version = self._set_data_from_gs(file_address) self.data_version = version - time_period = self._set_data_time_period( - file_address - ) - + time_period = self._set_data_time_period(file_address) + self.options.data = Dataset.from_file( filename, time_period=time_period ) @@ -370,7 +368,7 @@ def check_data_version(self) -> None: raise ValueError( f"Data version {self.data_version} does not match expected version {self.options.data_version}." ) - + def _set_data_time_period(self, file_address: str) -> Optional[int]: """ Set the time period based on the file address. @@ -383,9 +381,7 @@ def _set_data_time_period(self, file_address: str) -> Optional[int]: # Local file, no time period available return None - def _set_data_from_gs( - self, file_address: str - ) -> tuple[str, str | None]: + def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]: """ Set the data from a GCS path and return the filename and version. """ @@ -403,4 +399,3 @@ def _set_data_from_gs( ) return filename, version - \ No newline at end of file diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py index 079f1ee1..4dcd8af6 100644 --- a/policyengine/utils/data/datasets.py +++ b/policyengine/utils/data/datasets.py @@ -23,6 +23,7 @@ ECPS_2024: 2023, } + def get_default_dataset( country: str, region: str, version: Optional[str] = None ) -> str: @@ -38,11 +39,12 @@ def get_default_dataset( f"Unable to select a default dataset for country {country} and region {region}." ) + def process_gs_path(path: str) -> Tuple[str, str]: """Process a GS path to return bucket and object.""" if not path.startswith("gs://"): raise ValueError(f"Invalid GS path: {path}") - + path = path[5:] # Remove 'gs://' bucket, obj = path.split("/", 1) - return bucket, obj \ No newline at end of file + return bucket, obj diff --git a/tests/fixtures/simulation.py b/tests/fixtures/simulation.py index 0dbb9554..7cc51720 100644 --- a/tests/fixtures/simulation.py +++ b/tests/fixtures/simulation.py @@ -1,7 +1,7 @@ from policyengine.simulation import SimulationOptions from unittest.mock import patch, Mock import pytest -from policyengine.utils.data.datasets import CPS_2023 +from policyengine.utils.data.datasets import CPS_2023 non_data_uk_sim_options = { "country": "uk", @@ -21,41 +21,43 @@ "baseline": None, } -uk_sim_options_no_data = SimulationOptions.model_validate({ - **non_data_uk_sim_options, - "data": None, -}) +uk_sim_options_no_data = SimulationOptions.model_validate( + { + **non_data_uk_sim_options, + "data": None, + } +) -us_sim_options_cps_dataset = SimulationOptions.model_validate({ - **non_data_us_sim_options, - "data": CPS_2023 -}) +us_sim_options_cps_dataset = SimulationOptions.model_validate( + {**non_data_us_sim_options, "data": CPS_2023} +) SAMPLE_DATASET_FILENAME = "sample_value.h5" SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private" SAMPLE_DATASET_URI_PREFIX = "gs://" SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}" -uk_sim_options_pe_dataset = SimulationOptions.model_validate({ - **non_data_uk_sim_options, - "data": SAMPLE_DATASET_FILE_ADDRESS -}) +uk_sim_options_pe_dataset = SimulationOptions.model_validate( + {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS} +) + @pytest.fixture def mock_get_default_dataset(): with patch( "policyengine.simulation.get_default_dataset", - return_value=SAMPLE_DATASET_FILE_ADDRESS + return_value=SAMPLE_DATASET_FILE_ADDRESS, ) as mock_get_default_dataset: yield mock_get_default_dataset + @pytest.fixture def mock_dataset(): """Simple Dataset mock fixture""" - with patch('policyengine.simulation.Dataset') as mock_dataset_class: + with patch("policyengine.simulation.Dataset") as mock_dataset_class: mock_instance = Mock() # Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths mock_instance.from_file = Mock() mock_instance.file_path = SAMPLE_DATASET_FILENAME mock_dataset_class.from_file.return_value = mock_instance - yield mock_instance \ No newline at end of file + yield mock_instance diff --git a/tests/test_simulation.py b/tests/test_simulation.py index e5b05b05..a3697bc2 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -4,16 +4,19 @@ us_sim_options_cps_dataset, mock_get_default_dataset, mock_dataset, - SAMPLE_DATASET_FILENAME + SAMPLE_DATASET_FILENAME, ) import sys from copy import deepcopy from policyengine import Simulation + class TestSimulation: class TestSetData: - def test__given_no_data_option__sets_default_dataset(self, mock_get_default_dataset, mock_dataset): + def test__given_no_data_option__sets_default_dataset( + self, mock_get_default_dataset, mock_dataset + ): # Don't run entire init script sim = object.__new__(Simulation) @@ -21,14 +24,20 @@ def test__given_no_data_option__sets_default_dataset(self, mock_get_default_data sim._set_data(uk_sim_options_no_data.data) assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME - def test__given_pe_dataset__sets_data_option_to_dataset(self, mock_dataset): + + def test__given_pe_dataset__sets_data_option_to_dataset( + self, mock_dataset + ): sim = object.__new__(Simulation) sim.options = deepcopy(uk_sim_options_pe_dataset) sim._set_data(uk_sim_options_pe_dataset.data) assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME - def test__given_cps_2023_in_filename__sets_time_period_to_2023(self, mock_dataset): + + def test__given_cps_2023_in_filename__sets_time_period_to_2023( + self, mock_dataset + ): from policyengine import Simulation sim = object.__new__(Simulation) @@ -36,9 +45,9 @@ def test__given_cps_2023_in_filename__sets_time_period_to_2023(self, mock_datase sim._set_data(us_sim_options_cps_dataset.data) assert mock_dataset.from_file.called_with( - us_sim_options_cps_dataset.data, - time_period=2023 + us_sim_options_cps_dataset.data, time_period=2023 ) + class TestSetDataTimePeriod: def test__given_dataset_with_time_period__sets_time_period(self): from policyengine import Simulation @@ -46,10 +55,18 @@ def test__given_dataset_with_time_period__sets_time_period(self): sim = object.__new__(Simulation) print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr) - assert sim._set_data_time_period(us_sim_options_cps_dataset.data) == 2023 + assert ( + sim._set_data_time_period(us_sim_options_cps_dataset.data) + == 2023 + ) - def test__given_dataset_without_time_period__does_not_set_time_period(self): + def test__given_dataset_without_time_period__does_not_set_time_period( + self, + ): from policyengine import Simulation sim = object.__new__(Simulation) - assert sim._set_data_time_period(uk_sim_options_pe_dataset.data) == None + assert ( + sim._set_data_time_period(uk_sim_options_pe_dataset.data) + == None + ) From 23d3fab211071a98abadd9eb3655f11d31bd4954 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 28 May 2025 17:04:07 -0400 Subject: [PATCH 3/5] fix: Reallow arbitrary dict passage to Simulation(data) --- policyengine/simulation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 70286fa0..47466479 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -30,13 +30,15 @@ import pandas as pd from typing import Type, Optional from functools import wraps, partial -from typing import Dict, Any, Callable +from typing import Callable import importlib from policyengine.utils.data_download import download CountryType = Literal["uk", "us"] ScopeType = Literal["household", "macro"] -DataType = str | Dataset | None +DataType = ( + str | dict | Dataset | None +) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. TimePeriodType = int ReformType = ParametricReform | Type[StructuralReform] | None RegionType = Optional[str] From bc16cb4ba81d018c98a2082bfda03c9a062002e5 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 28 May 2025 17:26:39 -0400 Subject: [PATCH 4/5] fix: Properly check for dict type before running _set_data --- policyengine/simulation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 47466479..24752aaf 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -28,7 +28,7 @@ import h5py from pathlib import Path import pandas as pd -from typing import Type, Optional +from typing import Type, Any, Optional from functools import wraps, partial from typing import Callable import importlib @@ -37,7 +37,7 @@ CountryType = Literal["uk", "us"] ScopeType = Literal["household", "macro"] DataType = ( - str | dict | Dataset | None + str | dict[Any, Any] | Dataset | None ) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. TimePeriodType = int ReformType = ParametricReform | Type[StructuralReform] | None @@ -99,7 +99,10 @@ class Simulation: def __init__(self, **options: SimulationOptions): self.options = SimulationOptions(**options) self.check_model_version() - if not isinstance(self.options.data, Dataset): + if not isinstance(self.options.data, dict) and not isinstance( + self.options.data, Dataset + ): + print(type(self.options.data), sys.stderr) self._set_data(self.options.data) self._initialise_simulations() self.check_data_version() From f1c94348666e933582d77d344e5e78b530d44867 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 30 May 2025 14:39:04 -0400 Subject: [PATCH 5/5] chore: Remove unneeded print statement --- policyengine/simulation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 24752aaf..a8a818fb 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -102,7 +102,6 @@ def __init__(self, **options: SimulationOptions): if not isinstance(self.options.data, dict) and not isinstance( self.options.data, Dataset ): - print(type(self.options.data), sys.stderr) self._set_data(self.options.data) self._initialise_simulations() self.check_data_version()