diff --git a/pyproject.toml b/pyproject.toml index d1301d19..1bbf5442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "plotly==5.24.1", "psycopg2-binary==2.9.10", "SQLAlchemy==2.0.36", - "streamlit==1.51.0", + "streamlit==1.57.0", "uvicorn==0.34.0", "geopandas==1.0.1", "garrett-streamlit-auth0==0.9", @@ -37,9 +37,7 @@ dependencies = [ "matplotlib>=3.8,<4.0", "dp-sdk", "aiocache", - "grpcio>=1.80.0", - "grpcio-tools>=1.50.0", - "grpc-requests>=0.1.17", + "grpclib==0.4.8", "betterproto>=2.0.0b7", "pytest-asyncio>=1.3.0", "testcontainers>=4.14.0", @@ -77,7 +75,7 @@ extra-index-url = ["https://pypi.org/simple"] [tool.uv.sources] -dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.28.0/dp_sdk-0.28.0-py3-none-any.whl" } +dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.30.1/dp_sdk-0.30.1-py3-none-any.whl" } [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] diff --git a/src/dataplatform/forecast/backend.py b/src/dataplatform/forecast/backend.py index 7b5ac14c..9752a67c 100644 --- a/src/dataplatform/forecast/backend.py +++ b/src/dataplatform/forecast/backend.py @@ -6,21 +6,22 @@ import pandas as pd import streamlit as st -from ocf import dp +from ocf.dp.dp_data import messages_pb2, service_pb2_grpc +from ocf.dp.dp import common_pb2 async def fetch_timeseries( - client: dp.DataPlatformDataServiceStub, + client: service_pb2_grpc.DataPlatformDataServiceStub, location_uuid: str, start_date: datetime.datetime, end_date: datetime.datetime, horizon_mins: int, - forecasters: list[dp.Forecaster], + forecasters: list[messages_pb2.Forecaster], init_times_utc: list[datetime.datetime] | None = None, ) -> pd.DataFrame: """Directly calls GetForecastAsTimeseries for selected models and init times.""" - time_window = dp.TimeWindow( + time_window = messages_pb2.TimeWindow( start_timestamp_utc=start_date, end_timestamp_utc=end_date ) time_windows = [] @@ -28,7 +29,7 @@ async def fetch_timeseries( while current_start < end_date: current_end = min(current_start + datetime.timedelta(days=7), end_date) time_windows.append( - dp.TimeWindow( + messages_pb2.TimeWindow( start_timestamp_utc=current_start, end_timestamp_utc=current_end ) @@ -38,13 +39,13 @@ async def fetch_timeseries( times_to_fetch = init_times_utc if init_times_utc else [None] async def fetch_one( - forecaster_obj: dp.Forecaster, - window: dp.TimeWindow, + forecaster_obj: messages_pb2.Forecaster, + window: messages_pb2.TimeWindow, init_time: datetime.datetime | None ): - req = dp.GetForecastAsTimeseriesRequest( + req = messages_pb2.GetForecastAsTimeseriesRequest( location_uuid=location_uuid, - energy_source=dp.EnergySource.SOLAR, + energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR, horizon_mins=horizon_mins, time_window=window, forecaster=forecaster_obj, @@ -52,18 +53,18 @@ async def fetch_one( ) try: - resp = await client.get_forecast_as_timeseries(req) + resp = await client.GetForecastAsTimeseries(req) rows = [] for val in resp.values: row = { - "target_timestamp_utc": val.target_timestamp_utc, - "initialization_timestamp_utc": val.initialization_timestamp_utc, - "created_timestamp_utc": val.created_timestamp_utc, + "target_timestamp_utc": val.target_timestamp_utc.ToDatetime(tzinfo=datetime.UTC), + "initialization_timestamp_utc": val.initialization_timestamp_utc.ToDatetime(tzinfo=datetime.UTC), + "created_timestamp_utc": val.created_timestamp_utc.ToDatetime(tzinfo=datetime.UTC), "effective_capacity_watts": val.effective_capacity_watts, "forecaster_name": forecaster_obj.forecaster_name, "location_uuid": resp.location_uuid, "horizon_mins": ( - val.target_timestamp_utc - val.initialization_timestamp_utc + val.target_timestamp_utc.ToDatetime(tzinfo=datetime.UTC) - val.initialization_timestamp_utc.ToDatetime(tzinfo=datetime.UTC) ).total_seconds() // 60, "p50_watts": int( @@ -108,16 +109,16 @@ async def fetch_one( async def fetch_observations( - client: dp.DataPlatformDataServiceStub, + client: service_pb2_grpc.DataPlatformDataServiceStub, location_uuid: str, start_date: datetime.datetime, end_date: datetime.datetime, observers: list[str], - energy_source: dp.EnergySource = dp.EnergySource.SOLAR, + energy_source: common_pb2.EnergySource = common_pb2.EnergySource.ENERGY_SOURCE_SOLAR, ) -> pd.DataFrame: """Directly calls GetObservationsAsTimeseries for selected observers.""" - time_window = dp.TimeWindow( + time_window = messages_pb2.TimeWindow( start_timestamp_utc=start_date, end_timestamp_utc=end_date ) time_windows = [] @@ -125,7 +126,7 @@ async def fetch_observations( while current_start < end_date: current_end = min(current_start + datetime.timedelta(days=7), end_date) time_windows.append( - dp.TimeWindow( + messages_pb2.TimeWindow( start_timestamp_utc=current_start, end_timestamp_utc=current_end ) @@ -134,8 +135,8 @@ async def fetch_observations( # Run requests concurrently for all selected observers - async def fetch_one(obs_name: str, window: dp.TimeWindow): - req = dp.GetObservationsAsTimeseriesRequest( + async def fetch_one(obs_name: str, window: messages_pb2.TimeWindow): + req = messages_pb2.GetObservationsAsTimeseriesRequest( location_uuid=location_uuid, observer_name=obs_name, energy_source=energy_source, @@ -143,12 +144,12 @@ async def fetch_one(obs_name: str, window: dp.TimeWindow): ) try: - resp = await client.get_observations_as_timeseries(req) + resp = await client.GetObservationsAsTimeseries(req) rows = [] for val in resp.values: rows.append( { - "target_timestamp_utc": val.timestamp_utc, + "target_timestamp_utc": val.timestamp_utc.ToDatetime(tzinfo=datetime.UTC), "value_fraction": val.value_fraction, "effective_capacity_watts": val.effective_capacity_watts, "observer_name": obs_name, diff --git a/src/dataplatform/forecast/cache.py b/src/dataplatform/forecast/cache.py index 280ebac6..8b9628c8 100644 --- a/src/dataplatform/forecast/cache.py +++ b/src/dataplatform/forecast/cache.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime, timedelta -from ocf import dp +from ocf.dp.dp_data import service_pb2_grpc from dataplatform.forecast.constant import cache_seconds @@ -11,7 +11,7 @@ def key_builder_remove_client(func: callable, *args: list, **kwargs: dict) -> st """Custom key builder that ignores the client argument for caching purposes.""" key = f"{func.__name__}:" for arg in args: - if not isinstance(arg, dp.DataPlatformDataServiceStub): + if not isinstance(arg, service_pb2_grpc.DataPlatformDataServiceStub): key += f"{arg}-" for k, v in kwargs.items(): diff --git a/src/dataplatform/forecast/data.py b/src/dataplatform/forecast/data.py index 0cf05f75..9792029f 100644 --- a/src/dataplatform/forecast/data.py +++ b/src/dataplatform/forecast/data.py @@ -5,7 +5,6 @@ import pandas as pd from aiocache import Cache, cached -from ocf import dp from google.protobuf.json_format import MessageToDict from dataplatform.forecast.cache import key_builder_remove_client @@ -16,10 +15,10 @@ async def get_forecast_data( dpc: service_pb2_grpc.DataPlatformDataServiceStub, - location: dp.ListLocationsResponseLocationSummary, + location: messages_pb2.ListLocationsResponse.LocationSummary, start_date: datetime, end_date: datetime, - selected_forecasters: list[dp.Forecaster], + selected_forecasters: list[messages_pb2.Forecaster], ) -> pd.DataFrame: """Get forecast data for the given location and time window.""" all_data_df = [] @@ -66,10 +65,10 @@ async def get_forecast_data( @cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) async def get_forecast_data_one_forecaster( dpc: service_pb2_grpc.DataPlatformDataServiceStub, - location: dp.ListLocationsResponseLocationSummary, + location: messages_pb2.ListLocationsResponse.LocationSummary, start_date: datetime, end_date: datetime, - selected_forecaster: dp.Forecaster, + selected_forecaster: messages_pb2.Forecaster, ) -> pd.DataFrame | None: """Get forecast data for one forecaster for the given location and time window.""" all_data_list_dict = [] @@ -93,7 +92,7 @@ async def get_forecast_data_one_forecaster( forecasts = [] async for chunk in dpc.StreamForecastData(stream_forecast_data_request): - forecasts.append(chunk) + forecasts.extend(chunk.values) if len(forecasts) > 0: all_data_list_dict.extend(MessageToDict(f, always_print_fields_with_no_presence=True) for f in forecasts) @@ -149,7 +148,7 @@ async def get_forecast_data_one_forecaster( @cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) async def get_all_observations( client: service_pb2_grpc.DataPlatformDataServiceStub, - location: dp.ListLocationsResponseLocationSummary, + location: messages_pb2.ListLocationsResponse.LocationSummary, start_date: datetime, end_date: datetime, ) -> pd.DataFrame: @@ -225,10 +224,10 @@ async def get_all_observations( async def get_all_data( client: service_pb2_grpc.DataPlatformDataServiceStub, - selected_location: dp.ListLocationsResponseLocationSummary, + selected_location: messages_pb2.ListLocationsResponse.LocationSummary, start_date: datetime, end_date: datetime, - selected_forecasters: list[dp.Forecaster], + selected_forecasters: list[messages_pb2.Forecaster], ) -> dict: """Get all forecast and observation data, and merge them.""" # get generation data diff --git a/src/dataplatform/forecast/main.py b/src/dataplatform/forecast/main.py index 5b0dd7e9..a8415e01 100644 --- a/src/dataplatform/forecast/main.py +++ b/src/dataplatform/forecast/main.py @@ -7,9 +7,9 @@ import pandas as pd import streamlit as st -from grpclib.client import Channel - -from ocf import dp +import grpc.aio +from ocf.dp.dp import common_pb2 +from ocf.dp.dp_data import messages_pb2, service_pb2_grpc from dataplatform.forecast.constant import metrics, observer_names from dataplatform.forecast.backend import fetch_observations, fetch_timeseries @@ -49,9 +49,10 @@ async def async_dp_forecast_page() -> None: st.title("Data Platform Forecast Page") st.write("This is the forecast page from the Data Platform module.") - async with Channel(host=data_platform_host, port=data_platform_port) as channel: - client = dp.DataPlatformDataServiceStub(channel) - + channel = grpc.aio.insecure_channel(f"{data_platform_host}:{data_platform_port}"): + client = service_pb2_grpc.DataPlatformDataServiceStub(channel) + + try: cfg = await setup_page(client) st.divider() st.subheader("1. Fetch Data") @@ -76,7 +77,7 @@ async def async_dp_forecast_page() -> None: start_date=cfg.start_date, end_date=cfg.end_date, observers=observer_names, - energy_source=dp.EnergySource.SOLAR, + energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR, ) fetch_duration = (datetime.datetime.now() - start_time).total_seconds() @@ -242,3 +243,6 @@ async def async_dp_forecast_page() -> None: "Configure your filters in the sidebar and click 'Fetch Forecast & Observations' to begin." ) + finally: + await channel.close() + diff --git a/src/dataplatform/forecast/setup.py b/src/dataplatform/forecast/setup.py index fe85305e..c5c0d440 100644 --- a/src/dataplatform/forecast/setup.py +++ b/src/dataplatform/forecast/setup.py @@ -6,7 +6,8 @@ import pandas as pd import streamlit as st from aiocache import Cache, cached -from ocf import dp +from ocf.dp.dp import common_pb2 +from ocf.dp.dp_data import messages_pb2, service_pb2_grpc from dataplatform.forecast.cache import key_builder_remove_client from dataplatform.forecast.constant import cache_seconds, metrics @@ -14,11 +15,11 @@ @cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) async def get_location_names( - client: dp.DataPlatformDataServiceStub, + client: service_pb2_grpc.DataPlatformDataServiceStub, ) -> dict: """Get location names.""" - list_locations_request = dp.ListLocationsRequest() - list_locations_response = await client.list_locations(list_locations_request) + list_locations_request = messages_pb2.ListLocationsRequest() + list_locations_response = await client.ListLocations(list_locations_request) all_locations = list_locations_response.locations location_names = {loc.location_name: loc for loc in all_locations} @@ -38,19 +39,19 @@ async def get_location_names( @cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) async def get_forecasters( - client: dp.DataPlatformDataServiceStub, -) -> list[dp.Forecaster]: + client: service_pb2_grpc.DataPlatformDataServiceStub, +) -> list[messages_pb2.Forecaster]: """Get all forecasters.""" - get_forecasters_request = dp.ListForecastersRequest() - get_forecasters_response = await client.list_forecasters(get_forecasters_request) + get_forecasters_request = messages_pb2.ListForecastersRequest() + get_forecasters_response = await client.ListForecasters(get_forecasters_request) forecasters = get_forecasters_response.forecasters return forecasters @dataclasses.dataclass class PageConfig: - location: dp.ListLocationsResponseLocationSummary - forecasters: list[dp.Forecaster] + location: messages_pb2.ListLocationsResponse.LocationSummary + forecasters: list[messages_pb2.Forecaster] start_date: dt.datetime end_date: dt.datetime forecast_type: str @@ -62,7 +63,7 @@ class PageConfig: strict_horizon_filtering: bool -async def setup_page(client: dp.DataPlatformDataServiceStub) -> PageConfig: +async def setup_page(client: service_pb2_grpc.DataPlatformDataServiceStub) -> PageConfig: """Setup the Streamlit page with sidebar options.""" location_names = await get_location_names(client) selected_location_name = st.sidebar.selectbox( diff --git a/src/dataplatform/toolbox/location.py b/src/dataplatform/toolbox/location.py index 82aaa8ad..4ad03163 100644 --- a/src/dataplatform/toolbox/location.py +++ b/src/dataplatform/toolbox/location.py @@ -2,8 +2,10 @@ import streamlit as st import json -from ocf import dp +from ocf.dp.dp import common_pb2 +from ocf.dp.dp_data import messages_pb2 from grpclib.exceptions import GRPCError +from google.protobuf.json_format import MessageToDict async def locations_section(data_client): @@ -11,21 +13,21 @@ async def locations_section(data_client): # Energy source and location type mappings ENERGY_SOURCES = { - "UNSPECIFIED": dp.EnergySource.UNSPECIFIED, - "SOLAR": dp.EnergySource.SOLAR, - "WIND": dp.EnergySource.WIND, + "UNSPECIFIED": common_pb2.EnergySource.ENERGY_SOURCE_UNSPECIFIED, + "SOLAR": common_pb2.EnergySource.ENERGY_SOURCE_SOLAR, + "WIND": common_pb2.EnergySource.ENERGY_SOURCE_WIND, } LOCATION_TYPES = { - "UNSPECIFIED": dp.LocationType.UNSPECIFIED, - "SITE": dp.LocationType.SITE, - "GSP": dp.LocationType.GSP, - "DNO": dp.LocationType.DNO, - "NATION": dp.LocationType.NATION, - "STATE": dp.LocationType.STATE, - "COUNTY": dp.LocationType.COUNTY, - "CITY": dp.LocationType.CITY, - "PRIMARY SUBSTATION": dp.LocationType.PRIMARY_SUBSTATION, + "UNSPECIFIED": common_pb2.LocationType.LOCATION_TYPE_UNSPECIFIED, + "SITE": common_pb2.LocationType.LOCATION_TYPE_SITE, + "GSP": common_pb2.LocationType.LOCATION_TYPE_GSP, + "DNO": common_pb2.LocationType.LOCATION_TYPE_DNO, + "NATION": common_pb2.LocationType.LOCATION_TYPE_NATION, + "STATE": common_pb2.LocationType.LOCATION_TYPE_STATE, + "COUNTY": common_pb2.LocationType.LOCATION_TYPE_COUNTY, + "CITY": common_pb2.LocationType.LOCATION_TYPE_CITY, + "PRIMARY SUBSTATION": common_pb2.LocationType.LOCATION_TYPE_PRIMARY_SUBSTATION, } # List Locations @@ -50,7 +52,7 @@ async def locations_section(data_client): st.error("❌ Could not connect to Data Platform") else: try: - request = dp.ListLocationsRequest() + request = messages_pb2.ListLocationsRequest() if energy_source_filter != "UNSPECIFIED": request.energy_source_filter = ENERGY_SOURCES[energy_source_filter] if location_type_filter != "UNSPECIFIED": @@ -58,12 +60,12 @@ async def locations_section(data_client): if user_filter: request.user_oauth_id_filter = user_filter - response = await data_client.list_locations(request) + response = await data_client.ListLocations(request) locations = response.locations if locations: st.success(f"✅ Found {len(locations)} location(s)") - loc_dicts = [loc.to_dict() for loc in locations] + loc_dicts = [MessageToDict(loc) for loc in locations] st.write(loc_dicts) else: st.info("No locations found with the specified filters") @@ -92,14 +94,14 @@ async def locations_section(data_client): st.error("❌ Could not connect to Data Platform") else: try: - response = await data_client.get_location( - dp.GetLocationRequest( + response = await data_client.GetLocation( + messages_pb2.GetLocationRequest( location_uuid=loc_uuid, energy_source=ENERGY_SOURCES[loc_energy], include_geometry=include_geometry, ) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success(f"✅ Found location: {loc_uuid}") st.write(response_dict) except GRPCError as e: @@ -154,8 +156,8 @@ async def locations_section(data_client): try: # Parse metadata JSON metadata = json.loads(loc_metadata) if loc_metadata.strip() else {} - response = await data_client.create_location( - dp.CreateLocationRequest( + response = await data_client.CreateLocation( + messages_pb2.CreateLocationRequest( location_name=loc_name, energy_source=ENERGY_SOURCES.get(loc_energy_src, 1), location_type=LOCATION_TYPES.get(loc_type, 1), @@ -164,7 +166,7 @@ async def locations_section(data_client): metadata=metadata, ) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success(f"✅ Location '{loc_name}' created successfully!") st.write(response_dict) diff --git a/src/dataplatform/toolbox/main.py b/src/dataplatform/toolbox/main.py index 80d45bb2..f22041b3 100644 --- a/src/dataplatform/toolbox/main.py +++ b/src/dataplatform/toolbox/main.py @@ -1,14 +1,16 @@ """Data Platform Toolbox Streamlit Page Main Code.""" import asyncio -from grpclib.client import Channel +import grpc.aio import streamlit as st from dataplatform.toolbox.organisation import organisation_section from dataplatform.toolbox.users import users_section from dataplatform.toolbox.user_organisation import user_organisation_section from dataplatform.toolbox.location import locations_section from dataplatform.toolbox.policy import policies_section -from ocf import dp +from ocf.dp.dp import common_pb2 +from ocf.dp.dp_data import messages_pb2, service_pb2_grpc +from ocf.dp.dp_admin import messages_pb2 as dp_admin_messages_pb2, service_pb2_grpc as dp_admin_service_pb2_grpc import os # Color scheme (matching existing toolbox) @@ -28,9 +30,10 @@ async def async_dataplatform_toolbox_page(): """Async Main function for the Data Platform Toolbox Streamlit page.""" host = os.environ.get("DATA_PLATFORM_HOST", "localhost") port = os.environ.get("DATA_PLATFORM_PORT", "50051") - async with Channel(host=host, port=int(port)) as channel: - admin_client = dp.DataPlatformAdministrationServiceStub(channel) - data_client = dp.DataPlatformDataServiceStub(channel) + channel = grpc.aio.insecure_channel(f"{host}:{int(port)}") + try: + admin_client = dp_admin_service_pb2_grpc.DataPlatformAdministrationServiceStub(channel) + data_client = service_pb2_grpc.DataPlatformDataServiceStub(channel) st.markdown( '

Data Platform Toolbox

', @@ -63,6 +66,9 @@ async def async_dataplatform_toolbox_page(): with tab5: await policies_section(admin_client, data_client) + finally: + await channel.close() + # Required for the tests to run this as a script if __name__ == "__main__": diff --git a/src/dataplatform/toolbox/organisation.py b/src/dataplatform/toolbox/organisation.py index 5458ecd2..3cd7360e 100644 --- a/src/dataplatform/toolbox/organisation.py +++ b/src/dataplatform/toolbox/organisation.py @@ -2,8 +2,9 @@ import streamlit as st import json -from ocf import dp +from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc from grpclib.exceptions import GRPCError +from google.protobuf.json_format import MessageToDict async def organisation_section(admin_client): @@ -22,10 +23,10 @@ async def organisation_section(admin_client): st.warning("⚠️ Please enter an organisation name") else: try: - response = await admin_client.get_organisation( - dp.GetOrganisationRequest(org_name=org_name) + response = await admin_client.GetOrganisation( + messages_pb2.GetOrganisationRequest(org_name=org_name) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success(f"✅ Found organisation: {org_name}") st.write(response_dict) @@ -61,12 +62,12 @@ async def organisation_section(admin_client): metadata = ( json.loads(metadata_json) if metadata_json.strip() else {} ) - response = await admin_client.create_organisation( - dp.CreateOrganisationRequest( + response = await admin_client.CreateOrganisation( + messages_pb2.CreateOrganisationRequest( org_name=new_org_name, metadata=metadata ) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success( f"✅ Organisation '{new_org_name}' created successfully!" ) @@ -104,8 +105,8 @@ async def organisation_section(admin_client): st.warning("⚠️ Please confirm deletion by checking the box above") else: try: - await admin_client.delete_organisation( - dp.DeleteOrganisationRequest(org_name=del_org_name) + await admin_client.DeleteOrganisation( + messages_pb2.DeleteOrganisationRequest(org_name=del_org_name) ) st.success( f"✅ Organisation '{del_org_name}' deleted successfully!" diff --git a/src/dataplatform/toolbox/policy.py b/src/dataplatform/toolbox/policy.py index f6938af7..9352ae50 100644 --- a/src/dataplatform/toolbox/policy.py +++ b/src/dataplatform/toolbox/policy.py @@ -2,7 +2,10 @@ import streamlit as st from grpclib.exceptions import GRPCError -from ocf import dp +from ocf.dp.dp import common_pb2 +from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc +from ocf.dp.dp_data import messages_pb2 as data_messages_pb2 +from google.protobuf.json_format import MessageToDict async def policies_section(admin_client, data_client): @@ -10,13 +13,13 @@ async def policies_section(admin_client, data_client): # Permission mappings PERMISSIONS = { - "READ": dp.Permission.READ, - "WRITE": dp.Permission.WRITE, + "READ": common_pb2.Permission.PERMISSION_READ, + "WRITE": common_pb2.Permission.PERMISSION_WRITE, } ENERGY_SOURCES = { - "SOLAR": dp.EnergySource.SOLAR, - "WIND": dp.EnergySource.WIND, + "SOLAR": common_pb2.EnergySource.ENERGY_SOURCE_SOLAR, + "WIND": common_pb2.EnergySource.ENERGY_SOURCE_WIND, } # Create Location Policy Group @@ -35,10 +38,10 @@ async def policies_section(admin_client, data_client): st.warning("⚠️ Please enter a policy group name") else: try: - response = await admin_client.create_location_policy_group( - dp.CreateLocationPolicyGroupRequest(name=new_policy_group_name) + response = await admin_client.CreateLocationPolicyGroup( + messages_pb2.CreateLocationPolicyGroupRequest(name=new_policy_group_name) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success(f"✅ Policy Group '{new_policy_group_name}' created!") st.write(response_dict) except GRPCError as e: @@ -61,12 +64,12 @@ async def policies_section(admin_client, data_client): st.warning("⚠️ Please enter a policy group name") else: try: - response = await admin_client.get_location_policy_group( - dp.GetLocationPolicyGroupRequest( + response = await admin_client.GetLocationPolicyGroup( + messages_pb2.GetLocationPolicyGroupRequest( location_policy_group_name=policy_group_name ) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success(f"✅ Found policy group: {policy_group_name}") st.write(response_dict) @@ -88,8 +91,8 @@ async def policies_section(admin_client, data_client): if data_client: try: - response = await data_client.list_locations(dp.ListLocationsRequest()) - response_dict = response.to_dict() + response = await data_client.ListLocations(data_messages_pb2.ListLocationsRequest()) + response_dict = MessageToDict(response) locations = response_dict.get("locations", []) except Exception as e: st.error(f"❌ Failed to fetch locations: {e}") @@ -124,11 +127,11 @@ async def policies_section(admin_client, data_client): st.warning("⚠️ Please fill in all required fields") else: try: - await admin_client.add_location_policies_to_group( - dp.AddLocationPoliciesToGroupRequest( + await admin_client.AddLocationPoliciesToGroup( + messages_pb2.AddLocationPoliciesToGroupRequest( location_policy_group_name=add_policy_group, location_policies=[ - dp.LocationPolicy( + messages_pb2.LocationPolicy( location_id=add_location_id, energy_source=ENERGY_SOURCES[add_energy_source], permission=PERMISSIONS[add_permission], @@ -157,8 +160,8 @@ async def policies_section(admin_client, data_client): if data_client: try: - response = await data_client.list_locations(dp.ListLocationsRequest()) - response_dict = response.to_dict() + response = await data_client.ListLocations(data_messages_pb2.ListLocationsRequest()) + response_dict = MessageToDict(response) locations = response_dict.get("locations", []) except Exception as e: st.error(f"❌ Failed to fetch locations: {e}") @@ -194,11 +197,11 @@ async def policies_section(admin_client, data_client): st.warning("⚠️ Please fill in all required fields") else: try: - await admin_client.remove_location_policies_from_group( - dp.RemoveLocationPoliciesFromGroupRequest( + await admin_client.RemoveLocationPoliciesFromGroup( + messages_pb2.RemoveLocationPoliciesFromGroupRequest( location_policy_group_name=remove_policy_group, location_policies=[ - dp.LocationPolicy( + messages_pb2.LocationPolicy( location_id=remove_location_id, energy_source=ENERGY_SOURCES[remove_energy_source], permission=PERMISSIONS[remove_permission], @@ -228,8 +231,8 @@ async def policies_section(admin_client, data_client): st.warning("⚠️ Please fill in all fields") else: try: - await admin_client.add_location_policy_group_to_organisation( - dp.AddLocationPolicyGroupToOrganisationRequest( + await admin_client.AddLocationPolicyGroupToOrganisation( + messages_pb2.AddLocationPolicyGroupToOrganisationRequest( org_name=add_pg_org, location_policy_group_name=add_pg_name ) ) @@ -264,8 +267,8 @@ async def policies_section(admin_client, data_client): st.error("❌ Could not connect to Data Platform") else: try: - await admin_client.remove_location_policy_group_from_organisation( - dp.RemoveLocationPolicyGroupFromOrganisationRequest( + await admin_client.RemoveLocationPolicyGroupFromOrganisation( + messages_pb2.RemoveLocationPolicyGroupFromOrganisationRequest( org_name=remove_policy_group_org, location_policy_group_name=remove_policy_group_name, ) diff --git a/src/dataplatform/toolbox/user_organisation.py b/src/dataplatform/toolbox/user_organisation.py index cd9c8eae..f8ec9249 100644 --- a/src/dataplatform/toolbox/user_organisation.py +++ b/src/dataplatform/toolbox/user_organisation.py @@ -1,7 +1,7 @@ """User-Organisation relationship management section for the Data Platform Toolbox.""" import streamlit as st -from ocf import dp +from ocf.dp.dp_admin import messages_pb2 from grpclib.exceptions import GRPCError @@ -22,8 +22,8 @@ async def user_organisation_section(admin_client): st.warning("⚠️ Please fill in all fields") else: try: - await admin_client.add_user_to_organisation( - dp.AddUserToOrganisationRequest( + await admin_client.AddUserToOrganisation( + messages_pb2.AddUserToOrganisationRequest( org_name=add_org, user_oauth_id=add_user_oauth ) ) @@ -52,8 +52,8 @@ async def user_organisation_section(admin_client): st.warning("⚠️ Please fill in all fields") else: try: - await admin_client.remove_user_from_organisation( - dp.RemoveUserFromOrganisationRequest( + await admin_client.RemoveUserFromOrganisation( + messages_pb2.RemoveUserFromOrganisationRequest( org_name=remove_org, user_oauth_id=remove_user_oauth ) ) diff --git a/src/dataplatform/toolbox/users.py b/src/dataplatform/toolbox/users.py index 211087b6..731f5a79 100644 --- a/src/dataplatform/toolbox/users.py +++ b/src/dataplatform/toolbox/users.py @@ -2,9 +2,9 @@ import streamlit as st import json -from ocf import dp +from ocf.dp.dp_admin import messages_pb2 from grpclib.exceptions import GRPCError - +from google.protobuf.json_format import MessageToDict async def users_section(admin_client): """User management section.""" @@ -22,10 +22,10 @@ async def users_section(admin_client): st.warning("⚠️ Please enter an OAuth ID") else: try: - response = await admin_client.get_user( - dp.GetUserRequest(oauth_id=oauth_id) + response = await admin_client.GetUser( + messages_pb2.GetUserRequest(oauth_id=oauth_id) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success(f"✅ Found user: {oauth_id}") st.write(response_dict) @@ -65,14 +65,14 @@ async def users_section(admin_client): metadata = ( json.loads(user_metadata) if user_metadata.strip() else {} ) - response = await admin_client.create_user( - dp.CreateUserRequest( + response = await admin_client.CreateUser( + messages_pb2.CreateUserRequest( oauth_id=new_oauth_id, organisation=user_org, metadata=metadata, ) ) - response_dict = response.to_dict() + response_dict = MessageToDict(response) st.success( f"✅ User '{new_oauth_id}' created in organisation '{user_org}'!" ) @@ -113,8 +113,8 @@ async def users_section(admin_client): else: try: # admin_client.DeleteUser({"user_id": del_user_id}) - await admin_client.delete_user( - dp.DeleteUserRequest(user_id=del_user_id) + await admin_client.DeleteUser( + messages_pb2.DeleteUserRequest(user_id=del_user_id) ) st.success(f"✅ User '{del_user_id}' deleted successfully!") diff --git a/src/plots/elexon_plots.py b/src/plots/elexon_plots.py index ea6edc65..b67cd1db 100644 --- a/src/plots/elexon_plots.py +++ b/src/plots/elexon_plots.py @@ -5,6 +5,7 @@ import streamlit as st from elexonpy.api_client import ApiClient from elexonpy.api.generation_forecast_api import GenerationForecastApi +from google.protobuf.json_format import MessageToDict def add_elexon_plot( @@ -105,7 +106,11 @@ def fetch_forecast_data( if not response.data: return pd.DataFrame() - df = pd.DataFrame([item.to_dict() for item in response.data]) + df = pd.DataFrame([MessageToDict( + item, + preserving_proto_field_name=True, + always_print_fields_with_no_presence=True + ) for item in response.data]) solar_df = df[df["business_type"] == "Solar generation"] solar_df["start_time"] = pd.to_datetime(solar_df["start_time"]) solar_df = solar_df.set_index("start_time") diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index dc8d0889..0ca602e3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -9,8 +9,10 @@ from importlib.metadata import version import os from streamlit.testing.v1 import AppTest -from ocf import dp -from grpclib.client import Channel +from ocf.dp.dp import common_pb2 +from ocf.dp.dp_data import messages_pb2, service_pb2_grpc +from ocf.dp.dp_admin import messages_pb2 as dp_admin_messages_pb2, service_pb2_grpc as dp_admin_service_pb2_grpc +import grpc.aio DATA_PLATFORM_GRPC_PORT = 50051 DATA_PLATFORM_STARTUP_TIMEOUT_SECONDS = 60 @@ -64,19 +66,19 @@ async def dp_channel(): os.environ["DATA_PLATFORM_HOST"] = host os.environ["DATA_PLATFORM_PORT"] = str(port) - channel = Channel(host=host, port=port) + channel = grpc.aio.insecure_channel(f"{host}:{port}") yield channel - channel.close() + await channel.close() @pytest_asyncio.fixture(scope="session") async def admin_client(dp_channel): - return dp.DataPlatformAdministrationServiceStub(dp_channel) + return dp_admin_service_pb2_grpc.DataPlatformAdministrationServiceStub(dp_channel) @pytest_asyncio.fixture(scope="session") async def data_client(dp_channel): - return dp.DataPlatformDataServiceStub(dp_channel) + return service_pb2_grpc.DataPlatformDataServiceStub(dp_channel) @pytest.fixture @@ -103,41 +105,41 @@ def random_policy_name(): async def create_org_grpc(admin_client, org_name: str): - await admin_client.create_organisation( - dp.CreateOrganisationRequest(org_name=org_name, metadata={}) + await admin_client.CreateOrganisation( + dp_admin_messages_pb2.CreateOrganisationRequest(org_name=org_name, metadata={}) ) async def get_org_grpc(admin_client, org_name: str): - return await admin_client.get_organisation( - dp.GetOrganisationRequest(org_name=org_name) + return await admin_client.GetOrganisation( + dp_admin_messages_pb2.GetOrganisationRequest(org_name=org_name) ) async def create_user_grpc(admin_client, user_oauth_id: str, org_name: str): - await admin_client.create_user( - dp.CreateUserRequest(oauth_id=user_oauth_id, organisation=org_name, metadata={}) + await admin_client.CreateUser( + dp_admin_messages_pb2.CreateUserRequest(oauth_id=user_oauth_id, organisation=org_name, metadata={}) ) async def get_user_grpc(admin_client, user_oauth_id: str): - return await admin_client.get_user(dp.GetUserRequest(oauth_id=user_oauth_id)) + return await admin_client.GetUser(dp_admin_messages_pb2.GetUserRequest(oauth_id=user_oauth_id)) async def add_user_to_org_grpc(admin_client, user_oauth_id: str, org_name: str): - return await admin_client.add_user_to_organisation( - dp.AddUserToOrganisationRequest(org_name=org_name, user_oauth_id=user_oauth_id) + return await admin_client.AddUserToOrganisation( + dp_admin_messages_pb2.AddUserToOrganisationRequest(org_name=org_name, user_oauth_id=user_oauth_id) ) async def create_location_grpc( data_client, location_name: str, - energy_source=dp.EnergySource.SOLAR, - location_type=dp.LocationType.SITE, + energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR, + location_type=common_pb2.LocationType.LOCATION_TYPE_SITE, ): - return await data_client.create_location( - dp.CreateLocationRequest( + return await data_client.CreateLocation( + messages_pb2.CreateLocationRequest( location_name=location_name, energy_source=energy_source, geometry_wkt="POINT(0 0)", @@ -149,18 +151,18 @@ async def create_location_grpc( async def list_locations_grpc(data_client): - return await data_client.list_locations(dp.ListLocationsRequest()) + return await data_client.ListLocations(messages_pb2.ListLocationsRequest()) async def create_policy_group_grpc(admin_client, policy_name: str): - await admin_client.create_location_policy_group( - dp.CreateLocationPolicyGroupRequest(name=policy_name) + await admin_client.CreateLocationPolicyGroup( + dp_admin_messages_pb2.CreateLocationPolicyGroupRequest(name=policy_name) ) async def get_policy_group_grpc(admin_client, policy_name: str): - return await admin_client.get_location_policy_group( - dp.GetLocationPolicyGroupRequest(location_policy_group_name=policy_name) + return await admin_client.GetLocationPolicyGroup( + dp_admin_messages_pb2.GetLocationPolicyGroupRequest(location_policy_group_name=policy_name) ) @@ -168,14 +170,14 @@ async def add_policy_to_group_grpc( admin_client, policy_name: str, location_uuid: str, - energy_source=dp.EnergySource.WIND, - permission=dp.Permission.WRITE, + energy_source=common_pb2.EnergySource.ENERGY_SOURCE_WIND, + permission=common_pb2.Permission.PERMISSION_WRITE, ): - await admin_client.add_location_policies_to_group( - dp.AddLocationPoliciesToGroupRequest( + await admin_client.AddLocationPoliciesToGroup( + dp_admin_messages_pb2.AddLocationPoliciesToGroupRequest( location_policy_group_name=policy_name, location_policies=[ - dp.LocationPolicy( + dp_admin_messages_pb2.LocationPolicy( location_id=location_uuid, energy_source=energy_source, permission=permission, @@ -186,8 +188,8 @@ async def add_policy_to_group_grpc( async def add_policy_to_org_grpc(admin_client, org_name, policy_name): - await admin_client.add_location_policy_group_to_organisation( - dp.AddLocationPolicyGroupToOrganisationRequest( + await admin_client.AddLocationPolicyGroupToOrganisation( + dp_admin_messages_pb2.AddLocationPolicyGroupToOrganisationRequest( org_name=org_name, location_policy_group_name=policy_name, ) diff --git a/tests/integration/test_locations_ui.py b/tests/integration/test_locations_ui.py index 1a78e4c0..8805dfaf 100644 --- a/tests/integration/test_locations_ui.py +++ b/tests/integration/test_locations_ui.py @@ -6,7 +6,8 @@ """ import pytest -from ocf import dp + +from ocf.dp.dp_data import service_pb2_grpc from tests.integration.conftest import ( create_location_grpc, @@ -17,7 +18,7 @@ @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") -async def test_list_locations_ui(app, data_client: dp.DataPlatformDataServiceStub): +async def test_list_locations_ui(app, data_client: service_pb2_grpc.DataPlatformDataServiceStub): """ - create some locations via grpc - fill in list locations form and submit @@ -51,7 +52,7 @@ async def test_list_locations_ui(app, data_client: dp.DataPlatformDataServiceStu @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") -async def test_get_location_ui(app, data_client: dp.DataPlatformDataServiceStub): +async def test_get_location_ui(app, data_client: service_pb2_grpc.DataPlatformDataServiceStub): """ - create a location via grpc - fill in get location form and submit @@ -74,7 +75,7 @@ async def test_get_location_ui(app, data_client: dp.DataPlatformDataServiceStub) @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") -async def test_create_location_ui(app, data_client: dp.DataPlatformDataServiceStub): +async def test_create_location_ui(app, data_client: service_pb2_grpc.DataPlatformDataServiceStub): """ - fill in create location form and submit - assert success message @@ -92,6 +93,9 @@ async def test_create_location_ui(app, data_client: dp.DataPlatformDataServiceSt app.button("create_location_button").click() app.run() + # assert no errors in app + assert len(app.error) == 0, f"Expected no errors, but got: {[e.value for e in app.error]}" + # Assert success message in UI assert any("created" in s.value.lower() for s in app.success) diff --git a/tests/integration/test_organisations_ui.py b/tests/integration/test_organisations_ui.py index 3d68b3b2..2588ccd7 100644 --- a/tests/integration/test_organisations_ui.py +++ b/tests/integration/test_organisations_ui.py @@ -6,7 +6,7 @@ """ import pytest -from ocf import dp +from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc from tests.integration.conftest import create_org_grpc, random_org_name @@ -14,7 +14,7 @@ @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_create_organisation_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create random org name @@ -37,8 +37,8 @@ async def test_create_organisation_ui( # Assert success assert any("created" in s.value.lower() for s in app.success) - response = await admin_client.get_organisation( - dp.GetOrganisationRequest(org_name=org_name) + response = await admin_client.GetOrganisation( + messages_pb2.GetOrganisationRequest(org_name=org_name) ) assert response.org_name == org_name @@ -46,7 +46,7 @@ async def test_create_organisation_ui( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_get_organisation_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create random org name via grpc @@ -66,7 +66,7 @@ async def test_get_organisation_ui( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_delete_organisation_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create random org name via grpc @@ -87,6 +87,6 @@ async def test_delete_organisation_ui( # verify deletion via grpc with pytest.raises(Exception): - await admin_client.get_organisation( - dp.GetOrganisationRequest(org_name=org_name) + await admin_client.GetOrganisation( + messages_pb2.GetOrganisationRequest(org_name=org_name) ) diff --git a/tests/integration/test_policy_ui.py b/tests/integration/test_policy_ui.py index 44a51b89..2e018ddd 100644 --- a/tests/integration/test_policy_ui.py +++ b/tests/integration/test_policy_ui.py @@ -9,7 +9,8 @@ """ import pytest -from ocf import dp +from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc +from ocf.dp.dp_data import messages_pb2 as data_messages_pb2, service_pb2_grpc as data_service_pb2_grpc from tests.integration.conftest import ( add_policy_to_group_grpc, @@ -28,7 +29,7 @@ @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_create_policy_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - fill in create policy group form and submit @@ -56,7 +57,7 @@ async def test_create_policy_ui( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_get_policy_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create a policy group via grpc @@ -77,8 +78,8 @@ async def test_get_policy_ui( @pytest.mark.asyncio(loop_scope="session") async def test_add_policy_to_group( app, - admin_client: dp.DataPlatformAdministrationServiceStub, - data_client: dp.DataPlatformDataServiceStub, + admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub, + data_client: data_service_pb2_grpc.DataPlatformDataServiceStub, ): """ - create a policy group via grpc @@ -108,8 +109,8 @@ async def test_add_policy_to_group( @pytest.mark.asyncio(loop_scope="session") async def test_remove_policy_from_group( app, - admin_client: dp.DataPlatformAdministrationServiceStub, - data_client: dp.DataPlatformDataServiceStub, + admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub, + data_client: data_service_pb2_grpc.DataPlatformDataServiceStub, ): """ - create a policy group via grpc @@ -145,7 +146,7 @@ async def test_remove_policy_from_group( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_add_policy_to_org( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create a policy group via grpc @@ -173,7 +174,7 @@ async def test_add_policy_to_org( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_remove_policy_from_org( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create a policy group via grpc diff --git a/tests/integration/test_user_org_ui.py b/tests/integration/test_user_org_ui.py index cea30eb8..58d502d4 100644 --- a/tests/integration/test_user_org_ui.py +++ b/tests/integration/test_user_org_ui.py @@ -5,7 +5,7 @@ """ import pytest -from ocf import dp +from ocf.dp.dp_admin import service_pb2_grpc from tests.integration.conftest import ( @@ -21,7 +21,7 @@ @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_add_user_org_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create two orgs and a user in one org @@ -47,7 +47,9 @@ async def test_add_user_org_ui( app.button("add_user_to_org_button").click() app.run() + # Assert success + assert len(app.error) == 0, f"Expected no errors, but got: {[s.value for s in app.error]}" assert any("added" in s.value.lower() for s in app.success) user = await get_user_grpc(admin_client, user_id) @@ -62,7 +64,7 @@ async def test_add_user_org_ui( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_remove_user_org_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create two orgs and a user in both orgs diff --git a/tests/integration/test_users_ui.py b/tests/integration/test_users_ui.py index 347346b2..2bb762a2 100644 --- a/tests/integration/test_users_ui.py +++ b/tests/integration/test_users_ui.py @@ -6,7 +6,7 @@ """ import pytest -from ocf import dp +from ocf.dp.dp_admin import service_pb2_grpc from tests.integration.conftest import ( create_org_grpc, @@ -20,7 +20,7 @@ @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_create_user_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create an org for the user @@ -47,6 +47,7 @@ async def test_create_user_ui( app.run() # Assert success + assert len(app.error) == 0, f"Expected no errors, but got: {[e.value for e in app.error]}" assert any("created" in s.value.lower() for s in app.success) response = await get_user_grpc(admin_client, user_oauth_id) @@ -55,7 +56,7 @@ async def test_create_user_ui( @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") -async def test_get_user_ui(app, admin_client: dp.DataPlatformAdministrationServiceStub): +async def test_get_user_ui(app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub): """ - create random user via grpc - fill in get user form and submit @@ -70,13 +71,14 @@ async def test_get_user_ui(app, admin_client: dp.DataPlatformAdministrationServi app.button("get_user_button").click() app.run() + assert len(app.error) == 0, f"Expected no errors, but got: {[e.value for e in app.error]}" assert any(user_oauth_id in s.value for s in app.success) @pytest.mark.integration @pytest.mark.asyncio(loop_scope="session") async def test_delete_user_ui( - app, admin_client: dp.DataPlatformAdministrationServiceStub + app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub ): """ - create random user via grpc diff --git a/tests/test_elexon_plot.py b/tests/test_elexon_plot.py index 2b94df0e..d2c59467 100644 --- a/tests/test_elexon_plot.py +++ b/tests/test_elexon_plot.py @@ -74,6 +74,7 @@ def test_add_elexon_plot_no_data(mock_fetch): assert len(updated_fig.data) == 0, "Figure should have no traces added if no data is available" @pytest.mark.integration +@pytest.mark.skip(reason="Elexonpy not currently being used") def test_fetch_forecast_data_integration(): # Initialize the actual API client and the function to be tested api_client = ApiClient() @@ -92,4 +93,4 @@ def test_fetch_forecast_data_integration(): assert not result.empty, "DataFrame should not be empty" assert "start_time" in result.columns, "DataFrame should contain 'start_time' column" assert "quantity" in result.columns, "DataFrame should contain 'quantity' column" - assert result["quantity"].notna().all(), "Quantity values should not be NaN" \ No newline at end of file + assert result["quantity"].notna().all(), "Quantity values should not be NaN"