diff --git a/pyproject.toml b/pyproject.toml
index d1301d19..1bbf5442 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
"plotly==5.24.1",
"psycopg2-binary==2.9.10",
"SQLAlchemy==2.0.36",
- "streamlit==1.51.0",
+ "streamlit==1.57.0",
"uvicorn==0.34.0",
"geopandas==1.0.1",
"garrett-streamlit-auth0==0.9",
@@ -37,9 +37,7 @@ dependencies = [
"matplotlib>=3.8,<4.0",
"dp-sdk",
"aiocache",
- "grpcio>=1.80.0",
- "grpcio-tools>=1.50.0",
- "grpc-requests>=0.1.17",
+ "grpclib==0.4.8",
"betterproto>=2.0.0b7",
"pytest-asyncio>=1.3.0",
"testcontainers>=4.14.0",
@@ -77,7 +75,7 @@ extra-index-url = ["https://pypi.org/simple"]
[tool.uv.sources]
-dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.28.0/dp_sdk-0.28.0-py3-none-any.whl" }
+dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.30.1/dp_sdk-0.30.1-py3-none-any.whl" }
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
diff --git a/src/dataplatform/forecast/backend.py b/src/dataplatform/forecast/backend.py
index 7b5ac14c..9752a67c 100644
--- a/src/dataplatform/forecast/backend.py
+++ b/src/dataplatform/forecast/backend.py
@@ -6,21 +6,22 @@
import pandas as pd
import streamlit as st
-from ocf import dp
+from ocf.dp.dp_data import messages_pb2, service_pb2_grpc
+from ocf.dp.dp import common_pb2
async def fetch_timeseries(
- client: dp.DataPlatformDataServiceStub,
+ client: service_pb2_grpc.DataPlatformDataServiceStub,
location_uuid: str,
start_date: datetime.datetime,
end_date: datetime.datetime,
horizon_mins: int,
- forecasters: list[dp.Forecaster],
+ forecasters: list[messages_pb2.Forecaster],
init_times_utc: list[datetime.datetime] | None = None,
) -> pd.DataFrame:
"""Directly calls GetForecastAsTimeseries for selected models and init times."""
- time_window = dp.TimeWindow(
+ time_window = messages_pb2.TimeWindow(
start_timestamp_utc=start_date, end_timestamp_utc=end_date
)
time_windows = []
@@ -28,7 +29,7 @@ async def fetch_timeseries(
while current_start < end_date:
current_end = min(current_start + datetime.timedelta(days=7), end_date)
time_windows.append(
- dp.TimeWindow(
+ messages_pb2.TimeWindow(
start_timestamp_utc=current_start,
end_timestamp_utc=current_end
)
@@ -38,13 +39,13 @@ async def fetch_timeseries(
times_to_fetch = init_times_utc if init_times_utc else [None]
async def fetch_one(
- forecaster_obj: dp.Forecaster,
- window: dp.TimeWindow,
+ forecaster_obj: messages_pb2.Forecaster,
+ window: messages_pb2.TimeWindow,
init_time: datetime.datetime | None
):
- req = dp.GetForecastAsTimeseriesRequest(
+ req = messages_pb2.GetForecastAsTimeseriesRequest(
location_uuid=location_uuid,
- energy_source=dp.EnergySource.SOLAR,
+ energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
horizon_mins=horizon_mins,
time_window=window,
forecaster=forecaster_obj,
@@ -52,18 +53,18 @@ async def fetch_one(
)
try:
- resp = await client.get_forecast_as_timeseries(req)
+ resp = await client.GetForecastAsTimeseries(req)
rows = []
for val in resp.values:
row = {
- "target_timestamp_utc": val.target_timestamp_utc,
- "initialization_timestamp_utc": val.initialization_timestamp_utc,
- "created_timestamp_utc": val.created_timestamp_utc,
+ "target_timestamp_utc": val.target_timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
+ "initialization_timestamp_utc": val.initialization_timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
+ "created_timestamp_utc": val.created_timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
"effective_capacity_watts": val.effective_capacity_watts,
"forecaster_name": forecaster_obj.forecaster_name,
"location_uuid": resp.location_uuid,
"horizon_mins": (
- val.target_timestamp_utc - val.initialization_timestamp_utc
+ val.target_timestamp_utc.ToDatetime(tzinfo=datetime.UTC) - val.initialization_timestamp_utc.ToDatetime(tzinfo=datetime.UTC)
).total_seconds()
// 60,
"p50_watts": int(
@@ -108,16 +109,16 @@ async def fetch_one(
async def fetch_observations(
- client: dp.DataPlatformDataServiceStub,
+ client: service_pb2_grpc.DataPlatformDataServiceStub,
location_uuid: str,
start_date: datetime.datetime,
end_date: datetime.datetime,
observers: list[str],
- energy_source: dp.EnergySource = dp.EnergySource.SOLAR,
+ energy_source: common_pb2.EnergySource = common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
) -> pd.DataFrame:
"""Directly calls GetObservationsAsTimeseries for selected observers."""
- time_window = dp.TimeWindow(
+ time_window = messages_pb2.TimeWindow(
start_timestamp_utc=start_date, end_timestamp_utc=end_date
)
time_windows = []
@@ -125,7 +126,7 @@ async def fetch_observations(
while current_start < end_date:
current_end = min(current_start + datetime.timedelta(days=7), end_date)
time_windows.append(
- dp.TimeWindow(
+ messages_pb2.TimeWindow(
start_timestamp_utc=current_start,
end_timestamp_utc=current_end
)
@@ -134,8 +135,8 @@ async def fetch_observations(
# Run requests concurrently for all selected observers
- async def fetch_one(obs_name: str, window: dp.TimeWindow):
- req = dp.GetObservationsAsTimeseriesRequest(
+ async def fetch_one(obs_name: str, window: messages_pb2.TimeWindow):
+ req = messages_pb2.GetObservationsAsTimeseriesRequest(
location_uuid=location_uuid,
observer_name=obs_name,
energy_source=energy_source,
@@ -143,12 +144,12 @@ async def fetch_one(obs_name: str, window: dp.TimeWindow):
)
try:
- resp = await client.get_observations_as_timeseries(req)
+ resp = await client.GetObservationsAsTimeseries(req)
rows = []
for val in resp.values:
rows.append(
{
- "target_timestamp_utc": val.timestamp_utc,
+ "target_timestamp_utc": val.timestamp_utc.ToDatetime(tzinfo=datetime.UTC),
"value_fraction": val.value_fraction,
"effective_capacity_watts": val.effective_capacity_watts,
"observer_name": obs_name,
diff --git a/src/dataplatform/forecast/cache.py b/src/dataplatform/forecast/cache.py
index 280ebac6..8b9628c8 100644
--- a/src/dataplatform/forecast/cache.py
+++ b/src/dataplatform/forecast/cache.py
@@ -2,7 +2,7 @@
from datetime import UTC, datetime, timedelta
-from ocf import dp
+from ocf.dp.dp_data import service_pb2_grpc
from dataplatform.forecast.constant import cache_seconds
@@ -11,7 +11,7 @@ def key_builder_remove_client(func: callable, *args: list, **kwargs: dict) -> st
"""Custom key builder that ignores the client argument for caching purposes."""
key = f"{func.__name__}:"
for arg in args:
- if not isinstance(arg, dp.DataPlatformDataServiceStub):
+ if not isinstance(arg, service_pb2_grpc.DataPlatformDataServiceStub):
key += f"{arg}-"
for k, v in kwargs.items():
diff --git a/src/dataplatform/forecast/data.py b/src/dataplatform/forecast/data.py
index 0cf05f75..9792029f 100644
--- a/src/dataplatform/forecast/data.py
+++ b/src/dataplatform/forecast/data.py
@@ -5,7 +5,6 @@
import pandas as pd
from aiocache import Cache, cached
-from ocf import dp
from google.protobuf.json_format import MessageToDict
from dataplatform.forecast.cache import key_builder_remove_client
@@ -16,10 +15,10 @@
async def get_forecast_data(
dpc: service_pb2_grpc.DataPlatformDataServiceStub,
- location: dp.ListLocationsResponseLocationSummary,
+ location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
- selected_forecasters: list[dp.Forecaster],
+ selected_forecasters: list[messages_pb2.Forecaster],
) -> pd.DataFrame:
"""Get forecast data for the given location and time window."""
all_data_df = []
@@ -66,10 +65,10 @@ async def get_forecast_data(
@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_forecast_data_one_forecaster(
dpc: service_pb2_grpc.DataPlatformDataServiceStub,
- location: dp.ListLocationsResponseLocationSummary,
+ location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
- selected_forecaster: dp.Forecaster,
+ selected_forecaster: messages_pb2.Forecaster,
) -> pd.DataFrame | None:
"""Get forecast data for one forecaster for the given location and time window."""
all_data_list_dict = []
@@ -93,7 +92,7 @@ async def get_forecast_data_one_forecaster(
forecasts = []
async for chunk in dpc.StreamForecastData(stream_forecast_data_request):
- forecasts.append(chunk)
+ forecasts.extend(chunk.values)
if len(forecasts) > 0:
all_data_list_dict.extend(MessageToDict(f, always_print_fields_with_no_presence=True) for f in forecasts)
@@ -149,7 +148,7 @@ async def get_forecast_data_one_forecaster(
@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_all_observations(
client: service_pb2_grpc.DataPlatformDataServiceStub,
- location: dp.ListLocationsResponseLocationSummary,
+ location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
) -> pd.DataFrame:
@@ -225,10 +224,10 @@ async def get_all_observations(
async def get_all_data(
client: service_pb2_grpc.DataPlatformDataServiceStub,
- selected_location: dp.ListLocationsResponseLocationSummary,
+ selected_location: messages_pb2.ListLocationsResponse.LocationSummary,
start_date: datetime,
end_date: datetime,
- selected_forecasters: list[dp.Forecaster],
+ selected_forecasters: list[messages_pb2.Forecaster],
) -> dict:
"""Get all forecast and observation data, and merge them."""
# get generation data
diff --git a/src/dataplatform/forecast/main.py b/src/dataplatform/forecast/main.py
index 5b0dd7e9..a8415e01 100644
--- a/src/dataplatform/forecast/main.py
+++ b/src/dataplatform/forecast/main.py
@@ -7,9 +7,9 @@
import pandas as pd
import streamlit as st
-from grpclib.client import Channel
-
-from ocf import dp
+import grpc.aio
+from ocf.dp.dp import common_pb2
+from ocf.dp.dp_data import messages_pb2, service_pb2_grpc
from dataplatform.forecast.constant import metrics, observer_names
from dataplatform.forecast.backend import fetch_observations, fetch_timeseries
@@ -49,9 +49,10 @@ async def async_dp_forecast_page() -> None:
st.title("Data Platform Forecast Page")
st.write("This is the forecast page from the Data Platform module.")
- async with Channel(host=data_platform_host, port=data_platform_port) as channel:
- client = dp.DataPlatformDataServiceStub(channel)
-
+ channel = grpc.aio.insecure_channel(f"{data_platform_host}:{data_platform_port}"):
+ client = service_pb2_grpc.DataPlatformDataServiceStub(channel)
+
+ try:
cfg = await setup_page(client)
st.divider()
st.subheader("1. Fetch Data")
@@ -76,7 +77,7 @@ async def async_dp_forecast_page() -> None:
start_date=cfg.start_date,
end_date=cfg.end_date,
observers=observer_names,
- energy_source=dp.EnergySource.SOLAR,
+ energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
)
fetch_duration = (datetime.datetime.now() - start_time).total_seconds()
@@ -242,3 +243,6 @@ async def async_dp_forecast_page() -> None:
"Configure your filters in the sidebar and click 'Fetch Forecast & Observations' to begin."
)
+ finally:
+ await channel.close()
+
diff --git a/src/dataplatform/forecast/setup.py b/src/dataplatform/forecast/setup.py
index fe85305e..c5c0d440 100644
--- a/src/dataplatform/forecast/setup.py
+++ b/src/dataplatform/forecast/setup.py
@@ -6,7 +6,8 @@
import pandas as pd
import streamlit as st
from aiocache import Cache, cached
-from ocf import dp
+from ocf.dp.dp import common_pb2
+from ocf.dp.dp_data import messages_pb2, service_pb2_grpc
from dataplatform.forecast.cache import key_builder_remove_client
from dataplatform.forecast.constant import cache_seconds, metrics
@@ -14,11 +15,11 @@
@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_location_names(
- client: dp.DataPlatformDataServiceStub,
+ client: service_pb2_grpc.DataPlatformDataServiceStub,
) -> dict:
"""Get location names."""
- list_locations_request = dp.ListLocationsRequest()
- list_locations_response = await client.list_locations(list_locations_request)
+ list_locations_request = messages_pb2.ListLocationsRequest()
+ list_locations_response = await client.ListLocations(list_locations_request)
all_locations = list_locations_response.locations
location_names = {loc.location_name: loc for loc in all_locations}
@@ -38,19 +39,19 @@ async def get_location_names(
@cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client)
async def get_forecasters(
- client: dp.DataPlatformDataServiceStub,
-) -> list[dp.Forecaster]:
+ client: service_pb2_grpc.DataPlatformDataServiceStub,
+) -> list[messages_pb2.Forecaster]:
"""Get all forecasters."""
- get_forecasters_request = dp.ListForecastersRequest()
- get_forecasters_response = await client.list_forecasters(get_forecasters_request)
+ get_forecasters_request = messages_pb2.ListForecastersRequest()
+ get_forecasters_response = await client.ListForecasters(get_forecasters_request)
forecasters = get_forecasters_response.forecasters
return forecasters
@dataclasses.dataclass
class PageConfig:
- location: dp.ListLocationsResponseLocationSummary
- forecasters: list[dp.Forecaster]
+ location: messages_pb2.ListLocationsResponse.LocationSummary
+ forecasters: list[messages_pb2.Forecaster]
start_date: dt.datetime
end_date: dt.datetime
forecast_type: str
@@ -62,7 +63,7 @@ class PageConfig:
strict_horizon_filtering: bool
-async def setup_page(client: dp.DataPlatformDataServiceStub) -> PageConfig:
+async def setup_page(client: service_pb2_grpc.DataPlatformDataServiceStub) -> PageConfig:
"""Setup the Streamlit page with sidebar options."""
location_names = await get_location_names(client)
selected_location_name = st.sidebar.selectbox(
diff --git a/src/dataplatform/toolbox/location.py b/src/dataplatform/toolbox/location.py
index 82aaa8ad..4ad03163 100644
--- a/src/dataplatform/toolbox/location.py
+++ b/src/dataplatform/toolbox/location.py
@@ -2,8 +2,10 @@
import streamlit as st
import json
-from ocf import dp
+from ocf.dp.dp import common_pb2
+from ocf.dp.dp_data import messages_pb2
from grpclib.exceptions import GRPCError
+from google.protobuf.json_format import MessageToDict
async def locations_section(data_client):
@@ -11,21 +13,21 @@ async def locations_section(data_client):
# Energy source and location type mappings
ENERGY_SOURCES = {
- "UNSPECIFIED": dp.EnergySource.UNSPECIFIED,
- "SOLAR": dp.EnergySource.SOLAR,
- "WIND": dp.EnergySource.WIND,
+ "UNSPECIFIED": common_pb2.EnergySource.ENERGY_SOURCE_UNSPECIFIED,
+ "SOLAR": common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
+ "WIND": common_pb2.EnergySource.ENERGY_SOURCE_WIND,
}
LOCATION_TYPES = {
- "UNSPECIFIED": dp.LocationType.UNSPECIFIED,
- "SITE": dp.LocationType.SITE,
- "GSP": dp.LocationType.GSP,
- "DNO": dp.LocationType.DNO,
- "NATION": dp.LocationType.NATION,
- "STATE": dp.LocationType.STATE,
- "COUNTY": dp.LocationType.COUNTY,
- "CITY": dp.LocationType.CITY,
- "PRIMARY SUBSTATION": dp.LocationType.PRIMARY_SUBSTATION,
+ "UNSPECIFIED": common_pb2.LocationType.LOCATION_TYPE_UNSPECIFIED,
+ "SITE": common_pb2.LocationType.LOCATION_TYPE_SITE,
+ "GSP": common_pb2.LocationType.LOCATION_TYPE_GSP,
+ "DNO": common_pb2.LocationType.LOCATION_TYPE_DNO,
+ "NATION": common_pb2.LocationType.LOCATION_TYPE_NATION,
+ "STATE": common_pb2.LocationType.LOCATION_TYPE_STATE,
+ "COUNTY": common_pb2.LocationType.LOCATION_TYPE_COUNTY,
+ "CITY": common_pb2.LocationType.LOCATION_TYPE_CITY,
+ "PRIMARY SUBSTATION": common_pb2.LocationType.LOCATION_TYPE_PRIMARY_SUBSTATION,
}
# List Locations
@@ -50,7 +52,7 @@ async def locations_section(data_client):
st.error("❌ Could not connect to Data Platform")
else:
try:
- request = dp.ListLocationsRequest()
+ request = messages_pb2.ListLocationsRequest()
if energy_source_filter != "UNSPECIFIED":
request.energy_source_filter = ENERGY_SOURCES[energy_source_filter]
if location_type_filter != "UNSPECIFIED":
@@ -58,12 +60,12 @@ async def locations_section(data_client):
if user_filter:
request.user_oauth_id_filter = user_filter
- response = await data_client.list_locations(request)
+ response = await data_client.ListLocations(request)
locations = response.locations
if locations:
st.success(f"✅ Found {len(locations)} location(s)")
- loc_dicts = [loc.to_dict() for loc in locations]
+ loc_dicts = [MessageToDict(loc) for loc in locations]
st.write(loc_dicts)
else:
st.info("No locations found with the specified filters")
@@ -92,14 +94,14 @@ async def locations_section(data_client):
st.error("❌ Could not connect to Data Platform")
else:
try:
- response = await data_client.get_location(
- dp.GetLocationRequest(
+ response = await data_client.GetLocation(
+ messages_pb2.GetLocationRequest(
location_uuid=loc_uuid,
energy_source=ENERGY_SOURCES[loc_energy],
include_geometry=include_geometry,
)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(f"✅ Found location: {loc_uuid}")
st.write(response_dict)
except GRPCError as e:
@@ -154,8 +156,8 @@ async def locations_section(data_client):
try:
# Parse metadata JSON
metadata = json.loads(loc_metadata) if loc_metadata.strip() else {}
- response = await data_client.create_location(
- dp.CreateLocationRequest(
+ response = await data_client.CreateLocation(
+ messages_pb2.CreateLocationRequest(
location_name=loc_name,
energy_source=ENERGY_SOURCES.get(loc_energy_src, 1),
location_type=LOCATION_TYPES.get(loc_type, 1),
@@ -164,7 +166,7 @@ async def locations_section(data_client):
metadata=metadata,
)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(f"✅ Location '{loc_name}' created successfully!")
st.write(response_dict)
diff --git a/src/dataplatform/toolbox/main.py b/src/dataplatform/toolbox/main.py
index 80d45bb2..f22041b3 100644
--- a/src/dataplatform/toolbox/main.py
+++ b/src/dataplatform/toolbox/main.py
@@ -1,14 +1,16 @@
"""Data Platform Toolbox Streamlit Page Main Code."""
import asyncio
-from grpclib.client import Channel
+import grpc.aio
import streamlit as st
from dataplatform.toolbox.organisation import organisation_section
from dataplatform.toolbox.users import users_section
from dataplatform.toolbox.user_organisation import user_organisation_section
from dataplatform.toolbox.location import locations_section
from dataplatform.toolbox.policy import policies_section
-from ocf import dp
+from ocf.dp.dp import common_pb2
+from ocf.dp.dp_data import messages_pb2, service_pb2_grpc
+from ocf.dp.dp_admin import messages_pb2 as dp_admin_messages_pb2, service_pb2_grpc as dp_admin_service_pb2_grpc
import os
# Color scheme (matching existing toolbox)
@@ -28,9 +30,10 @@ async def async_dataplatform_toolbox_page():
"""Async Main function for the Data Platform Toolbox Streamlit page."""
host = os.environ.get("DATA_PLATFORM_HOST", "localhost")
port = os.environ.get("DATA_PLATFORM_PORT", "50051")
- async with Channel(host=host, port=int(port)) as channel:
- admin_client = dp.DataPlatformAdministrationServiceStub(channel)
- data_client = dp.DataPlatformDataServiceStub(channel)
+ channel = grpc.aio.insecure_channel(f"{host}:{int(port)}")
+ try:
+ admin_client = dp_admin_service_pb2_grpc.DataPlatformAdministrationServiceStub(channel)
+ data_client = service_pb2_grpc.DataPlatformDataServiceStub(channel)
st.markdown(
'
Data Platform Toolbox
',
@@ -63,6 +66,9 @@ async def async_dataplatform_toolbox_page():
with tab5:
await policies_section(admin_client, data_client)
+ finally:
+ await channel.close()
+
# Required for the tests to run this as a script
if __name__ == "__main__":
diff --git a/src/dataplatform/toolbox/organisation.py b/src/dataplatform/toolbox/organisation.py
index 5458ecd2..3cd7360e 100644
--- a/src/dataplatform/toolbox/organisation.py
+++ b/src/dataplatform/toolbox/organisation.py
@@ -2,8 +2,9 @@
import streamlit as st
import json
-from ocf import dp
+from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc
from grpclib.exceptions import GRPCError
+from google.protobuf.json_format import MessageToDict
async def organisation_section(admin_client):
@@ -22,10 +23,10 @@ async def organisation_section(admin_client):
st.warning("⚠️ Please enter an organisation name")
else:
try:
- response = await admin_client.get_organisation(
- dp.GetOrganisationRequest(org_name=org_name)
+ response = await admin_client.GetOrganisation(
+ messages_pb2.GetOrganisationRequest(org_name=org_name)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(f"✅ Found organisation: {org_name}")
st.write(response_dict)
@@ -61,12 +62,12 @@ async def organisation_section(admin_client):
metadata = (
json.loads(metadata_json) if metadata_json.strip() else {}
)
- response = await admin_client.create_organisation(
- dp.CreateOrganisationRequest(
+ response = await admin_client.CreateOrganisation(
+ messages_pb2.CreateOrganisationRequest(
org_name=new_org_name, metadata=metadata
)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(
f"✅ Organisation '{new_org_name}' created successfully!"
)
@@ -104,8 +105,8 @@ async def organisation_section(admin_client):
st.warning("⚠️ Please confirm deletion by checking the box above")
else:
try:
- await admin_client.delete_organisation(
- dp.DeleteOrganisationRequest(org_name=del_org_name)
+ await admin_client.DeleteOrganisation(
+ messages_pb2.DeleteOrganisationRequest(org_name=del_org_name)
)
st.success(
f"✅ Organisation '{del_org_name}' deleted successfully!"
diff --git a/src/dataplatform/toolbox/policy.py b/src/dataplatform/toolbox/policy.py
index f6938af7..9352ae50 100644
--- a/src/dataplatform/toolbox/policy.py
+++ b/src/dataplatform/toolbox/policy.py
@@ -2,7 +2,10 @@
import streamlit as st
from grpclib.exceptions import GRPCError
-from ocf import dp
+from ocf.dp.dp import common_pb2
+from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc
+from ocf.dp.dp_data import messages_pb2 as data_messages_pb2
+from google.protobuf.json_format import MessageToDict
async def policies_section(admin_client, data_client):
@@ -10,13 +13,13 @@ async def policies_section(admin_client, data_client):
# Permission mappings
PERMISSIONS = {
- "READ": dp.Permission.READ,
- "WRITE": dp.Permission.WRITE,
+ "READ": common_pb2.Permission.PERMISSION_READ,
+ "WRITE": common_pb2.Permission.PERMISSION_WRITE,
}
ENERGY_SOURCES = {
- "SOLAR": dp.EnergySource.SOLAR,
- "WIND": dp.EnergySource.WIND,
+ "SOLAR": common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
+ "WIND": common_pb2.EnergySource.ENERGY_SOURCE_WIND,
}
# Create Location Policy Group
@@ -35,10 +38,10 @@ async def policies_section(admin_client, data_client):
st.warning("⚠️ Please enter a policy group name")
else:
try:
- response = await admin_client.create_location_policy_group(
- dp.CreateLocationPolicyGroupRequest(name=new_policy_group_name)
+ response = await admin_client.CreateLocationPolicyGroup(
+ messages_pb2.CreateLocationPolicyGroupRequest(name=new_policy_group_name)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(f"✅ Policy Group '{new_policy_group_name}' created!")
st.write(response_dict)
except GRPCError as e:
@@ -61,12 +64,12 @@ async def policies_section(admin_client, data_client):
st.warning("⚠️ Please enter a policy group name")
else:
try:
- response = await admin_client.get_location_policy_group(
- dp.GetLocationPolicyGroupRequest(
+ response = await admin_client.GetLocationPolicyGroup(
+ messages_pb2.GetLocationPolicyGroupRequest(
location_policy_group_name=policy_group_name
)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(f"✅ Found policy group: {policy_group_name}")
st.write(response_dict)
@@ -88,8 +91,8 @@ async def policies_section(admin_client, data_client):
if data_client:
try:
- response = await data_client.list_locations(dp.ListLocationsRequest())
- response_dict = response.to_dict()
+ response = await data_client.ListLocations(data_messages_pb2.ListLocationsRequest())
+ response_dict = MessageToDict(response)
locations = response_dict.get("locations", [])
except Exception as e:
st.error(f"❌ Failed to fetch locations: {e}")
@@ -124,11 +127,11 @@ async def policies_section(admin_client, data_client):
st.warning("⚠️ Please fill in all required fields")
else:
try:
- await admin_client.add_location_policies_to_group(
- dp.AddLocationPoliciesToGroupRequest(
+ await admin_client.AddLocationPoliciesToGroup(
+ messages_pb2.AddLocationPoliciesToGroupRequest(
location_policy_group_name=add_policy_group,
location_policies=[
- dp.LocationPolicy(
+ messages_pb2.LocationPolicy(
location_id=add_location_id,
energy_source=ENERGY_SOURCES[add_energy_source],
permission=PERMISSIONS[add_permission],
@@ -157,8 +160,8 @@ async def policies_section(admin_client, data_client):
if data_client:
try:
- response = await data_client.list_locations(dp.ListLocationsRequest())
- response_dict = response.to_dict()
+ response = await data_client.ListLocations(data_messages_pb2.ListLocationsRequest())
+ response_dict = MessageToDict(response)
locations = response_dict.get("locations", [])
except Exception as e:
st.error(f"❌ Failed to fetch locations: {e}")
@@ -194,11 +197,11 @@ async def policies_section(admin_client, data_client):
st.warning("⚠️ Please fill in all required fields")
else:
try:
- await admin_client.remove_location_policies_from_group(
- dp.RemoveLocationPoliciesFromGroupRequest(
+ await admin_client.RemoveLocationPoliciesFromGroup(
+ messages_pb2.RemoveLocationPoliciesFromGroupRequest(
location_policy_group_name=remove_policy_group,
location_policies=[
- dp.LocationPolicy(
+ messages_pb2.LocationPolicy(
location_id=remove_location_id,
energy_source=ENERGY_SOURCES[remove_energy_source],
permission=PERMISSIONS[remove_permission],
@@ -228,8 +231,8 @@ async def policies_section(admin_client, data_client):
st.warning("⚠️ Please fill in all fields")
else:
try:
- await admin_client.add_location_policy_group_to_organisation(
- dp.AddLocationPolicyGroupToOrganisationRequest(
+ await admin_client.AddLocationPolicyGroupToOrganisation(
+ messages_pb2.AddLocationPolicyGroupToOrganisationRequest(
org_name=add_pg_org, location_policy_group_name=add_pg_name
)
)
@@ -264,8 +267,8 @@ async def policies_section(admin_client, data_client):
st.error("❌ Could not connect to Data Platform")
else:
try:
- await admin_client.remove_location_policy_group_from_organisation(
- dp.RemoveLocationPolicyGroupFromOrganisationRequest(
+ await admin_client.RemoveLocationPolicyGroupFromOrganisation(
+ messages_pb2.RemoveLocationPolicyGroupFromOrganisationRequest(
org_name=remove_policy_group_org,
location_policy_group_name=remove_policy_group_name,
)
diff --git a/src/dataplatform/toolbox/user_organisation.py b/src/dataplatform/toolbox/user_organisation.py
index cd9c8eae..f8ec9249 100644
--- a/src/dataplatform/toolbox/user_organisation.py
+++ b/src/dataplatform/toolbox/user_organisation.py
@@ -1,7 +1,7 @@
"""User-Organisation relationship management section for the Data Platform Toolbox."""
import streamlit as st
-from ocf import dp
+from ocf.dp.dp_admin import messages_pb2
from grpclib.exceptions import GRPCError
@@ -22,8 +22,8 @@ async def user_organisation_section(admin_client):
st.warning("⚠️ Please fill in all fields")
else:
try:
- await admin_client.add_user_to_organisation(
- dp.AddUserToOrganisationRequest(
+ await admin_client.AddUserToOrganisation(
+ messages_pb2.AddUserToOrganisationRequest(
org_name=add_org, user_oauth_id=add_user_oauth
)
)
@@ -52,8 +52,8 @@ async def user_organisation_section(admin_client):
st.warning("⚠️ Please fill in all fields")
else:
try:
- await admin_client.remove_user_from_organisation(
- dp.RemoveUserFromOrganisationRequest(
+ await admin_client.RemoveUserFromOrganisation(
+ messages_pb2.RemoveUserFromOrganisationRequest(
org_name=remove_org, user_oauth_id=remove_user_oauth
)
)
diff --git a/src/dataplatform/toolbox/users.py b/src/dataplatform/toolbox/users.py
index 211087b6..731f5a79 100644
--- a/src/dataplatform/toolbox/users.py
+++ b/src/dataplatform/toolbox/users.py
@@ -2,9 +2,9 @@
import streamlit as st
import json
-from ocf import dp
+from ocf.dp.dp_admin import messages_pb2
from grpclib.exceptions import GRPCError
-
+from google.protobuf.json_format import MessageToDict
async def users_section(admin_client):
"""User management section."""
@@ -22,10 +22,10 @@ async def users_section(admin_client):
st.warning("⚠️ Please enter an OAuth ID")
else:
try:
- response = await admin_client.get_user(
- dp.GetUserRequest(oauth_id=oauth_id)
+ response = await admin_client.GetUser(
+ messages_pb2.GetUserRequest(oauth_id=oauth_id)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(f"✅ Found user: {oauth_id}")
st.write(response_dict)
@@ -65,14 +65,14 @@ async def users_section(admin_client):
metadata = (
json.loads(user_metadata) if user_metadata.strip() else {}
)
- response = await admin_client.create_user(
- dp.CreateUserRequest(
+ response = await admin_client.CreateUser(
+ messages_pb2.CreateUserRequest(
oauth_id=new_oauth_id,
organisation=user_org,
metadata=metadata,
)
)
- response_dict = response.to_dict()
+ response_dict = MessageToDict(response)
st.success(
f"✅ User '{new_oauth_id}' created in organisation '{user_org}'!"
)
@@ -113,8 +113,8 @@ async def users_section(admin_client):
else:
try:
# admin_client.DeleteUser({"user_id": del_user_id})
- await admin_client.delete_user(
- dp.DeleteUserRequest(user_id=del_user_id)
+ await admin_client.DeleteUser(
+ messages_pb2.DeleteUserRequest(user_id=del_user_id)
)
st.success(f"✅ User '{del_user_id}' deleted successfully!")
diff --git a/src/plots/elexon_plots.py b/src/plots/elexon_plots.py
index ea6edc65..b67cd1db 100644
--- a/src/plots/elexon_plots.py
+++ b/src/plots/elexon_plots.py
@@ -5,6 +5,7 @@
import streamlit as st
from elexonpy.api_client import ApiClient
from elexonpy.api.generation_forecast_api import GenerationForecastApi
+from google.protobuf.json_format import MessageToDict
def add_elexon_plot(
@@ -105,7 +106,11 @@ def fetch_forecast_data(
if not response.data:
return pd.DataFrame()
- df = pd.DataFrame([item.to_dict() for item in response.data])
+ df = pd.DataFrame([MessageToDict(
+ item,
+ preserving_proto_field_name=True,
+ always_print_fields_with_no_presence=True
+ ) for item in response.data])
solar_df = df[df["business_type"] == "Solar generation"]
solar_df["start_time"] = pd.to_datetime(solar_df["start_time"])
solar_df = solar_df.set_index("start_time")
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index dc8d0889..0ca602e3 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -9,8 +9,10 @@
from importlib.metadata import version
import os
from streamlit.testing.v1 import AppTest
-from ocf import dp
-from grpclib.client import Channel
+from ocf.dp.dp import common_pb2
+from ocf.dp.dp_data import messages_pb2, service_pb2_grpc
+from ocf.dp.dp_admin import messages_pb2 as dp_admin_messages_pb2, service_pb2_grpc as dp_admin_service_pb2_grpc
+import grpc.aio
DATA_PLATFORM_GRPC_PORT = 50051
DATA_PLATFORM_STARTUP_TIMEOUT_SECONDS = 60
@@ -64,19 +66,19 @@ async def dp_channel():
os.environ["DATA_PLATFORM_HOST"] = host
os.environ["DATA_PLATFORM_PORT"] = str(port)
- channel = Channel(host=host, port=port)
+ channel = grpc.aio.insecure_channel(f"{host}:{port}")
yield channel
- channel.close()
+ await channel.close()
@pytest_asyncio.fixture(scope="session")
async def admin_client(dp_channel):
- return dp.DataPlatformAdministrationServiceStub(dp_channel)
+ return dp_admin_service_pb2_grpc.DataPlatformAdministrationServiceStub(dp_channel)
@pytest_asyncio.fixture(scope="session")
async def data_client(dp_channel):
- return dp.DataPlatformDataServiceStub(dp_channel)
+ return service_pb2_grpc.DataPlatformDataServiceStub(dp_channel)
@pytest.fixture
@@ -103,41 +105,41 @@ def random_policy_name():
async def create_org_grpc(admin_client, org_name: str):
- await admin_client.create_organisation(
- dp.CreateOrganisationRequest(org_name=org_name, metadata={})
+ await admin_client.CreateOrganisation(
+ dp_admin_messages_pb2.CreateOrganisationRequest(org_name=org_name, metadata={})
)
async def get_org_grpc(admin_client, org_name: str):
- return await admin_client.get_organisation(
- dp.GetOrganisationRequest(org_name=org_name)
+ return await admin_client.GetOrganisation(
+ dp_admin_messages_pb2.GetOrganisationRequest(org_name=org_name)
)
async def create_user_grpc(admin_client, user_oauth_id: str, org_name: str):
- await admin_client.create_user(
- dp.CreateUserRequest(oauth_id=user_oauth_id, organisation=org_name, metadata={})
+ await admin_client.CreateUser(
+ dp_admin_messages_pb2.CreateUserRequest(oauth_id=user_oauth_id, organisation=org_name, metadata={})
)
async def get_user_grpc(admin_client, user_oauth_id: str):
- return await admin_client.get_user(dp.GetUserRequest(oauth_id=user_oauth_id))
+ return await admin_client.GetUser(dp_admin_messages_pb2.GetUserRequest(oauth_id=user_oauth_id))
async def add_user_to_org_grpc(admin_client, user_oauth_id: str, org_name: str):
- return await admin_client.add_user_to_organisation(
- dp.AddUserToOrganisationRequest(org_name=org_name, user_oauth_id=user_oauth_id)
+ return await admin_client.AddUserToOrganisation(
+ dp_admin_messages_pb2.AddUserToOrganisationRequest(org_name=org_name, user_oauth_id=user_oauth_id)
)
async def create_location_grpc(
data_client,
location_name: str,
- energy_source=dp.EnergySource.SOLAR,
- location_type=dp.LocationType.SITE,
+ energy_source=common_pb2.EnergySource.ENERGY_SOURCE_SOLAR,
+ location_type=common_pb2.LocationType.LOCATION_TYPE_SITE,
):
- return await data_client.create_location(
- dp.CreateLocationRequest(
+ return await data_client.CreateLocation(
+ messages_pb2.CreateLocationRequest(
location_name=location_name,
energy_source=energy_source,
geometry_wkt="POINT(0 0)",
@@ -149,18 +151,18 @@ async def create_location_grpc(
async def list_locations_grpc(data_client):
- return await data_client.list_locations(dp.ListLocationsRequest())
+ return await data_client.ListLocations(messages_pb2.ListLocationsRequest())
async def create_policy_group_grpc(admin_client, policy_name: str):
- await admin_client.create_location_policy_group(
- dp.CreateLocationPolicyGroupRequest(name=policy_name)
+ await admin_client.CreateLocationPolicyGroup(
+ dp_admin_messages_pb2.CreateLocationPolicyGroupRequest(name=policy_name)
)
async def get_policy_group_grpc(admin_client, policy_name: str):
- return await admin_client.get_location_policy_group(
- dp.GetLocationPolicyGroupRequest(location_policy_group_name=policy_name)
+ return await admin_client.GetLocationPolicyGroup(
+ dp_admin_messages_pb2.GetLocationPolicyGroupRequest(location_policy_group_name=policy_name)
)
@@ -168,14 +170,14 @@ async def add_policy_to_group_grpc(
admin_client,
policy_name: str,
location_uuid: str,
- energy_source=dp.EnergySource.WIND,
- permission=dp.Permission.WRITE,
+ energy_source=common_pb2.EnergySource.ENERGY_SOURCE_WIND,
+ permission=common_pb2.Permission.PERMISSION_WRITE,
):
- await admin_client.add_location_policies_to_group(
- dp.AddLocationPoliciesToGroupRequest(
+ await admin_client.AddLocationPoliciesToGroup(
+ dp_admin_messages_pb2.AddLocationPoliciesToGroupRequest(
location_policy_group_name=policy_name,
location_policies=[
- dp.LocationPolicy(
+ dp_admin_messages_pb2.LocationPolicy(
location_id=location_uuid,
energy_source=energy_source,
permission=permission,
@@ -186,8 +188,8 @@ async def add_policy_to_group_grpc(
async def add_policy_to_org_grpc(admin_client, org_name, policy_name):
- await admin_client.add_location_policy_group_to_organisation(
- dp.AddLocationPolicyGroupToOrganisationRequest(
+ await admin_client.AddLocationPolicyGroupToOrganisation(
+ dp_admin_messages_pb2.AddLocationPolicyGroupToOrganisationRequest(
org_name=org_name,
location_policy_group_name=policy_name,
)
diff --git a/tests/integration/test_locations_ui.py b/tests/integration/test_locations_ui.py
index 1a78e4c0..8805dfaf 100644
--- a/tests/integration/test_locations_ui.py
+++ b/tests/integration/test_locations_ui.py
@@ -6,7 +6,8 @@
"""
import pytest
-from ocf import dp
+
+from ocf.dp.dp_data import service_pb2_grpc
from tests.integration.conftest import (
create_location_grpc,
@@ -17,7 +18,7 @@
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
-async def test_list_locations_ui(app, data_client: dp.DataPlatformDataServiceStub):
+async def test_list_locations_ui(app, data_client: service_pb2_grpc.DataPlatformDataServiceStub):
"""
- create some locations via grpc
- fill in list locations form and submit
@@ -51,7 +52,7 @@ async def test_list_locations_ui(app, data_client: dp.DataPlatformDataServiceStu
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
-async def test_get_location_ui(app, data_client: dp.DataPlatformDataServiceStub):
+async def test_get_location_ui(app, data_client: service_pb2_grpc.DataPlatformDataServiceStub):
"""
- create a location via grpc
- fill in get location form and submit
@@ -74,7 +75,7 @@ async def test_get_location_ui(app, data_client: dp.DataPlatformDataServiceStub)
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
-async def test_create_location_ui(app, data_client: dp.DataPlatformDataServiceStub):
+async def test_create_location_ui(app, data_client: service_pb2_grpc.DataPlatformDataServiceStub):
"""
- fill in create location form and submit
- assert success message
@@ -92,6 +93,9 @@ async def test_create_location_ui(app, data_client: dp.DataPlatformDataServiceSt
app.button("create_location_button").click()
app.run()
+ # assert no errors in app
+ assert len(app.error) == 0, f"Expected no errors, but got: {[e.value for e in app.error]}"
+
# Assert success message in UI
assert any("created" in s.value.lower() for s in app.success)
diff --git a/tests/integration/test_organisations_ui.py b/tests/integration/test_organisations_ui.py
index 3d68b3b2..2588ccd7 100644
--- a/tests/integration/test_organisations_ui.py
+++ b/tests/integration/test_organisations_ui.py
@@ -6,7 +6,7 @@
"""
import pytest
-from ocf import dp
+from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc
from tests.integration.conftest import create_org_grpc, random_org_name
@@ -14,7 +14,7 @@
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_create_organisation_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create random org name
@@ -37,8 +37,8 @@ async def test_create_organisation_ui(
# Assert success
assert any("created" in s.value.lower() for s in app.success)
- response = await admin_client.get_organisation(
- dp.GetOrganisationRequest(org_name=org_name)
+ response = await admin_client.GetOrganisation(
+ messages_pb2.GetOrganisationRequest(org_name=org_name)
)
assert response.org_name == org_name
@@ -46,7 +46,7 @@ async def test_create_organisation_ui(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_get_organisation_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create random org name via grpc
@@ -66,7 +66,7 @@ async def test_get_organisation_ui(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_delete_organisation_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create random org name via grpc
@@ -87,6 +87,6 @@ async def test_delete_organisation_ui(
# verify deletion via grpc
with pytest.raises(Exception):
- await admin_client.get_organisation(
- dp.GetOrganisationRequest(org_name=org_name)
+ await admin_client.GetOrganisation(
+ messages_pb2.GetOrganisationRequest(org_name=org_name)
)
diff --git a/tests/integration/test_policy_ui.py b/tests/integration/test_policy_ui.py
index 44a51b89..2e018ddd 100644
--- a/tests/integration/test_policy_ui.py
+++ b/tests/integration/test_policy_ui.py
@@ -9,7 +9,8 @@
"""
import pytest
-from ocf import dp
+from ocf.dp.dp_admin import messages_pb2, service_pb2_grpc
+from ocf.dp.dp_data import messages_pb2 as data_messages_pb2, service_pb2_grpc as data_service_pb2_grpc
from tests.integration.conftest import (
add_policy_to_group_grpc,
@@ -28,7 +29,7 @@
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_create_policy_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- fill in create policy group form and submit
@@ -56,7 +57,7 @@ async def test_create_policy_ui(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_get_policy_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create a policy group via grpc
@@ -77,8 +78,8 @@ async def test_get_policy_ui(
@pytest.mark.asyncio(loop_scope="session")
async def test_add_policy_to_group(
app,
- admin_client: dp.DataPlatformAdministrationServiceStub,
- data_client: dp.DataPlatformDataServiceStub,
+ admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub,
+ data_client: data_service_pb2_grpc.DataPlatformDataServiceStub,
):
"""
- create a policy group via grpc
@@ -108,8 +109,8 @@ async def test_add_policy_to_group(
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_policy_from_group(
app,
- admin_client: dp.DataPlatformAdministrationServiceStub,
- data_client: dp.DataPlatformDataServiceStub,
+ admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub,
+ data_client: data_service_pb2_grpc.DataPlatformDataServiceStub,
):
"""
- create a policy group via grpc
@@ -145,7 +146,7 @@ async def test_remove_policy_from_group(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_add_policy_to_org(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create a policy group via grpc
@@ -173,7 +174,7 @@ async def test_add_policy_to_org(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_policy_from_org(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create a policy group via grpc
diff --git a/tests/integration/test_user_org_ui.py b/tests/integration/test_user_org_ui.py
index cea30eb8..58d502d4 100644
--- a/tests/integration/test_user_org_ui.py
+++ b/tests/integration/test_user_org_ui.py
@@ -5,7 +5,7 @@
"""
import pytest
-from ocf import dp
+from ocf.dp.dp_admin import service_pb2_grpc
from tests.integration.conftest import (
@@ -21,7 +21,7 @@
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_add_user_org_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create two orgs and a user in one org
@@ -47,7 +47,9 @@ async def test_add_user_org_ui(
app.button("add_user_to_org_button").click()
app.run()
+
# Assert success
+ assert len(app.error) == 0, f"Expected no errors, but got: {[s.value for s in app.error]}"
assert any("added" in s.value.lower() for s in app.success)
user = await get_user_grpc(admin_client, user_id)
@@ -62,7 +64,7 @@ async def test_add_user_org_ui(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_user_org_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create two orgs and a user in both orgs
diff --git a/tests/integration/test_users_ui.py b/tests/integration/test_users_ui.py
index 347346b2..2bb762a2 100644
--- a/tests/integration/test_users_ui.py
+++ b/tests/integration/test_users_ui.py
@@ -6,7 +6,7 @@
"""
import pytest
-from ocf import dp
+from ocf.dp.dp_admin import service_pb2_grpc
from tests.integration.conftest import (
create_org_grpc,
@@ -20,7 +20,7 @@
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_create_user_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create an org for the user
@@ -47,6 +47,7 @@ async def test_create_user_ui(
app.run()
# Assert success
+ assert len(app.error) == 0, f"Expected no errors, but got: {[e.value for e in app.error]}"
assert any("created" in s.value.lower() for s in app.success)
response = await get_user_grpc(admin_client, user_oauth_id)
@@ -55,7 +56,7 @@ async def test_create_user_ui(
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
-async def test_get_user_ui(app, admin_client: dp.DataPlatformAdministrationServiceStub):
+async def test_get_user_ui(app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub):
"""
- create random user via grpc
- fill in get user form and submit
@@ -70,13 +71,14 @@ async def test_get_user_ui(app, admin_client: dp.DataPlatformAdministrationServi
app.button("get_user_button").click()
app.run()
+ assert len(app.error) == 0, f"Expected no errors, but got: {[e.value for e in app.error]}"
assert any(user_oauth_id in s.value for s in app.success)
@pytest.mark.integration
@pytest.mark.asyncio(loop_scope="session")
async def test_delete_user_ui(
- app, admin_client: dp.DataPlatformAdministrationServiceStub
+ app, admin_client: service_pb2_grpc.DataPlatformAdministrationServiceStub
):
"""
- create random user via grpc
diff --git a/tests/test_elexon_plot.py b/tests/test_elexon_plot.py
index 2b94df0e..d2c59467 100644
--- a/tests/test_elexon_plot.py
+++ b/tests/test_elexon_plot.py
@@ -74,6 +74,7 @@ def test_add_elexon_plot_no_data(mock_fetch):
assert len(updated_fig.data) == 0, "Figure should have no traces added if no data is available"
@pytest.mark.integration
+@pytest.mark.skip(reason="Elexonpy not currently being used")
def test_fetch_forecast_data_integration():
# Initialize the actual API client and the function to be tested
api_client = ApiClient()
@@ -92,4 +93,4 @@ def test_fetch_forecast_data_integration():
assert not result.empty, "DataFrame should not be empty"
assert "start_time" in result.columns, "DataFrame should contain 'start_time' column"
assert "quantity" in result.columns, "DataFrame should contain 'quantity' column"
- assert result["quantity"].notna().all(), "Quantity values should not be NaN"
\ No newline at end of file
+ assert result["quantity"].notna().all(), "Quantity values should not be NaN"