diff --git a/sdk-endpoints.txt b/sdk-endpoints.txt index 9da2578..acf2b2a 100644 --- a/sdk-endpoints.txt +++ b/sdk-endpoints.txt @@ -16,6 +16,7 @@ POST /v1/chat/completions POST /v1/responses POST /v1/messages +POST /v1/messages/count_tokens POST /v1/embeddings POST /v1/moderations POST /v1/rerank diff --git a/src/otari/_client/__init__.py b/src/otari/_client/__init__.py index c5d554b..0749d2d 100644 --- a/src/otari/_client/__init__.py +++ b/src/otari/_client/__init__.py @@ -91,6 +91,8 @@ "Content8", "Content9Inner", "ContentAnyOfInner", + "CountTokensRequest", + "CountTokensResponse", "CreateBatchRequest", "CreateBudgetRequest", "CreateEmbeddingResponse", @@ -195,6 +197,7 @@ "SetPricingRequest", "Source", "System", + "System1", "ToolCallsInner", "ToolChoice", "UpdateBudgetRequest", @@ -284,6 +287,8 @@ from otari._client.models.content8 import Content8 as Content8 from otari._client.models.content9_inner import Content9Inner as Content9Inner from otari._client.models.content_any_of_inner import ContentAnyOfInner as ContentAnyOfInner +from otari._client.models.count_tokens_request import CountTokensRequest as CountTokensRequest +from otari._client.models.count_tokens_response import CountTokensResponse as CountTokensResponse from otari._client.models.create_batch_request import CreateBatchRequest as CreateBatchRequest from otari._client.models.create_budget_request import CreateBudgetRequest as CreateBudgetRequest from otari._client.models.create_embedding_response import CreateEmbeddingResponse as CreateEmbeddingResponse @@ -388,6 +393,7 @@ from otari._client.models.set_pricing_request import SetPricingRequest as SetPricingRequest from otari._client.models.source import Source as Source from otari._client.models.system import System as System +from otari._client.models.system1 import System1 as System1 from otari._client.models.tool_calls_inner import ToolCallsInner as ToolCallsInner from otari._client.models.tool_choice import ToolChoice as ToolChoice from otari._client.models.update_budget_request import UpdateBudgetRequest as UpdateBudgetRequest diff --git a/src/otari/_client/api/messages_api.py b/src/otari/_client/api/messages_api.py index 6b65bff..0d7dd03 100644 --- a/src/otari/_client/api/messages_api.py +++ b/src/otari/_client/api/messages_api.py @@ -15,6 +15,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union from typing_extensions import Annotated +from otari._client.models.count_tokens_request import CountTokensRequest +from otari._client.models.count_tokens_response import CountTokensResponse from otari._client.models.message_response import MessageResponse from otari._client.models.messages_request import MessagesRequest @@ -36,6 +38,282 @@ def __init__(self, api_client=None) -> None: self.api_client = api_client + @validate_call + def count_message_tokens_v1_messages_count_tokens_post( + self, + count_tokens_request: CountTokensRequest, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + Tuple[ + Annotated[StrictFloat, Field(gt=0)], + Annotated[StrictFloat, Field(gt=0)] + ] + ] = None, + _request_auth: Optional[Dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[Dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> CountTokensResponse: + """Count Message Tokens + + Anthropic ``/v1/messages/count_tokens``-compatible endpoint. Returns ``{\"input_tokens\": N}`` without contacting an upstream provider: counting is local, so there is no budget reservation, pricing, or usage logging. Authentication mirrors :func:`create_message` — platform mode resolves the caller's token against the platform, standalone mode validates the API key — so the endpoint is not an open token-counting oracle. + + :param count_tokens_request: (required) + :type count_tokens_request: CountTokensRequest + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + + _param = self._count_message_tokens_v1_messages_count_tokens_post_serialize( + count_tokens_request=count_tokens_request, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index + ) + + _response_types_map: Dict[str, Optional[str]] = { + '200': "CountTokensResponse", + '422': "HTTPValidationError", + } + response_data = self.api_client.call_api( + *_param, + _request_timeout=_request_timeout + ) + response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ).data + + + @validate_call + def count_message_tokens_v1_messages_count_tokens_post_with_http_info( + self, + count_tokens_request: CountTokensRequest, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + Tuple[ + Annotated[StrictFloat, Field(gt=0)], + Annotated[StrictFloat, Field(gt=0)] + ] + ] = None, + _request_auth: Optional[Dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[Dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> ApiResponse[CountTokensResponse]: + """Count Message Tokens + + Anthropic ``/v1/messages/count_tokens``-compatible endpoint. Returns ``{\"input_tokens\": N}`` without contacting an upstream provider: counting is local, so there is no budget reservation, pricing, or usage logging. Authentication mirrors :func:`create_message` — platform mode resolves the caller's token against the platform, standalone mode validates the API key — so the endpoint is not an open token-counting oracle. + + :param count_tokens_request: (required) + :type count_tokens_request: CountTokensRequest + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + + _param = self._count_message_tokens_v1_messages_count_tokens_post_serialize( + count_tokens_request=count_tokens_request, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index + ) + + _response_types_map: Dict[str, Optional[str]] = { + '200': "CountTokensResponse", + '422': "HTTPValidationError", + } + response_data = self.api_client.call_api( + *_param, + _request_timeout=_request_timeout + ) + response_data.read() + return self.api_client.response_deserialize( + response_data=response_data, + response_types_map=_response_types_map, + ) + + + @validate_call + def count_message_tokens_v1_messages_count_tokens_post_without_preload_content( + self, + count_tokens_request: CountTokensRequest, + _request_timeout: Union[ + None, + Annotated[StrictFloat, Field(gt=0)], + Tuple[ + Annotated[StrictFloat, Field(gt=0)], + Annotated[StrictFloat, Field(gt=0)] + ] + ] = None, + _request_auth: Optional[Dict[StrictStr, Any]] = None, + _content_type: Optional[StrictStr] = None, + _headers: Optional[Dict[StrictStr, Any]] = None, + _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0, + ) -> RESTResponseType: + """Count Message Tokens + + Anthropic ``/v1/messages/count_tokens``-compatible endpoint. Returns ``{\"input_tokens\": N}`` without contacting an upstream provider: counting is local, so there is no budget reservation, pricing, or usage logging. Authentication mirrors :func:`create_message` — platform mode resolves the caller's token against the platform, standalone mode validates the API key — so the endpoint is not an open token-counting oracle. + + :param count_tokens_request: (required) + :type count_tokens_request: CountTokensRequest + :param _request_timeout: timeout setting for this request. If one + number provided, it will be total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + :type _request_timeout: int, tuple(int, int), optional + :param _request_auth: set to override the auth_settings for an a single + request; this effectively ignores the + authentication in the spec for a single request. + :type _request_auth: dict, optional + :param _content_type: force content-type for the request. + :type _content_type: str, Optional + :param _headers: set to override the headers for a single + request; this effectively ignores the headers + in the spec for a single request. + :type _headers: dict, optional + :param _host_index: set to override the host_index for a single + request; this effectively ignores the host_index + in the spec for a single request. + :type _host_index: int, optional + :return: Returns the result object. + """ # noqa: E501 + + _param = self._count_message_tokens_v1_messages_count_tokens_post_serialize( + count_tokens_request=count_tokens_request, + _request_auth=_request_auth, + _content_type=_content_type, + _headers=_headers, + _host_index=_host_index + ) + + _response_types_map: Dict[str, Optional[str]] = { + '200': "CountTokensResponse", + '422': "HTTPValidationError", + } + response_data = self.api_client.call_api( + *_param, + _request_timeout=_request_timeout + ) + return response_data.response + + + def _count_message_tokens_v1_messages_count_tokens_post_serialize( + self, + count_tokens_request, + _request_auth, + _content_type, + _headers, + _host_index, + ) -> RequestSerialized: + + _host = None + + _collection_formats: Dict[str, str] = { + } + + _path_params: Dict[str, str] = {} + _query_params: List[Tuple[str, str]] = [] + _header_params: Dict[str, Optional[str]] = _headers or {} + _form_params: List[Tuple[str, str]] = [] + _files: Dict[ + str, Union[str, bytes, List[str], List[bytes], List[Tuple[str, bytes]]] + ] = {} + _body_params: Optional[bytes] = None + + # process the path parameters + # process the query parameters + # process the header parameters + # process the form parameters + # process the body parameter + if count_tokens_request is not None: + _body_params = count_tokens_request + + + # set the HTTP header `Accept` + if 'Accept' not in _header_params: + _header_params['Accept'] = self.api_client.select_header_accept( + [ + 'application/json' + ] + ) + + # set the HTTP header `Content-Type` + if _content_type: + _header_params['Content-Type'] = _content_type + else: + _default_content_type = ( + self.api_client.select_header_content_type( + [ + 'application/json' + ] + ) + ) + if _default_content_type is not None: + _header_params['Content-Type'] = _default_content_type + + # authentication setting + _auth_settings: List[str] = [ + ] + + return self.api_client.param_serialize( + method='POST', + resource_path='/v1/messages/count_tokens', + path_params=_path_params, + query_params=_query_params, + header_params=_header_params, + body=_body_params, + post_params=_form_params, + files=_files, + auth_settings=_auth_settings, + collection_formats=_collection_formats, + _host=_host, + _request_auth=_request_auth + ) + + + + @validate_call def create_message_v1_messages_post( self, diff --git a/src/otari/_client/models/__init__.py b/src/otari/_client/models/__init__.py index d886a76..f3be3b6 100644 --- a/src/otari/_client/models/__init__.py +++ b/src/otari/_client/models/__init__.py @@ -61,6 +61,8 @@ from otari._client.models.content8 import Content8 from otari._client.models.content9_inner import Content9Inner from otari._client.models.content_any_of_inner import ContentAnyOfInner +from otari._client.models.count_tokens_request import CountTokensRequest +from otari._client.models.count_tokens_response import CountTokensResponse from otari._client.models.create_batch_request import CreateBatchRequest from otari._client.models.create_budget_request import CreateBudgetRequest from otari._client.models.create_embedding_response import CreateEmbeddingResponse @@ -165,6 +167,7 @@ from otari._client.models.set_pricing_request import SetPricingRequest from otari._client.models.source import Source from otari._client.models.system import System +from otari._client.models.system1 import System1 from otari._client.models.tool_calls_inner import ToolCallsInner from otari._client.models.tool_choice import ToolChoice from otari._client.models.update_budget_request import UpdateBudgetRequest diff --git a/src/otari/_client/models/count_tokens_request.py b/src/otari/_client/models/count_tokens_request.py new file mode 100644 index 0000000..74f9ac3 --- /dev/null +++ b/src/otari/_client/models/count_tokens_request.py @@ -0,0 +1,137 @@ +# coding: utf-8 + +""" + otari-gateway + + A clean FastAPI gateway for otari with API key management + + The version of the OpenAPI document: 0.0.0-dev + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" # noqa: E501 + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + +from pydantic import BaseModel, ConfigDict, Field, StrictStr +from typing import Any, ClassVar, Dict, List, Optional +from typing_extensions import Annotated +from otari._client.models.system import System +from typing import Optional, Set +from typing_extensions import Self +from pydantic_core import to_jsonable_python + +class CountTokensRequest(BaseModel): + """ + Anthropic ``/v1/messages/count_tokens`` request. A subset of :class:`MessagesRequest`: the input fields that affect the token count, minus ``max_tokens`` and the streaming/sampling controls, since the endpoint only counts input tokens. Clients such as Claude Code call this on every turn to keep their prompt within the model's context window. + """ # noqa: E501 + cache_control: Optional[Dict[str, Any]] = None + messages: Annotated[List[Optional[Dict[str, Any]]], Field(min_length=1)] + metadata: Optional[Dict[str, Any]] = None + model: StrictStr + system: Optional[System] = None + thinking: Optional[Dict[str, Any]] = None + tool_choice: Optional[Dict[str, Any]] = None + tools: Optional[List[Dict[str, Any]]] = None + __properties: ClassVar[List[str]] = ["cache_control", "messages", "metadata", "model", "system", "thinking", "tool_choice", "tools"] + + model_config = ConfigDict( + validate_by_name=True, + validate_by_alias=True, + validate_assignment=True, + protected_namespaces=(), + ) + + + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias""" + return json.dumps(to_jsonable_python(self.to_dict())) + + @classmethod + def from_json(cls, json_str: str) -> Optional[Self]: + """Create an instance of CountTokensRequest from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. + """ + excluded_fields: Set[str] = set([ + ]) + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + # override the default output from pydantic by calling `to_dict()` of system + if self.system: + _dict['system'] = self.system.to_dict() + # set to None if cache_control (nullable) is None + # and model_fields_set contains the field + if self.cache_control is None and "cache_control" in self.model_fields_set: + _dict['cache_control'] = None + + # set to None if metadata (nullable) is None + # and model_fields_set contains the field + if self.metadata is None and "metadata" in self.model_fields_set: + _dict['metadata'] = None + + # set to None if system (nullable) is None + # and model_fields_set contains the field + if self.system is None and "system" in self.model_fields_set: + _dict['system'] = None + + # set to None if thinking (nullable) is None + # and model_fields_set contains the field + if self.thinking is None and "thinking" in self.model_fields_set: + _dict['thinking'] = None + + # set to None if tool_choice (nullable) is None + # and model_fields_set contains the field + if self.tool_choice is None and "tool_choice" in self.model_fields_set: + _dict['tool_choice'] = None + + # set to None if tools (nullable) is None + # and model_fields_set contains the field + if self.tools is None and "tools" in self.model_fields_set: + _dict['tools'] = None + + return _dict + + @classmethod + def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: + """Create an instance of CountTokensRequest from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + _obj = cls.model_validate({ + "cache_control": obj.get("cache_control"), + "messages": obj.get("messages"), + "metadata": obj.get("metadata"), + "model": obj.get("model"), + "system": System.from_dict(obj["system"]) if obj.get("system") is not None else None, + "thinking": obj.get("thinking"), + "tool_choice": obj.get("tool_choice"), + "tools": obj.get("tools") + }) + return _obj + + diff --git a/src/otari/_client/models/count_tokens_response.py b/src/otari/_client/models/count_tokens_response.py new file mode 100644 index 0000000..50f1347 --- /dev/null +++ b/src/otari/_client/models/count_tokens_response.py @@ -0,0 +1,88 @@ +# coding: utf-8 + +""" + otari-gateway + + A clean FastAPI gateway for otari with API key management + + The version of the OpenAPI document: 0.0.0-dev + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" # noqa: E501 + + +from __future__ import annotations +import pprint +import re # noqa: F401 +import json + +from pydantic import BaseModel, ConfigDict, StrictInt +from typing import Any, ClassVar, Dict, List +from typing import Optional, Set +from typing_extensions import Self +from pydantic_core import to_jsonable_python + +class CountTokensResponse(BaseModel): + """ + Anthropic ``/v1/messages/count_tokens`` response. + """ # noqa: E501 + input_tokens: StrictInt + __properties: ClassVar[List[str]] = ["input_tokens"] + + model_config = ConfigDict( + validate_by_name=True, + validate_by_alias=True, + validate_assignment=True, + protected_namespaces=(), + ) + + + def to_str(self) -> str: + """Returns the string representation of the model using alias""" + return pprint.pformat(self.model_dump(by_alias=True)) + + def to_json(self) -> str: + """Returns the JSON representation of the model using alias""" + return json.dumps(to_jsonable_python(self.to_dict())) + + @classmethod + def from_json(cls, json_str: str) -> Optional[Self]: + """Create an instance of CountTokensResponse from a JSON string""" + return cls.from_dict(json.loads(json_str)) + + def to_dict(self) -> Dict[str, Any]: + """Return the dictionary representation of the model using alias. + + This has the following differences from calling pydantic's + `self.model_dump(by_alias=True)`: + + * `None` is only added to the output dict for nullable fields that + were set at model initialization. Other fields with value `None` + are ignored. + """ + excluded_fields: Set[str] = set([ + ]) + + _dict = self.model_dump( + by_alias=True, + exclude=excluded_fields, + exclude_none=True, + ) + return _dict + + @classmethod + def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: + """Create an instance of CountTokensResponse from a dict""" + if obj is None: + return None + + if not isinstance(obj, dict): + return cls.model_validate(obj) + + _obj = cls.model_validate({ + "input_tokens": obj.get("input_tokens") + }) + return _obj + + diff --git a/src/otari/_client/models/messages_request.py b/src/otari/_client/models/messages_request.py index 8f9b4c8..1d9e9d7 100644 --- a/src/otari/_client/models/messages_request.py +++ b/src/otari/_client/models/messages_request.py @@ -23,7 +23,7 @@ from uuid import UUID from otari._client.models.guardrail_config import GuardrailConfig from otari._client.models.mcp_server_config import McpServerConfig -from otari._client.models.system import System +from otari._client.models.system1 import System1 from typing import Optional, Set from typing_extensions import Self from pydantic_core import to_jsonable_python @@ -38,12 +38,12 @@ class MessagesRequest(BaseModel): max_tool_iterations: Optional[Annotated[int, Field(le=25, strict=True, ge=1)]] = None mcp_server_ids: Optional[List[UUID]] = None mcp_servers: Optional[List[McpServerConfig]] = None - messages: Annotated[List[Optional[Dict[str, Any]]], Field(min_length=1)] + messages: Annotated[List[Dict[str, Any]], Field(min_length=1)] metadata: Optional[Dict[str, Any]] = None model: StrictStr stop_sequences: Optional[List[StrictStr]] = None stream: Optional[StrictBool] = False - system: Optional[System] = None + system: Optional[System1] = None temperature: Optional[Union[StrictFloat, StrictInt]] = None thinking: Optional[Dict[str, Any]] = None tool_choice: Optional[Dict[str, Any]] = None @@ -207,7 +207,7 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]: "model": obj.get("model"), "stop_sequences": obj.get("stop_sequences"), "stream": obj.get("stream") if obj.get("stream") is not None else False, - "system": System.from_dict(obj["system"]) if obj.get("system") is not None else None, + "system": System1.from_dict(obj["system"]) if obj.get("system") is not None else None, "temperature": obj.get("temperature"), "thinking": obj.get("thinking"), "tool_choice": obj.get("tool_choice"), diff --git a/src/otari/_client/models/system1.py b/src/otari/_client/models/system1.py new file mode 100644 index 0000000..d44f94d --- /dev/null +++ b/src/otari/_client/models/system1.py @@ -0,0 +1,144 @@ +# coding: utf-8 + +""" + otari-gateway + + A clean FastAPI gateway for otari with API key management + + The version of the OpenAPI document: 0.0.0-dev + Generated by OpenAPI Generator (https://openapi-generator.tech) + + Do not edit the class manually. +""" # noqa: E501 + + +from __future__ import annotations +from inspect import getfullargspec +import json +import pprint +import re # noqa: F401 +from pydantic import BaseModel, ConfigDict, Field, StrictStr, ValidationError, field_validator +from typing import Any, Dict, List, Optional +from typing import Union, Any, List, Set, TYPE_CHECKING, Optional, Dict +from typing_extensions import Literal, Self +from pydantic import Field + +SYSTEM1_ANY_OF_SCHEMAS = ["List[Dict[str, object]]", "str"] + +class System1(BaseModel): + """ + System1 + """ + + # data type: str + anyof_schema_1_validator: Optional[StrictStr] = None + # data type: List[Dict[str, object]] + anyof_schema_2_validator: Optional[List[Dict[str, Any]]] = None + if TYPE_CHECKING: + actual_instance: Optional[Union[List[Dict[str, object]], str]] = None + else: + actual_instance: Any = None + any_of_schemas: Set[str] = { "List[Dict[str, object]]", "str" } + + model_config = { + "validate_assignment": True, + "protected_namespaces": (), + } + + def __init__(self, *args, **kwargs) -> None: + if args: + if len(args) > 1: + raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`") + if kwargs: + raise ValueError("If a position argument is used, keyword arguments cannot be used.") + super().__init__(actual_instance=args[0]) + else: + super().__init__(**kwargs) + + @field_validator('actual_instance') + def actual_instance_must_validate_anyof(cls, v): + if v is None: + return v + + instance = System1.model_construct() + error_messages = [] + # validate data type: str + try: + instance.anyof_schema_1_validator = v + return v + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + # validate data type: List[Dict[str, object]] + try: + instance.anyof_schema_2_validator = v + return v + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + if error_messages: + # no match + raise ValueError("No match found when setting the actual_instance in System1 with anyOf schemas: List[Dict[str, object]], str. Details: " + ", ".join(error_messages)) + else: + return v + + @classmethod + def from_dict(cls, obj: Dict[str, Any]) -> Self: + return cls.from_json(json.dumps(obj)) + + @classmethod + def from_json(cls, json_str: str) -> Self: + """Returns the object represented by the json string""" + instance = cls.model_construct() + if json_str is None: + return instance + + error_messages = [] + # deserialize data into str + try: + # validation + instance.anyof_schema_1_validator = json.loads(json_str) + # assign value to actual_instance + instance.actual_instance = instance.anyof_schema_1_validator + return instance + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + # deserialize data into List[Dict[str, object]] + try: + # validation + instance.anyof_schema_2_validator = json.loads(json_str) + # assign value to actual_instance + instance.actual_instance = instance.anyof_schema_2_validator + return instance + except (ValidationError, ValueError) as e: + error_messages.append(str(e)) + + if error_messages: + # no match + raise ValueError("No match found when deserializing the JSON string into System1 with anyOf schemas: List[Dict[str, object]], str. Details: " + ", ".join(error_messages)) + else: + return instance + + def to_json(self) -> str: + """Returns the JSON representation of the actual instance""" + if self.actual_instance is None: + return "null" + + if hasattr(self.actual_instance, "to_json") and callable(self.actual_instance.to_json): + return self.actual_instance.to_json() + else: + return json.dumps(self.actual_instance) + + def to_dict(self) -> Optional[Union[Dict[str, Any], List[Dict[str, object]], str]]: + """Returns the dict representation of the actual instance""" + if self.actual_instance is None: + return None + + if hasattr(self.actual_instance, "to_dict") and callable(self.actual_instance.to_dict): + return self.actual_instance.to_dict() + else: + return self.actual_instance + + def to_str(self) -> str: + """Returns the string representation of the actual instance""" + return pprint.pformat(self.model_dump()) + + diff --git a/src/otari/async_client.py b/src/otari/async_client.py index 6ac0874..992a0dd 100644 --- a/src/otari/async_client.py +++ b/src/otari/async_client.py @@ -43,6 +43,7 @@ from otari._client.api.responses_api import ResponsesApi from otari._client.exceptions import ApiException from otari._client.models.chat_completion_request import ChatCompletionRequest +from otari._client.models.count_tokens_request import CountTokensRequest from otari._client.models.create_batch_request import CreateBatchRequest from otari._client.models.embedding_request import EmbeddingRequest from otari._client.models.messages_request import MessagesRequest @@ -57,6 +58,7 @@ from otari._client.models.chat_completion import ChatCompletion from otari._client.models.chat_completion_chunk import ChatCompletionChunk + from otari._client.models.count_tokens_response import CountTokensResponse from otari._client.models.create_embedding_response import CreateEmbeddingResponse from otari._client.models.model_object import ModelObject from otari._client.models.moderation_response import ModerationResponse @@ -231,6 +233,25 @@ async def message( request = build_request(MessagesRequest, body) return await self._call(lambda: self._messages.create_message_v1_messages_post(request)) + async def count_tokens( + self, + *, + model: str, + messages: list[dict[str, Any]], + **kwargs: Any, + ) -> CountTokensResponse: + """Count input tokens for an Anthropic-style message request. + + Calls the gateway ``/v1/messages/count_tokens`` endpoint, which counts + the tokens a ``/messages`` request would consume without generating a + response. Returns a typed ``CountTokensResponse``. + """ + request = build_request(CountTokensRequest, {"model": model, "messages": messages, **kwargs}) + result = await self._call( + lambda: self._messages.count_message_tokens_v1_messages_count_tokens_post(request) + ) + return cast("CountTokensResponse", result) + # -- Embeddings --------------------------------------------------------- async def embedding( diff --git a/src/otari/client.py b/src/otari/client.py index 5dc31db..1b3f995 100644 --- a/src/otari/client.py +++ b/src/otari/client.py @@ -43,6 +43,7 @@ from otari._client.api.responses_api import ResponsesApi from otari._client.exceptions import ApiException from otari._client.models.chat_completion_request import ChatCompletionRequest +from otari._client.models.count_tokens_request import CountTokensRequest from otari._client.models.create_batch_request import CreateBatchRequest from otari._client.models.embedding_request import EmbeddingRequest from otari._client.models.messages_request import MessagesRequest @@ -57,6 +58,7 @@ from otari._client.models.chat_completion import ChatCompletion from otari._client.models.chat_completion_chunk import ChatCompletionChunk + from otari._client.models.count_tokens_response import CountTokensResponse from otari._client.models.create_embedding_response import CreateEmbeddingResponse from otari._client.models.model_object import ModelObject from otari._client.models.moderation_response import ModerationResponse @@ -254,6 +256,32 @@ def message( request = build_request(MessagesRequest, body) return self._call(lambda: self._messages.create_message_v1_messages_post(request)) + def count_tokens( + self, + *, + model: str, + messages: list[dict[str, Any]], + **kwargs: Any, + ) -> CountTokensResponse: + """Count input tokens for an Anthropic-style message request. + + Calls the gateway ``/v1/messages/count_tokens`` endpoint, which counts + the tokens a ``/messages`` request would consume without generating a + response. Returns a typed + :class:`~otari._client.models.count_tokens_response.CountTokensResponse`. + + Args: + model: Model identifier (e.g. ``"anthropic:claude-3-5-sonnet"``). + messages: Anthropic-style message list. + **kwargs: Additional count-tokens parameters (``system``, ``tools``, + ``tool_choice``, ``thinking``, ...). + """ + request = build_request(CountTokensRequest, {"model": model, "messages": messages, **kwargs}) + result = self._call( + lambda: self._messages.count_message_tokens_v1_messages_count_tokens_post(request), + ) + return cast("CountTokensResponse", result) + # -- Embeddings --------------------------------------------------------- def embedding( diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index cf5bd96..d7c1245 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -28,6 +28,7 @@ ) from tests.unit.test_client import ( CHAT_RESPONSE, + COUNT_TOKENS_RESPONSE, EMBEDDING_RESPONSE, MESSAGE_RESPONSE, MODELS_RESPONSE, @@ -87,6 +88,15 @@ async def test_message_returns_typed(self, mock_rest: Any) -> None: assert result.id == "msg-1" assert mock.last.url.endswith("/v1/messages") + async def test_count_tokens_returns_typed(self, mock_rest: Any) -> None: + mock = mock_rest(status=200, body=COUNT_TOKENS_RESPONSE) + client = AsyncOtariClient(api_base="http://localhost:8000", api_key="vk") + result = await client.count_tokens( + model="anthropic:claude", messages=[{"role": "user", "content": "Hi"}] + ) + assert result.input_tokens == 42 + assert mock.last.url.endswith("/v1/messages/count_tokens") + async def test_list_models_returns_typed(self, mock_rest: Any) -> None: mock_rest(status=200, body=MODELS_RESPONSE) client = AsyncOtariClient(api_base="http://localhost:8000", api_key="vk") diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e1a839e..bf0027a 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -64,6 +64,8 @@ "usage": {"input_tokens": 1, "output_tokens": 1}, } +COUNT_TOKENS_RESPONSE: dict[str, Any] = {"input_tokens": 42} + MODERATION_RESPONSE: dict[str, Any] = { "id": "modr-1", "model": "openai:omni-moderation-latest", @@ -213,6 +215,19 @@ def test_returns_typed_message_response(self, mock_rest: Any) -> None: assert body["max_tokens"] == 64 assert body["model"] == "anthropic:claude-3-5-sonnet" + def test_count_tokens_returns_typed_response(self, mock_rest: Any) -> None: + mock = mock_rest(status=200, body=COUNT_TOKENS_RESPONSE) + client = OtariClient(api_base="http://localhost:8000", api_key="vk") + result = client.count_tokens( + model="anthropic:claude-3-5-sonnet", + messages=[{"role": "user", "content": "Hi"}], + ) + assert result.input_tokens == 42 + assert mock.last.url.endswith("/v1/messages/count_tokens") + body = mock.last.json_body + assert body["model"] == "anthropic:claude-3-5-sonnet" + assert "max_tokens" not in body + class TestModeration: def test_returns_typed_moderation(self, mock_rest: Any) -> None: