From 91aaddc1888474ffcfe6fd4cd6605c9c1c7259e5 Mon Sep 17 00:00:00 2001 From: Tanya Verma Date: Wed, 22 Apr 2026 23:05:39 -0700 Subject: [PATCH] Add constructor timeout support --- src/tinfoil/__init__.py | 22 +++++- tests/test_constructor_timeout.py | 113 ++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 tests/test_constructor_timeout.py diff --git a/src/tinfoil/__init__.py b/src/tinfoil/__init__.py index 29ded91..81a500d 100644 --- a/src/tinfoil/__init__.py +++ b/src/tinfoil/__init__.py @@ -1,5 +1,5 @@ from typing import Optional -from openai import OpenAI, AsyncOpenAI +from openai import OpenAI, AsyncOpenAI, NOT_GIVEN, NotGiven from openai.resources.chat import Chat as OpenAIChat from openai.resources.embeddings import Embeddings as OpenAIEmbeddings from openai.resources.audio import Audio as OpenAIAudio @@ -14,7 +14,14 @@ class TinfoilAI: api_key: str enclave: str - def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: Optional[dict] = None): + def __init__( + self, + enclave: str = "", + repo: str = "tinfoilsh/confidential-model-router", + api_key: str = "tinfoil", + measurement: Optional[dict] = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ): if measurement is not None: repo = "" @@ -33,6 +40,7 @@ def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model- self.client = OpenAI( base_url=f"https://{enclave}/v1/", api_key=api_key, + timeout=timeout, http_client=secure_http, ) self.chat = self.client.chat @@ -53,7 +61,14 @@ class AsyncTinfoilAI: api_key: str enclave: str - def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: Optional[dict] = None): + def __init__( + self, + enclave: str = "", + repo: str = "tinfoilsh/confidential-model-router", + api_key: str = "tinfoil", + measurement: Optional[dict] = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ): if measurement is not None: repo = "" @@ -73,6 +88,7 @@ def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model- self.client = AsyncOpenAI( base_url=f"https://{enclave}/v1/", api_key=api_key, + timeout=timeout, http_client=async_http, ) self.chat = self.client.chat diff --git a/tests/test_constructor_timeout.py b/tests/test_constructor_timeout.py new file mode 100644 index 0000000..3bb7a0a --- /dev/null +++ b/tests/test_constructor_timeout.py @@ -0,0 +1,113 @@ +import httpx +from openai import NOT_GIVEN + +import tinfoil + + +class FakeSecureClient: + def __init__(self, enclave, repo, measurement): + self.enclave = enclave + self.repo = repo + self.measurement = measurement + + def make_secure_http_client(self): + return "sync-http-client" + + def make_secure_async_http_client(self): + return "async-http-client" + + +class FakeOpenAIClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.chat = object() + self.embeddings = object() + self.audio = object() + + +def test_tinfoilai_forwards_timeout_to_openai(monkeypatch): + captured = {} + + def fake_openai(**kwargs): + captured.update(kwargs) + return FakeOpenAIClient(**kwargs) + + monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient) + monkeypatch.setattr(tinfoil, "OpenAI", fake_openai) + + timeout = httpx.Timeout(12.5) + client = tinfoil.TinfoilAI( + enclave="router.test", + repo="tinfoilsh/confidential-model-router", + api_key="test-key", + timeout=timeout, + ) + + assert client.enclave == "router.test" + assert captured["base_url"] == "https://router.test/v1/" + assert captured["api_key"] == "test-key" + assert captured["timeout"] is timeout + assert captured["http_client"] == "sync-http-client" + + +def test_tinfoilai_uses_openai_default_timeout_when_unspecified(monkeypatch): + captured = {} + + def fake_openai(**kwargs): + captured.update(kwargs) + return FakeOpenAIClient(**kwargs) + + monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient) + monkeypatch.setattr(tinfoil, "OpenAI", fake_openai) + + tinfoil.TinfoilAI( + enclave="router.test", + repo="tinfoilsh/confidential-model-router", + api_key="test-key", + ) + + assert captured["timeout"] is NOT_GIVEN + + +def test_async_tinfoilai_forwards_timeout_to_async_openai(monkeypatch): + captured = {} + + def fake_async_openai(**kwargs): + captured.update(kwargs) + return FakeOpenAIClient(**kwargs) + + monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient) + monkeypatch.setattr(tinfoil, "AsyncOpenAI", fake_async_openai) + + timeout = httpx.Timeout(8.0) + client = tinfoil.AsyncTinfoilAI( + enclave="router.test", + repo="tinfoilsh/confidential-model-router", + api_key="test-key", + timeout=timeout, + ) + + assert client.enclave == "router.test" + assert captured["base_url"] == "https://router.test/v1/" + assert captured["api_key"] == "test-key" + assert captured["timeout"] is timeout + assert captured["http_client"] == "async-http-client" + + +def test_async_tinfoilai_uses_openai_default_timeout_when_unspecified(monkeypatch): + captured = {} + + def fake_async_openai(**kwargs): + captured.update(kwargs) + return FakeOpenAIClient(**kwargs) + + monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient) + monkeypatch.setattr(tinfoil, "AsyncOpenAI", fake_async_openai) + + tinfoil.AsyncTinfoilAI( + enclave="router.test", + repo="tinfoilsh/confidential-model-router", + api_key="test-key", + ) + + assert captured["timeout"] is NOT_GIVEN