diff --git a/claude/src/main/scala/sttp/ai/claude/agent/ClaudeAgent.scala b/claude/src/main/scala/sttp/ai/claude/agent/ClaudeAgent.scala index c7631894..dae48ff8 100644 --- a/claude/src/main/scala/sttp/ai/claude/agent/ClaudeAgent.scala +++ b/claude/src/main/scala/sttp/ai/claude/agent/ClaudeAgent.scala @@ -16,7 +16,7 @@ import sttp.monad.IdentityMonad private[claude] class ClaudeAgentBackend[F[_]]( client: ClaudeClient, modelName: String, - val tools: Seq[AgentTool[_]], + val tools: Seq[AgentTool[F, _]], val systemPrompt: Option[String], responseSchema: Option[ResponseSchema[_]] )(implicit monad: sttp.monad.MonadError[F]) @@ -27,7 +27,7 @@ private[claude] class ClaudeAgentBackend[F[_]]( private val outputConfig: Option[OutputConfig] = responseSchema.map(rs => OutputConfig(format = Some(OutputFormat.JsonSchema(rs.schema)))) - private def convertTool(tool: AgentTool[_]): Tool = { + private def convertTool(tool: AgentTool[F, _]): Tool = { val schemaCursor = tool.jsonSchema.asJson.hcursor val properties = schemaCursor diff --git a/claude/src/test/scala/sttp/ai/claude/integration/ClaudeAgentIntegrationSpec.scala b/claude/src/test/scala/sttp/ai/claude/integration/ClaudeAgentIntegrationSpec.scala index ff04f9b2..753f4534 100644 --- a/claude/src/test/scala/sttp/ai/claude/integration/ClaudeAgentIntegrationSpec.scala +++ b/claude/src/test/scala/sttp/ai/claude/integration/ClaudeAgentIntegrationSpec.scala @@ -13,7 +13,7 @@ class ClaudeAgentIntegrationSpec extends AgentIntegrationSpecBase { override def providerName: String = "Claude" override def apiKeyEnvVar: String = "ANTHROPIC_API_KEY" - override def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity] = { + override def createAgent(maxIterations: Int, tools: Seq[AgentTool[Identity, _]]): Agent[Identity] = { val config = ClaudeConfig.fromEnv val client = ClaudeClient(config) val agentConfig = AgentConfig[Identity](maxIterations = maxIterations, userTools = tools) @@ -29,7 +29,7 @@ class ClaudeAgentIntegrationSpec extends AgentIntegrationSpecBase { override def createTypedAgent[T]( maxIterations: Int, - tools: Seq[AgentTool[_]], + tools: Seq[AgentTool[Identity, _]], responseSchema: ResponseSchema[T] ): Agent[Identity] = ClaudeAgent diff --git a/core/src/main/scala/sttp/ai/core/agent/Agent.scala b/core/src/main/scala/sttp/ai/core/agent/Agent.scala index 8da1874c..ea3eccc8 100644 --- a/core/src/main/scala/sttp/ai/core/agent/Agent.scala +++ b/core/src/main/scala/sttp/ai/core/agent/Agent.scala @@ -121,7 +121,7 @@ class Agent[F[_]]( } } - private def executeTool[T](tool: AgentTool[T], toolCall: ToolCall): F[String] = + private def executeTool[T](tool: AgentTool[F, T], toolCall: ToolCall): F[String] = monad .eval(decode[T](toolCall.input)(tool.codec).fold(throw _, identity)) .map[Either[String, T]](Right(_)) @@ -134,7 +134,7 @@ class Agent[F[_]]( .flatMap { case Left(errorMessage) => monad.unit(errorMessage) case Right(typedInput) => - monad.eval(tool.execute(typedInput)).handleError { case e: Exception => + tool.execute(typedInput).handleError { case e: Exception => config.exceptionHandler.handleToolException(toolCall.toolName, e) match { case Left(errorMessage) => monad.unit(errorMessage) case Right(ex) => monad.error(ex) diff --git a/core/src/main/scala/sttp/ai/core/agent/AgentBackend.scala b/core/src/main/scala/sttp/ai/core/agent/AgentBackend.scala index 519826d2..40e0977c 100644 --- a/core/src/main/scala/sttp/ai/core/agent/AgentBackend.scala +++ b/core/src/main/scala/sttp/ai/core/agent/AgentBackend.scala @@ -17,7 +17,7 @@ import sttp.client4.Backend trait AgentBackend[F[_]] { /** The tools available to the agent */ - def tools: Seq[AgentTool[_]] + def tools: Seq[AgentTool[F, _]] /** Optional system prompt to guide the agent's behavior */ def systemPrompt: Option[String] diff --git a/core/src/main/scala/sttp/ai/core/agent/AgentBuilder.scala b/core/src/main/scala/sttp/ai/core/agent/AgentBuilder.scala index 646e015f..df80cfe9 100644 --- a/core/src/main/scala/sttp/ai/core/agent/AgentBuilder.scala +++ b/core/src/main/scala/sttp/ai/core/agent/AgentBuilder.scala @@ -20,11 +20,11 @@ final class AgentBuilder[F[_]] private ( def systemPrompt(prompt: String): AgentBuilder[F] = systemPrompt(_ => prompt) - def tools(values: Seq[AgentTool[_]]): AgentBuilder[F] = withConfig(config.copy(userTools = values)) + def tools(values: Seq[AgentTool[F, _]]): AgentBuilder[F] = withConfig(config.copy(userTools = values)) - def tools(first: AgentTool[_], rest: AgentTool[_]*): AgentBuilder[F] = tools(first +: rest) + def tools(first: AgentTool[F, _], rest: AgentTool[F, _]*): AgentBuilder[F] = tools(first +: rest) - def addTool(tool: AgentTool[_]): AgentBuilder[F] = withConfig(config.copy(userTools = config.userTools :+ tool)) + def addTool(tool: AgentTool[F, _]): AgentBuilder[F] = withConfig(config.copy(userTools = config.userTools :+ tool)) def exceptionHandler(handler: ExceptionHandler): AgentBuilder[F] = withConfig(config.copy(exceptionHandler = handler)) diff --git a/core/src/main/scala/sttp/ai/core/agent/AgentConfig.scala b/core/src/main/scala/sttp/ai/core/agent/AgentConfig.scala index 05c45ee2..cab48eb7 100644 --- a/core/src/main/scala/sttp/ai/core/agent/AgentConfig.scala +++ b/core/src/main/scala/sttp/ai/core/agent/AgentConfig.scala @@ -5,7 +5,7 @@ import sttp.ai.core.agent.AgentConfig.SystemPromptParameters case class AgentConfig[F[_]]( maxIterations: Int = 10, systemPromptBuilder: Option[SystemPromptParameters => String] = Some(AgentConfig.buildSystemPrompt), - userTools: Seq[AgentTool[_]] = Seq.empty, + userTools: Seq[AgentTool[F, _]] = Seq.empty[AgentTool[F, _]], exceptionHandler: ExceptionHandler = ExceptionHandler.default, responseSchema: Option[ResponseSchema[_]] = None, beforeToolCall: Option[ToolCall => F[Unit]] = None, diff --git a/core/src/main/scala/sttp/ai/core/agent/AgentTool.scala b/core/src/main/scala/sttp/ai/core/agent/AgentTool.scala index e7be099d..41298681 100644 --- a/core/src/main/scala/sttp/ai/core/agent/AgentTool.scala +++ b/core/src/main/scala/sttp/ai/core/agent/AgentTool.scala @@ -1,42 +1,57 @@ package sttp.ai.core.agent -import io.circe.{Codec, Decoder, Encoder, Json} +import io.circe.{Codec, Json} import sttp.apispec.Schema +import sttp.shared.Identity import sttp.tapir.{Schema => TapirSchema} import sttp.tapir.docs.apispec.schema.TapirSchemaToJsonSchema -trait AgentTool[T] { +trait AgentTool[F[_], T] { def name: String def description: String def jsonSchema: Schema def codec: Codec[T] - def execute(input: T): String + def execute(input: T): F[String] } object AgentTool { + def fromFunction[T]( toolName: String, toolDescription: String - )(f: T => String)(implicit tapirSchema: TapirSchema[T], toolCodec: Codec[T]): AgentTool[T] = - new AgentTool[T] { + )(f: T => String)(implicit tapirSchema: TapirSchema[T], toolCodec: Codec[T]): AgentTool[Identity, T] = + fromFunctionF[Identity, T](toolName, toolDescription)(f) + + def fromFunctionF[F[_], T]( + toolName: String, + toolDescription: String + )(f: T => F[String])(implicit tapirSchema: TapirSchema[T], toolCodec: Codec[T]): AgentTool[F, T] = + new AgentTool[F, T] { override def name: String = toolName override def description: String = toolDescription override def jsonSchema: Schema = TapirSchemaToJsonSchema(tapirSchema, markOptionsAsNullable = true) override def codec: Codec[T] = toolCodec - override def execute(input: T): String = f(input) + override def execute(input: T): F[String] = f(input) } def dynamic( toolName: String, toolDescription: String, toolSchema: Schema - )(f: Map[String, Json] => String): AgentTool[Map[String, Json]] = - new AgentTool[Map[String, Json]] { + )(f: Map[String, Json] => String): AgentTool[Identity, Map[String, Json]] = + dynamicF[Identity](toolName, toolDescription, toolSchema)(f) + + def dynamicF[F[_]]( + toolName: String, + toolDescription: String, + toolSchema: Schema + )(f: Map[String, Json] => F[String]): AgentTool[F, Map[String, Json]] = + new AgentTool[F, Map[String, Json]] { override def name: String = toolName override def description: String = toolDescription override def jsonSchema: Schema = toolSchema override def codec: Codec[Map[String, Json]] = Codec.implied - override def execute(input: Map[String, Json]): String = f(input) + override def execute(input: Map[String, Json]): F[String] = f(input) } } diff --git a/core/src/test/scala/sttp/ai/core/agent/AgentSpec.scala b/core/src/test/scala/sttp/ai/core/agent/AgentSpec.scala index bea32e3b..ae94c83e 100644 --- a/core/src/test/scala/sttp/ai/core/agent/AgentSpec.scala +++ b/core/src/test/scala/sttp/ai/core/agent/AgentSpec.scala @@ -17,7 +17,7 @@ class AgentSpec extends AnyFlatSpec with Matchers with OptionValues { private var callCount = 0 var receivedHistories: Seq[ConversationHistory] = Seq.empty - override def tools: Seq[AgentTool[_]] = Seq.empty + override def tools: Seq[AgentTool[Identity, _]] = Seq.empty override def systemPrompt: Option[String] = None override def sendRequest( diff --git a/core/src/test/scala/sttp/ai/core/agent/integration/AgentIntegrationSpecBase.scala b/core/src/test/scala/sttp/ai/core/agent/integration/AgentIntegrationSpecBase.scala index ce367f1f..446d5db6 100644 --- a/core/src/test/scala/sttp/ai/core/agent/integration/AgentIntegrationSpecBase.scala +++ b/core/src/test/scala/sttp/ai/core/agent/integration/AgentIntegrationSpecBase.scala @@ -14,11 +14,11 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers { def providerName: String def apiKeyEnvVar: String - def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity] + def createAgent(maxIterations: Int, tools: Seq[AgentTool[Identity, _]]): Agent[Identity] def createTypedAgent[T]( maxIterations: Int, - tools: Seq[AgentTool[_]], + tools: Seq[AgentTool[Identity, _]], responseSchema: ResponseSchema[T] ): Agent[Identity] = cancel(s"$providerName typed agent factory not implemented for this spec") @@ -29,7 +29,7 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers { implicit val calculatorInputCodec: Codec[CalculatorInput] = deriveCodec implicit val calculatorInputSchema: Schema[CalculatorInput] = Schema.derived - protected val calculatorTool: AgentTool[CalculatorInput] = AgentTool.fromFunction( + protected val calculatorTool: AgentTool[Identity, CalculatorInput] = AgentTool.fromFunction( "calculator", "Perform basic arithmetic operations (one of: `add`, `subtract`, `multiply`, `divide`)" ) { (input: CalculatorInput) => @@ -47,7 +47,7 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers { implicit val weatherInputCodec: Codec[WeatherInput] = deriveCodec implicit val weatherInputSchema: Schema[WeatherInput] = Schema.derived - protected val weatherTool: AgentTool[WeatherInput] = AgentTool.fromFunction( + protected val weatherTool: AgentTool[Identity, WeatherInput] = AgentTool.fromFunction( "get_weather", "Get current weather for a city" ) { (input: WeatherInput) => @@ -84,7 +84,7 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers { s"Should have at least $min iterations, but had ${result.iterations}" ) - def withAgent[T](maxIter: Int, tools: Seq[AgentTool[_]])(test: (Agent[Identity], Backend[Identity]) => T): T = { + def withAgent[T](maxIter: Int, tools: Seq[AgentTool[Identity, _]])(test: (Agent[Identity], Backend[Identity]) => T): T = { if (maybeApiKey.isEmpty) { cancel(s"$apiKeyEnvVar not defined - skipping integration test") } diff --git a/openai/src/main/scala/sttp/ai/openai/agent/OpenAIAgent.scala b/openai/src/main/scala/sttp/ai/openai/agent/OpenAIAgent.scala index 2d8c36d8..c2d6e54f 100644 --- a/openai/src/main/scala/sttp/ai/openai/agent/OpenAIAgent.scala +++ b/openai/src/main/scala/sttp/ai/openai/agent/OpenAIAgent.scala @@ -12,7 +12,7 @@ import sttp.monad.IdentityMonad private[openai] class OpenAIAgentBackend[F[_]]( openAI: OpenAI, modelName: String, - val tools: Seq[AgentTool[_]], + val tools: Seq[AgentTool[F, _]], val systemPrompt: Option[String], responseSchema: Option[ResponseSchema[_]] )(implicit monad: sttp.monad.MonadError[F]) @@ -29,7 +29,7 @@ private[openai] class OpenAIAgentBackend[F[_]]( ) } - private def convertTool(tool: AgentTool[_]): Tool.Function = { + private def convertTool(tool: AgentTool[F, _]): Tool.Function = { val schema = tool.jsonSchema val schemaJson = SchemaSupport.schemaCodec(schema) diff --git a/openai/src/test/scala/sttp/ai/openai/integration/OpenAIAgentIntegrationSpec.scala b/openai/src/test/scala/sttp/ai/openai/integration/OpenAIAgentIntegrationSpec.scala index 4c1ed5f6..e98220ed 100644 --- a/openai/src/test/scala/sttp/ai/openai/integration/OpenAIAgentIntegrationSpec.scala +++ b/openai/src/test/scala/sttp/ai/openai/integration/OpenAIAgentIntegrationSpec.scala @@ -12,7 +12,7 @@ class OpenAIAgentIntegrationSpec extends AgentIntegrationSpecBase { override def providerName: String = "OpenAI" override def apiKeyEnvVar: String = "OPENAI_API_KEY" - override def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity] = { + override def createAgent(maxIterations: Int, tools: Seq[AgentTool[Identity, _]]): Agent[Identity] = { val openai = OpenAI.fromEnv val agentConfig = AgentConfig[Identity](maxIterations = maxIterations, userTools = tools) val agentBackend = new OpenAIAgentBackend[Identity]( @@ -27,7 +27,7 @@ class OpenAIAgentIntegrationSpec extends AgentIntegrationSpecBase { override def createTypedAgent[T]( maxIterations: Int, - tools: Seq[AgentTool[_]], + tools: Seq[AgentTool[Identity, _]], responseSchema: ResponseSchema[T] ): Agent[Identity] = { val openai = OpenAI.fromEnv