Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,4 @@ jobs:
- uses: ./.github/workflows/install_deps

- name: Run formatter
run: poetry run black --check --diff .

- name: Run isort
run: poetry run isort --check-only --diff .

run: poetry run ruff format --check .
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
- uses: ./.github/workflows/install_deps

- name: Run linter
run: poetry run pylint .
run: poetry run ruff check .
4 changes: 2 additions & 2 deletions armis_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from armis_sdk.core.armis_sdk import ArmisSdk # noqa: F401
from armis_sdk.core.client_credentials import ClientCredentials # noqa: F401
from armis_sdk.core.armis_sdk import ArmisSdk
from armis_sdk.core.client_credentials import ClientCredentials
74 changes: 35 additions & 39 deletions armis_sdk/clients/assets_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import datetime
from typing import AsyncIterator
from typing import Optional
from typing import Type
from typing import Type # noqa: UP035 # TODO: fix UP035 (deprecated import, use updated module)
from typing import TYPE_CHECKING
from typing import Union

import universalasync
Expand All @@ -18,9 +20,12 @@
from armis_sdk.types.asset_id_source import AssetIdSource


if TYPE_CHECKING:
from collections.abc import AsyncIterator


@universalasync.wrap
class AssetsClient(BaseEntityClient): # pylint: disable=too-few-public-methods
# pylint: disable=line-too-long
class AssetsClient(BaseEntityClient):
"""
A client for interacting with assets.

Expand All @@ -31,10 +36,10 @@ class AssetsClient(BaseEntityClient): # pylint: disable=too-few-public-methods

async def list_by_asset_id(
self,
asset_class: Type[AssetT],
asset_ids: Union[list[int], list[str]],
asset_class: type[AssetT],
asset_ids: list[int] | list[str],
asset_id_source: AssetIdSource = "ASSET_ID",
fields: Optional[list[str]] = None,
fields: list[str] | None = None,
) -> AsyncIterator[AssetT]:
"""List assets by asset ID or other identifiers.

Expand All @@ -54,6 +59,7 @@ async def list_by_asset_id(
from armis_sdk.clients.assets_client import AssetsClient
from armis_sdk.entities.device import Device


async def main():
assets_client = AssetsClient()

Expand All @@ -68,6 +74,7 @@ async def main():
async for device in assets_client.list_by_asset_id(Device, ipv4_addresses, asset_id_source="IPV4_ADDRESS"):
print(device)


asyncio.run(main())
```
"""
Expand All @@ -81,9 +88,9 @@ async def main():

async def list_by_last_seen(
self,
asset_class: Type[AssetT],
last_seen: Union[datetime.datetime, datetime.timedelta],
fields: Optional[list[str]] = None,
asset_class: type[AssetT],
last_seen: datetime.datetime | datetime.timedelta,
fields: list[str] | None = None,
) -> AsyncIterator[AssetT]:
"""List assets by last seen timestamp.

Expand All @@ -106,6 +113,7 @@ async def list_by_last_seen(
from armis_sdk.clients.assets_client import AssetsClient
from armis_sdk.entities.device import Device


async def main():
assets_client = AssetsClient()

Expand All @@ -117,10 +125,11 @@ async def main():
async for device in assets_client.list_by_last_seen(Device, datetime.datetime(2025, 12, 8)):
print(device)


asyncio.run(main())
```
"""
filter_: dict[str, Union[str, int]] = {"filter_criteria": "LAST_SEEN"}
filter_: dict[str, str | int] = {"filter_criteria": "LAST_SEEN"}

if isinstance(last_seen, datetime.datetime):
filter_["last_seen_ge"] = last_seen.isoformat()
Expand All @@ -132,9 +141,7 @@ async def main():
async for item in self._list_assets(asset_class, fields, filter_):
yield item

async def list_fields(
self, asset_class: Type[AssetT]
) -> AsyncIterator[AssetFieldDescription]:
async def list_fields(self, asset_class: type[AssetT]) -> AsyncIterator[AssetFieldDescription]:
"""List all available fields for a given asset class.

Args:
Expand All @@ -150,12 +157,14 @@ async def list_fields(
from armis_sdk.clients.assets_client import AssetsClient
from armis_sdk.entities.device import Device


async def main():
assets_client = AssetsClient()

async for field in assets_client.list_fields(Device):
print(f"{field.name}: {field.type}")


asyncio.run(main())
```
"""
Expand All @@ -174,7 +183,6 @@ async def update(
fields: list[str],
asset_id_source: AssetIdSource = "ASSET_ID",
) -> None:
# pylint: disable=line-too-long
"""Bulk update assets.

Args:
Expand Down Expand Up @@ -204,6 +212,7 @@ async def main():
# Update based on the explicit source "IPV4_ADDRESS"
await assets_client.update([device], ["custom.MyField"], asset_id_source="IPV4_ADDRESS")


asyncio.run(main())
```
"""
Expand Down Expand Up @@ -244,7 +253,7 @@ async def main():
def _create_bulk_update_request(
cls,
asset: Asset,
asset_id: Union[str, int],
asset_id: str | int,
field: str,
):
request = {"asset_id": asset_id, "key": field}
Expand All @@ -266,7 +275,7 @@ def _get_asset_id(
asset: Asset,
index: int,
asset_id_source: AssetIdSource,
) -> Union[str, int]:
) -> str | int:
if isinstance(asset, Device):
return cls._get_device_asset_id(asset, index, asset_id_source)

Expand All @@ -286,30 +295,22 @@ def _get_device_asset_id(

if asset_id_source == "MAC_ADDRESS":
if device.mac_addresses is None or len(device.mac_addresses) != 1:
raise ArmisError(
f"Device at index {index} doesn't have exactly one mac address"
)
raise ArmisError(f"Device at index {index} doesn't have exactly one mac address")
return device.mac_addresses[0]

if asset_id_source == "IPV4_ADDRESS":
if device.ipv4_addresses is None or len(device.ipv4_addresses) != 1:
raise ArmisError(
f"Device at index {index} doesn't have exactly one IPv4 address"
)
raise ArmisError(f"Device at index {index} doesn't have exactly one IPv4 address")
return device.ipv4_addresses[0]

if asset_id_source == "IPV6_ADDRESS":
if device.ipv6_addresses is None or len(device.ipv6_addresses) != 1:
raise ArmisError(
f"Device at index {index} doesn't have exactly one IPv6 address"
)
raise ArmisError(f"Device at index {index} doesn't have exactly one IPv6 address")
return device.ipv6_addresses[0]

if asset_id_source == "SERIAL_NUMBER":
if device.serial_numbers is None or len(device.serial_numbers) != 1:
raise ArmisError(
f"Device at index {index} doesn't have exactly one serial number"
)
raise ArmisError(f"Device at index {index} doesn't have exactly one serial number")
return device.serial_numbers[0]

raise ArmisError(f"Can't get {asset_id_source!r} of device at index {index}")
Expand All @@ -324,8 +325,8 @@ def _is_integration_field(cls, field: str) -> bool:

async def _list_assets(
self,
asset_class: Type[AssetT],
fields: Optional[list[str]],
asset_class: type[AssetT],
fields: list[str] | None,
filter_: dict,
) -> AsyncIterator[AssetT]:
fields = fields or sorted(asset_class.all_fields())
Expand All @@ -345,15 +346,12 @@ def _validate_asset_class(cls, assets: list[AssetT]):
asset_types = {type(asset) for asset in assets}
if len(asset_types) > 1:
asset_types_str = ", ".join(sorted(repr(at.__name__) for at in asset_types))
raise ArmisError(
"All assets must be of the same type, "
f"got {len(asset_types)} types: {asset_types_str}"
)
raise ArmisError(f"All assets must be of the same type, got {len(asset_types)} types: {asset_types_str}")

@classmethod
def _validate_fields(
cls,
asset_class: Type[AssetT],
asset_class: type[AssetT],
fields: list[str],
allow_model_members=True,
):
Expand All @@ -373,6 +371,4 @@ def _validate_fields(

if invalid_fields:
fields_str = ", ".join(map(repr, invalid_fields))
raise ArmisError(
f"The following fields are not supported with this operation: {fields_str}"
)
raise ArmisError(f"The following fields are not supported with this operation: {fields_str}")
24 changes: 13 additions & 11 deletions armis_sdk/clients/collectors_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import contextlib
from typing import IO
from typing import AsyncIterator
from typing import Generator
from typing import TYPE_CHECKING
from typing import Union

import httpx
Expand All @@ -14,9 +15,13 @@
from armis_sdk.types.collector_image_type import CollectorImageType


if TYPE_CHECKING:
from collections.abc import AsyncIterator
from collections.abc import Generator


@universalasync.wrap
class CollectorsClient(BaseEntityClient):
# pylint: disable=line-too-long
"""
A client for interacting with Armis collectors.

Expand All @@ -25,7 +30,7 @@ class CollectorsClient(BaseEntityClient):

async def download_image(
self,
destination: Union[str, IO[bytes]],
destination: str | IO[bytes],
image_type: CollectorImageType = "OVA",
) -> AsyncIterator[DownloadProgress]:
"""Download a collector image to a specified destination path / file.
Expand Down Expand Up @@ -56,6 +61,7 @@ async def main():
async for progress in armis_sdk.collectors.download_image(file):
print(progress.percent)


asyncio.run(main())
```
Will output:
Expand All @@ -71,7 +77,6 @@ async def main():
async with client.stream("GET", collector_image.url) as response:
response.raise_for_status()
total_size = int(response.headers.get("Content-Length", "0"))
# pylint: disable-next=contextmanager-generator-missing-cleanup
with self.open_file(destination) as file:
async for chunk in response.aiter_bytes():
file.write(chunk)
Expand All @@ -97,6 +102,7 @@ async def main():
collectors_client = CollectorsClient()
print(await collectors_client.get_image(image_type="OVA"))


asyncio.run(main())
```
Will output:
Expand All @@ -105,17 +111,13 @@ async def main():
```
"""
async with self._armis_client.client() as client:
response = await client.get(
"/v3/collectors/_image", params={"image_type": image_type}
)
response = await client.get("/v3/collectors/_image", params={"image_type": image_type})
data = response_utils.get_data_dict(response)
return CollectorImage.model_validate(data)

@classmethod
@contextlib.contextmanager
def open_file(
cls, destination: Union[str, IO[bytes]]
) -> Generator[IO[bytes], None, None]:
def open_file(cls, destination: str | IO[bytes]) -> Generator[IO[bytes], None, None]:
if isinstance(destination, str):
with open(destination, "wb") as file:
yield file
Expand Down
Loading
Loading