Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions claude/src/main/scala/sttp/ai/claude/agent/ClaudeAgent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import sttp.monad.IdentityMonad
private[claude] class ClaudeAgentBackend[F[_]](
client: ClaudeClient,
modelName: String,
val tools: Seq[AgentTool[_]],
val tools: Seq[AgentTool[F, _]],
val systemPrompt: Option[String],
responseSchema: Option[ResponseSchema[_]]
)(implicit monad: sttp.monad.MonadError[F])
Expand All @@ -27,7 +27,7 @@ private[claude] class ClaudeAgentBackend[F[_]](
private val outputConfig: Option[OutputConfig] =
responseSchema.map(rs => OutputConfig(format = Some(OutputFormat.JsonSchema(rs.schema))))

private def convertTool(tool: AgentTool[_]): Tool = {
private def convertTool(tool: AgentTool[F, _]): Tool = {
val schemaCursor = tool.jsonSchema.asJson.hcursor

val properties = schemaCursor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ClaudeAgentIntegrationSpec extends AgentIntegrationSpecBase {
override def providerName: String = "Claude"
override def apiKeyEnvVar: String = "ANTHROPIC_API_KEY"

override def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity] = {
override def createAgent(maxIterations: Int, tools: Seq[AgentTool[Identity, _]]): Agent[Identity] = {
val config = ClaudeConfig.fromEnv
val client = ClaudeClient(config)
val agentConfig = AgentConfig[Identity](maxIterations = maxIterations, userTools = tools)
Expand All @@ -29,7 +29,7 @@ class ClaudeAgentIntegrationSpec extends AgentIntegrationSpecBase {

override def createTypedAgent[T](
maxIterations: Int,
tools: Seq[AgentTool[_]],
tools: Seq[AgentTool[Identity, _]],
responseSchema: ResponseSchema[T]
): Agent[Identity] =
ClaudeAgent
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/sttp/ai/core/agent/Agent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class Agent[F[_]](
}
}

private def executeTool[T](tool: AgentTool[T], toolCall: ToolCall): F[String] =
private def executeTool[T](tool: AgentTool[F, T], toolCall: ToolCall): F[String] =
monad
.eval(decode[T](toolCall.input)(tool.codec).fold(throw _, identity))
.map[Either[String, T]](Right(_))
Expand All @@ -134,7 +134,7 @@ class Agent[F[_]](
.flatMap {
case Left(errorMessage) => monad.unit(errorMessage)
case Right(typedInput) =>
monad.eval(tool.execute(typedInput)).handleError { case e: Exception =>
tool.execute(typedInput).handleError { case e: Exception =>
config.exceptionHandler.handleToolException(toolCall.toolName, e) match {
case Left(errorMessage) => monad.unit(errorMessage)
case Right(ex) => monad.error(ex)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/sttp/ai/core/agent/AgentBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import sttp.client4.Backend
trait AgentBackend[F[_]] {

/** The tools available to the agent */
def tools: Seq[AgentTool[_]]
def tools: Seq[AgentTool[F, _]]

/** Optional system prompt to guide the agent's behavior */
def systemPrompt: Option[String]
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/sttp/ai/core/agent/AgentBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ final class AgentBuilder[F[_]] private (

def systemPrompt(prompt: String): AgentBuilder[F] = systemPrompt(_ => prompt)

def tools(values: Seq[AgentTool[_]]): AgentBuilder[F] = withConfig(config.copy(userTools = values))
def tools(values: Seq[AgentTool[F, _]]): AgentBuilder[F] = withConfig(config.copy(userTools = values))

def tools(first: AgentTool[_], rest: AgentTool[_]*): AgentBuilder[F] = tools(first +: rest)
def tools(first: AgentTool[F, _], rest: AgentTool[F, _]*): AgentBuilder[F] = tools(first +: rest)

def addTool(tool: AgentTool[_]): AgentBuilder[F] = withConfig(config.copy(userTools = config.userTools :+ tool))
def addTool(tool: AgentTool[F, _]): AgentBuilder[F] = withConfig(config.copy(userTools = config.userTools :+ tool))

def exceptionHandler(handler: ExceptionHandler): AgentBuilder[F] = withConfig(config.copy(exceptionHandler = handler))

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/sttp/ai/core/agent/AgentConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import sttp.ai.core.agent.AgentConfig.SystemPromptParameters
case class AgentConfig[F[_]](
maxIterations: Int = 10,
systemPromptBuilder: Option[SystemPromptParameters => String] = Some(AgentConfig.buildSystemPrompt),
userTools: Seq[AgentTool[_]] = Seq.empty,
userTools: Seq[AgentTool[F, _]] = Seq.empty[AgentTool[F, _]],
exceptionHandler: ExceptionHandler = ExceptionHandler.default,
responseSchema: Option[ResponseSchema[_]] = None,
beforeToolCall: Option[ToolCall => F[Unit]] = None,
Expand Down
33 changes: 24 additions & 9 deletions core/src/main/scala/sttp/ai/core/agent/AgentTool.scala
Original file line number Diff line number Diff line change
@@ -1,42 +1,57 @@
package sttp.ai.core.agent

import io.circe.{Codec, Decoder, Encoder, Json}
import io.circe.{Codec, Json}
import sttp.apispec.Schema
import sttp.shared.Identity
import sttp.tapir.{Schema => TapirSchema}
import sttp.tapir.docs.apispec.schema.TapirSchemaToJsonSchema

trait AgentTool[T] {
trait AgentTool[F[_], T] {
def name: String
def description: String
def jsonSchema: Schema
def codec: Codec[T]
def execute(input: T): String
def execute(input: T): F[String]
}

object AgentTool {

def fromFunction[T](
toolName: String,
toolDescription: String
)(f: T => String)(implicit tapirSchema: TapirSchema[T], toolCodec: Codec[T]): AgentTool[T] =
new AgentTool[T] {
)(f: T => String)(implicit tapirSchema: TapirSchema[T], toolCodec: Codec[T]): AgentTool[Identity, T] =
fromFunctionF[Identity, T](toolName, toolDescription)(f)

def fromFunctionF[F[_], T](
toolName: String,
toolDescription: String
)(f: T => F[String])(implicit tapirSchema: TapirSchema[T], toolCodec: Codec[T]): AgentTool[F, T] =
new AgentTool[F, T] {
override def name: String = toolName
override def description: String = toolDescription
override def jsonSchema: Schema =
TapirSchemaToJsonSchema(tapirSchema, markOptionsAsNullable = true)
override def codec: Codec[T] = toolCodec
override def execute(input: T): String = f(input)
override def execute(input: T): F[String] = f(input)
}

def dynamic(
toolName: String,
toolDescription: String,
toolSchema: Schema
)(f: Map[String, Json] => String): AgentTool[Map[String, Json]] =
new AgentTool[Map[String, Json]] {
)(f: Map[String, Json] => String): AgentTool[Identity, Map[String, Json]] =
dynamicF[Identity](toolName, toolDescription, toolSchema)(f)

def dynamicF[F[_]](
toolName: String,
toolDescription: String,
toolSchema: Schema
)(f: Map[String, Json] => F[String]): AgentTool[F, Map[String, Json]] =
new AgentTool[F, Map[String, Json]] {
override def name: String = toolName
override def description: String = toolDescription
override def jsonSchema: Schema = toolSchema
override def codec: Codec[Map[String, Json]] = Codec.implied
override def execute(input: Map[String, Json]): String = f(input)
override def execute(input: Map[String, Json]): F[String] = f(input)
}
}
2 changes: 1 addition & 1 deletion core/src/test/scala/sttp/ai/core/agent/AgentSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class AgentSpec extends AnyFlatSpec with Matchers with OptionValues {
private var callCount = 0
var receivedHistories: Seq[ConversationHistory] = Seq.empty

override def tools: Seq[AgentTool[_]] = Seq.empty
override def tools: Seq[AgentTool[Identity, _]] = Seq.empty
override def systemPrompt: Option[String] = None

override def sendRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers {

def providerName: String
def apiKeyEnvVar: String
def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity]
def createAgent(maxIterations: Int, tools: Seq[AgentTool[Identity, _]]): Agent[Identity]

def createTypedAgent[T](
maxIterations: Int,
tools: Seq[AgentTool[_]],
tools: Seq[AgentTool[Identity, _]],
responseSchema: ResponseSchema[T]
): Agent[Identity] =
cancel(s"$providerName typed agent factory not implemented for this spec")
Expand All @@ -29,7 +29,7 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers {
implicit val calculatorInputCodec: Codec[CalculatorInput] = deriveCodec
implicit val calculatorInputSchema: Schema[CalculatorInput] = Schema.derived

protected val calculatorTool: AgentTool[CalculatorInput] = AgentTool.fromFunction(
protected val calculatorTool: AgentTool[Identity, CalculatorInput] = AgentTool.fromFunction(
"calculator",
"Perform basic arithmetic operations (one of: `add`, `subtract`, `multiply`, `divide`)"
) { (input: CalculatorInput) =>
Expand All @@ -47,7 +47,7 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers {
implicit val weatherInputCodec: Codec[WeatherInput] = deriveCodec
implicit val weatherInputSchema: Schema[WeatherInput] = Schema.derived

protected val weatherTool: AgentTool[WeatherInput] = AgentTool.fromFunction(
protected val weatherTool: AgentTool[Identity, WeatherInput] = AgentTool.fromFunction(
"get_weather",
"Get current weather for a city"
) { (input: WeatherInput) =>
Expand Down Expand Up @@ -84,7 +84,7 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers {
s"Should have at least $min iterations, but had ${result.iterations}"
)

def withAgent[T](maxIter: Int, tools: Seq[AgentTool[_]])(test: (Agent[Identity], Backend[Identity]) => T): T = {
def withAgent[T](maxIter: Int, tools: Seq[AgentTool[Identity, _]])(test: (Agent[Identity], Backend[Identity]) => T): T = {
if (maybeApiKey.isEmpty) {
cancel(s"$apiKeyEnvVar not defined - skipping integration test")
}
Expand Down
4 changes: 2 additions & 2 deletions openai/src/main/scala/sttp/ai/openai/agent/OpenAIAgent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import sttp.monad.IdentityMonad
private[openai] class OpenAIAgentBackend[F[_]](
openAI: OpenAI,
modelName: String,
val tools: Seq[AgentTool[_]],
val tools: Seq[AgentTool[F, _]],
val systemPrompt: Option[String],
responseSchema: Option[ResponseSchema[_]]
)(implicit monad: sttp.monad.MonadError[F])
Expand All @@ -29,7 +29,7 @@ private[openai] class OpenAIAgentBackend[F[_]](
)
}

private def convertTool(tool: AgentTool[_]): Tool.Function = {
private def convertTool(tool: AgentTool[F, _]): Tool.Function = {
val schema = tool.jsonSchema
val schemaJson = SchemaSupport.schemaCodec(schema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class OpenAIAgentIntegrationSpec extends AgentIntegrationSpecBase {
override def providerName: String = "OpenAI"
override def apiKeyEnvVar: String = "OPENAI_API_KEY"

override def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity] = {
override def createAgent(maxIterations: Int, tools: Seq[AgentTool[Identity, _]]): Agent[Identity] = {
val openai = OpenAI.fromEnv
val agentConfig = AgentConfig[Identity](maxIterations = maxIterations, userTools = tools)
val agentBackend = new OpenAIAgentBackend[Identity](
Expand All @@ -27,7 +27,7 @@ class OpenAIAgentIntegrationSpec extends AgentIntegrationSpecBase {

override def createTypedAgent[T](
maxIterations: Int,
tools: Seq[AgentTool[_]],
tools: Seq[AgentTool[Identity, _]],
responseSchema: ResponseSchema[T]
): Agent[Identity] = {
val openai = OpenAI.fromEnv
Expand Down