Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/tinfoil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional
from openai import OpenAI, AsyncOpenAI
from openai import OpenAI, AsyncOpenAI, NOT_GIVEN, NotGiven
from openai.resources.chat import Chat as OpenAIChat
from openai.resources.embeddings import Embeddings as OpenAIEmbeddings
from openai.resources.audio import Audio as OpenAIAudio
Expand All @@ -14,7 +14,14 @@ class TinfoilAI:
api_key: str
enclave: str

def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: Optional[dict] = None):
def __init__(
self,
enclave: str = "",
repo: str = "tinfoilsh/confidential-model-router",
api_key: str = "tinfoil",
measurement: Optional[dict] = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
):
if measurement is not None:
repo = ""

Expand All @@ -33,6 +40,7 @@ def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-
self.client = OpenAI(
base_url=f"https://{enclave}/v1/",
api_key=api_key,
timeout=timeout,
http_client=secure_http,
)
self.chat = self.client.chat
Expand All @@ -53,7 +61,14 @@ class AsyncTinfoilAI:
api_key: str
enclave: str

def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-router", api_key: str = "tinfoil", measurement: Optional[dict] = None):
def __init__(
self,
enclave: str = "",
repo: str = "tinfoilsh/confidential-model-router",
api_key: str = "tinfoil",
measurement: Optional[dict] = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
):
if measurement is not None:
repo = ""

Expand All @@ -73,6 +88,7 @@ def __init__(self, enclave: str = "", repo: str = "tinfoilsh/confidential-model-
self.client = AsyncOpenAI(
base_url=f"https://{enclave}/v1/",
api_key=api_key,
timeout=timeout,
http_client=async_http,
)
self.chat = self.client.chat
Expand Down
113 changes: 113 additions & 0 deletions tests/test_constructor_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import httpx
from openai import NOT_GIVEN

import tinfoil


class FakeSecureClient:
def __init__(self, enclave, repo, measurement):
self.enclave = enclave
self.repo = repo
self.measurement = measurement

def make_secure_http_client(self):
return "sync-http-client"

def make_secure_async_http_client(self):
return "async-http-client"


class FakeOpenAIClient:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.chat = object()
self.embeddings = object()
self.audio = object()


def test_tinfoilai_forwards_timeout_to_openai(monkeypatch):
captured = {}

def fake_openai(**kwargs):
captured.update(kwargs)
return FakeOpenAIClient(**kwargs)

monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient)
monkeypatch.setattr(tinfoil, "OpenAI", fake_openai)

timeout = httpx.Timeout(12.5)
client = tinfoil.TinfoilAI(
enclave="router.test",
repo="tinfoilsh/confidential-model-router",
api_key="test-key",
timeout=timeout,
)

assert client.enclave == "router.test"
assert captured["base_url"] == "https://router.test/v1/"
assert captured["api_key"] == "test-key"
assert captured["timeout"] is timeout
assert captured["http_client"] == "sync-http-client"


def test_tinfoilai_uses_openai_default_timeout_when_unspecified(monkeypatch):
captured = {}

def fake_openai(**kwargs):
captured.update(kwargs)
return FakeOpenAIClient(**kwargs)

monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient)
monkeypatch.setattr(tinfoil, "OpenAI", fake_openai)

tinfoil.TinfoilAI(
enclave="router.test",
repo="tinfoilsh/confidential-model-router",
api_key="test-key",
)

assert captured["timeout"] is NOT_GIVEN


def test_async_tinfoilai_forwards_timeout_to_async_openai(monkeypatch):
captured = {}

def fake_async_openai(**kwargs):
captured.update(kwargs)
return FakeOpenAIClient(**kwargs)

monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient)
monkeypatch.setattr(tinfoil, "AsyncOpenAI", fake_async_openai)

timeout = httpx.Timeout(8.0)
client = tinfoil.AsyncTinfoilAI(
enclave="router.test",
repo="tinfoilsh/confidential-model-router",
api_key="test-key",
timeout=timeout,
)

assert client.enclave == "router.test"
assert captured["base_url"] == "https://router.test/v1/"
assert captured["api_key"] == "test-key"
assert captured["timeout"] is timeout
assert captured["http_client"] == "async-http-client"


def test_async_tinfoilai_uses_openai_default_timeout_when_unspecified(monkeypatch):
captured = {}

def fake_async_openai(**kwargs):
captured.update(kwargs)
return FakeOpenAIClient(**kwargs)

monkeypatch.setattr(tinfoil, "SecureClient", FakeSecureClient)
monkeypatch.setattr(tinfoil, "AsyncOpenAI", fake_async_openai)

tinfoil.AsyncTinfoilAI(
enclave="router.test",
repo="tinfoilsh/confidential-model-router",
api_key="test-key",
)

assert captured["timeout"] is NOT_GIVEN
Loading