diff --git a/simvue/api/objects/alert/base.py b/simvue/api/objects/alert/base.py index 04fa03ad..5816845b 100644 --- a/simvue/api/objects/alert/base.py +++ b/simvue/api/objects/alert/base.py @@ -249,7 +249,9 @@ def get_status(self, run_id: str) -> typing.Literal["ok", "critical"]: ) _url: URL = self.url / f"status/{run_id}" - _response = sv_get(url=f"{_url}", headers=self._headers) + _response = sv_get( + url=f"{_url}", headers=self._headers, verify=self._user_config.server_verify + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], diff --git a/simvue/api/objects/alert/fetch.py b/simvue/api/objects/alert/fetch.py index 7599fddb..6f5a65b2 100644 --- a/simvue/api/objects/alert/fetch.py +++ b/simvue/api/objects/alert/fetch.py @@ -181,6 +181,7 @@ def get( f"{_url}", headers=_config.headers, params=_params | kwargs, + verify=_config.server_verify, ) _label: str = cls.__name__.lower() diff --git a/simvue/api/objects/alert/user.py b/simvue/api/objects/alert/user.py index 97573a5a..e4787e74 100644 --- a/simvue/api/objects/alert/user.py +++ b/simvue/api/objects/alert/user.py @@ -159,6 +159,7 @@ def set_status(self, run_id: str, status: typing.Literal["ok", "critical"]) -> N url=self.url / "status" / run_id, data={"status": status}, headers=self._headers, + verify=self._user_config.server_verify, ) get_json_from_response( diff --git a/simvue/api/objects/artifact/base.py b/simvue/api/objects/artifact/base.py index db76e5ff..813fc096 100644 --- a/simvue/api/objects/artifact/base.py +++ b/simvue/api/objects/artifact/base.py @@ -8,6 +8,7 @@ import http import io import logging +import pathlib import typing import pydantic @@ -109,6 +110,7 @@ def attach_to_run(self, run_id: str, category: Category) -> None: url=f"{_run_artifacts_url}", headers=self._headers, json={"category": category}, + verify=self._user_config.server_verify, ) get_json_from_response( @@ -157,6 +159,7 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: params={}, is_json=False, timeout=timeout, + verify=self.storage_ca_cert, files={"file": file}, data=_fields, ) @@ -168,6 +171,7 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: headers={}, is_json=False, timeout=timeout, + verify=self.storage_ca_cert, data=file, ) @@ -203,6 +207,15 @@ def _get( **kwargs, ) + @property + def storage_ca_cert(self) -> str | bool: + """Return current storage CA certificate.""" + _ca_cert: pathlib.Path | bool = ( + self._user_config.server.certificates.storage_ca_cert + ) + + return f"{_ca_cert}" if isinstance(_ca_cert, pathlib.Path) else _ca_cert + @property def checksum(self) -> str: """Retrieve the checksum for this artifact. @@ -357,7 +370,9 @@ def get_category(self, run_id: str) -> Category: URL(self._user_config.server.url) / f"runs/{run_id}/artifacts/{self._identifier}" ) - _response = sv_get(url=_run_url, header=self._headers) + _response = sv_get( + url=_run_url, header=self._headers, verify=self._user_config.server_verify + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK, http.HTTPStatus.NOT_FOUND], @@ -406,6 +421,7 @@ def download_content(self) -> Generator[bytes]: _response = sv_get( f"{self.download_url}", timeout=_timeout, + verify=self.storage_ca_cert, headers=None, ) diff --git a/simvue/api/objects/artifact/fetch.py b/simvue/api/objects/artifact/fetch.py index 109a366c..7e37c933 100644 --- a/simvue/api/objects/artifact/fetch.py +++ b/simvue/api/objects/artifact/fetch.py @@ -142,6 +142,7 @@ def from_run( url=f"{_url}", params={"category": category}, headers=_config.headers, + verify=_config.server_verify, ) _json_response = get_json_from_response( expected_type=list, @@ -212,6 +213,7 @@ def from_name( url=f"{_url}", params={"name": name}, headers=_config.headers, + verify=_config.server_verify, ) _json_response = get_json_from_response( expected_type=list, @@ -296,6 +298,7 @@ def get( _url, headers=_config.headers, params=_params | kwargs, + verify=_config.server_verify, ) _label: str = cls.__name__.lower() _label = _label.replace("base", "") diff --git a/simvue/api/objects/base.py b/simvue/api/objects/base.py index 6a7533f1..457ba1ce 100644 --- a/simvue/api/objects/base.py +++ b/simvue/api/objects/base.py @@ -720,6 +720,7 @@ def _post_batch( headers=self._headers | {"Content-Type": "application/msgpack"}, params=self._params or {}, data=batch_data, + verify=self._user_config.server_verify, is_json=True, ) @@ -767,6 +768,7 @@ def _post_single( params=self._params or {}, data=data or kwargs, is_json=is_json, + verify=self._user_config.server_verify, ) if _response.status_code == http.HTTPStatus.FORBIDDEN: @@ -809,6 +811,7 @@ def _put(self, **kwargs) -> dict[str, typing.Any]: headers=self._headers, data=kwargs, is_json=True, + verify=self._user_config.server_verify, ) if _response.status_code == http.HTTPStatus.FORBIDDEN: @@ -843,7 +846,12 @@ def delete(self, **kwargs) -> dict[str, typing.Any]: if not self.url: raise RuntimeError(f"Identifier for instance of {self.label()} Unknown") - _response = sv_delete(url=f"{self.url}", headers=self._headers, params=kwargs) + _response = sv_delete( + url=f"{self.url}", + headers=self._headers, + params=kwargs, + verify=self._user_config.server_verify, + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK, http.HTTPStatus.NO_CONTENT], @@ -870,6 +878,7 @@ def _get( url=f"{url or self.url}", headers=self._headers, params=kwargs, + verify=self._user_config.server_verify, ) if _response.status_code == http.HTTPStatus.NOT_FOUND: diff --git a/simvue/api/objects/events.py b/simvue/api/objects/events.py index 6b6b543b..4e31afac 100644 --- a/simvue/api/objects/events.py +++ b/simvue/api/objects/events.py @@ -217,6 +217,7 @@ def histogram( _response = sv_get( url=_url, headers=self._headers, + verify=self._user_config.server_verify, params={ "run": self._run_id, "window": window, diff --git a/simvue/api/objects/folder.py b/simvue/api/objects/folder.py index 4ae6b126..8147696b 100644 --- a/simvue/api/objects/folder.py +++ b/simvue/api/objects/folder.py @@ -314,6 +314,7 @@ def _set_favourite(self, *, starred: bool) -> dict: f"{_url}", headers=self._user_config.headers, data={"starred": starred}, + verify=self._user_config.server_verify, ) return get_json_from_response( expected_status=[http.HTTPStatus.OK], diff --git a/simvue/api/objects/grids.py b/simvue/api/objects/grids.py index 208b8bcb..46eb2f01 100644 --- a/simvue/api/objects/grids.py +++ b/simvue/api/objects/grids.py @@ -123,6 +123,7 @@ def attach_metric_for_run(self, run_id: str, metric_name: str) -> None: _response = sv_put( url=f"{self.run_data_url(run_id)}", headers=self._headers, + verify=self._user_config.server_verify, json={"metric": metric_name}, ) @@ -279,6 +280,7 @@ def get_run_metric_values( _response = sv_get( url=f"{self.run_metric_url(run_id, metric_name) / 'values'}", headers=self._headers, + verify=self._user_config.server_verify, params={"step": step}, ) @@ -311,6 +313,7 @@ def get_run_metric_span(self, *, run_id: str, metric_name: str) -> dict: """ _response = sv_get( url=f"{self.run_metric_url(run_id, metric_name) / 'span'}", + verify=self._user_config.server_verify, headers=self._headers, ) @@ -529,6 +532,7 @@ def _log_values(self, metrics: list[GridMetricSet]) -> None: url=f"{self._user_config.server.url}/{self.run_grids_endpoint(self._run_id)}", headers=self._headers | {"Content-Type": "application/msgpack"}, data=msgpack.packb(metrics, use_bin_type=True), + verify=self._user_config.server_verify, is_json=False, params={}, ) diff --git a/simvue/api/objects/metrics.py b/simvue/api/objects/metrics.py index 0fedbc31..a1fb3ce2 100644 --- a/simvue/api/objects/metrics.py +++ b/simvue/api/objects/metrics.py @@ -168,9 +168,14 @@ def get( @pydantic.validate_call def span(self, run_ids: list[str]) -> dict[str, int | float]: - """Returns the metrics span for the given runs.""" - _url = self.base_url / "span" - _response = sv_get(url=f"{_url}", headers=self._headers, json=run_ids) + """Returns the metrics span for the given runs""" + _url = self._base_url / "span" + _response = sv_get( + url=f"{_url}", + headers=self._headers, + json=run_ids, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], @@ -185,6 +190,7 @@ def names(self, run_ids: list[str]) -> list[str]: url=f"{_url}", headers=self._headers, params={"runs": json.dumps(run_ids)}, + verify=self._user_config.server_verify, ) return get_json_from_response( response=_response, diff --git a/simvue/api/objects/run.py b/simvue/api/objects/run.py index 2ef5eac9..e0a369c1 100644 --- a/simvue/api/objects/run.py +++ b/simvue/api/objects/run.py @@ -638,8 +638,9 @@ def _set_favourite(self, *, starred: bool) -> dict: _url = self.url / "starred" _response = sv_put( f"{_url}", - headers=self.user_config.headers, + headers=self._user_config.headers, data={"starred": starred}, + verify=self._user_config.server_verify, ) return get_json_from_response( expected_status=[http.HTTPStatus.OK], @@ -730,7 +731,12 @@ def send_heartbeat(self) -> dict[str, typing.Any] | None: _url = self.base_url _url /= f"{self._identifier}/heartbeat" - _response = sv_put(f"{_url}", headers=self._headers, data={}) + _response = sv_put( + f"{_url}", + headers=self._headers, + data={}, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], @@ -770,7 +776,11 @@ def abort_trigger(self) -> bool: if self._offline or not self._identifier: return False - _response = sv_get(f"{self._abort_url}", headers=self._headers) + _response = sv_get( + f"{self._abort_url}", + headers=self._headers, + verify=self._user_config.server_verify, + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], @@ -791,7 +801,11 @@ def artifacts(self) -> list[dict[str, typing.Any]]: if self._offline or not self._artifact_url: return [] - _response = sv_get(url=self._artifact_url, headers=self._headers) + _response = sv_get( + url=self._artifact_url, + headers=self._headers, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, @@ -813,7 +827,11 @@ def grids(self) -> list[dict[str, str]]: if self._offline or not self._grid_url: return [] - _response = sv_get(url=self._grid_url, headers=self._headers) + _response = sv_get( + url=self._grid_url, + headers=self._headers, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, @@ -849,6 +867,7 @@ def abort(self, reason: str) -> dict[str, typing.Any]: f"{self._abort_url}", headers=self._headers, data={"reason": reason}, + verify=self._user_config.server_verify, ) return get_json_from_response( diff --git a/simvue/api/objects/stats.py b/simvue/api/objects/stats.py index de4099f2..9cf96dd5 100644 --- a/simvue/api/objects/stats.py +++ b/simvue/api/objects/stats.py @@ -148,7 +148,9 @@ def whoami(self) -> dict[str, str]: """ _url: URL = URL(self._user_config.server.url) / "whoami" - _response = sv_get(url=f"{_url}", headers=self._headers) + _response = sv_get( + url=f"{_url}", headers=self._headers, verify=self._user_config.server_verify + ) return get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], diff --git a/simvue/api/objects/storage/fetch.py b/simvue/api/objects/storage/fetch.py index 006228b6..5a4b221f 100644 --- a/simvue/api/objects/storage/fetch.py +++ b/simvue/api/objects/storage/fetch.py @@ -122,6 +122,7 @@ def get( _url, headers=_class_instance.user_config.headers, params={"start": offset, "count": count} | kwargs, + verify=_class_instance._user_config.server_verify, ) _label: str = _class_instance.__class__.__name__.lower() _label = _label.replace("base", "") diff --git a/simvue/api/request.py b/simvue/api/request.py index 668021c3..567cb66e 100644 --- a/simvue/api/request.py +++ b/simvue/api/request.py @@ -66,6 +66,7 @@ def post( data: typing.Any, is_json: bool = True, timeout: int | None = None, + verify: str | bool = True, files: dict[str, typing.Any] | None = None, ) -> requests.Response: """HTTP POST with retries. @@ -106,6 +107,7 @@ def post( data=data_sent, timeout=timeout, files=files, + verify=verify, ) if response.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY: @@ -142,6 +144,7 @@ def put( data: dict[str, typing.Any] | None = None, json: dict[str, typing.Any] | None = None, is_json: bool = True, + verify: bool | str = True, timeout: int = DEFAULT_API_TIMEOUT, ) -> requests.Response: """HTTP PUT with retries. @@ -181,6 +184,7 @@ def put( data=data_sent, timeout=timeout, json=json, + verify=verify, ) if response.status_code in RETRY_STATUSES: @@ -208,6 +212,7 @@ def get( headers: dict[str, str] | None = None, params: dict[str, str | int | float | None] | None = None, timeout: int = DEFAULT_API_TIMEOUT, + verify: str | bool = True, json: dict[str, typing.Any] | None = None, ) -> requests.Response: """HTTP GET. @@ -238,6 +243,7 @@ def get( timeout=timeout, params=params, json=json, + verify=verify ) if response.status_code in RETRY_STATUSES: @@ -264,6 +270,7 @@ def delete( url: str, headers: dict[str, str], timeout: int = DEFAULT_API_TIMEOUT, + verify: str | bool = True, params: dict[str, typing.Any] | None = None, ) -> requests.Response: """HTTP DELETE. @@ -286,7 +293,7 @@ def delete( """ logger.debug("DELETE: %s\n\tparams=%s", url, params) - response = requests.delete(url, headers=headers, timeout=timeout, params=params) + response = requests.delete(url, headers=headers, timeout=timeout, params=params, verify=verify) if response.status_code in RETRY_STATUSES: raise RetryableHTTPError( diff --git a/simvue/config/parameters.py b/simvue/config/parameters.py index 2407410c..92dbd36c 100644 --- a/simvue/config/parameters.py +++ b/simvue/config/parameters.py @@ -19,6 +19,29 @@ logger = logging.getLogger(__name__) +class CertificateSpecifications(pydantic.BaseModel): + storage_ca_cert: pydantic.FilePath | bool = True + server_ca_cert: pydantic.FilePath | bool = True + client_cert: pydantic.FilePath | None = None + client_key: pydantic.SecretStr | None = None + + @pydantic.model_validator(mode="before") + @classmethod + def check_for_cert_env( + cls, values: dict[str, pathlib.Path | None | str] + ) -> dict[str, pathlib.Path | None | str]: + """Check for CA certificate for storage specification in environment.""" + if ( + _env_ca_cert := os.environ.get("SIMVUE_STORAGE_CA_CERTIFICATE") + ) is not None: + values["storage_ca_cert"] = _env_ca_cert + if (_env_ca_cert := os.environ.get("SIMVUE_SERVER_CA_CERTIFICATE")) is not None: + values["server_ca_cert"] = _env_ca_cert + if _env_client_cert := os.environ.get("SIMVUE_SERVER_CLIENT_CERTIFICATE"): + values["client_cert"] = _env_ca_cert + return values + + class ServerSpecifications(pydantic.BaseModel): model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict( extra="forbid", @@ -27,6 +50,9 @@ class ServerSpecifications(pydantic.BaseModel): url: pydantic.AnyHttpUrl | None token: pydantic.SecretStr | None env: dict[str, str] | None = None + certificates: CertificateSpecifications = pydantic.Field( + default_factory=CertificateSpecifications + ) @pydantic.field_validator("url") @classmethod @@ -74,7 +100,7 @@ class DefaultRunSpecifications(pydantic.BaseModel): name: str | None = None description: str | None = None tags: list[str] | None = None - folder: str = pydantic.Field("/", pattern=sv_models.FOLDER_REGEX) + folder: str = pydantic.Field(default="/", pattern=sv_models.FOLDER_REGEX) metadata: dict[str, str | int | float | bool] | None = None mode: typing.Literal["offline", "disabled", "online"] = "online" record_shell_vars: list[str] | None = None diff --git a/simvue/config/user.py b/simvue/config/user.py index 6c53647c..22305ff1 100644 --- a/simvue/config/user.py +++ b/simvue/config/user.py @@ -71,6 +71,12 @@ class SimvueConfiguration(pydantic.BaseModel): current_profile: str | None = None _server_version: semver.Version | None = None + @property + def server_verify(self) -> str | bool: + """Return current server CA certificate.""" + _ca_cert: pathlib.Path | bool = self.server.certificates.server_ca_cert + return f"{_ca_cert}" if isinstance(_ca_cert, pathlib.Path) else _ca_cert + @property def server_version(self) -> semver.Version: """Retrieve current Server version.""" @@ -118,6 +124,7 @@ def _check_server( cls, token: str, url: str, + verify: str | bool, mode: typing.Literal["offline", "online", "disabled"], ) -> semver.Version | None: if mode in {"offline", "disabled"}: @@ -129,7 +136,7 @@ def _check_server( } try: _url = URL(url) / "version" - _response = sv_get(f"{_url}", headers) + _response = sv_get(f"{_url}", headers, verify=verify) if _response.status_code == http.HTTPStatus.UNAUTHORIZED: raise AssertionError("Unauthorised token") @@ -176,9 +183,10 @@ def check_valid_server(self) -> Self: raise ValueError("No token provided.") self._server_version = self._check_server( - self.server.token.get_secret_value(), - self.server.url, - self.run.mode, + token=self.server.token.get_secret_value(), + url=self.server.url, + verify=self.server_verify, + mode=self.run.mode, ) return self diff --git a/simvue/sender/actions.py b/simvue/sender/actions.py index 53cd4f66..7513f943 100644 --- a/simvue/sender/actions.py +++ b/simvue/sender/actions.py @@ -1005,6 +1005,7 @@ def _single_item_upload( _response: requests.Response = sv_put( url=f"{_local_config.server.url}/runs/{_online_id}/heartbeat", headers=_local_config.headers, + verify=_local_config.server_verify, ) try: