diff --git a/src/Fim/PendingRequest.php b/src/Fim/PendingRequest.php new file mode 100644 index 000000000..bc30efedc --- /dev/null +++ b/src/Fim/PendingRequest.php @@ -0,0 +1,113 @@ + */ + protected array $stop = []; + + public function withPrompt(string $prompt): self + { + $this->prompt = $prompt; + + return $this; + } + + public function withSuffix(?string $suffix): self + { + $this->suffix = $suffix; + + return $this; + } + + public function withMaxTokens(int $maxTokens): self + { + $this->maxTokens = $maxTokens; + + return $this; + } + + public function withTemperature(int|float $temperature): self + { + $this->temperature = $temperature; + + return $this; + } + + public function withTopP(int|float $topP): self + { + $this->topP = $topP; + + return $this; + } + + /** + * @param string|array $stop + */ + public function withStop(string|array $stop): self + { + $this->stop = is_string($stop) ? [$stop] : $stop; + + return $this; + } + + public function asText(): Response + { + $request = $this->toRequest(); + + try { + return $this->provider->fim($request); + } catch (RequestException $e) { + $this->provider->handleRequestException($request->model(), $e); + } + } + + /** + * @deprecated Use `asText` instead. + */ + public function generate(): Response + { + return $this->asText(); + } + + public function toRequest(): Request + { + return new Request( + model: $this->model, + providerKey: $this->providerKey(), + prompt: $this->prompt, + suffix: $this->suffix, + maxTokens: $this->maxTokens, + temperature: $this->temperature, + topP: $this->topP, + stop: $this->stop, + clientOptions: $this->clientOptions, + clientRetry: $this->clientRetry, + providerOptions: $this->providerOptions, + ); + } +} diff --git a/src/Fim/Request.php b/src/Fim/Request.php new file mode 100644 index 000000000..4e45f7942 --- /dev/null +++ b/src/Fim/Request.php @@ -0,0 +1,95 @@ + $stop + * @param array $clientOptions + * @param array $clientRetry + * @param array $providerOptions + */ + public function __construct( + protected string $model, + protected string $providerKey, + protected string $prompt, + protected ?string $suffix, + protected ?int $maxTokens, + protected int|float|null $temperature, + protected int|float|null $topP, + protected array $stop, + protected array $clientOptions, + protected array $clientRetry, + array $providerOptions = [], + ) { + $this->providerOptions = $providerOptions; + } + + public function model(): string + { + return $this->model; + } + + public function provider(): string + { + return $this->providerKey; + } + + public function prompt(): string + { + return $this->prompt; + } + + public function suffix(): ?string + { + return $this->suffix; + } + + public function maxTokens(): ?int + { + return $this->maxTokens; + } + + public function temperature(): int|float|null + { + return $this->temperature; + } + + public function topP(): int|float|null + { + return $this->topP; + } + + /** + * @return array + */ + public function stop(): array + { + return $this->stop; + } + + /** + * @return array + */ + public function clientOptions(): array + { + return $this->clientOptions; + } + + /** + * @return array + */ + public function clientRetry(): array + { + return $this->clientRetry; + } +} diff --git a/src/Fim/Response.php b/src/Fim/Response.php new file mode 100644 index 000000000..82cd528d2 --- /dev/null +++ b/src/Fim/Response.php @@ -0,0 +1,37 @@ + + */ +readonly class Response implements Arrayable +{ + public function __construct( + public string $text, + public FinishReason $finishReason, + public Usage $usage, + public Meta $meta, + ) {} + + /** + * @return array + */ + #[\Override] + public function toArray(): array + { + return [ + 'text' => $this->text, + 'finish_reason' => $this->finishReason->value, + 'usage' => $this->usage->toArray(), + 'meta' => $this->meta->toArray(), + ]; + } +} diff --git a/src/Prism.php b/src/Prism.php index 7ed3a4953..b070f2b90 100644 --- a/src/Prism.php +++ b/src/Prism.php @@ -7,6 +7,7 @@ use Illuminate\Support\Traits\Macroable; use Prism\Prism\Audio\PendingRequest as PendingAudioRequest; use Prism\Prism\Embeddings\PendingRequest as PendingEmbeddingRequest; +use Prism\Prism\Fim\PendingRequest as PendingFimRequest; use Prism\Prism\Images\PendingRequest as PendingImageRequest; use Prism\Prism\Moderation\PendingRequest as PendingModerationRequest; use Prism\Prism\Structured\PendingRequest as PendingStructuredRequest; @@ -45,4 +46,9 @@ public function moderation(): PendingModerationRequest { return new PendingModerationRequest; } + + public function fim(): PendingFimRequest + { + return new PendingFimRequest; + } } diff --git a/src/Providers/Mistral/Handlers/Fim.php b/src/Providers/Mistral/Handlers/Fim.php new file mode 100644 index 000000000..5a8531102 --- /dev/null +++ b/src/Providers/Mistral/Handlers/Fim.php @@ -0,0 +1,79 @@ +sendRequest($request); + + $this->validateResponse($response); + + $data = $response->json(); + + return new Response( + text: data_get($data, 'choices.0.message.content', ''), + finishReason: $this->mapFinishReason($data), + usage: new Usage( + data_get($data, 'usage.prompt_tokens', 0), + data_get($data, 'usage.completion_tokens', 0), + ), + meta: new Meta( + id: data_get($data, 'id', ''), + model: data_get($data, 'model', ''), + rateLimits: $this->processRateLimits($response), + ) + ); + } + + protected function sendRequest(Request $request): ClientResponse + { + /** @var ClientResponse $response */ + $response = $this->client->post( + 'fim/completions', + array_merge([ + 'model' => $request->model(), + 'prompt' => $request->prompt(), + ], Arr::whereNotNull([ + 'suffix' => $request->suffix(), + 'max_tokens' => $request->maxTokens(), + 'temperature' => $request->temperature(), + 'top_p' => $request->topP(), + 'stop' => empty($request->stop()) ? null : $request->stop(), + ])) + ); + + return $response; + } + + /** + * @param array $data + */ + protected function mapFinishReason(array $data): FinishReason + { + return match (data_get($data, 'choices.0.finish_reason')) { + 'stop' => FinishReason::Stop, + 'length' => FinishReason::Length, + default => FinishReason::Unknown, + }; + } +} diff --git a/src/Providers/Mistral/Mistral.php b/src/Providers/Mistral/Mistral.php index 9287e454e..fb75d1f81 100644 --- a/src/Providers/Mistral/Mistral.php +++ b/src/Providers/Mistral/Mistral.php @@ -17,9 +17,12 @@ use Prism\Prism\Exceptions\PrismProviderOverloadedException; use Prism\Prism\Exceptions\PrismRateLimitedException; use Prism\Prism\Exceptions\PrismRequestTooLargeException; +use Prism\Prism\Fim\Request as FimRequest; +use Prism\Prism\Fim\Response as FimResponse; use Prism\Prism\Providers\Mistral\Concerns\ProcessRateLimits; use Prism\Prism\Providers\Mistral\Handlers\Audio; use Prism\Prism\Providers\Mistral\Handlers\Embeddings; +use Prism\Prism\Providers\Mistral\Handlers\Fim; use Prism\Prism\Providers\Mistral\Handlers\OCR; use Prism\Prism\Providers\Mistral\Handlers\Stream; use Prism\Prism\Providers\Mistral\Handlers\Structured; @@ -66,6 +69,19 @@ public function structured(StructuredRequest $request): StructuredResponse return $handler->handle($request); } + #[\Override] + public function fim(FimRequest $request): FimResponse + { + $handler = new Fim( + $this->client( + $request->clientOptions(), + $request->clientRetry() + ) + ); + + return $handler->handle($request); + } + #[\Override] public function embeddings(EmbeddingRequest $request): EmbeddingResponse { diff --git a/src/Providers/Provider.php b/src/Providers/Provider.php index 3f70a1c9a..59724bf30 100644 --- a/src/Providers/Provider.php +++ b/src/Providers/Provider.php @@ -16,6 +16,8 @@ use Prism\Prism\Exceptions\PrismProviderOverloadedException; use Prism\Prism\Exceptions\PrismRateLimitedException; use Prism\Prism\Exceptions\PrismRequestTooLargeException; +use Prism\Prism\Fim\Request as FimRequest; +use Prism\Prism\Fim\Response as FimResponse; use Prism\Prism\Images\Request as ImagesRequest; use Prism\Prism\Images\Response as ImagesResponse; use Prism\Prism\Moderation\Request as ModerationRequest; @@ -63,6 +65,11 @@ public function speechToText(SpeechToTextRequest $request): SpeechToTextResponse throw PrismException::unsupportedProviderAction('speechToText', class_basename($this)); } + public function fim(FimRequest $request): FimResponse + { + throw PrismException::unsupportedProviderAction('fim', class_basename($this)); + } + /** * @return Generator */ diff --git a/tests/Fixtures/mistral/fim-completion-1.json b/tests/Fixtures/mistral/fim-completion-1.json new file mode 100644 index 000000000..d9fb2b367 --- /dev/null +++ b/tests/Fixtures/mistral/fim-completion-1.json @@ -0,0 +1,23 @@ +{ + "id": "447e3e0d457e42e98248b5d2ef52a2a3", + "object": "chat.completion", + "model": "codestral-2405", + "usage": { + "prompt_tokens": 8, + "completion_tokens": 91, + "total_tokens": 99 + }, + "created": 1759496862, + "choices": [ + { + "index": 0, + "message": { + "content": "return a+b", + "tool_calls": null, + "prefix": false, + "role": "assistant" + }, + "finish_reason": "stop" + } + ] +} diff --git a/tests/Providers/Mistral/MistralFimTest.php b/tests/Providers/Mistral/MistralFimTest.php new file mode 100644 index 000000000..e4c1a578d --- /dev/null +++ b/tests/Providers/Mistral/MistralFimTest.php @@ -0,0 +1,70 @@ +set('prism.providers.mistral.api_key', env('MISTRAL_API_KEY', 'sk-1234')); +}); + +describe('Fim completion', function (): void { + it('can generate fim completion', function (): void { + FixtureResponse::fakeResponseSequence('v1/fim/completions', 'mistral/fim-completion'); + + $response = Prism::fim() + ->using(Provider::Mistral, 'codestral-2405') + ->withPrompt('def add(a, b):') + ->withSuffix(' print("Done")') + ->asText(); + + expect($response->usage->promptTokens)->toBe(8); + expect($response->usage->completionTokens)->toBe(91); + expect($response->meta->id)->toBe('447e3e0d457e42e98248b5d2ef52a2a3'); + expect($response->meta->model)->toBe('codestral-2405'); + expect($response->text)->toBe('return a+b'); + expect($response->finishReason)->toBe(FinishReason::Stop); + + Http::assertSent(function (Request $request): true { + $data = $request->data(); + + expect($data['model'])->toBe('codestral-2405'); + expect($data['prompt'])->toBe('def add(a, b):'); + expect($data['suffix'])->toBe(' print("Done")'); + + return true; + }); + }); + + it('sets the rate limits on meta', function (): void { + $this->freezeTime(function (Carbon $time): void { + $time = $time->toImmutable(); + + FixtureResponse::fakeResponseSequence('v1/fim/completions', 'mistral/fim-completion', [ + 'ratelimitbysize-limit' => 500000, + 'ratelimitbysize-remaining' => 499900, + 'ratelimitbysize-reset' => 28, + ]); + + $response = Prism::fim() + ->using(Provider::Mistral, 'codestral-2405') + ->withPrompt('def mul(a, b):') + ->asText(); + + expect($response->meta->rateLimits[0])->toBeInstanceOf(ProviderRateLimit::class); + expect($response->meta->rateLimits[0]->name)->toEqual('tokens'); + expect($response->meta->rateLimits[0]->limit)->toEqual(500000); + expect($response->meta->rateLimits[0]->remaining)->toEqual(499900); + expect($response->meta->rateLimits[0]->resetsAt->equalTo($time->addSeconds(28)))->toBeTrue(); + }); + }); +});