diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..bb60cbd3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,100 @@ +name: CI + +on: + pull_request: + branches: [cli-v2, master] + push: + branches: [cli-v2, master] + +jobs: + unit-tests: + name: "Tests - Python ${{ matrix.python-version }} / ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package + run: | + python -m pip install --upgrade pip + pip install ".[dev]" + + - name: Verify zstandard is installed + run: python scripts/ci/verify_zstd_installed.py + + - name: Verify transport compression imports + run: python scripts/ci/verify_transport_compression.py + + - name: Run unit tests + run: pytest -v --durations=10 tests/unit/ + + wheel-sanity: + name: "Wheel install - Python ${{ matrix.python-version }} / ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.13"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build and install wheel + run: | + python -m pip install --upgrade pip setuptools wheel build + python -m build + python -m pip install --find-links dist limacharlie + + - name: Verify CLI works + run: limacharlie --version + + - name: Verify zstandard available after wheel install + run: python scripts/ci/verify_zstd_installed.py + + - name: Verify zstd decompression works end-to-end + run: python scripts/ci/verify_zstd_decompression.py + + 
no-zstd-fallback: + name: "Fallback without zstandard - Python ${{ matrix.python-version }} / ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.13"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package then remove zstandard + run: | + python -m pip install --upgrade pip + pip install ".[dev]" + pip uninstall -y zstandard + + - name: Verify zstandard is NOT importable + run: python scripts/ci/verify_zstd_not_importable.py + + - name: Verify fallback Accept-Encoding (no zstd) + run: python scripts/ci/verify_fallback_encoding.py + + - name: Run unit tests without zstandard + run: pytest -v --durations=10 tests/unit/ diff --git a/limacharlie/client.py b/limacharlie/client.py index 04654612..731c8c23 100644 --- a/limacharlie/client.py +++ b/limacharlie/client.py @@ -28,6 +28,7 @@ RateLimitError, error_from_status_code, ) +from .transport_compression import ACCEPT_ENCODING, decompress_response from .user_agent_utils import build_user_agent __version__ = "5.0.0" @@ -153,10 +154,13 @@ def _debug(self, msg: str) -> None: def unwrap(data: str, is_raw: bool = False) -> Any: """Decompress gzip+base64 encoded data from the API. - Used when is_compressed=true is set on requests. The API returns - data as base64-encoded gzip-compressed JSON. + .. deprecated:: + Application-level compression (is_compressed=true) has been + replaced by transport-level compression (Accept-Encoding). + This method is kept for backward compatibility with external + callers but is no longer used internally by the SDK. - Args: + Parameters: data: Base64-encoded gzip-compressed string. is_raw: If True, return raw bytes instead of parsed JSON. 
@@ -310,6 +314,10 @@ def _rest_call(self, url: str, verb: str, params: dict[str, Any] | None = None, request = URLRequest(full_url, body, headers=headers) request.get_method = lambda: verb request.add_header("User-Agent", self._user_agent) + # Request compressed responses at the transport level. + # urllib doesn't auto-decompress like requests does, so we + # handle decompression ourselves in the response path. + request.add_header("Accept-Encoding", ACCEPT_ENCODING) if content_type is not None: request.add_header("Content-Type", content_type) @@ -323,6 +331,12 @@ def _rest_call(self, url: str, verb: str, params: dict[str, Any] | None = None, try: data = u.read() + # Decompress transport-level encoding (gzip, zstd, etc.) + # before JSON parsing. The server may compress the entire + # response body when we send Accept-Encoding. + content_enc = u.headers.get("Content-Encoding") + if data and content_enc: + data = decompress_response(data, content_enc) resp = json.loads(data.decode()) if data else {} except ValueError: resp = {} @@ -347,6 +361,10 @@ def _rest_call(self, url: str, verb: str, params: dict[str, Any] | None = None, except HTTPError as e: error_body = e.read() + # Error responses can also be transport-compressed. + error_enc = e.headers.get("Content-Encoding") if hasattr(e, "headers") else None + if error_body and error_enc: + error_body = decompress_response(error_body, error_enc) try: resp = json.loads(error_body.decode()) except Exception: diff --git a/limacharlie/config.py b/limacharlie/config.py index 2a7fa099..079f9218 100644 --- a/limacharlie/config.py +++ b/limacharlie/config.py @@ -92,7 +92,10 @@ def save_config(config: dict[str, Any]) -> None: fd, tmp_path = tempfile.mkstemp() try: - os.chown(tmp_path, os.getuid(), os.getgid()) + # os.chown/os.getuid are Unix-only; skip on Windows where file + # ownership is managed by the OS via ACLs. 
+ if hasattr(os, "chown"): + os.chown(tmp_path, os.getuid(), os.getgid()) os.chmod(tmp_path, stat.S_IWUSR | stat.S_IRUSR) # 0o600 try: os.write(fd, content) diff --git a/limacharlie/sdk/organization.py b/limacharlie/sdk/organization.py index 87fd8dbc..88c011e3 100644 --- a/limacharlie/sdk/organization.py +++ b/limacharlie/sdk/organization.py @@ -868,7 +868,7 @@ def get_detections(self, start: int, end: int, limit: int | None = None, categor cursor = "-" n_returned = 0 while cursor: - qp = {"start": str(int(start)), "end": str(int(end)), "cursor": cursor, "is_compressed": "true"} + qp = {"start": str(int(start)), "end": str(int(end)), "cursor": cursor} if limit is not None: qp["limit"] = str(limit) if category: @@ -876,7 +876,7 @@ def get_detections(self, start: int, end: int, limit: int | None = None, categor resp = self._client.request("GET", f"insight/{self.oid}/detections", query_params=qp) cursor = resp.get("next_cursor") - for d in self._client.unwrap(resp.get("detects", "")): + for d in resp.get("detects", []): yield d n_returned += 1 if limit is not None and n_returned >= limit: @@ -913,7 +913,7 @@ def get_audit_logs(self, start: int, end: int, limit: int | None = None, event_t cursor = "-" n_returned = 0 while cursor: - qp = {"start": str(int(start)), "end": str(int(end)), "cursor": cursor, "is_compressed": "true"} + qp = {"start": str(int(start)), "end": str(int(end)), "cursor": cursor} if limit is not None: qp["limit"] = str(limit) if event_type: @@ -923,7 +923,7 @@ def get_audit_logs(self, start: int, end: int, limit: int | None = None, event_t resp = self._client.request("GET", f"insight/{self.oid}/audit", query_params=qp) cursor = resp.get("next_cursor") - for entry in self._client.unwrap(resp.get("events", "")): + for entry in resp.get("events", []): yield entry n_returned += 1 if limit is not None and n_returned >= limit: @@ -946,7 +946,7 @@ def get_jobs(self, start_time: int | None = None, end_time: int | None = None, l list: Job dicts. 
""" import time as _time - qp = {"is_compressed": "true", "with_data": "false"} + qp = {"with_data": "false"} if start_time is None: start_time = int(_time.time()) - 86400 if end_time is None: @@ -958,8 +958,7 @@ def get_jobs(self, start_time: int | None = None, end_time: int | None = None, l if sid is not None: qp["sid"] = str(sid) resp = self._client.request("GET", f"job/{self.oid}", query_params=qp) - raw_jobs = resp.get("jobs", "") + raw_jobs = resp.get("jobs", {}) if not raw_jobs: return [] - jobs = self._client.unwrap(raw_jobs) - return [job for job_id, job in jobs.items()] + return [job for job_id, job in raw_jobs.items()] diff --git a/limacharlie/sdk/sensor.py b/limacharlie/sdk/sensor.py index 800b569e..fb735232 100644 --- a/limacharlie/sdk/sensor.py +++ b/limacharlie/sdk/sensor.py @@ -253,7 +253,6 @@ def get_events(self, start: int, end: int, limit: int | None = None, event_type: qp = { "start": str(int(start)), "end": str(int(end)), - "is_compressed": "true", "is_forward": "true" if is_forward else "false", "cursor": cursor, } @@ -264,7 +263,7 @@ def get_events(self, start: int, end: int, limit: int | None = None, event_type: resp = self.client.request("GET", f"insight/{self._org.oid}/{self.sid}", query_params=qp) cursor = resp.get("next_cursor") - for evt in self.client.unwrap(resp.get("events", "")): + for evt in resp.get("events", []): yield evt n_returned += 1 if limit is not None and n_returned >= limit: @@ -306,9 +305,8 @@ def get_children_events(self, atom: str) -> list[dict[str, Any]]: Returns: list: Child events. 
""" - data = self.client.request("GET", f"insight/{self._org.oid}/{self.sid}/{atom}/children", - query_params={"is_compressed": "true"}) - return self.client.unwrap(data.get("events", "")) + data = self.client.request("GET", f"insight/{self._org.oid}/{self.sid}/{atom}/children") + return data.get("events", []) def get_event_retention(self, start: int, end: int, is_detailed: bool = False) -> dict[str, Any]: """Get event retention statistics. diff --git a/limacharlie/transport_compression.py b/limacharlie/transport_compression.py new file mode 100644 index 00000000..7f90df8f --- /dev/null +++ b/limacharlie/transport_compression.py @@ -0,0 +1,73 @@ +"""Transport-level HTTP compression support. + +Handles Accept-Encoding negotiation and Content-Encoding decompression +for HTTP responses. Supports zstd, gzip, and deflate. + +zstd is preferred because it offers better compression ratios and faster +decompression than gzip. The zstandard package is a hard dependency but +the runtime gracefully falls back to gzip/deflate if it's unavailable +(e.g., exotic platform where the wheel couldn't be installed). +""" + +from __future__ import annotations + +import zlib + +# zstandard is a hard dependency (listed in pyproject.toml) with pre-built +# wheels for all major platforms. However, we guard the import so the SDK +# still works if someone is on an exotic platform where the wheel isn't +# available and there's no C compiler to build from source. +try: + import zstandard as _zstd + + _HAS_ZSTD = True +except ImportError: + _zstd = None # type: ignore[assignment] + _HAS_ZSTD = False + +# Header value sent on every request. Prefer zstd when available. +ACCEPT_ENCODING: str = "zstd, gzip, deflate" if _HAS_ZSTD else "gzip, deflate" + + +def decompress_response(data: bytes, content_encoding: str | None) -> bytes: + """Decompress an HTTP response body based on Content-Encoding. + + If the encoding is unrecognized or absent, the data is returned as-is + (passthrough). 
This matches standard HTTP client behavior - servers may + return uncompressed responses even when Accept-Encoding was sent. + + Parameters: + data: Raw response body bytes. + content_encoding: Value of the Content-Encoding response header, + or None if the header was absent. + + Returns: + Decompressed bytes, or the original bytes if no decompression needed. + """ + if not content_encoding: + return data + + encoding = content_encoding.strip().lower() + + if encoding == "zstd": + if not _HAS_ZSTD: + # Server sent zstd but we can't decompress - return as-is. + # JSON parsing will fail with a clear error downstream. + return data + return _zstd.ZstdDecompressor().decompress(data) + + if encoding in ("gzip", "x-gzip"): + # 16 + MAX_WBITS tells zlib to require a gzip wrapper (header + trailer) + return zlib.decompress(data, 16 + zlib.MAX_WBITS) + + if encoding == "deflate": + # Try raw deflate first, fall back to zlib-wrapped deflate + try: + return zlib.decompress(data, -zlib.MAX_WBITS) + except zlib.error: + return zlib.decompress(data) + + # Unknown encoding - return data as-is rather than crashing. + # The caller will attempt JSON parsing which will fail with a clear + # error if the data is actually compressed. 
+ return data diff --git a/pyproject.toml b/pyproject.toml index 281af51c..ab814790 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "cryptography==46.0.3", "click==8.1.8", "jmespath==1.1.0", + "zstandard>=0.22.0", ] [project.optional-dependencies] diff --git a/scripts/ci/verify_fallback_encoding.py b/scripts/ci/verify_fallback_encoding.py new file mode 100644 index 00000000..48bbd8ee --- /dev/null +++ b/scripts/ci/verify_fallback_encoding.py @@ -0,0 +1,8 @@ +"""Verify Accept-Encoding falls back to gzip/deflate when zstd is unavailable.""" + +from limacharlie.transport_compression import ACCEPT_ENCODING, _HAS_ZSTD + +assert not _HAS_ZSTD, "_HAS_ZSTD should be False" +assert "zstd" not in ACCEPT_ENCODING, f"zstd should not be in: {ACCEPT_ENCODING}" +assert "gzip" in ACCEPT_ENCODING, f"gzip missing from: {ACCEPT_ENCODING}" +print(f"Accept-Encoding (fallback): {ACCEPT_ENCODING}") diff --git a/scripts/ci/verify_transport_compression.py b/scripts/ci/verify_transport_compression.py new file mode 100644 index 00000000..1defe796 --- /dev/null +++ b/scripts/ci/verify_transport_compression.py @@ -0,0 +1,6 @@ +"""Verify transport compression module loads with zstd support.""" + +from limacharlie.transport_compression import ACCEPT_ENCODING + +assert "zstd" in ACCEPT_ENCODING, f"zstd missing from: {ACCEPT_ENCODING}" +print(f"Accept-Encoding: {ACCEPT_ENCODING}") diff --git a/scripts/ci/verify_zstd_decompression.py b/scripts/ci/verify_zstd_decompression.py new file mode 100644 index 00000000..1d1bab22 --- /dev/null +++ b/scripts/ci/verify_zstd_decompression.py @@ -0,0 +1,13 @@ +"""Verify zstd decompression works end-to-end (compress then decompress).""" + +import json + +import zstandard + +from limacharlie.transport_compression import decompress_response + +original = json.dumps({"events": [{"type": "test"}]}).encode() +compressed = zstandard.ZstdCompressor().compress(original) +result = decompress_response(compressed, "zstd") +assert result 
== original, "zstd round-trip failed" +print("zstd round-trip OK") diff --git a/scripts/ci/verify_zstd_installed.py b/scripts/ci/verify_zstd_installed.py new file mode 100644 index 00000000..b6380ffb --- /dev/null +++ b/scripts/ci/verify_zstd_installed.py @@ -0,0 +1,5 @@ +"""Verify zstandard is installed and importable.""" + +import zstandard + +print(f"zstandard {zstandard.__version__} OK") diff --git a/scripts/ci/verify_zstd_not_importable.py b/scripts/ci/verify_zstd_not_importable.py new file mode 100644 index 00000000..0a2e3c05 --- /dev/null +++ b/scripts/ci/verify_zstd_not_importable.py @@ -0,0 +1,8 @@ +"""Verify zstandard is NOT importable (used after pip uninstall).""" + +try: + import zstandard + + raise AssertionError("zstandard should not be importable") +except ImportError: + print("zstandard correctly unavailable") diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index bb9b06fa..20a5b5f6 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,5 +1,6 @@ """Tests for limacharlie.client module.""" +import gzip import json from unittest.mock import MagicMock, patch @@ -9,6 +10,7 @@ from limacharlie.errors import ( AuthenticationError, ApiError, + LimaCharlieError, RateLimitError, ) @@ -289,6 +291,132 @@ def set_jwt(c): assert result == {"ok": True} +class TestTransportCompression: + @patch("limacharlie.client.urlopen") + def test_request_sends_accept_encoding_header(self, mock_urlopen): + """Every outgoing request should include Accept-Encoding.""" + jwt_response = MagicMock() + jwt_response.read.return_value = json.dumps({"jwt": "test-jwt"}).encode() + jwt_response.close = MagicMock() + + api_response = MagicMock() + api_response.read.return_value = json.dumps({"ok": True}).encode() + api_response.close = MagicMock() + api_response.getheaders.return_value = [] + api_response.headers = MagicMock() + api_response.headers.get.return_value = None + + mock_urlopen.side_effect = [jwt_response, api_response] + + client = 
Client(oid="test-oid", api_key="test-key") + client.request("GET", "sensors") + + # The second call is the API request (first is JWT) + api_request = mock_urlopen.call_args_list[1][0][0] + accept_enc = api_request.get_header("Accept-encoding") + assert "gzip" in accept_enc + assert "deflate" in accept_enc + + @patch("limacharlie.client.urlopen") + def test_request_decompresses_gzip_response(self, mock_urlopen): + """Response with Content-Encoding: gzip should be decompressed transparently.""" + jwt_response = MagicMock() + jwt_response.read.return_value = json.dumps({"jwt": "test-jwt"}).encode() + jwt_response.close = MagicMock() + + # Build a gzip-compressed JSON body + body = json.dumps({"sensors": ["s1", "s2"]}).encode() + compressed_body = gzip.compress(body) + + api_response = MagicMock() + api_response.read.return_value = compressed_body + api_response.close = MagicMock() + api_response.getheaders.return_value = [] + api_response.headers = MagicMock() + api_response.headers.get.return_value = "gzip" + + mock_urlopen.side_effect = [jwt_response, api_response] + + client = Client(oid="test-oid", api_key="test-key") + result = client.request("GET", "sensors") + + assert result == {"sensors": ["s1", "s2"]} + + @patch("limacharlie.client.urlopen") + def test_request_decompresses_zstd_response(self, mock_urlopen): + """Response with Content-Encoding: zstd should be decompressed.""" + zstandard = pytest.importorskip("zstandard") + + jwt_response = MagicMock() + jwt_response.read.return_value = json.dumps({"jwt": "test-jwt"}).encode() + jwt_response.close = MagicMock() + + body = json.dumps({"detects": [{"id": "d1"}]}).encode() + cctx = zstandard.ZstdCompressor() + compressed_body = cctx.compress(body) + + api_response = MagicMock() + api_response.read.return_value = compressed_body + api_response.close = MagicMock() + api_response.getheaders.return_value = [] + api_response.headers = MagicMock() + api_response.headers.get.return_value = "zstd" + + 
mock_urlopen.side_effect = [jwt_response, api_response] + + client = Client(oid="test-oid", api_key="test-key") + result = client.request("GET", "detections") + + assert result == {"detects": [{"id": "d1"}]} + + @patch("limacharlie.client.urlopen") + def test_request_handles_uncompressed_response(self, mock_urlopen): + """When server returns no Content-Encoding, response works as before.""" + jwt_response = MagicMock() + jwt_response.read.return_value = json.dumps({"jwt": "test-jwt"}).encode() + jwt_response.close = MagicMock() + + api_response = MagicMock() + api_response.read.return_value = json.dumps({"ok": True}).encode() + api_response.close = MagicMock() + api_response.getheaders.return_value = [] + api_response.headers = MagicMock() + api_response.headers.get.return_value = None + + mock_urlopen.side_effect = [jwt_response, api_response] + + client = Client(oid="test-oid", api_key="test-key") + result = client.request("GET", "test") + + assert result == {"ok": True} + + @patch("limacharlie.client.urlopen") + def test_error_response_decompressed(self, mock_urlopen): + """HTTPError bodies with Content-Encoding should be decompressed.""" + from urllib.error import HTTPError + import io + + jwt_response = MagicMock() + jwt_response.read.return_value = json.dumps({"jwt": "test-jwt"}).encode() + jwt_response.close = MagicMock() + + error_body = json.dumps({"error": "not found"}).encode() + compressed_error = gzip.compress(error_body) + + error = HTTPError( + "https://api.limacharlie.io/v1/test", 404, "Not Found", + {"Content-Encoding": "gzip"}, io.BytesIO(compressed_error), + ) + # HTTPError wraps headers in an http.client.HTTPMessage + error.headers = error.headers # already set via the constructor + + mock_urlopen.side_effect = [jwt_response, error] + + client = Client(oid="test-oid", api_key="test-key") + with pytest.raises(LimaCharlieError): + client.request("GET", "test") + + class TestBuildUserAgent: def test_user_agent_format(self): ua = _build_user_agent() 
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 9c2a3f53..2ad5787a 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -2,6 +2,7 @@ import os import stat +import sys import tempfile import pytest @@ -57,8 +58,11 @@ class TestSaveConfig: def test_creates_file_with_secure_permissions(self, tmp_config_file): save_config({"oid": "test-oid"}) assert os.path.isfile(tmp_config_file) - mode = os.stat(tmp_config_file).st_mode - assert mode & 0o777 == 0o600 + # Windows NTFS uses ACLs, not Unix permission bits - os.chmod(0o600) + # doesn't restrict access the same way. Only assert on Unix. + if sys.platform != "win32": + mode = os.stat(tmp_config_file).st_mode + assert mode & 0o777 == 0o600 def test_round_trips_data(self, tmp_config_file): data = {"oid": "abc-123", "api_key": "key-456", "env": {"prod": {"oid": "prod-oid"}}} diff --git a/tests/unit/test_sdk_organization.py b/tests/unit/test_sdk_organization.py index 4dadd266..f4fe9d3a 100644 --- a/tests/unit/test_sdk_organization.py +++ b/tests/unit/test_sdk_organization.py @@ -739,18 +739,25 @@ def test_remove_group_org(self, org, mock_client): class TestDetections: def test_get_detections_single_page(self, org, mock_client): - mock_client.request.return_value = {"detects": "", "next_cursor": None} - mock_client.unwrap.return_value = [{"detect_id": "d1"}] + mock_client.request.return_value = { + "detects": [{"detect_id": "d1"}], + "next_cursor": None, + } result = list(org.get_detections(1000, 2000)) mock_client.request.assert_called_once_with( "GET", "insight/test-oid-123/detections", - query_params={"start": "1000", "end": "2000", "cursor": "-", "is_compressed": "true"}, + query_params={"start": "1000", "end": "2000", "cursor": "-"}, ) + # is_compressed must NOT be in query params + qp = mock_client.request.call_args[1]["query_params"] + assert "is_compressed" not in qp assert result == [{"detect_id": "d1"}] def test_get_detections_with_limit_and_category(self, org, 
mock_client): - mock_client.request.return_value = {"detects": "", "next_cursor": None} - mock_client.unwrap.return_value = [{"detect_id": "d1"}] + mock_client.request.return_value = { + "detects": [{"detect_id": "d1"}], + "next_cursor": None, + } result = list(org.get_detections(1000, 2000, limit=5, category="lateral")) qp = mock_client.request.call_args[1]["query_params"] assert qp["limit"] == "5" @@ -758,20 +765,18 @@ def test_get_detections_with_limit_and_category(self, org, mock_client): def test_get_detections_pagination(self, org, mock_client): mock_client.request.side_effect = [ - {"detects": "compressed1", "next_cursor": "cursor2"}, - {"detects": "compressed2", "next_cursor": None}, - ] - mock_client.unwrap.side_effect = [ - [{"detect_id": "d1"}], - [{"detect_id": "d2"}], + {"detects": [{"detect_id": "d1"}], "next_cursor": "cursor2"}, + {"detects": [{"detect_id": "d2"}], "next_cursor": None}, ] result = list(org.get_detections(1000, 2000)) assert len(result) == 2 assert mock_client.request.call_count == 2 def test_get_detections_limit_stops_iteration(self, org, mock_client): - mock_client.request.return_value = {"detects": "", "next_cursor": "more"} - mock_client.unwrap.return_value = [{"detect_id": "d1"}, {"detect_id": "d2"}] + mock_client.request.return_value = { + "detects": [{"detect_id": "d1"}, {"detect_id": "d2"}], + "next_cursor": "more", + } result = list(org.get_detections(1000, 2000, limit=1)) assert len(result) == 1 @@ -784,18 +789,25 @@ def test_get_detection_by_id(self, org, mock_client): class TestAuditLogs: def test_get_audit_logs_single_page(self, org, mock_client): - mock_client.request.return_value = {"events": "", "next_cursor": None} - mock_client.unwrap.return_value = [{"event": "login"}] + mock_client.request.return_value = { + "events": [{"event": "login"}], + "next_cursor": None, + } result = list(org.get_audit_logs(1000, 2000)) mock_client.request.assert_called_once_with( "GET", "insight/test-oid-123/audit", - 
query_params={"start": "1000", "end": "2000", "cursor": "-", "is_compressed": "true"}, + query_params={"start": "1000", "end": "2000", "cursor": "-"}, ) + # is_compressed must NOT be in query params + qp = mock_client.request.call_args[1]["query_params"] + assert "is_compressed" not in qp assert result == [{"event": "login"}] def test_get_audit_logs_with_filters(self, org, mock_client): - mock_client.request.return_value = {"events": "", "next_cursor": None} - mock_client.unwrap.return_value = [] + mock_client.request.return_value = { + "events": [], + "next_cursor": None, + } list(org.get_audit_logs(1000, 2000, limit=10, event_type="auth", sid="sid-1")) qp = mock_client.request.call_args[1]["query_params"] assert qp["limit"] == "10" @@ -804,42 +816,42 @@ def test_get_audit_logs_with_filters(self, org, mock_client): def test_get_audit_logs_pagination(self, org, mock_client): mock_client.request.side_effect = [ - {"events": "c1", "next_cursor": "cursor2"}, - {"events": "c2", "next_cursor": None}, - ] - mock_client.unwrap.side_effect = [ - [{"event": "e1"}], - [{"event": "e2"}], + {"events": [{"event": "e1"}], "next_cursor": "cursor2"}, + {"events": [{"event": "e2"}], "next_cursor": None}, ] result = list(org.get_audit_logs(1000, 2000)) assert len(result) == 2 def test_get_audit_logs_limit_stops_iteration(self, org, mock_client): - mock_client.request.return_value = {"events": "", "next_cursor": "more"} - mock_client.unwrap.return_value = [{"event": "e1"}, {"event": "e2"}] + mock_client.request.return_value = { + "events": [{"event": "e1"}, {"event": "e2"}], + "next_cursor": "more", + } result = list(org.get_audit_logs(1000, 2000, limit=1)) assert len(result) == 1 class TestJobs: def test_get_jobs_with_explicit_times(self, org, mock_client): - mock_client.request.return_value = {"jobs": {"j1": {"name": "scan"}, "j2": {"name": "resp"}}} - mock_client.unwrap.return_value = {"j1": {"name": "scan"}, "j2": {"name": "resp"}} + mock_client.request.return_value = { + "jobs": 
{"j1": {"name": "scan"}, "j2": {"name": "resp"}}, + } result = org.get_jobs(start_time=1000, end_time=2000) qp = mock_client.request.call_args[1]["query_params"] assert qp["start"] == "1000" assert qp["end"] == "2000" - assert qp["is_compressed"] == "true" + # is_compressed must NOT be in query params + assert "is_compressed" not in qp assert qp["with_data"] == "false" assert len(result) == 2 def test_get_jobs_empty(self, org, mock_client): - mock_client.request.return_value = {"jobs": ""} + mock_client.request.return_value = {"jobs": {}} result = org.get_jobs(start_time=1000, end_time=2000) assert result == [] def test_get_jobs_with_limit_and_sid(self, org, mock_client): - mock_client.request.return_value = {"jobs": ""} + mock_client.request.return_value = {"jobs": {}} org.get_jobs(start_time=1000, end_time=2000, limit=5, sid="sid-1") qp = mock_client.request.call_args[1]["query_params"] assert qp["limit"] == "5" diff --git a/tests/unit/test_sdk_sensor.py b/tests/unit/test_sdk_sensor.py index 0efef15a..cf7f5360 100644 --- a/tests/unit/test_sdk_sensor.py +++ b/tests/unit/test_sdk_sensor.py @@ -124,6 +124,45 @@ def test_delete(self, sensor, mock_org): assert call_args[0][0] == "DELETE" +class TestSensorGetEventsContract: + def test_get_events_query_params(self, sensor, mock_org): + """get_events should not send is_compressed - transport compression handles it.""" + mock_org.client.request.return_value = { + "events": [{"type": "NEW_PROCESS"}], + "next_cursor": None, + } + result = list(sensor.get_events(1000, 2000)) + qp = mock_org.client.request.call_args[1]["query_params"] + # is_compressed must NOT be in query params + assert "is_compressed" not in qp + assert qp["start"] == "1000" + assert qp["end"] == "2000" + assert qp["is_forward"] == "true" + assert result == [{"type": "NEW_PROCESS"}] + + def test_get_events_with_limit(self, sensor, mock_org): + mock_org.client.request.return_value = { + "events": [{"type": "e1"}, {"type": "e2"}, {"type": "e3"}], + 
"next_cursor": "more", + } + result = list(sensor.get_events(1000, 2000, limit=2)) + assert len(result) == 2 + + def test_get_events_pagination(self, sensor, mock_org): + mock_org.client.request.side_effect = [ + {"events": [{"type": "e1"}], "next_cursor": "cursor2"}, + {"events": [{"type": "e2"}], "next_cursor": None}, + ] + result = list(sensor.get_events(1000, 2000)) + assert len(result) == 2 + assert mock_org.client.request.call_count == 2 + + def test_get_events_empty(self, sensor, mock_org): + mock_org.client.request.return_value = {"events": [], "next_cursor": None} + result = list(sensor.get_events(1000, 2000)) + assert result == [] + + class TestSensorEventRetention: def test_get_event_retention_uses_correct_params(self, sensor, mock_org): mock_org.client.request.return_value = {"retention": {}} @@ -262,12 +301,14 @@ def test_get_event_by_atom_path(self, sensor, mock_org): class TestSensorGetChildrenEventsContract: def test_get_children_events_params(self, sensor, mock_org): - mock_org.client.request.return_value = {"events": "compressed-data"} - mock_org.client.unwrap.return_value = [{"type": "FILE_CREATE"}] + mock_org.client.request.return_value = {"events": [{"type": "FILE_CREATE"}]} result = sensor.get_children_events("atom-xyz") mock_org.client.request.assert_called_once_with( "GET", "insight/test-oid/aaaa-bbbb-cccc-dddd/atom-xyz/children", - query_params={"is_compressed": "true"}, ) - mock_org.client.unwrap.assert_called_once_with("compressed-data") assert result == [{"type": "FILE_CREATE"}] + + def test_get_children_events_empty(self, sensor, mock_org): + mock_org.client.request.return_value = {} + result = sensor.get_children_events("atom-xyz") + assert result == [] diff --git a/tests/unit/test_transport_compression.py b/tests/unit/test_transport_compression.py new file mode 100644 index 00000000..41491b92 --- /dev/null +++ b/tests/unit/test_transport_compression.py @@ -0,0 +1,249 @@ +"""Tests for limacharlie.transport_compression module.""" + 
import gzip
import importlib
import json
import sys
import zlib
from unittest import mock

import pytest

from limacharlie.transport_compression import (
    ACCEPT_ENCODING,
    _HAS_ZSTD,
    decompress_response,
)


def _bulk_json_payload() -> bytes:
    """Build a realistically sized JSON body (500 synthetic events)."""
    records = [
        {"type": "NEW_PROCESS", "id": f"evt-{i}", "data": "x" * 100}
        for i in range(500)
    ]
    return json.dumps({"events": records}).encode()


class _ZstdUnavailable:
    """Context manager that reloads a module with zstandard masked out.

    On entry, the real ``zstandard`` entry (if any) is removed from
    ``sys.modules`` and replaced with ``None`` so that any import attempt
    raises ``ImportError``; the target module is then reloaded so it takes
    its no-zstd code path. On exit, the original module state is restored
    and the target module is reloaded again.
    """

    def __init__(self, module):
        self._module = module
        self._saved = None

    def __enter__(self):
        self._saved = sys.modules.pop("zstandard", None)
        # A None entry in sys.modules makes "import zstandard" raise ImportError.
        sys.modules["zstandard"] = None
        importlib.reload(self._module)
        return self._module

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._saved is None:
            sys.modules.pop("zstandard", None)
        else:
            sys.modules["zstandard"] = self._saved
        importlib.reload(self._module)
        return False


class TestAcceptEncoding:
    def test_includes_gzip_and_deflate(self):
        """gzip and deflate must always be advertised."""
        for codec in ("gzip", "deflate"):
            assert codec in ACCEPT_ENCODING

    @pytest.mark.skipif(not _HAS_ZSTD, reason="zstandard not installed")
    def test_includes_zstd_when_available(self):
        """zstd is advertised whenever the zstandard package is importable."""
        assert "zstd" in ACCEPT_ENCODING

    @pytest.mark.skipif(not _HAS_ZSTD, reason="zstandard not installed")
    def test_zstd_preferred(self):
        """zstd must come first, i.e. carry the highest priority."""
        assert ACCEPT_ENCODING.startswith("zstd")

    @pytest.mark.skipif(not _HAS_ZSTD, reason="zstandard not installed")
    def test_has_zstd_flag_true(self):
        """The module flag mirrors zstandard's importability."""
        assert _HAS_ZSTD is True

    @pytest.mark.skipif(_HAS_ZSTD, reason="zstandard is installed")
    def test_excludes_zstd_when_unavailable(self):
        """Without zstandard installed, zstd must not be advertised."""
        assert "zstd" not in ACCEPT_ENCODING

    @pytest.mark.skipif(_HAS_ZSTD, reason="zstandard is installed")
    def test_has_zstd_flag_false(self):
        """The module flag is False when zstandard cannot be imported."""
        assert _HAS_ZSTD is False


class TestAcceptEncodingWithoutZstd:
    def test_fallback_without_zstandard(self):
        """With zstandard masked out, only gzip/deflate are advertised."""
        import limacharlie.transport_compression as tc_mod

        with _ZstdUnavailable(tc_mod) as mod:
            assert mod._HAS_ZSTD is False
            assert "zstd" not in mod.ACCEPT_ENCODING
            assert "gzip" in mod.ACCEPT_ENCODING
            assert "deflate" in mod.ACCEPT_ENCODING

    def test_zstd_passthrough_when_unavailable(self):
        """A zstd-encoded body is returned untouched when the lib is missing."""
        import limacharlie.transport_compression as tc_mod

        blob = b"some-zstd-compressed-bytes"
        with _ZstdUnavailable(tc_mod) as mod:
            # Must not crash; must hand back the very same object.
            assert mod.decompress_response(blob, "zstd") is blob

    def test_zstd_passthrough_case_insensitive_when_unavailable(self):
        """Passthrough must ignore the casing of the Content-Encoding value."""
        import limacharlie.transport_compression as tc_mod

        blob = b"compressed-bytes"
        with _ZstdUnavailable(tc_mod) as mod:
            for header in ("ZSTD", "Zstd"):
                assert mod.decompress_response(blob, header) is blob


class TestDecompressGzip:
    def test_round_trip(self):
        """gzip-compressed data decompresses back to the original bytes."""
        payload = b'{"events": [{"type": "NEW_PROCESS"}]}'
        assert decompress_response(gzip.compress(payload), "gzip") == payload

    def test_x_gzip_alias(self):
        """The legacy x-gzip name behaves exactly like gzip."""
        payload = b'{"ok": true}'
        assert decompress_response(gzip.compress(payload), "x-gzip") == payload

    def test_case_insensitive(self):
        """Content-Encoding matching ignores case."""
        payload = b'{"data": "test"}'
        blob = gzip.compress(payload)
        for header in ("GZIP", "Gzip"):
            assert decompress_response(blob, header) == payload

    def test_large_payload(self):
        """A multi-hundred-event JSON body survives the round trip."""
        payload = _bulk_json_payload()
        assert decompress_response(gzip.compress(payload), "gzip") == payload

    def test_whitespace_around_header(self):
        """Surrounding whitespace in the Content-Encoding value is ignored."""
        payload = b'{"ok": true}'
        assert decompress_response(gzip.compress(payload), " gzip ") == payload


class TestDecompressDeflate:
    def test_zlib_wrapped_deflate(self):
        """zlib-wrapped deflate (what zlib.compress emits) decompresses."""
        payload = b'{"sensors": []}'
        assert decompress_response(zlib.compress(payload), "deflate") == payload

    def test_raw_deflate(self):
        """Raw deflate (no zlib header) decompresses via the fallback path."""
        payload = b'{"sensors": [{"sid": "test-sensor"}]}'
        # Negative wbits asks zlib for a headerless (raw) deflate stream.
        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS
        )
        blob = deflater.compress(payload) + deflater.flush()
        assert decompress_response(blob, "deflate") == payload

    def test_case_insensitive(self):
        """Content-Encoding matching ignores case."""
        payload = b'{"data": "test"}'
        blob = zlib.compress(payload)
        for header in ("DEFLATE", "Deflate"):
            assert decompress_response(blob, header) == payload


@pytest.mark.skipif(not _HAS_ZSTD, reason="zstandard not installed")
class TestDecompressZstd:
    @staticmethod
    def _compress(payload):
        """Compress *payload* with the real zstandard library."""
        import zstandard

        return zstandard.ZstdCompressor().compress(payload)

    def test_round_trip(self):
        """zstd-compressed data decompresses back to the original bytes."""
        payload = b'{"detects": [{"id": "d-1", "title": "suspicious"}]}'
        assert decompress_response(self._compress(payload), "zstd") == payload

    def test_large_payload(self):
        """A realistic large JSON response survives the round trip."""
        payload = _bulk_json_payload()
        assert decompress_response(self._compress(payload), "zstd") == payload

    def test_case_insensitive(self):
        """Content-Encoding matching ignores case."""
        payload = b'{"data": "test"}'
        blob = self._compress(payload)
        for header in ("ZSTD", "Zstd"):
            assert decompress_response(blob, header) == payload

    def test_whitespace_around_header(self):
        """Surrounding whitespace in the Content-Encoding value is ignored."""
        payload = b'{"ok": true}'
        assert decompress_response(self._compress(payload), " zstd ") == payload


class TestPassthrough:
    def test_none_encoding(self):
        """No Content-Encoding means the body is returned as-is."""
        raw = b'{"events": []}'
        # Identity, not merely equality: no copy should be made.
        assert decompress_response(raw, None) is raw

    def test_empty_encoding(self):
        """An empty encoding string is treated as no compression."""
        raw = b'{"events": []}'
        assert decompress_response(raw, "") is raw

    def test_unknown_encoding(self):
        """Unsupported encodings (e.g. brotli) pass through without crashing."""
        raw = b'{"events": []}'
        assert decompress_response(raw, "br") is raw

    def test_whitespace_encoding(self):
        """A whitespace-only encoding value is treated as no compression."""
        raw = b'{"events": []}'
        assert decompress_response(raw, " ") is raw

    def test_empty_data_with_encoding(self):
        """A valid gzip stream of zero bytes decodes to empty bytes."""
        assert decompress_response(gzip.compress(b""), "gzip") == b""

    def test_empty_data_no_encoding(self):
        """An empty body with no encoding passes straight through."""
        raw = b""
        assert decompress_response(raw, None) is raw