diff --git a/armis_sdk/clients/assets_client.py b/armis_sdk/clients/assets_client.py index f3ec9f2..f9a0d92 100644 --- a/armis_sdk/clients/assets_client.py +++ b/armis_sdk/clients/assets_client.py @@ -501,7 +501,7 @@ async def _list_assets( "fields": fields, "filter": filter_, } - async for item in self._armis_client.list("/v3/assets/_search", body=body): + async for item in self._armis_client.list("/v3/assets/_search", body=body, after_location="filter"): yield asset_class.from_search_result(item) @classmethod diff --git a/armis_sdk/core/armis_client.py b/armis_sdk/core/armis_client.py index 287b26c..9d1d4de 100644 --- a/armis_sdk/core/armis_client.py +++ b/armis_sdk/core/armis_client.py @@ -89,12 +89,13 @@ def client(self, retries: int | None = None, backoff: float | None = None): trust_env=True, ) - async def list(self, url: str, body: dict | None = None) -> AsyncIterator[dict]: + async def list(self, url: str, body: dict | None = None, after_location: str | None = None) -> AsyncIterator[dict]: """List all items from a paginated endpoint. Args: url (str): The relative endpoint URL. body (dict): Payload to send as POST request. + after_location (str): The nested object location to use for pagination. Returns: An (async) iterator of `dict`s. @@ -133,7 +134,10 @@ async def main(): for item in items: yield item if next_ := data.get("next"): - params["after"] = next_ + if after_location: + params[after_location]["after"] = next_ + else: + params["after"] = next_ else: break diff --git a/pyproject.toml b/pyproject.toml index f0071fa..5f48443 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "armis_sdk" -version = "1.2.0" +version = "1.2.1" description = "The Armis SDK is a package that encapsulates common use-cases for interacting with the Armis platform." authors = [ { name = "Shai Lachmanovich", email = "shai@armis.com" }, diff --git a/tests/armis_sdk/clients/assets_client_test.py b/tests/armis_sdk/clients/assets_client_test.py index 0048cdc..640af4a 100644 --- a/tests/armis_sdk/clients/assets_client_test.py +++ b/tests/armis_sdk/clients/assets_client_test.py @@ -493,6 +493,37 @@ async def test_update_with_validation_errors(assets, fields, expected_error): await assets_client.update(assets, fields) +async def test_list_assets_pagination(monkeypatch, httpx_mock: pytest_httpx.HTTPXMock): + monkeypatch.setenv("ARMIS_PAGE_SIZE", "1") + httpx_mock.add_response( + url="https://api.armis.com/v3/assets/_search", + method="POST", + match_json={ + "limit": 1, + "asset_type": "DEVICE", + "fields": assets_test_data.ALL_DEVICE_FIELDS, + "filter": {"filter_criteria": "LAST_SEEN", "last_seen_seconds": 3600}, + }, + json={"next": 2, "items": [{"asset_id": 1, "fields": assets_test_data.MOCK_DEVICE_FULL_RAW_DATA}]}, + ) + httpx_mock.add_response( + url="https://api.armis.com/v3/assets/_search", + method="POST", + match_json={ + "limit": 1, + "asset_type": "DEVICE", + "fields": assets_test_data.ALL_DEVICE_FIELDS, + "filter": {"filter_criteria": "LAST_SEEN", "last_seen_seconds": 3600, "after": 2}, + }, + json={"next": None, "items": [{"asset_id": 2, "fields": assets_test_data.MOCK_DEVICE_FULL_RAW_DATA}]}, + ) + + assets_client = AssetsClient() + devices = [device async for device in assets_client.list_by_last_seen(Device, datetime.timedelta(hours=1))] + + assert len(devices) == 2 + + async def test_list_fields(httpx_mock: pytest_httpx.HTTPXMock): httpx_mock.add_response( url="https://api.armis.com/v3/assets/_search/fields?asset_type=DEVICE",