Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"plotly==5.24.1",
"psycopg2-binary==2.9.10",
"SQLAlchemy==2.0.36",
"streamlit==1.51.0",
"streamlit==1.57.0",
"uvicorn==0.34.0",
"geopandas==1.0.1",
"garrett-streamlit-auth0==0.9",
Expand All @@ -37,9 +37,7 @@ dependencies = [
"matplotlib>=3.8,<4.0",
"dp-sdk",
"aiocache",
"grpcio>=1.80.0",
"grpcio-tools>=1.50.0",
"grpc-requests>=0.1.17",
"grpclib==0.4.8",
"betterproto>=2.0.0b7",
"pytest-asyncio>=1.3.0",
"testcontainers>=4.14.0",
Expand Down Expand Up @@ -77,7 +75,7 @@ extra-index-url = ["https://pypi.org/simple"]

[tool.uv.sources]

dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.28.0/dp_sdk-0.28.0-py3-none-any.whl" }
dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.30.1/dp_sdk-0.30.1-py3-none-any.whl" }
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
Expand Down
45 changes: 23 additions & 22 deletions src/dataplatform/forecast/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@
import pandas as pd
import streamlit as st

from ocf import dp
from ocf.dp.dp_data import messages_pb2, service_pb2_grpc
from ocf.dp.dp import common_pb2


async def fetch_timeseries(
client: dp.DataPlatformDataServiceStub,
client: service_pb2_grpc.DataPlatformDataServiceStub,
location_uuid: str,
start_date: datetime.datetime,
end_date: datetime.datetime,
horizon_mins: int,
forecasters: list[dp.Forecaster],
forecasters: list[messages_pb2.Forecaster],
init_times_utc: list[datetime.datetime] | None = None,
) -> pd.DataFrame:
"""Directly calls GetForecastAsTimeseries for selected models and init times."""

time_window = dp.TimeWindow(
time_window = messages_pb2.TimeWindow(
start_timestamp_utc=start_date, end_timestamp_utc=end_date
)
time_windows = []
current_start = start_date
while current_start < end_date:
current_end = min(current_start + datetime.timedelta(days=7), end_date)
time_windows.append(
dp.TimeWindow(
messages_pb2.TimeWindow(
start_timestamp_utc=current_start,
end_timestamp_utc=current_end
)
Expand All @@ -38,32 +39,32 @@ async def fetch_timeseries(
times_to_fetch = init_times_utc if init_times_utc else [None]

async def fetch_one(
forecaster_obj: dp.Forecaster,
window: dp.TimeWindow,
forecaster_obj: messages_pb2.Forecaster,
window: messages_pb2.TimeWindow,
init_time: datetime.datetime | None
):
req = dp.GetForecastAsTimeseriesRequest(
req = messages_pb2.GetForecastAsTimeseriesRequest(
location_uuid=location_uuid,
energy_source=dp.EnergySource.SOLAR,
energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
horizon_mins=horizon_mins,
time_window=window,
forecaster=forecaster_obj,
initialization_timestamp_utc=init_time,
)

try:
resp = await client.get_forecast_as_timeseries(req)
resp = await client.GetForecastAsTimeseries(req)
rows = []
for val in resp.values:
row = {
"target_timestamp_utc": val.target_timestamp_utc,
"initialization_timestamp_utc": val.initialization_timestamp_utc,
"created_timestamp_utc": val.created_timestamp_utc,
"target_timestamp_utc": val.target_timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
"initialization_timestamp_utc": val.initialization_timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
"created_timestamp_utc": val.created_timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
"effective_capacity_watts": val.effective_capacity_watts,
"forecaster_name": forecaster_obj.forecaster_name,
"location_uuid": resp.location_uuid,
"horizon_mins": (
val.target_timestamp_utc - val.initialization_timestamp_utc
val.target_timestamp_utc.ToDatetime(tzinfo=datetime.UTC) - val.initialization_timestamp_utc.ToDatetime(tzinfo=datetime.UTC)
).total_seconds()
// 60,
"p50_watts": int(
Expand Down Expand Up @@ -108,24 +109,24 @@ async def fetch_one(


async def fetch_observations(
client: dp.DataPlatformDataServiceStub,
client: service_pb2_grpc.DataPlatformDataServiceStub,
location_uuid: str,
start_date: datetime.datetime,
end_date: datetime.datetime,
observers: list[str],
energy_source: dp.EnergySource = dp.EnergySource.SOLAR,
energy_source: common_pb2.EnergySource = common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
) -> pd.DataFrame:
"""Directly calls GetObservationsAsTimeseries for selected observers."""

time_window = dp.TimeWindow(
time_window = messages_pb2.TimeWindow(
start_timestamp_utc=start_date, end_timestamp_utc=end_date
)
time_windows = []
current_start = start_date
while current_start < end_date:
current_end = min(current_start + datetime.timedelta(days=7), end_date)
time_windows.append(
dp.TimeWindow(
messages_pb2.TimeWindow(
start_timestamp_utc=current_start,
end_timestamp_utc=current_end
)
Expand All @@ -134,21 +135,21 @@ async def fetch_observations(


# Run requests concurrently for all selected observers
async def fetch_one(obs_name: str, window: dp.TimeWindow):
req = dp.GetObservationsAsTimeseriesRequest(
async def fetch_one(obs_name: str, window: messages_pb2.TimeWindow):
req = messages_pb2.GetObservationsAsTimeseriesRequest(
location_uuid=location_uuid,
observer_name=obs_name,
energy_source=energy_source,
time_window=window,
)

try:
resp = await client.get_observations_as_timeseries(req)
resp = await client.GetObservationsAsTimeseries(req)
rows = []
for val in resp.values:
rows.append(
{
"target_timestamp_utc": val.timestamp_utc,
"target_timestamp_utc": val.timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
"value_fraction": val.value_fraction,
"effective_capacity_watts": val.effective_capacity_watts,
"observer_name": obs_name,
Expand Down
4 changes: 2 additions & 2 deletions src/dataplatform/forecast/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import UTC, datetime, timedelta

from ocf import dp
from ocf.dp.dp_data import service_pb2_grpc

from dataplatform.forecast.constant import cache_seconds

Expand All @@ -11,7 +11,7 @@ def key_builder_remove_client(func: callable, *args: list, **kwargs: dict) -> st
"""Custom key builder that ignores the client argument for caching purposes."""
key = f"{func.__name__}:"
for arg in args:
if not isinstance(arg, dp.DataPlatformDataServiceStub):
if not isinstance(arg, service_pb2_grpc.DataPlatformDataServiceStub):
key += f"{arg}-"

for k, v in kwargs.items():
Expand Down
17 changes: 8 additions & 9 deletions src/dataplatform/forecast/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import pandas as pd
from aiocache import Cache, cached
from ocf import dp
from google.protobuf.json_format import MessageToDict

from dataplatform.forecast.cache import key_builder_remove_client
Expand All @@ -16,10 +15,10 @@

async def get_forecast_data(
dpc: service_pb2_grpc.DataPlatformDataServiceStub,
location: dp.ListLocationsResponseLocationSummary,
location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
selected_forecasters: list[dp.Forecaster],
selected_forecasters: list[messages_pb2.Forecaster],
) -> pd.DataFrame:
"""Get forecast data for the given location and time window."""
all_data_df = []
Expand Down Expand Up @@ -66,10 +65,10 @@ async def get_forecast_data(
@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_forecast_data_one_forecaster(
dpc: service_pb2_grpc.DataPlatformDataServiceStub,
location: dp.ListLocationsResponseLocationSummary,
location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
selected_forecaster: dp.Forecaster,
selected_forecaster: messages_pb2.Forecaster,
) -> pd.DataFrame | None:
"""Get forecast data for one forecaster for the given location and time window."""
all_data_list_dict = []
Expand All @@ -93,7 +92,7 @@ async def get_forecast_data_one_forecaster(

forecasts = []
async for chunk in dpc.StreamForecastData(stream_forecast_data_request):
forecasts.append(chunk)
forecasts.extend(chunk.values)

if len(forecasts) > 0:
all_data_list_dict.extend(MessageToDict(f, always_print_fields_with_no_presence=True) for f in forecasts)
Expand Down Expand Up @@ -149,7 +148,7 @@ async def get_forecast_data_one_forecaster(
@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_all_observations(
client: service_pb2_grpc.DataPlatformDataServiceStub,
location: dp.ListLocationsResponseLocationSummary,
location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
) -> pd.DataFrame:
Expand Down Expand Up @@ -225,10 +224,10 @@ async def get_all_observations(

async def get_all_data(
client: service_pb2_grpc.DataPlatformDataServiceStub,
selected_location: dp.ListLocationsResponseLocationSummary,
selected_location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
selected_forecasters: list[dp.Forecaster],
selected_forecasters: list[messages_pb2.Forecaster],
) -> dict:
"""Get all forecast and observation data, and merge them."""
# get generation data
Expand Down
18 changes: 11 additions & 7 deletions src/dataplatform/forecast/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import pandas as pd
import streamlit as st
from grpclib.client import Channel

from ocf import dp
import grpc.aio
from ocf.dp.dp import common_pb2
from ocf.dp.dp_data import messages_pb2, service_pb2_grpc

from dataplatform.forecast.constant import metrics, observer_names
from dataplatform.forecast.backend import fetch_observations, fetch_timeseries
Expand Down Expand Up @@ -49,9 +49,10 @@ async def async_dp_forecast_page() -> None:
st.title("Data Platform Forecast Page")
st.write("This is the forecast page from the Data Platform module.")

async with Channel(host=data_platform_host, port=data_platform_port) as channel:
client = dp.DataPlatformDataServiceStub(channel)

channel = grpc.aio.insecure_channel(f"{data_platform_host}:{data_platform_port}"):
client = service_pb2_grpc.DataPlatformDataServiceStub(channel)

try:
cfg = await setup_page(client)
st.divider()
st.subheader("1. Fetch Data")
Expand All @@ -76,7 +77,7 @@ async def async_dp_forecast_page() -> None:
start_date=cfg.start_date,
end_date=cfg.end_date,
observers=observer_names,
energy_source=dp.EnergySource.SOLAR,
energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
)

fetch_duration = (datetime.datetime.now() - start_time).total_seconds()
Expand Down Expand Up @@ -242,3 +243,6 @@ async def async_dp_forecast_page() -> None:
"Configure your filters in the sidebar and click 'Fetch Forecast & Observations' to begin."
)

finally:
await channel.close()

23 changes: 12 additions & 11 deletions src/dataplatform/forecast/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
import pandas as pd
import streamlit as st
from aiocache import Cache, cached
from ocf import dp
from ocf.dp.dp import common_pb2
from ocf.dp.dp_data import messages_pb2, service_pb2_grpc

from dataplatform.forecast.cache import key_builder_remove_client
from dataplatform.forecast.constant import cache_seconds, metrics


@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_location_names(
client: dp.DataPlatformDataServiceStub,
client: service_pb2_grpc.DataPlatformDataServiceStub,
) -> dict:
"""Get location names."""
list_locations_request = dp.ListLocationsRequest()
list_locations_response = await client.list_locations(list_locations_request)
list_locations_request = messages_pb2.ListLocationsRequest()
list_locations_response = await client.ListLocations(list_locations_request)
all_locations = list_locations_response.locations

location_names = {loc.location_name: loc for loc in all_locations}
Expand All @@ -38,19 +39,19 @@ async def get_location_names(

@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_forecasters(
client: dp.DataPlatformDataServiceStub,
) -> list[dp.Forecaster]:
client: service_pb2_grpc.DataPlatformDataServiceStub,
) -> list[messages_pb2.Forecaster]:
"""Get all forecasters."""
get_forecasters_request = dp.ListForecastersRequest()
get_forecasters_response = await client.list_forecasters(get_forecasters_request)
get_forecasters_request = messages_pb2.ListForecastersRequest()
get_forecasters_response = await client.ListForecasters(get_forecasters_request)
forecasters = get_forecasters_response.forecasters
return forecasters


@dataclasses.dataclass
class PageConfig:
location: dp.ListLocationsResponseLocationSummary
forecasters: list[dp.Forecaster]
location: messages_pb2.ListLocationsResponse.LocationSummary
forecasters: list[messages_pb2.Forecaster]
start_date: dt.datetime
end_date: dt.datetime
forecast_type: str
Expand All @@ -62,7 +63,7 @@ class PageConfig:
strict_horizon_filtering: bool


async def setup_page(client: dp.DataPlatformDataServiceStub) -> PageConfig:
async def setup_page(client: service_pb2_grpc.DataPlatformDataServiceStub) -> PageConfig:
"""Setup the Streamlit page with sidebar options."""
location_names = await get_location_names(client)
selected_location_name = st.sidebar.selectbox(
Expand Down
Loading
Loading