diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..a2aff78 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,184 @@ +""" +Tests for client.chat — complete() and stream() +""" + +import pytest +import respx +import httpx + +from tests.conftest import CHAT_RESPONSE, STREAM_CHUNKS +import kitefishai +from kitefishai.types import ChatCompletion, StreamChunk + + +BASE = "https://api.kitefishai.com/v1" + + +class TestChatComplete: + + @respx.mock + def test_basic_complete(self, client): + respx.post(f"{BASE}/chat").mock( + return_value=httpx.Response(200, json=CHAT_RESPONSE) + ) + response = client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hello"}], + ) + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == "Hello from KiteFish!" + assert response.choices[0].message.role == "assistant" + assert response.choices[0].finish_reason == "stop" + + @respx.mock + def test_usage_parsed(self, client): + respx.post(f"{BASE}/chat").mock( + return_value=httpx.Response(200, json=CHAT_RESPONSE) + ) + response = client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + assert response.usage.prompt_tokens == 10 + assert response.usage.completion_tokens == 5 + assert response.usage.total_tokens == 15 + + @respx.mock + def test_system_prompt_prepended(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=CHAT_RESPONSE) + + respx.post(f"{BASE}/chat").mock(side_effect=capture) + client.chat.complete( + model="kf-reasoning-10b", + system="You are a compliance assistant.", + messages=[{"role": "user", "content": "Hello"}], + ) + messages = captured["body"]["messages"] + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "You are a compliance assistant." + assert messages[1]["role"] == "user" + + @respx.mock + def test_stream_false_by_default(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=CHAT_RESPONSE) + + respx.post(f"{BASE}/chat").mock(side_effect=capture) + client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + assert captured["body"]["stream"] is False + + @respx.mock + def test_model_sent_in_payload(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=CHAT_RESPONSE) + + respx.post(f"{BASE}/chat").mock(side_effect=capture) + client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + assert captured["body"]["model"] == "kf-reasoning-10b" + + @respx.mock + def test_extra_params_forwarded(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=CHAT_RESPONSE) + + respx.post(f"{BASE}/chat").mock(side_effect=capture) + client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + extra={"frequency_penalty": 0.5}, + ) + assert captured["body"]["frequency_penalty"] == 0.5 + + @respx.mock + def test_auth_error_on_401(self, client): + respx.post(f"{BASE}/chat").mock( + return_value=httpx.Response(401, json={"error": {"message": "Unauthorized"}}) + ) + with pytest.raises(kitefishai.AuthenticationError): + client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + + @respx.mock + def test_rate_limit_error_on_429(self, client): + respx.post(f"{BASE}/chat").mock( + return_value=httpx.Response(429, json={"error": {"message": "Rate limited"}}) + ) + with pytest.raises(kitefishai.RateLimitError): + client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + + @respx.mock + def test_api_error_on_500(self, client): + respx.post(f"{BASE}/chat").mock( + return_value=httpx.Response(500, json={"error": {"message": "Server error"}}) + ) + with pytest.raises(kitefishai.APIError) as exc_info: + client.chat.complete( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + assert exc_info.value.status_code == 500 + + +class TestChatStream: + + def test_stream_returns_chat_stream(self, client): + from kitefishai.types import ChatStream + result = client.chat.stream( + model="kf-reasoning-10b", + messages=[{"role": "user", "content": "Hi"}], + ) + assert isinstance(result, ChatStream) + + def test_stream_chunk_parsing(self): + import json + from kitefishai.types import StreamChunk + + raw = '{"id":"r1","model":"kf-reasoning-10b","choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}' + chunk = StreamChunk.from_dict(json.loads(raw)) + assert chunk.delta == "Hello" + assert chunk.finish_reason is None + assert chunk.model == "kf-reasoning-10b" + + def test_stream_chunk_finish_reason(self): + import json + from kitefishai.types import StreamChunk + + raw = '{"id":"r1","model":"kf-reasoning-10b","choices":[{"delta":{"content":""},"finish_reason":"stop"}]}' + chunk = StreamChunk.from_dict(json.loads(raw)) + assert chunk.finish_reason == "stop" + + def test_stream_empty_delta(self): + import json + from kitefishai.types import StreamChunk + + raw = '{"id":"r1","model":"kf-reasoning-10b","choices":[{"delta":{},"finish_reason":null}]}' + chunk = StreamChunk.from_dict(json.loads(raw)) + assert chunk.delta == "" diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..ffa4a5a --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,70 @@ +""" +Tests for KiteFishAI Client initialisation and configuration. +""" + +import pytest +import os +import kitefishai +from kitefishai._exceptions import AuthenticationError + + +class TestClientInit: + + def test_requires_api_key(self): + env = os.environ.pop("KITEFISH_API_KEY", None) + try: + with pytest.raises(AuthenticationError, match="No API key"): + kitefishai.Client() + finally: + if env: + os.environ["KITEFISH_API_KEY"] = env + + def test_accepts_api_key_arg(self): + client = kitefishai.Client(api_key="kf-test") + assert client.api_key == "kf-test" + + def test_reads_api_key_from_env(self, monkeypatch): + monkeypatch.setenv("KITEFISH_API_KEY", "kf-from-env") + client = kitefishai.Client() + assert client.api_key == "kf-from-env" + + def test_default_base_url(self): + client = kitefishai.Client(api_key="kf-test") + assert client.base_url == "https://api.kitefishai.com/v1" + + def test_custom_base_url(self): + client = kitefishai.Client(api_key="kf-test", base_url="https://internal/v1") + assert client.base_url == "https://internal/v1" + + def test_base_url_strips_trailing_slash(self): + client = kitefishai.Client(api_key="kf-test", base_url="https://internal/v1/") + assert client.base_url == "https://internal/v1" + + def test_reads_base_url_from_env(self, monkeypatch): + monkeypatch.setenv("KITEFISH_BASE_URL", "https://onprem/v1") + client = kitefishai.Client(api_key="kf-test") + assert client.base_url == "https://onprem/v1" + + def test_default_timeout(self): + client = kitefishai.Client(api_key="kf-test") + assert client.timeout == 60.0 + + def test_custom_timeout(self): + client = kitefishai.Client(api_key="kf-test", timeout=120.0) + assert client.timeout == 120.0 + + def test_default_max_retries(self): + client = kitefishai.Client(api_key="kf-test") + assert client.max_retries == 2 + + def test_resources_attached(self): + client = kitefishai.Client(api_key="kf-test") + assert hasattr(client, "chat") + assert hasattr(client, "embeddings") + + def test_context_manager(self): + with kitefishai.Client(api_key="kf-test") as client: + assert client.api_key == "kf-test" + + def test_version_exported(self): + assert kitefishai.__version__ == "0.1.0" diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py new file mode 100644 index 0000000..cae8ab1 --- /dev/null +++ b/tests/test_embeddings.py @@ -0,0 +1,148 @@ +""" +Tests for client.embeddings — create() +""" + +import pytest +import respx +import httpx + +from tests.conftest import EMBEDDING_RESPONSE +import kitefishai +from kitefishai.types import EmbeddingResponse, Embedding + + +BASE = "https://api.kitefishai.com/v1" + + +class TestEmbeddings: + + @respx.mock + def test_basic_create(self, client): + respx.post(f"{BASE}/embeddings").mock( + return_value=httpx.Response(200, json=EMBEDDING_RESPONSE) + ) + result = client.embeddings.create( + model="minnow-em-v1", + input=["query: hello", "passage: world"], + ) + assert isinstance(result, EmbeddingResponse) + assert len(result.data) == 2 + assert isinstance(result.data[0], Embedding) + + @respx.mock + def test_embedding_values(self, client): + respx.post(f"{BASE}/embeddings").mock( + return_value=httpx.Response(200, json=EMBEDDING_RESPONSE) + ) + result = client.embeddings.create( + model="minnow-em-v1", + input=["query: test"], + ) + assert result.data[0].embedding == [0.1, 0.2, 0.3] + assert result.data[1].embedding == [0.4, 0.5, 0.6] + + @respx.mock + def test_single_string_input_wrapped(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=EMBEDDING_RESPONSE) + + respx.post(f"{BASE}/embeddings").mock(side_effect=capture) + client.embeddings.create( + model="minnow-em-v1", + input="query: single string", + ) + assert isinstance(captured["body"]["input"], list) + assert captured["body"]["input"] == ["query: single string"] + + @respx.mock + def test_dimensions_forwarded(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=EMBEDDING_RESPONSE) + + respx.post(f"{BASE}/embeddings").mock(side_effect=capture) + client.embeddings.create( + model="minnow-em-v1", + input=["query: test"], + dimensions=256, + ) + assert captured["body"]["dimensions"] == 256 + + @respx.mock + def test_dimensions_omitted_by_default(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=EMBEDDING_RESPONSE) + + respx.post(f"{BASE}/embeddings").mock(side_effect=capture) + client.embeddings.create( + model="minnow-em-v1", + input=["query: test"], + ) + assert "dimensions" not in captured["body"] + + @respx.mock + def test_model_sent_in_payload(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=EMBEDDING_RESPONSE) + + respx.post(f"{BASE}/embeddings").mock(side_effect=capture) + client.embeddings.create(model="minnow-em-v1", input=["test"]) + assert captured["body"]["model"] == "minnow-em-v1" + + @respx.mock + def test_usage_parsed(self, client): + respx.post(f"{BASE}/embeddings").mock( + return_value=httpx.Response(200, json=EMBEDDING_RESPONSE) + ) + result = client.embeddings.create(model="minnow-em-v1", input=["test"]) + assert result.usage.prompt_tokens == 8 + assert result.usage.total_tokens == 8 + + @respx.mock + def test_index_on_each_embedding(self, client): + respx.post(f"{BASE}/embeddings").mock( + return_value=httpx.Response(200, json=EMBEDDING_RESPONSE) + ) + result = client.embeddings.create(model="minnow-em-v1", input=["a", "b"]) + assert result.data[0].index == 0 + assert result.data[1].index == 1 + + @respx.mock + def test_auth_error_on_401(self, client): + respx.post(f"{BASE}/embeddings").mock( + return_value=httpx.Response(401, json={"error": {"message": "Unauthorized"}}) + ) + with pytest.raises(kitefishai.AuthenticationError): + client.embeddings.create(model="minnow-em-v1", input=["test"]) + + @respx.mock + def test_extra_params_forwarded(self, client): + captured = {} + + def capture(request): + import json + captured["body"] = json.loads(request.content) + return httpx.Response(200, json=EMBEDDING_RESPONSE) + + respx.post(f"{BASE}/embeddings").mock(side_effect=capture) + client.embeddings.create( + model="minnow-em-v1", + input=["test"], + extra={"truncation": True}, + ) + assert captured["body"]["truncation"] is True diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..dd1bca9 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,141 @@ +""" +Tests for exceptions and response type parsing. +""" + +import pytest +from kitefishai._exceptions import ( + KiteFishAIError, + AuthenticationError, + RateLimitError, + NotFoundError, + APIError, +) +from kitefishai.types import ( + ChatCompletion, + EmbeddingResponse, + StreamChunk, + Usage, + Message, + Choice, + Embedding, +) + + +class TestExceptions: + + def test_hierarchy(self): + assert issubclass(AuthenticationError, KiteFishAIError) + assert issubclass(RateLimitError, KiteFishAIError) + assert issubclass(NotFoundError, KiteFishAIError) + assert issubclass(APIError, KiteFishAIError) + + def test_api_error_status_code(self): + err = APIError("Something failed", status_code=503) + assert err.status_code == 503 + assert "503" in repr(err) + + def test_message_accessible(self): + err = KiteFishAIError("test message") + assert err.message == "test message" + assert str(err) == "test message" + + +class TestChatCompletionParsing: + + def test_full_response(self): + data = { + "id": "req_1", + "model": "kf-reasoning-10b", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}, + "created": 1700000000, + } + cc = ChatCompletion.from_dict(data) + assert cc.id == "req_1" + assert cc.model == "kf-reasoning-10b" + assert len(cc.choices) == 1 + assert cc.choices[0].message.content == "Hello!" + assert cc.choices[0].message.role == "assistant" + assert cc.choices[0].finish_reason == "stop" + assert cc.usage.total_tokens == 8 + assert cc.created == 1700000000 + + def test_missing_usage(self): + data = { + "id": "req_2", + "model": "kf-reasoning-10b", + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": "Hi"}, "finish_reason": None} + ], + } + cc = ChatCompletion.from_dict(data) + assert cc.usage is None + + def test_empty_choices(self): + cc = ChatCompletion.from_dict({"id": "", "model": "", "choices": []}) + assert cc.choices == [] + + +class TestStreamChunkParsing: + + def test_delta_extracted(self): + data = { + "id": "r1", + "model": "kf-reasoning-10b", + "choices": [{"delta": {"content": "Hello"}, "finish_reason": None}], + } + chunk = StreamChunk.from_dict(data) + assert chunk.delta == "Hello" + assert chunk.finish_reason is None + + def test_empty_delta_on_missing_content(self): + data = { + "id": "r1", + "model": "kf-reasoning-10b", + "choices": [{"delta": {}, "finish_reason": "stop"}], + } + chunk = StreamChunk.from_dict(data) + assert chunk.delta == "" + assert chunk.finish_reason == "stop" + + def test_no_choices(self): + data = {"id": "r1", "model": "kf-reasoning-10b", "choices": []} + chunk = StreamChunk.from_dict(data) + assert chunk.delta == "" + + +class TestEmbeddingResponseParsing: + + def test_full_response(self): + data = { + "model": "minnow-em-v1", + "data": [ + {"index": 0, "embedding": [0.1, 0.2], "object": "embedding"}, + {"index": 1, "embedding": [0.3, 0.4], "object": "embedding"}, + ], + "usage": {"prompt_tokens": 4, "total_tokens": 4}, + } + er = EmbeddingResponse.from_dict(data) + assert er.model == "minnow-em-v1" + assert len(er.data) == 2 + assert er.data[0].embedding == [0.1, 0.2] + assert er.data[1].index == 1 + assert er.usage.prompt_tokens == 4 + + def test_missing_usage(self): + data = { + "model": "minnow-em-v1", + "data": [{"index": 0, "embedding": [0.1]}], + } + er = EmbeddingResponse.from_dict(data) + assert er.usage is None + + def test_empty_data(self): + er = EmbeddingResponse.from_dict({"model": "minnow-em-v1", "data": []}) + assert er.data == []