Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions src/spotforecast2_safe/processing/n2n_predict_with_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,17 +278,29 @@ def _get_day_night_features(

extended_index = pd.date_range(start=start, end=cov_end, freq=freq, tz=timezone)

sunrise_hour = [
sun(location.observer, date=date, tzinfo=location.timezone)["sunrise"]
for date in extended_index
]
sunset_hour = [
sun(location.observer, date=date, tzinfo=location.timezone)["sunset"]
for date in extended_index
]
# Cache sunrise and sunset times per unique calendar date to avoid
# recomputing them for every timestamp in the extended_index.
normalized_dates = extended_index.normalize()
unique_dates = normalized_dates.unique()

sunrise_map = {}
sunset_map = {}
for d in unique_dates:
s = sun(location.observer, date=d, tzinfo=location.timezone)
sunrise_map[d] = s["sunrise"]
sunset_map[d] = s["sunset"]

sunrise_series = pd.Series(
[sunrise_map[d] for d in normalized_dates],
index=extended_index,
)
sunset_series = pd.Series(
[sunset_map[d] for d in normalized_dates],
index=extended_index,
)

sunrise_hour = pd.Series(sunrise_hour, index=extended_index).dt.round("h").dt.hour
sunset_hour = pd.Series(sunset_hour, index=extended_index).dt.round("h").dt.hour
sunrise_hour = sunrise_series.dt.round("h").dt.hour
sunset_hour = sunset_series.dt.round("h").dt.hour

sun_light_features = pd.DataFrame(
{
Expand Down Expand Up @@ -444,7 +456,8 @@ def _create_interaction_features(

transformer_poly = PolynomialFeatures(
degree=degree, interaction_only=True, include_bias=False
).set_output(transform="pandas")
)
transformer_poly = transformer_poly.set_output(transform="pandas")

weather_window_cols = [
col
Expand Down Expand Up @@ -725,8 +738,7 @@ def n2n_predict_with_covariates(
# Set default model_dir if not provided
if model_dir is None:
from spotforecast2_safe.data.fetch_data import get_cache_home

model_dir = get_cache_home() / "forecasters"
model_dir = get_cache_home() / "forecasters"

# Input Validation
if forecast_horizon <= 0:
Expand Down Expand Up @@ -888,9 +900,11 @@ def n2n_predict_with_covariates(
axis=1,
)

assert (
sum(exogenous_features.isnull().sum()) == 0
), "Missing values in exogenous features"
missing_count = exogenous_features.isnull().sum().sum()
if missing_count != 0:
raise ValueError(
f"Missing values in exogenous features: {missing_count} missing entries"
)

# Apply cyclical encoding
exogenous_features = _apply_cyclical_encoding(
Expand Down
Loading