From bdbab753d572a40004210beb4ec50f2dbd44c249 Mon Sep 17 00:00:00 2001 From: Mingrui Han Date: Tue, 30 Jun 2026 14:45:34 -0400 Subject: [PATCH] Fix type safety on Chunk and added model test files --- ...nai-compatible-chat-language-model.test.ts | 246 ++++++++++++++++++ .../openai-compatible-chat-language-model.ts | 28 +- 2 files changed, 263 insertions(+), 11 deletions(-) create mode 100644 packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.test.ts diff --git a/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.test.ts b/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.test.ts new file mode 100644 index 0000000000..3d14cb84af --- /dev/null +++ b/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.test.ts @@ -0,0 +1,246 @@ +import { describe, expect, it } from 'bun:test' + +import { OpenAICompatibleChatLanguageModel } from './openai-compatible-chat-language-model' + +import type { LanguageModelV2StreamPart } from '@ai-sdk/provider' +import type { FetchFunction } from '@ai-sdk/provider-utils' + +/** Build a fake SSE Response the way `createEventSourceResponseHandler` expects: + * each event is one `data: <json>\n\n` line, finished with `data: [DONE]\n\n`. 
*/ +function sseResponse(events: object[]): Response { + const encoder = new TextEncoder() + const stream = new ReadableStream({ + start(controller) { + for (const event of events) { + controller.enqueue( + encoder.encode(`data: ${JSON.stringify(event)}\n\n`), + ) + } + controller.enqueue(encoder.encode('data: [DONE]\n\n')) + controller.close() + }, + }) + return new Response(stream, { + status: 200, + headers: { 'content-type': 'text/event-stream' }, + }) +} + +async function readAllParts( + stream: ReadableStream, +): Promise { + const reader = stream.getReader() + const parts: LanguageModelV2StreamPart[] = [] + while (true) { + const { done, value } = await reader.read() + if (done) break + parts.push(value) + } + return parts +} + +/** Build a `typeof fetch`-shaped mock by attaching a no-op `preconnect`. + * `FetchFunction` (from `@ai-sdk/provider-utils`) resolves to `typeof globalThis.fetch`, + * whose TypeScript type under Bun also exposes static members like `preconnect` — a + * plain async function doesn't satisfy that shape under strict TS, hence the shim. 
*/ +function mockFetch(events: object[]): FetchFunction { + const stub = ((..._args: Parameters) => + Promise.resolve(sseResponse(events))) as typeof fetch + return Object.assign(stub, { preconnect: () => {} }) as FetchFunction +} + +function makeModel(fetch: FetchFunction) { + return new OpenAICompatibleChatLanguageModel('test-model', { + provider: 'test.chat', + url: ({ modelId, path }) => + `https://example.test${path}?model=${encodeURIComponent(modelId)}`, + headers: () => ({}), + fetch, + }) +} + +describe('OpenAICompatibleChatLanguageModel.doStream', () => { + it('emits text-delta stream parts and a finish part for valid chunks', async () => { + const model = makeModel( + mockFetch([ + { + id: 'chunk-1', + model: 'test-model', + choices: [ + { + index: 0, + delta: { role: 'assistant', content: 'Hello' }, + finish_reason: null, + }, + ], + }, + { + id: 'chunk-2', + model: 'test-model', + choices: [ + { + index: 0, + delta: { content: ' world' }, + finish_reason: null, + }, + ], + }, + { + id: 'chunk-3', + model: 'test-model', + choices: [ + { + index: 0, + delta: {}, + finish_reason: 'stop', + }, + ], + usage: { + prompt_tokens: 1, + completion_tokens: 2, + total_tokens: 3, + }, + }, + ]), + ) + const { stream } = await model.doStream({ + prompt: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + }) + + const parts = await readAllParts(stream) + + // First emitted stream part must be a stream-start. + expect(parts[0].type).toBe('stream-start') + + // Concatenate every text-delta — recovers the full text from a stream split + // across two chunks of delta.content. + const textDeltas = parts.filter( + (p): p is Extract => + p.type === 'text-delta', + ) + expect(textDeltas.map((p) => p.delta).join('')).toBe('Hello world') + + // The finish part must carry the OpenAI "stop" reason and the parsed usage. 
+ const finishes = parts.filter( + (p): p is Extract => + p.type === 'finish', + ) + expect(finishes.length).toBe(1) + expect(finishes[0].finishReason).toBe('stop') + expect(finishes[0].usage.outputTokens).toBe(2) + expect(finishes[0].usage.inputTokens).toBe(1) + expect(finishes[0].usage.totalTokens).toBe(3) + }) + + it('emits an error stream part when the server sends an error chunk', async () => { + const model = makeModel( + mockFetch([ + { + error: { + message: 'rate limited', + type: 'rate_limit_error', + code: 'rate_limit', + }, + }, + ]), + ) + const { stream } = await model.doStream({ + prompt: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + }) + + const parts = await readAllParts(stream) + + // The transform branch `'error' in value` should forward the message into + // an `{ type: 'error', error: <message> }` stream part and set finishReason + // to 'error'. This is the second arm of the chunk-schema union whose + // type safety the original "MUST FIX" TODO was about. + const errorParts = parts.filter( + (p): p is Extract => + p.type === 'error', + ) + expect(errorParts.length).toBe(1) + expect(errorParts[0].error).toBe('rate limited') + + const finishes = parts.filter( + (p): p is Extract => + p.type === 'finish', + ) + expect(finishes.length).toBe(1) + expect(finishes[0].finishReason).toBe('error') + }) + + it('emits tool-call stream parts when the server sends a streaming tool call', async () => { + const model = makeModel( + mockFetch([ + { + id: 'chunk-1', + model: 'test-model', + choices: [ + { + index: 0, + delta: { + role: 'assistant', + tool_calls: [ + { + index: 0, + id: 'call_1', + function: { name: 'search', arguments: '{"q":' }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + id: 'chunk-2', + model: 'test-model', + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: '"codebuff"}' }, + }, + ], + }, + finish_reason: 'tool_calls', + }, + ], + }, + ]), + ) + const { stream } = await
model.doStream({ + prompt: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + }) + + const parts = await readAllParts(stream) + + // Two-part streaming arguments get merged into a single tool-call part + // when the second chunk closes the JSON argument. + const toolInputStarts = parts.filter( + (p): p is Extract => + p.type === 'tool-input-start', + ) + expect(toolInputStarts.length).toBe(1) + expect(toolInputStarts[0].toolName).toBe('search') + + const toolCalls = parts.filter( + (p): p is Extract => + p.type === 'tool-call', + ) + expect(toolCalls.length).toBe(1) + expect(toolCalls[0].toolName).toBe('search') + expect(toolCalls[0].input).toBe('{"q":"codebuff"}') + + // OpenAI's wire-level 'tool_calls' is mapped to AI SDK's 'tool-calls'. + const finishes = parts.filter( + (p): p is Extract => + p.type === 'finish', + ) + expect(finishes.length).toBe(1) + expect(finishes[0].finishReason).toBe('tool-calls') + }) +}) diff --git a/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.ts b/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.ts index 7e49bfcadc..460340f26e 100644 --- a/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.ts +++ b/packages/llm-providers/src/openai-compatible/chat/openai-compatible-chat-language-model.ts @@ -19,7 +19,10 @@ import { defaultOpenAICompatibleErrorStructure } from '../openai-compatible-erro import { prepareTools } from './openai-compatible-prepare-tools' import type { OpenAICompatibleChatModelId } from './openai-compatible-chat-options' -import type { ProviderErrorStructure } from '../openai-compatible-error' +import type { + OpenAICompatibleErrorData, + ProviderErrorStructure, +} from '../openai-compatible-error' import type { MetadataExtractor } from './openai-compatible-metadata-extractor' import type { APICallError, @@ -63,8 +66,8 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV2 
{ readonly modelId: OpenAICompatibleChatModelId private readonly config: OpenAICompatibleChatConfig + private readonly errorStructure: ProviderErrorStructure private readonly failedResponseHandler: ResponseHandler - private readonly chunkSchema // type inferred via constructor constructor( modelId: OpenAICompatibleChatModelId, @@ -74,12 +77,11 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV2 { this.config = config // initialize error handling: - const errorStructure = + this.errorStructure = config.errorStructure ?? defaultOpenAICompatibleErrorStructure - this.chunkSchema = createOpenAICompatibleChatChunkSchema( - errorStructure.errorSchema, + this.failedResponseHandler = createJsonErrorResponseHandler( + this.errorStructure, ) - this.failedResponseHandler = createJsonErrorResponseHandler(errorStructure) this.supportsStructuredOutputs = config.supportsStructuredOutputs ?? false } @@ -327,6 +329,13 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV2 { const metadataExtractor = this.config.metadataExtractor?.createStreamExtractor() + // Build chunkSchema here (not in the constructor) so the ERROR_SCHEMA + // generic is bound at the use site; otherwise `z.infer` + // degenerates and the transform callback sees `chunk: any`. 
+ const chunkSchema = createOpenAICompatibleChatChunkSchema( + this.errorStructure.errorSchema, + ) + const { responseHeaders, value: response } = await postJsonToApi({ url: this.config.url({ path: '/chat/completions', @@ -335,9 +344,7 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV2 { headers: combineHeaders(this.config.headers(), options.headers), body, failedResponseHandler: this.failedResponseHandler, - successfulResponseHandler: createEventSourceResponseHandler( - this.chunkSchema, - ), + successfulResponseHandler: createEventSourceResponseHandler(chunkSchema), abortSignal: options.abortSignal, fetch: this.config.fetch, }) @@ -386,14 +393,13 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV2 { return { stream: response.pipeThrough( new TransformStream< - ParseResult>, + ParseResult>, LanguageModelV2StreamPart >({ start(controller) { controller.enqueue({ type: 'stream-start', warnings }) }, - // TODO we lost type safety on Chunk, most likely due to the error schema. MUST FIX transform(chunk, controller) { // Emit raw chunk if requested (before anything else) if (options.includeRawChunks) {