diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index d014b06c..3f314d5f 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -1,7 +1,8 @@ +from difflib import get_close_matches from enum import Enum from typing import Any, Optional -from policyengine.core import Output, Simulation +from policyengine.core import Output, Simulation, Variable class AggregateType(str, Enum): @@ -10,6 +11,72 @@ class AggregateType(str, Enum): COUNT = "count" +def get_aggregate_variable( + simulation: Simulation, + variable: str, + context: str, +) -> Variable: + """Return a model variable with an aggregation-specific error message.""" + model_version = simulation.tax_benefit_model_version + try: + return model_version.get_variable(variable) + except ValueError as exc: + candidates = sorted(model_version.variables_by_name) + suggestions = get_close_matches(variable, candidates, n=3, cutoff=0.65) + suggestion_text = ( + f" Did you mean: {', '.join(repr(name) for name in suggestions)}?" + if suggestions + else "" + ) + raise ValueError( + f"{context} references missing variable '{variable}' in " + f"{model_version.model.id} version {model_version.version}." + f"{suggestion_text}" + ) from exc + + +def get_output_entity_data( + simulation: Simulation, + entity: str, + context: str, +) -> Any: + """Return output data for an entity with a clear error if it is unavailable.""" + if simulation.output_dataset is None or simulation.output_dataset.data is None: + raise ValueError( + f"{context} requires simulation '{simulation.id}' to have an " + "output dataset before aggregation." + ) + + try: + return getattr(simulation.output_dataset.data, entity) + except AttributeError as exc: + raise ValueError( + f"{context} references entity '{entity}', but simulation " + f"'{simulation.id}' has no output data for that entity." + ) from exc + + +def require_output_column( + data: Any, + variable: str, + entity: str, + simulation: Simulation, + context: str, +) -> None: + """Raise a descriptive error when a known variable was not materialized.""" + if variable in data.columns: + return + + model_version = simulation.tax_benefit_model_version + raise ValueError( + f"{context} variable '{variable}' exists in {model_version.model.id} " + f"version {model_version.version}, but is not present in simulation " + f"'{simulation.id}' output data for entity '{entity}'. Add '{variable}' " + f"to {model_version.__class__.__name__}.entity_variables or pass it via " + "Simulation.extra_variables before running the simulation." + ) + + class Aggregate(Output): simulation: Simulation variable: str @@ -47,35 +114,61 @@ def run(self): elif self.quantile_geq is not None: self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile - # Get variable object - var_obj = next( - v - for v in self.simulation.tax_benefit_model_version.variables - if v.name == self.variable + var_obj = get_aggregate_variable( + self.simulation, self.variable, "Aggregate.variable" ) # Get the target entity data target_entity = self.entity or var_obj.entity - data = getattr(self.simulation.output_dataset.data, target_entity) + data = get_output_entity_data( + self.simulation, target_entity, "Aggregate.entity" + ) # Map variable to target entity if needed if var_obj.entity != target_entity: + source_data = get_output_entity_data( + self.simulation, var_obj.entity, "Aggregate.variable" + ) + require_output_column( + source_data, + self.variable, + var_obj.entity, + self.simulation, + "Aggregate.variable", + ) mapped = self.simulation.output_dataset.data.map_to_entity( var_obj.entity, target_entity, columns=[self.variable] ) series = mapped[self.variable] else: + require_output_column( + data, + self.variable, + target_entity, + self.simulation, + "Aggregate.variable", + ) series = data[self.variable] # Apply filters if self.filter_variable is not None: - filter_var_obj = next( - v - for v in self.simulation.tax_benefit_model_version.variables - if v.name == self.filter_variable + filter_var_obj = get_aggregate_variable( + self.simulation, self.filter_variable, "Aggregate.filter_variable" ) if filter_var_obj.entity != target_entity: + filter_source_data = get_output_entity_data( + self.simulation, + filter_var_obj.entity, + "Aggregate.filter_variable", + ) + require_output_column( + filter_source_data, + self.filter_variable, + filter_var_obj.entity, + self.simulation, + "Aggregate.filter_variable", + ) filter_mapped = self.simulation.output_dataset.data.map_to_entity( filter_var_obj.entity, target_entity, @@ -83,6 +176,13 @@ def run(self): ) filter_series = filter_mapped[self.filter_variable] else: + require_output_column( + data, + self.filter_variable, + target_entity, + self.simulation, + "Aggregate.filter_variable", + ) filter_series = data[self.filter_variable] if self.filter_variable_describes_quantiles: diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py index 87d2e0d9..f9ea6502 100644 --- a/src/policyengine/outputs/change_aggregate.py +++ b/src/policyengine/outputs/change_aggregate.py @@ -2,6 +2,11 @@ from typing import Any, Optional from policyengine.core import Output, Simulation +from policyengine.outputs.aggregate import ( + get_aggregate_variable, + get_output_entity_data, + require_output_column, +) class ChangeAggregateType(str, Enum): @@ -59,34 +64,75 @@ def run(self): elif self.quantile_geq is not None: self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile - # Get variable object - var_obj = next( - v - for v in self.baseline_simulation.tax_benefit_model_version.variables - if v.name == self.variable + var_obj = get_aggregate_variable( + self.baseline_simulation, self.variable, "ChangeAggregate.variable" ) # Get the target entity data target_entity = self.entity or var_obj.entity - baseline_data = getattr( - self.baseline_simulation.output_dataset.data, target_entity + baseline_data = get_output_entity_data( + self.baseline_simulation, + target_entity, + "ChangeAggregate.baseline_entity", + ) + reform_data = get_output_entity_data( + self.reform_simulation, + target_entity, + "ChangeAggregate.reform_entity", ) - reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity) # Map variable to target entity if needed if var_obj.entity != target_entity: + baseline_source_data = get_output_entity_data( + self.baseline_simulation, + var_obj.entity, + "ChangeAggregate.variable", + ) + reform_source_data = get_output_entity_data( + self.reform_simulation, + var_obj.entity, + "ChangeAggregate.variable", + ) + require_output_column( + baseline_source_data, + self.variable, + var_obj.entity, + self.baseline_simulation, + "ChangeAggregate.variable", + ) + require_output_column( + reform_source_data, + self.variable, + var_obj.entity, + self.reform_simulation, + "ChangeAggregate.variable", + ) baseline_mapped = ( self.baseline_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + var_obj.entity, target_entity, columns=[self.variable] ) ) baseline_series = baseline_mapped[self.variable] reform_mapped = self.reform_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + var_obj.entity, target_entity, columns=[self.variable] ) reform_series = reform_mapped[self.variable] else: + require_output_column( + baseline_data, + self.variable, + target_entity, + self.baseline_simulation, + "ChangeAggregate.variable", + ) + require_output_column( + reform_data, + self.variable, + target_entity, + self.reform_simulation, + "ChangeAggregate.variable", + ) baseline_series = baseline_data[self.variable] reform_series = reform_data[self.variable] @@ -124,20 +170,41 @@ def run(self): # Apply filter_variable filters if self.filter_variable is not None: - filter_var_obj = next( - v - for v in self.baseline_simulation.tax_benefit_model_version.variables - if v.name == self.filter_variable + filter_var_obj = get_aggregate_variable( + self.baseline_simulation, + self.filter_variable, + "ChangeAggregate.filter_variable", ) if filter_var_obj.entity != target_entity: + filter_source_data = get_output_entity_data( + self.baseline_simulation, + filter_var_obj.entity, + "ChangeAggregate.filter_variable", + ) + require_output_column( + filter_source_data, + self.filter_variable, + filter_var_obj.entity, + self.baseline_simulation, + "ChangeAggregate.filter_variable", + ) filter_mapped = ( self.baseline_simulation.output_dataset.data.map_to_entity( - filter_var_obj.entity, target_entity + filter_var_obj.entity, + target_entity, + columns=[self.filter_variable], ) ) filter_series = filter_mapped[self.filter_variable] else: + require_output_column( + baseline_data, + self.filter_variable, + target_entity, + self.baseline_simulation, + "ChangeAggregate.filter_variable", + ) filter_series = baseline_data[self.filter_variable] if self.filter_variable_describes_quantiles: diff --git a/src/policyengine/outputs/program_statistics.py b/src/policyengine/outputs/program_statistics.py index a48ff8a8..ccb4f1e1 100644 --- a/src/policyengine/outputs/program_statistics.py +++ b/src/policyengine/outputs/program_statistics.py @@ -13,7 +13,12 @@ class ProgramStatistics(Output): - """Single program's statistics from a policy reform - represents one database row.""" + """Single program's statistics from a policy reform. + + Count fields are reported in the configured entity's units. For example, + a tax-unit variable reports tax-unit recipient/winner/loser counts, while + a person variable reports person counts. + """ model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 8b3eefc8..7bc1cd52 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -26,6 +26,21 @@ Poverty, calculate_us_poverty_rates, ) +from policyengine.utils.errors import format_conditional_error_detail + +US_PROGRAMS = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "tax_unit", "is_tax": True}, + "state_income_tax": {"entity": "tax_unit", "is_tax": True}, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "person", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + "medicare_cost": {"entity": "person", "is_tax": False}, + "medicaid": {"entity": "person", "is_tax": False}, + "eitc": {"entity": "tax_unit", "is_tax": False}, + "ctc": {"entity": "tax_unit", "is_tax": False}, +} class PolicyReformAnalysis(BaseModel): @@ -39,6 +54,71 @@ class PolicyReformAnalysis(BaseModel): reform_inequality: Inequality +def _format_missing_program_variables(missing_variables: set[str]) -> str | None: + """Format the optional missing-variable detail for program statistics.""" + return format_conditional_error_detail( + "Missing model variables", + missing_variables, + ) + + +def _program_statistics_config_error_message( + missing_variables: set[str], + missing_outputs: set[tuple[str, str]], +) -> str: + lines = ["US program statistics config is invalid:"] + + missing_variables_message = _format_missing_program_variables(missing_variables) + if missing_variables_message is not None: + lines.append(missing_variables_message) + + if missing_outputs: + formatted = ", ".join( + f"{program_name} on {entity}" + for program_name, entity in sorted(missing_outputs) + ) + lines.append("Variables not materialized in simulation outputs: " + formatted) + lines.append( + "Add them to the model version's entity_variables or pass them " + "via Simulation.extra_variables before running the simulation." + ) + + return "\n".join(lines) + + +def _validate_program_statistics_config( + baseline_simulation: Simulation, + reform_simulation: Simulation, +) -> None: + """Validate US program-stat variables before running simulations.""" + missing_variables: set[str] = set() + missing_outputs: set[tuple[str, str]] = set() + + simulations = (baseline_simulation, reform_simulation) + for program_name, program_info in US_PROGRAMS.items(): + for simulation in simulations: + model_version = simulation.tax_benefit_model_version + try: + variable = model_version.get_variable(program_name) + except ValueError: + missing_variables.add(program_name) + continue + + resolved_variables = model_version.resolve_entity_variables(simulation) + if program_name not in resolved_variables.get(variable.entity, []): + missing_outputs.add((program_name, variable.entity)) + + if not missing_variables and not missing_outputs: + return + + raise ValueError( + _program_statistics_config_error_message( + missing_variables, + missing_outputs, + ), + ) + + def economic_impact_analysis( baseline_simulation: Simulation, reform_simulation: Simulation, @@ -55,6 +135,8 @@ def economic_impact_analysis( ``PolicyReformAnalysis`` with decile impacts, program statistics, baseline and reform poverty, and inequality. """ + _validate_program_statistics_config(baseline_simulation, reform_simulation) + baseline_simulation.ensure() reform_simulation.ensure() @@ -71,22 +153,8 @@ def economic_impact_analysis( income_variable="household_net_income", ) - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "payroll_tax": {"entity": "person", "is_tax": True}, - "state_income_tax": {"entity": "tax_unit", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "person", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - "medicare": {"entity": "person", "is_tax": False}, - "medicaid": {"entity": "person", "is_tax": False}, - "eitc": {"entity": "tax_unit", "is_tax": False}, - "ctc": {"entity": "tax_unit", "is_tax": False}, - } - program_statistics = [] - for program_name, program_info in programs.items(): + for program_name, program_info in US_PROGRAMS.items(): stats = ProgramStatistics( baseline_simulation=baseline_simulation, reform_simulation=reform_simulation, diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index 184dd110..655e05d6 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -65,6 +65,7 @@ class PolicyEngineUSLatest(MicrosimulationModelVersion): # Benefits "ssi", "social_security", + "medicare_cost", "medicaid", "unemployment_compensation", ], @@ -91,6 +92,7 @@ class PolicyEngineUSLatest(MicrosimulationModelVersion): "tax_unit_weight", "income_tax", "employee_payroll_tax", + "state_income_tax", "household_state_income_tax", "eitc", "ctc", diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py index bfbfe10b..8cee3ff2 100644 --- a/src/policyengine/utils/__init__.py +++ b/src/policyengine/utils/__init__.py @@ -1,5 +1,6 @@ from .dates import parse_safe_date as parse_safe_date from .design import COLORS as COLORS +from .errors import format_conditional_error_detail as format_conditional_error_detail from .parameter_labels import build_scale_lookup as build_scale_lookup from .parameter_labels import ( generate_label_for_parameter as generate_label_for_parameter, diff --git a/src/policyengine/utils/errors.py b/src/policyengine/utils/errors.py new file mode 100644 index 00000000..34213b59 --- /dev/null +++ b/src/policyengine/utils/errors.py @@ -0,0 +1,17 @@ +"""Shared helpers for constructing consistent errors.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Optional + + +def format_conditional_error_detail( + label: str, + values: Iterable[str], +) -> Optional[str]: + """Build a labelled error detail line when ``values`` is non-empty.""" + sorted_values = sorted(set(values)) + if not sorted_values: + return None + return f"{label}: {', '.join(sorted_values)}" diff --git a/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json b/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json index 1d5e98ca..0a8c2662 100644 --- a/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json +++ b/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 0.0, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -38,6 +39,7 @@ "person[1].is_male": 1.0, "person[1].marital_unit_id": 0.0, "person[1].medicaid": 0.0, + "person[1].medicare_cost": 14500.0, "person[1].person_id": 1.0, "person[1].person_weight": 1.0, "person[1].race": 3.0, @@ -55,6 +57,7 @@ "person[2].is_male": 1.0, "person[2].marital_unit_id": 0.0, "person[2].medicaid": 0.0, + "person[2].medicare_cost": 14500.0, "person[2].person_id": 2.0, "person[2].person_weight": 1.0, "person[2].race": 3.0, @@ -72,6 +75,7 @@ "person[3].is_male": 1.0, "person[3].marital_unit_id": 0.0, "person[3].medicaid": 0.0, + "person[3].medicare_cost": 14500.0, "person[3].person_id": 3.0, "person[3].person_weight": 1.0, "person[3].race": 3.0, @@ -92,6 +96,7 @@ "tax_unit.employee_payroll_tax": 21480.0, "tax_unit.household_state_income_tax": 12690.07, "tax_unit.income_tax": 30740.0, + "tax_unit.state_income_tax": 12690.07, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json b/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json index d94660a9..8284c6fc 100644 --- a/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json +++ b/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 0.0, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -41,6 +42,7 @@ "tax_unit.employee_payroll_tax": 5370.0, "tax_unit.household_state_income_tax": 1602.86, "tax_unit.income_tax": 5020.0, + "tax_unit.state_income_tax": 1602.86, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json b/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json index 258db6f1..b77b54f4 100644 --- a/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json +++ b/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 6439.11, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -41,6 +42,7 @@ "tax_unit.employee_payroll_tax": 0.0, "tax_unit.household_state_income_tax": 0.0, "tax_unit.income_tax": 0.0, + "tax_unit.state_income_tax": 0.0, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json b/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json index 78ba7237..46504931 100644 --- a/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json +++ b/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 0.0, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -38,6 +39,7 @@ "person[1].is_male": 1.0, "person[1].marital_unit_id": 0.0, "person[1].medicaid": 3258.31, + "person[1].medicare_cost": 14500.0, "person[1].person_id": 1.0, "person[1].person_weight": 1.0, "person[1].race": 3.0, @@ -58,6 +60,7 @@ "tax_unit.employee_payroll_tax": 3580.0, "tax_unit.household_state_income_tax": 0.0, "tax_unit.income_tax": -2467.62, + "tax_unit.state_income_tax": 0.0, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 5b4e8b27..28c29928 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -478,8 +478,10 @@ def test_aggregate_invalid_variable(): variable="nonexistent_variable", aggregate_type=AggregateType.SUM, ) - with pytest.raises(StopIteration): + with pytest.raises(ValueError) as exc_info: agg.run() + assert "nonexistent_variable" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) # Invalid filter variable name should raise error on run() agg = Aggregate( @@ -488,5 +490,7 @@ def test_aggregate_invalid_variable(): aggregate_type=AggregateType.SUM, filter_variable="nonexistent_filter", ) - with pytest.raises(StopIteration): + with pytest.raises(ValueError) as exc_info: agg.run() + assert "nonexistent_filter" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) diff --git a/tests/test_change_aggregate.py b/tests/test_change_aggregate.py index ea900db6..0728b880 100644 --- a/tests/test_change_aggregate.py +++ b/tests/test_change_aggregate.py @@ -2,6 +2,7 @@ import tempfile import pandas as pd +import pytest from microdf import MicroDataFrame from policyengine.core import ( @@ -18,6 +19,120 @@ ) +def _make_change_aggregate_simulations(tmp_path): + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2], + "benunit_id": [1, 2], + "household_id": [1, 2], + "age": [30, 40], + "employment_income": [50000, 60000], + "person_weight": [1.0, 1.0], + } + ), + weights="person_weight", + ) + reform_person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2], + "benunit_id": [1, 2], + "household_id": [1, 2], + "age": [30, 40], + "employment_income": [51000, 61000], + "person_weight": [1.0, 1.0], + } + ), + weights="person_weight", + ) + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=str(tmp_path / "baseline.h5"), + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=str(tmp_path / "reform.h5"), + year=2024, + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + return baseline_sim, reform_sim + + +def test_change_aggregate_invalid_variable(tmp_path): + baseline_sim, reform_sim = _make_change_aggregate_simulations(tmp_path) + + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="not_a_variable", + aggregate_type=ChangeAggregateType.COUNT, + ) + + with pytest.raises(ValueError) as exc_info: + agg.run() + + assert "not_a_variable" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) + + +def test_change_aggregate_invalid_filter_variable(tmp_path): + baseline_sim, reform_sim = _make_change_aggregate_simulations(tmp_path) + + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + filter_variable="not_a_filter_variable", + filter_variable_geq=0, + ) + + with pytest.raises(ValueError) as exc_info: + agg.run() + + assert "not_a_filter_variable" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) + + def test_change_aggregate_count(): """Test counting people with any change.""" person_df = MicroDataFrame( diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 00000000..81803b40 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,9 @@ +from policyengine.utils.errors import format_conditional_error_detail + + +def test_format_conditional_error_detail(): + assert ( + format_conditional_error_detail("Missing model variables", {"beta", "alpha"}) + == "Missing model variables: alpha, beta" + ) + assert format_conditional_error_detail("Missing model variables", set()) is None diff --git a/tests/test_us_program_statistics.py b/tests/test_us_program_statistics.py new file mode 100644 index 00000000..2c5044f8 --- /dev/null +++ b/tests/test_us_program_statistics.py @@ -0,0 +1,146 @@ +import pandas as pd +import pytest +from microdf import MicroDataFrame + +from policyengine.core import Simulation +from policyengine.outputs import ProgramStatistics +from policyengine.tax_benefit_models.us.analysis import ( + US_PROGRAMS, + _validate_program_statistics_config, +) +from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, +) +from policyengine.tax_benefit_models.us.model import us_latest + + +def _microdf(data: dict, weights: str) -> MicroDataFrame: + return MicroDataFrame(pd.DataFrame(data), weights=weights) + + +def _make_us_output_simulation(tmp_path, simulation_id: str, multiplier: float): + data = USYearData( + person=_microdf( + { + "person_id": [1, 2], + "household_id": [1, 2], + "marital_unit_id": [1, 2], + "family_id": [1, 2], + "spm_unit_id": [1, 2], + "tax_unit_id": [1, 2], + "person_weight": [1.0, 2.0], + "ssi": [100.0 * multiplier, 0.0], + "social_security": [0.0, 200.0 * multiplier], + "medicare_cost": [300.0 * multiplier, 0.0], + "medicaid": [0.0, 400.0 * multiplier], + }, + "person_weight", + ), + marital_unit=_microdf( + { + "marital_unit_id": [1, 2], + "marital_unit_weight": [1.0, 2.0], + }, + "marital_unit_weight", + ), + family=_microdf( + { + "family_id": [1, 2], + "family_weight": [1.0, 2.0], + }, + "family_weight", + ), + spm_unit=_microdf( + { + "spm_unit_id": [1, 2], + "spm_unit_weight": [1.0, 2.0], + "snap": [500.0 * multiplier, 0.0], + "tanf": [0.0, 600.0 * multiplier], + }, + "spm_unit_weight", + ), + tax_unit=_microdf( + { + "tax_unit_id": [1, 2], + "tax_unit_weight": [1.0, 2.0], + "income_tax": [700.0 * multiplier, 0.0], + "employee_payroll_tax": [0.0, 800.0 * multiplier], + "state_income_tax": [900.0 * multiplier, 0.0], + "eitc": [0.0, 1_000.0 * multiplier], + "ctc": [1_100.0 * multiplier, 0.0], + }, + "tax_unit_weight", + ), + household=_microdf( + { + "household_id": [1, 2], + "household_weight": [1.0, 2.0], + }, + "household_weight", + ), + ) + dataset = PolicyEngineUSDataset( + id=simulation_id, + name=f"{simulation_id} output", + description="Mocked US output dataset for program statistics", + filepath=str(tmp_path / f"{simulation_id}.h5"), + year=2026, + is_output_dataset=True, + data=data, + ) + return Simulation( + id=simulation_id, + dataset=dataset, + tax_benefit_model_version=us_latest, + output_dataset=dataset, + ) + + +def test_us_program_statistics_config_runs_against_mocked_outputs(tmp_path): + baseline = _make_us_output_simulation(tmp_path, "baseline", 1.0) + reform = _make_us_output_simulation(tmp_path, "reform", 2.0) + + _validate_program_statistics_config(baseline, reform) + + results = {} + for program_name, program_info in US_PROGRAMS.items(): + stats = ProgramStatistics( + baseline_simulation=baseline, + reform_simulation=reform, + program_name=program_name, + entity=program_info["entity"], + is_tax=program_info["is_tax"], + ) + stats.run() + results[program_name] = stats + + assert set(results) == set(US_PROGRAMS) + assert results["employee_payroll_tax"].baseline_total == 1_600.0 + assert results["medicare_cost"].baseline_total == 300.0 + assert results["state_income_tax"].baseline_total == 900.0 + + +def test_us_program_statistics_config_fails_before_simulation_run( + tmp_path, monkeypatch +): + baseline = _make_us_output_simulation(tmp_path, "baseline", 1.0) + reform = _make_us_output_simulation(tmp_path, "reform", 2.0) + + entity_variables = { + entity: list(variables) + for entity, variables in us_latest.entity_variables.items() + } + entity_variables["person"].remove("medicare_cost") + monkeypatch.setattr( + baseline.tax_benefit_model_version, + "entity_variables", + entity_variables, + ) + + with pytest.raises( + ValueError, match="US program statistics config is invalid" + ) as exc_info: + _validate_program_statistics_config(baseline, reform) + + assert "medicare_cost" in str(exc_info.value)