[codex] add typed generate_text provider adapters#145
[codex] add typed generate_text provider adapters#145Hynek Kydlíček (hynky1999) wants to merge 4 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a new generate_text API for typed, multimodal inference, adding native support for Google Gemini, Anthropic Messages, and OpenAI Responses APIs. The changes include a message conversion layer to translate canonical Refiner messages into provider-specific formats, along with new client implementations and updated documentation. Feedback identifies a potential issue where sending providerOptions as a top-level key to OpenAI-compatible endpoints could cause request failures. Additionally, a suggestion was made to reduce code duplication in the Google client by using a shared HTTP helper function.
| if providerOptions is not None and not isinstance( | ||
| provider, | ||
| GoogleEndpointProvider | ||
| | AnthropicEndpointProvider | ||
| | OpenAIResponsesProvider, | ||
| ): | ||
| payload["providerOptions"] = providerOptions |
There was a problem hiding this comment.
For OpenAIEndpointProvider and VLLMProvider, including providerOptions as a top-level key in the request payload is likely to cause 400 Bad Request errors from most OpenAI-compatible endpoints, as they typically do not recognize this field. Since the relevant options (like reasoningEffort) are already extracted and normalized into the payload in previous steps (lines 93-94, 116-117), this assignment should be removed for these providers.
| async def generate_text(self, payload: Mapping[str, Any]) -> InferenceResponse: | ||
| response_json = await self._post_json( | ||
| f"{_google_model_path(self.model)}:generateContent", | ||
| payload, | ||
| operation="google generation", | ||
| ) | ||
| if not isinstance(response_json, Mapping): | ||
| raise RuntimeError("google generation response must be a JSON object") | ||
| return _parse_google_inference_response(response_json) | ||
|
|
||
| async def _post_json( | ||
| self, | ||
| endpoint_path: str, | ||
| payload: Mapping[str, Any], | ||
| *, | ||
| operation: str, | ||
| ) -> Any: | ||
| client = self._ensure_client() | ||
| for attempt in range(_OPENAI_ENDPOINT_MAX_RETRIES): | ||
| try: | ||
| response = await client.post(endpoint_path, json=dict(payload)) | ||
| break | ||
| except ( | ||
| ConnectionError, | ||
| OSError, | ||
| asyncio.TimeoutError, | ||
| httpx.NetworkError, | ||
| httpx.TimeoutException, | ||
| ) as err: | ||
| if attempt + 1 >= _OPENAI_ENDPOINT_MAX_RETRIES: | ||
| message = ( | ||
| f"{operation} request failed after " | ||
| f"{_OPENAI_ENDPOINT_MAX_RETRIES} attempts: " | ||
| f"{type(err).__name__}: {err}" | ||
| ) | ||
| raise RuntimeError(message) from err | ||
| await asyncio.sleep(_retry_delay_seconds(attempt)) | ||
| else: | ||
| raise RuntimeError(f"{operation} request failed without a response") | ||
| try: | ||
| response.raise_for_status() | ||
| except httpx.HTTPStatusError as err: | ||
| detail = "" | ||
| try: | ||
| detail = str(err.response.json()) | ||
| except ValueError: | ||
| detail = err.response.text.strip() | ||
| message = f"{operation} request failed with HTTP {err.response.status_code}" | ||
| if detail: | ||
| message = f"{message}: {detail}" | ||
| raise RuntimeError(message) from err | ||
| return response.json() |
There was a problem hiding this comment.
The _post_json method in _GoogleEndpointClient is identical to the _post_json_with_retries helper function defined later in this file. To improve maintainability and reduce code duplication, _GoogleEndpointClient should use the helper function.
| async def generate_text(self, payload: Mapping[str, Any]) -> InferenceResponse: | |
| response_json = await self._post_json( | |
| f"{_google_model_path(self.model)}:generateContent", | |
| payload, | |
| operation="google generation", | |
| ) | |
| if not isinstance(response_json, Mapping): | |
| raise RuntimeError("google generation response must be a JSON object") | |
| return _parse_google_inference_response(response_json) | |
| async def _post_json( | |
| self, | |
| endpoint_path: str, | |
| payload: Mapping[str, Any], | |
| *, | |
| operation: str, | |
| ) -> Any: | |
| client = self._ensure_client() | |
| for attempt in range(_OPENAI_ENDPOINT_MAX_RETRIES): | |
| try: | |
| response = await client.post(endpoint_path, json=dict(payload)) | |
| break | |
| except ( | |
| ConnectionError, | |
| OSError, | |
| asyncio.TimeoutError, | |
| httpx.NetworkError, | |
| httpx.TimeoutException, | |
| ) as err: | |
| if attempt + 1 >= _OPENAI_ENDPOINT_MAX_RETRIES: | |
| message = ( | |
| f"{operation} request failed after " | |
| f"{_OPENAI_ENDPOINT_MAX_RETRIES} attempts: " | |
| f"{type(err).__name__}: {err}" | |
| ) | |
| raise RuntimeError(message) from err | |
| await asyncio.sleep(_retry_delay_seconds(attempt)) | |
| else: | |
| raise RuntimeError(f"{operation} request failed without a response") | |
| try: | |
| response.raise_for_status() | |
| except httpx.HTTPStatusError as err: | |
| detail = "" | |
| try: | |
| detail = str(err.response.json()) | |
| except ValueError: | |
| detail = err.response.text.strip() | |
| message = f"{operation} request failed with HTTP {err.response.status_code}" | |
| if detail: | |
| message = f"{message}: {detail}" | |
| raise RuntimeError(message) from err | |
| return response.json() | |
| async def generate_text(self, payload: Mapping[str, Any]) -> InferenceResponse: | |
| response_json = await _post_json_with_retries( | |
| self._ensure_client(), | |
| f"{_google_model_path(self.model)}:generateContent", | |
| payload, | |
| operation="google generation", | |
| ) | |
| if not isinstance(response_json, Mapping): | |
| raise RuntimeError("google generation response must be a JSON object") | |
| return _parse_google_inference_response(response_json) |
Purpose
Add a Vercel-style typed
generate_textinference surface while preserving the existing rawgenerateescape hatch.Changes
TypedDictmessage/content-part types for text, image, and file inputs.inlineDatafor video bytes.providerOptionsusage.Validation
uv run pytest tests/test_inference.pyuv run ty checkuv run ruff check --force-exclude src/refiner/inference tests/test_inference.py docs/inference.mduv run pytest(644 passed)