diff --git a/src/spotforecast2_safe/processing/n2n_predict_with_covariates.py b/src/spotforecast2_safe/processing/n2n_predict_with_covariates.py index dffbcda13..a575a031e 100644 --- a/src/spotforecast2_safe/processing/n2n_predict_with_covariates.py +++ b/src/spotforecast2_safe/processing/n2n_predict_with_covariates.py @@ -278,17 +278,29 @@ def _get_day_night_features( extended_index = pd.date_range(start=start, end=cov_end, freq=freq, tz=timezone) - sunrise_hour = [ - sun(location.observer, date=date, tzinfo=location.timezone)["sunrise"] - for date in extended_index - ] - sunset_hour = [ - sun(location.observer, date=date, tzinfo=location.timezone)["sunset"] - for date in extended_index - ] + # Cache sunrise and sunset times per unique calendar date to avoid + # recomputing them for every timestamp in the extended_index. + normalized_dates = extended_index.normalize() + unique_dates = normalized_dates.unique() + + sunrise_map = {} + sunset_map = {} + for d in unique_dates: + s = sun(location.observer, date=d, tzinfo=location.timezone) + sunrise_map[d] = s["sunrise"] + sunset_map[d] = s["sunset"] + + sunrise_series = pd.Series( + [sunrise_map[d] for d in normalized_dates], + index=extended_index, + ) + sunset_series = pd.Series( + [sunset_map[d] for d in normalized_dates], + index=extended_index, + ) - sunrise_hour = pd.Series(sunrise_hour, index=extended_index).dt.round("h").dt.hour - sunset_hour = pd.Series(sunset_hour, index=extended_index).dt.round("h").dt.hour + sunrise_hour = sunrise_series.dt.round("h").dt.hour + sunset_hour = sunset_series.dt.round("h").dt.hour sun_light_features = pd.DataFrame( { @@ -444,7 +456,8 @@ def _create_interaction_features( transformer_poly = PolynomialFeatures( degree=degree, interaction_only=True, include_bias=False - ).set_output(transform="pandas") + ) + transformer_poly = transformer_poly.set_output(transform="pandas") weather_window_cols = [ col @@ -725,8 +738,7 @@ def n2n_predict_with_covariates( # Set default model_dir if not provided if model_dir is None: from spotforecast2_safe.data.fetch_data import get_cache_home - - model_dir = get_cache_home() / "forecasters" + model_dir = get_cache_home() / "forecasters" # Input Validation if forecast_horizon <= 0: @@ -888,9 +900,11 @@ def n2n_predict_with_covariates( axis=1, ) - assert ( - sum(exogenous_features.isnull().sum()) == 0 - ), "Missing values in exogenous features" + missing_count = exogenous_features.isnull().sum().sum() + if missing_count != 0: + raise ValueError( + f"Missing values in exogenous features: {missing_count} missing entries" + ) # Apply cyclical encoding exogenous_features = _apply_cyclical_encoding(