From fedefd830e1a6df98972e9e01620f612121bd4f5 Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Wed, 6 May 2026 17:33:38 +0800 Subject: [PATCH 1/4] [api][java] Inject output format prompt after the first system prompt for ReAct agent. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../flink/agents/api/agents/ReActAgent.java | 3 ++- .../agents/api/chat/messages/ChatMessage.java | 10 +++++++++ .../api/chat/messages/ChatMessageTest.java | 21 +++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java index e32910500..0b394baab 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -160,7 +160,8 @@ public static void startAction(Event event, RunnerContext ctx) { if (schmaPrompt != null) { List instruct = schmaPrompt.formatMessages(MessageRole.SYSTEM, Map.of()); - inputMessages.addAll(0, instruct); + int index = ChatMessage.findFirstSystemMessage(inputMessages); + inputMessages.addAll(index + 1, instruct); } Object outputSchema = ctx.getActionConfigValue("output_schema"); diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java index ef12dd29c..c3be5ef84 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java @@ -154,4 +154,14 @@ public int hashCode() { public String toString() { return role.getValue() + ": " + content; } + + /** Return the index of the first system message in the list, or -1 if none. */ + public static int findFirstSystemMessage(List messages) { + for (int i = 0; i < messages.size(); i++) { + if (messages.get(i).getRole() == MessageRole.SYSTEM) { + return i; + } + } + return -1; + } } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/messages/ChatMessageTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/messages/ChatMessageTest.java index f24bae197..d58353ba8 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/messages/ChatMessageTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/messages/ChatMessageTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.*; @@ -149,4 +150,24 @@ void testRoleModification() { message.setRole(MessageRole.ASSISTANT); assertEquals(MessageRole.ASSISTANT, message.getRole()); } + + @Test + @DisplayName("findFirstSystemMessage returns the index of the first system message") + void testFindFirstSystemMessage() { + assertEquals(-1, ChatMessage.findFirstSystemMessage(List.of())); + assertEquals( + -1, ChatMessage.findFirstSystemMessage(List.of(userMessage, assistantMessage))); + assertEquals(0, ChatMessage.findFirstSystemMessage(List.of(systemMessage, userMessage))); + assertEquals( + 1, + ChatMessage.findFirstSystemMessage( + List.of(userMessage, systemMessage, assistantMessage))); + assertEquals( + 0, + ChatMessage.findFirstSystemMessage( + List.of( + systemMessage, + new ChatMessage(MessageRole.SYSTEM, "second system"), + userMessage))); + } } From 41f076471d371c17ff2a56a1d4a438c3b8f7f264 Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Wed, 6 May 2026 17:50:11 +0800 Subject: [PATCH 2/4] [api][java] Extract resource context. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../chat/model/BaseChatModelConnection.java | 7 +- .../api/chat/model/BaseChatModelSetup.java | 14 ++-- .../python/PythonChatModelConnection.java | 8 +-- .../model/python/PythonChatModelSetup.java | 8 +-- .../model/BaseEmbeddingModelConnection.java | 6 +- .../model/BaseEmbeddingModelSetup.java | 12 ++-- .../PythonEmbeddingModelConnection.java | 8 +-- .../python/PythonEmbeddingModelSetup.java | 8 +-- .../flink/agents/api/resource/Resource.java | 18 +++-- .../agents/api/resource/ResourceContext.java | 72 +++++++++++++++++++ .../api/vectorstores/BaseVectorStore.java | 11 ++- ...PythonCollectionManageableVectorStore.java | 8 +-- .../python/PythonVectorStore.java | 8 +-- ...seChatModelConnectionTokenMetricsTest.java | 8 +-- .../api/chat/model/BaseChatModelTest.java | 9 +-- .../python/PythonChatModelConnectionTest.java | 6 +- .../python/PythonChatModelSetupTest.java | 6 +- .../PythonEmbeddingModelConnectionTest.java | 6 +- .../python/PythonEmbeddingModelSetupTest.java | 6 +- ...onCollectionManageableVectorStoreTest.java | 6 +- docs/content/docs/development/chat_models.md | 6 +- .../content/docs/development/vector_stores.md | 4 +- .../AnthropicChatModelConnection.java | 8 +-- .../anthropic/AnthropicChatModelSetup.java | 13 ++-- .../azureai/AzureAIChatModelConnection.java | 8 +-- .../azureai/AzureAIChatModelSetup.java | 6 +- .../bedrock/BedrockChatModelConnection.java | 8 +-- .../bedrock/BedrockChatModelSetup.java | 9 +-- .../BedrockChatModelConnectionTest.java | 6 +- .../bedrock/BedrockChatModelSetupTest.java | 6 +- .../ollama/OllamaChatModelConnection.java | 13 ++-- .../ollama/OllamaChatModelSetup.java | 16 ++--- .../openai/OpenAICompletionsConnection.java | 8 +-- .../openai/OpenAICompletionsSetup.java | 13 ++-- .../OpenAIResponsesModelConnection.java | 8 +-- .../openai/OpenAIResponsesModelSetup.java | 12 ++-- .../BedrockEmbeddingModelConnection.java | 8 +-- .../bedrock/BedrockEmbeddingModelSetup.java | 8 +-- .../bedrock/BedrockEmbeddingModelTest.java | 6 +- .../OllamaEmbeddingModelConnection.java | 8 +-- .../ollama/OllamaEmbeddingModelSetup.java | 8 +-- .../OllamaEmbeddingModelConnectionTest.java | 12 ++-- .../agents/integrations/mcp/MCPServer.java | 7 +- .../ElasticsearchVectorStore.java | 8 +-- .../ElasticsearchVectorStoreTest.java | 4 +- .../opensearch/OpenSearchVectorStore.java | 9 +-- .../opensearch/OpenSearchVectorStoreTest.java | 9 ++- .../s3vectors/S3VectorsVectorStore.java | 9 +-- .../s3vectors/S3VectorsVectorStoreTest.java | 9 ++- .../plan/resource/python/PythonMCPServer.java | 6 +- .../JavaResourceProvider.java | 9 ++- .../JavaSerializableResourceProvider.java | 6 +- .../PythonResourceProvider.java | 10 +-- .../PythonSerializableResourceProvider.java | 5 +- .../resourceprovider/ResourceProvider.java | 9 ++- .../plan/AgentPlanDeclareChatModelTest.java | 34 ++++----- .../plan/AgentPlanDeclareMCPServerTest.java | 9 ++- .../plan/AgentPlanDeclareToolFieldTest.java | 9 ++- .../plan/AgentPlanDeclareToolMethodTest.java | 36 ++++++---- .../flink/agents/plan/AgentPlanTest.java | 6 +- .../agents/plan/FunctionToolPlanTest.java | 19 ++--- .../runtime/PythonMCPResourceDiscovery.java | 16 +++-- .../flink/agents/runtime/ResourceCache.java | 16 +++-- .../operator/ActionExecutionOperator.java | 20 +++--- .../python/utils/JavaResourceAdapter.java | 11 ++- .../utils/PythonResourceAdapterImpl.java | 20 ++++-- .../runtime/resource/ResourceContextImpl.java | 65 +++++++++++++++++ .../agents/runtime/ResourceCacheTest.java | 6 +- .../utils/PythonResourceAdapterImplTest.java | 30 ++++---- 69 files changed, 459 insertions(+), 367 deletions(-) create mode 100644 api/src/main/java/org/apache/flink/agents/api/resource/ResourceContext.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java index 7b70af0c3..a6dccc44a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java @@ -21,13 +21,13 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * Abstraction of chat model connection. @@ -37,9 +37,8 @@ */ public abstract class BaseChatModelConnection extends Resource { - public BaseChatModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public BaseChatModelConnection(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index bcfb0e503..b0f73d80f 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -22,6 +22,7 @@ import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; @@ -35,7 +36,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; public abstract class BaseChatModelSetup extends Resource { protected final String connectionName; @@ -46,9 +46,8 @@ public abstract class BaseChatModelSetup extends Resource { @Nullable protected BaseChatModelConnection connection; protected final List tools = new ArrayList<>(); - public BaseChatModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public BaseChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.connectionName = descriptor.getArgument("connection"); this.model = descriptor.getArgument("model"); this.prompt = descriptor.getArgument("prompt"); @@ -66,14 +65,15 @@ public BaseChatModelSetup( public void open() throws Exception { this.connection = (BaseChatModelConnection) - this.getResource.apply( + this.resourceContext.getResource( this.connectionName, ResourceType.CHAT_MODEL_CONNECTION); if (this.prompt != null && this.prompt instanceof String) { - this.prompt = this.getResource.apply((String) this.prompt, ResourceType.PROMPT); + this.prompt = + this.resourceContext.getResource((String) this.prompt, ResourceType.PROMPT); } if (this.toolNames != null) { for (String name : this.toolNames) { - this.tools.add((Tool) this.getResource.apply(name, ResourceType.TOOL)); + this.tools.add((Tool) this.resourceContext.getResource(name, ResourceType.TOOL)); } } } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java index 42018f6d8..2b92554d9 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnection.java @@ -19,9 +19,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.api.tools.Tool; @@ -31,7 +30,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * Python-based implementation of ChatModelConnection that wraps a Python chat model object. This @@ -56,8 +54,8 @@ public PythonChatModelConnection( PythonResourceAdapter adapter, PyObject chatModel, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); this.chatModel = chatModel; this.adapter = adapter; } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java index ab985cd6a..a15cede18 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java @@ -19,9 +19,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import pemja.core.object.PyObject; @@ -31,7 +30,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.apache.flink.util.Preconditions.checkState; @@ -52,8 +50,8 @@ public PythonChatModelSetup( PythonResourceAdapter adapter, PyObject chatModelSetup, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); this.chatModelSetup = chatModelSetup; this.adapter = adapter; } diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java index 377449036..4d46e3eed 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelConnection.java @@ -19,12 +19,12 @@ package org.apache.flink.agents.api.embedding.model; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * Abstraction of embedding model connection. @@ -45,8 +45,8 @@ public abstract class BaseEmbeddingModelConnection extends Resource { public BaseEmbeddingModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java index a7b4cf954..e7c19893d 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api.embedding.model; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.annotation.VisibleForTesting; @@ -29,7 +30,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * Base class for embedding model setup configurations. @@ -43,9 +43,8 @@ public abstract class BaseEmbeddingModelSetup extends Resource { @Nullable protected BaseEmbeddingModelConnection connection; - public BaseEmbeddingModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public BaseEmbeddingModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.connectionName = descriptor.getArgument("connection"); this.model = descriptor.getArgument("model"); } @@ -58,10 +57,11 @@ public BaseEmbeddingModelSetup( * resources object out of the method to be async executed and invoking it in the main thread. */ @Override - public void open() { + public void open() throws Exception { this.connection = (BaseEmbeddingModelConnection) - getResource.apply(connectionName, ResourceType.EMBEDDING_MODEL_CONNECTION); + resourceContext.getResource( + connectionName, ResourceType.EMBEDDING_MODEL_CONNECTION); } public abstract Map getParameters(); diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java index f91896d16..974e362a1 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnection.java @@ -19,9 +19,8 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; import org.apache.flink.agents.api.embedding.model.EmbeddingModelUtils; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import pemja.core.object.PyObject; @@ -30,7 +29,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.apache.flink.util.Preconditions.checkState; @@ -59,8 +57,8 @@ public PythonEmbeddingModelConnection( PythonResourceAdapter adapter, PyObject embeddingModel, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); this.embeddingModel = embeddingModel; this.adapter = adapter; } diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java index d0eb4979d..f0b9eca4b 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java @@ -19,9 +19,8 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.embedding.model.EmbeddingModelUtils; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import pemja.core.object.PyObject; @@ -31,7 +30,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.apache.flink.util.Preconditions.checkState; @@ -59,8 +57,8 @@ public PythonEmbeddingModelSetup( PythonResourceAdapter adapter, PyObject embeddingModelSetup, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); this.embeddingModelSetup = embeddingModelSetup; this.adapter = adapter; } diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java index de2e588be..52ba40ffd 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java @@ -20,22 +20,19 @@ import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; -import java.util.function.BiFunction; - /** * Base interface for all kinds of resources, including chat models, tools, prompts and so on. * *

Resources are components that can be used by agents during action execution. */ public abstract class Resource { - protected BiFunction getResource; + protected ResourceContext resourceContext; /** The metric group bound to this resource, injected by RunnerContext.getResource(). */ private transient FlinkAgentsMetricGroup metricGroup; - protected Resource( - ResourceDescriptor descriptor, BiFunction getResource) { - this.getResource = getResource; + protected Resource(ResourceDescriptor descriptor, ResourceContext resourceContext) { + this.resourceContext = resourceContext; } protected Resource() {} @@ -47,6 +44,15 @@ protected Resource() {} */ public abstract ResourceType getResourceType(); + /** + * Get the {@link ResourceContext} bound to this resource at construction time. + * + * @return the bound resource context, or {@code null} if not bound + */ + public ResourceContext getResourceContext() { + return resourceContext; + } + /** * Set the metric group for this resource. * diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceContext.java b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceContext.java new file mode 100644 index 000000000..b401bf999 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceContext.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.resource; + +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; + +/** + * Capabilities available to a {@link Resource} during execution. + * + *

Mirrors the Python {@code flink_agents.api.resource_context.ResourceContext}. + */ +public interface ResourceContext { + + /** Get another resource declared in the same Agent. */ + Resource getResource(String name, ResourceType type) throws Exception; + + /** + * Generate the available skills prompt for the given skill names. + * + *

Returns an empty string if no skills are configured. + */ + String generateAvailableSkillsPrompt(List skillNames) throws Exception; + + /** + * Return absolute directory paths for the given skill names. + * + *

Returns an empty list if no skills are configured or none of the requested skills are + * filesystem-backed. + */ + List getSkillDirs(List skillNames) throws Exception; + + /** + * Create a {@link ResourceContext} backed by the given resource lookup function. The skill + * methods return empty defaults — convenient for tests or for runtimes without skills support. + */ + static ResourceContext fromGetResource(BiFunction getResource) { + return new ResourceContext() { + @Override + public Resource getResource(String name, ResourceType type) { + return getResource.apply(name, type); + } + + @Override + public String generateAvailableSkillsPrompt(List skillNames) { + return ""; + } + + @Override + public List getSkillDirs(List skillNames) { + return Collections.emptyList(); + } + }; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java index 09a622b8f..442e705ed 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; @@ -28,7 +29,6 @@ import java.io.IOException; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * Base abstract class for vector store. Provides vector store functionality that integrates @@ -78,9 +78,8 @@ public abstract class BaseVectorStore extends Resource { */ protected final @Nullable String collection; - public BaseVectorStore( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public BaseVectorStore(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.embeddingModelName = descriptor.getArgument("embedding_model"); this.collection = descriptor.getArgument("collection"); } @@ -93,11 +92,11 @@ public BaseVectorStore( * resources object out of the method to be async executed and invoking it in the main thread. */ @Override - public void open() { + public void open() throws Exception { if (this.embeddingModelName != null) { this.embeddingModel = (BaseEmbeddingModelSetup) - this.getResource.apply( + this.resourceContext.getResource( this.embeddingModelName, ResourceType.EMBEDDING_MODEL); } } diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStore.java index 2c75167ec..26cce4236 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStore.java @@ -18,16 +18,14 @@ package org.apache.flink.agents.api.vectorstores.python; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; import pemja.core.object.PyObject; import java.util.HashMap; import java.util.Map; -import java.util.function.BiFunction; /** * Python-based implementation of VectorStore with collection management capabilities that bridges @@ -54,8 +52,8 @@ public PythonCollectionManageableVectorStore( PythonResourceAdapter adapter, PyObject vectorStore, ResourceDescriptor descriptor, - BiFunction getResource) { - super(adapter, vectorStore, descriptor, getResource); + ResourceContext resourceContext) { + super(adapter, vectorStore, descriptor, resourceContext); } @Override diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java index 65418f2ec..6b2d5d71a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java @@ -18,9 +18,8 @@ package org.apache.flink.agents.api.vectorstores.python; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; @@ -36,7 +35,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * Python-based implementation of VectorStore that bridges Java and Python vector store @@ -69,8 +67,8 @@ public PythonVectorStore( PythonResourceAdapter adapter, PyObject vectorStore, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); this.vectorStore = vectorStore; this.adapter = adapter; } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java index 53c9bc6cc..43654944a 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java @@ -21,7 +21,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; @@ -33,7 +33,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; @@ -51,9 +50,8 @@ class BaseChatModelConnectionTokenMetricsTest { private static class TestChatModelConnection extends BaseChatModelConnection { public TestChatModelConnection( - ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java index 65c83152a..61c9f823a 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java @@ -21,7 +21,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.prompt.Prompt; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.junit.jupiter.api.BeforeEach; @@ -33,7 +33,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.junit.jupiter.api.Assertions.*; @@ -51,10 +50,8 @@ class BaseChatModelTest { private static class TestChatModel extends BaseChatModelSetup { private String responsePrefix = "Test Response: "; - public TestChatModel( - ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + public TestChatModel(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java index 79069d7ad..3b939db5b 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelConnectionTest.java @@ -19,9 +19,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.api.tools.Tool; @@ -36,7 +35,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; @@ -48,7 +46,7 @@ public class PythonChatModelConnectionTest { @Mock private ResourceDescriptor mockDescriptor; - @Mock private BiFunction mockGetResource; + @Mock private ResourceContext mockGetResource; private PythonChatModelConnection pythonChatModelConnection; private AutoCloseable mocks; diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java index 274ddfe5f..3e8b1ebb5 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java @@ -18,9 +18,8 @@ package org.apache.flink.agents.api.chat.model.python; import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,7 +32,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -46,7 +44,7 @@ public class PythonChatModelSetupTest { @Mock private ResourceDescriptor mockDescriptor; - @Mock private BiFunction mockGetResource; + @Mock private ResourceContext mockGetResource; private PythonChatModelSetup pythonChatModelSetup; private AutoCloseable mocks; diff --git a/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java index 13ae25bf5..470d51270 100644 --- a/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelConnectionTest.java @@ -17,9 +17,8 @@ */ package org.apache.flink.agents.api.embedding.model.python; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,7 +32,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -50,7 +48,7 @@ public class PythonEmbeddingModelConnectionTest { @Mock private ResourceDescriptor mockDescriptor; - @Mock private BiFunction mockGetResource; + @Mock private ResourceContext mockGetResource; private PythonEmbeddingModelConnection pythonEmbeddingModelConnection; private AutoCloseable mocks; diff --git a/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java index cedd2a882..d8071d7cd 100644 --- a/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetupTest.java @@ -17,9 +17,8 @@ */ package org.apache.flink.agents.api.embedding.model.python; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,7 +32,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -50,7 +48,7 @@ public class PythonEmbeddingModelSetupTest { @Mock private ResourceDescriptor mockDescriptor; - @Mock private BiFunction mockGetResource; + @Mock private ResourceContext mockGetResource; private PythonEmbeddingModelSetup pythonEmbeddingModelSetup; private AutoCloseable mocks; diff --git a/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java b/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java index e580db58f..0a5e04cb1 100644 --- a/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/vectorstores/python/PythonCollectionManageableVectorStoreTest.java @@ -18,9 +18,8 @@ package org.apache.flink.agents.api.vectorstores.python; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; import org.apache.flink.agents.api.vectorstores.Document; @@ -35,7 +34,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -51,7 +49,7 @@ public class PythonCollectionManageableVectorStoreTest { @Mock private ResourceDescriptor mockDescriptor; - @Mock private BiFunction mockGetResource; + @Mock private ResourceContext mockGetResource; @Mock private PyObject mockPythonDocument; diff --git a/docs/content/docs/development/chat_models.md b/docs/content/docs/development/chat_models.md index 4ccce5e76..99ac9d7e5 100644 --- a/docs/content/docs/development/chat_models.md +++ b/docs/content/docs/development/chat_models.md @@ -1150,11 +1150,11 @@ public class MyChatModelConnection extends BaseChatModelConnection { * Creates a new chat model connection. * * @param descriptor a resource descriptor contains the initial parameters - * @param getResource a function to resolve resources (e.g., tools) by name and type + * @param resourceContext context for resolving resources (e.g., tools) by name and type */ public MyChatModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); // get custom arguments from descriptor String endpoint = descriptor.getArgument("endpoint"); ... diff --git a/docs/content/docs/development/vector_stores.md b/docs/content/docs/development/vector_stores.md index 4b730a889..f5903a0d3 100644 --- a/docs/content/docs/development/vector_stores.md +++ b/docs/content/docs/development/vector_stores.md @@ -857,8 +857,8 @@ public class MyVectorStore extends BaseVectorStore { public MyVectorStore( ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java index 6dded957a..248b464ea 100644 --- a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java +++ b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java @@ -37,9 +37,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.ToolMetadata; import java.time.Duration; @@ -49,7 +48,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -87,8 +85,8 @@ public class AnthropicChatModelConnection extends BaseChatModelConnection { private final String defaultModel; public AnthropicChatModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String apiKey = descriptor.getArgument("api_key"); if (apiKey == null || apiKey.isBlank()) { diff --git a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelSetup.java b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelSetup.java index 145213cdd..cf2a79358 100644 --- a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelSetup.java +++ b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelSetup.java @@ -18,15 +18,13 @@ package org.apache.flink.agents.integrations.chatmodels.anthropic; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; /** * Chat model setup for the Anthropic Messages API. @@ -85,9 +83,8 @@ public class AnthropicChatModelSetup extends BaseChatModelSetup { private final Boolean strictTools; private final Map additionalArguments; - public AnthropicChatModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public AnthropicChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.temperature = Optional.ofNullable(descriptor.getArgument("temperature")) .map(Number::doubleValue) @@ -129,10 +126,10 @@ public AnthropicChatModelSetup( long maxTokens, Map additionalArguments, List tools, - BiFunction getResource) { + ResourceContext resourceContext) { this( createDescriptor(model, temperature, maxTokens, additionalArguments, tools), - getResource); + resourceContext); } @Override diff --git a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java index f223ee553..3051ecf47 100644 --- a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java +++ b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java @@ -28,13 +28,11 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import java.util.*; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -74,8 +72,8 @@ public class AzureAIChatModelConnection extends BaseChatModelConnection { * @throws IllegalArgumentException if endpoint is null or empty */ public AzureAIChatModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String endpoint = descriptor.getArgument("endpoint"); String apiKey = descriptor.getArgument("apiKey"); diff --git a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelSetup.java b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelSetup.java index 465cbc829..84b0fb615 100644 --- a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelSetup.java +++ b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelSetup.java @@ -18,9 +18,7 @@ package org.apache.flink.agents.integrations.chatmodels.azureai; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; /** * A chat model integration for Azure AI Chat Completions service. @@ -52,8 +50,8 @@ public class AzureAIChatModelSetup extends BaseChatModelSetup { public AzureAIChatModelSetup( ResourceDescriptor descriptor, - java.util.function.BiFunction getResource) { - super(descriptor, getResource); + org.apache.flink.agents.api.resource.ResourceContext resourceContext) { + super(descriptor, resourceContext); } // For any other specific parameters, please refer to ChatCompletionsOptions diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java index 96b7683c4..8327795a3 100644 --- a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java @@ -25,9 +25,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.tools.ToolMetadata; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; @@ -55,7 +54,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -95,8 +93,8 @@ public class BedrockChatModelConnection extends BaseChatModelConnection { private final RetryExecutor retryExecutor; public BedrockChatModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String region = descriptor.getArgument("region"); if (region == null || region.isBlank()) { diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java index cbcd380b6..0ddc3aaaa 100644 --- a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java @@ -19,14 +19,12 @@ package org.apache.flink.agents.integrations.chatmodels.bedrock; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; /** * Chat model setup for AWS Bedrock Converse API. @@ -61,9 +59,8 @@ public class BedrockChatModelSetup extends BaseChatModelSetup { private final Double temperature; private final Integer maxTokens; - public BedrockChatModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public BedrockChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.temperature = Optional.ofNullable(descriptor.getArgument("temperature")) .map(Number::doubleValue) diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java index 84c294819..5e570f59d 100644 --- a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java @@ -21,14 +21,12 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.*; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -37,7 +35,7 @@ /** Tests for {@link BedrockChatModelConnection}. */ class BedrockChatModelConnectionTest { - private static final BiFunction NOOP = (a, b) -> null; + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); private static ResourceDescriptor descriptor(String region, String model) { ResourceDescriptor.Builder b = diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java index 05094f024..00ee2d7a0 100644 --- a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java @@ -19,21 +19,19 @@ package org.apache.flink.agents.integrations.chatmodels.bedrock; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; /** Tests for {@link BedrockChatModelSetup}. */ class BedrockChatModelSetupTest { - private static final BiFunction NOOP = (a, b) -> null; + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); @Test @DisplayName("getParameters includes model and default temperature") diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java index 773069f4c..2cda1ea40 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java @@ -28,13 +28,11 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import java.util.*; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -71,8 +69,8 @@ public class OllamaChatModelConnection extends BaseChatModelConnection { * @throws IllegalArgumentException if endpoint is null or empty */ public OllamaChatModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String endpoint = descriptor.getArgument("endpoint"); if (endpoint == null || endpoint.isEmpty()) { throw new IllegalArgumentException("endpoint should not be null or empty."); @@ -90,12 +88,11 @@ public OllamaChatModelConnection( * @param getResource a function to resolve resources (e.g., tools) by name and type * @throws IllegalArgumentException if endpoint is null or empty */ - public OllamaChatModelConnection( - String endpoint, BiFunction getResource) { + public OllamaChatModelConnection(String endpoint, ResourceContext resourceContext) { this( new ResourceDescriptor( OllamaChatModelConnection.class.getName(), Map.of("endpoint", endpoint)), - getResource); + resourceContext); } /** diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java index 63805c81b..81a3f92ff 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java @@ -19,14 +19,12 @@ package org.apache.flink.agents.integrations.chatmodels.ollama; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * A chat model integration for Ollama powered by the ollama4j client. @@ -59,9 +57,8 @@ public class OllamaChatModelSetup extends BaseChatModelSetup { private final Object think; private final boolean extractReasoning; - public OllamaChatModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public OllamaChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.model = descriptor.getArgument("model"); this.think = descriptor.getArgument("think", true); this.extractReasoning = descriptor.getArgument("extract_reasoning", true); @@ -77,15 +74,12 @@ public OllamaChatModelSetup( * @throws IllegalArgumentException if endpoint is null or empty */ public OllamaChatModelSetup( - String model, - String prompt, - List tools, - BiFunction getResource) { + String model, String prompt, List tools, ResourceContext resourceContext) { this( new ResourceDescriptor( OllamaChatModelSetup.class.getName(), Map.of("model", model, "prompt", prompt, "tools", tools)), - getResource); + resourceContext); } @Override diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java index 6ed1b1703..153073075 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java @@ -43,9 +43,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.tools.ToolMetadata; @@ -56,7 +55,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; import java.util.stream.Collectors; /** @@ -100,8 +98,8 @@ public class OpenAICompletionsConnection extends BaseChatModelConnection { private final String defaultModel; public OpenAICompletionsConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String apiKey = descriptor.getArgument("api_key"); if (apiKey == null || apiKey.isBlank()) { diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsSetup.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsSetup.java index 13a32ac29..29ee3753b 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsSetup.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsSetup.java @@ -18,16 +18,14 @@ package org.apache.flink.agents.integrations.chatmodels.openai; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.function.BiFunction; /** * Chat model setup for the OpenAI Chat Completions API. @@ -74,9 +72,8 @@ public class OpenAICompletionsSetup extends BaseChatModelSetup { private final String reasoningEffort; private final Map additionalArguments; - public OpenAICompletionsSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public OpenAICompletionsSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.temperature = Optional.ofNullable(descriptor.getArgument("temperature")) .map(Number::doubleValue) @@ -136,7 +133,7 @@ public OpenAICompletionsSetup( String reasoningEffort, Map additionalArguments, List tools, - BiFunction getResource) { + ResourceContext resourceContext) { this( createDescriptor( model, @@ -148,7 +145,7 @@ public OpenAICompletionsSetup( reasoningEffort, additionalArguments, tools), - getResource); + resourceContext); } @Override diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java index f185d65f0..9b0d143eb 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java @@ -31,9 +31,8 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.ToolMetadata; import java.time.Duration; @@ -43,7 +42,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; /** * A dedicated OpenAI chat model integration using the Responses API. @@ -90,8 +88,8 @@ public class OpenAIResponsesModelConnection extends BaseChatModelConnection { private final String defaultModel; public OpenAIResponsesModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String apiKey = descriptor.getArgument("api_key"); if (apiKey == null || apiKey.isBlank()) { diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelSetup.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelSetup.java index 54e7d49b8..8cdeacb5d 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelSetup.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelSetup.java @@ -18,16 +18,14 @@ package org.apache.flink.agents.integrations.chatmodels.openai; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.function.BiFunction; /** * Chat model setup for the OpenAI Responses API. @@ -73,8 +71,8 @@ public class OpenAIResponsesModelSetup extends BaseChatModelSetup { private final Map additionalArguments; public OpenAIResponsesModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.temperature = Optional.ofNullable(descriptor.getArgument("temperature")) @@ -129,7 +127,7 @@ public OpenAIResponsesModelSetup( String instructions, Map additionalArguments, List tools, - BiFunction getResource) { + ResourceContext resourceContext) { this( createDescriptor( model, @@ -141,7 +139,7 @@ public OpenAIResponsesModelSetup( instructions, additionalArguments, tools), - getResource); + resourceContext); } @Override diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java index 0d9a87838..a7dc2926c 100644 --- a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java @@ -23,9 +23,8 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import org.apache.flink.agents.api.RetryExecutor; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; @@ -39,7 +38,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.function.BiFunction; /** * Bedrock embedding model connection using Amazon Titan Text Embeddings V2. @@ -80,8 +78,8 @@ public class BedrockEmbeddingModelConnection extends BaseEmbeddingModelConnectio private final RetryExecutor retryExecutor; public BedrockEmbeddingModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String region = descriptor.getArgument("region"); if (region == null || region.isBlank()) { diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java index e3e6d6dd4..d5bd7855b 100644 --- a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java @@ -19,13 +19,11 @@ package org.apache.flink.agents.integrations.embeddingmodels.bedrock; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.Map; -import java.util.function.BiFunction; /** * Embedding model setup for Bedrock Titan Text Embeddings. @@ -56,8 +54,8 @@ public class BedrockEmbeddingModelSetup extends BaseEmbeddingModelSetup { private final Integer dimensions; public BedrockEmbeddingModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.dimensions = descriptor.getArgument("dimensions"); } diff --git a/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java index 3d2d3d07b..0891d7067 100644 --- a/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java +++ b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java @@ -20,14 +20,12 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -35,7 +33,7 @@ /** Tests for {@link BedrockEmbeddingModelConnection} and {@link BedrockEmbeddingModelSetup}. */ class BedrockEmbeddingModelTest { - private static final BiFunction NOOP = (a, b) -> null; + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); private static ResourceDescriptor connDescriptor(String region) { ResourceDescriptor.Builder b = diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java index 6ada023de..240dc14b5 100644 --- a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnection.java @@ -24,12 +24,10 @@ import io.github.ollama4j.models.embed.OllamaEmbedResult; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; import org.apache.flink.agents.api.embedding.model.EmbeddingModelUtils; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.*; -import java.util.function.BiFunction; /** An embedding model integration for Ollama powered by the ollama4j client. */ public class OllamaEmbeddingModelConnection extends BaseEmbeddingModelConnection { @@ -38,8 +36,8 @@ public class OllamaEmbeddingModelConnection extends BaseEmbeddingModelConnection private final String defaultModel; public OllamaEmbeddingModelConnection( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); String host = descriptor.getArgument("host") != null ? descriptor.getArgument("host") diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java index 3880dfb20..a718da7e2 100644 --- a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java @@ -19,20 +19,18 @@ package org.apache.flink.agents.integrations.embeddingmodels.ollama; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import java.util.HashMap; import java.util.Map; -import java.util.function.BiFunction; /** An embedding model setup for Ollama powered by the ollama4j client. */ public class OllamaEmbeddingModelSetup extends BaseEmbeddingModelSetup { public OllamaEmbeddingModelSetup( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java index 2b908e206..167ffecb0 100644 --- a/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java +++ b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java @@ -20,15 +20,13 @@ import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; -import java.util.function.BiFunction; import static org.junit.jupiter.api.Assertions.*; @@ -41,7 +39,7 @@ private static ResourceDescriptor buildDescriptor() { .build(); } - private static BiFunction dummyResource = (a, b) -> null; + private static ResourceContext dummyResource = ResourceContext.fromGetResource((a, b) -> null); @Test @DisplayName("Create OllamaEmbeddingModelConnection and check embed method") @@ -63,10 +61,8 @@ void testAnnotationPresence() { @DisplayName("Test EmbeddingModelSetup annotation presence on setup class") void testSetupAnnotationPresence() { class DummySetup extends BaseEmbeddingModelSetup { - public DummySetup( - ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + public DummySetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } public Map getParameters() { diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java index ea3eb01ed..7d53cfab2 100644 --- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -30,6 +30,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.ToolMetadata; @@ -47,7 +48,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.function.BiFunction; /** * Resource representing an MCP server and exposing its tools/prompts. @@ -173,9 +173,8 @@ public MCPServer build() { } } - public MCPServer( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public MCPServer(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.endpoint = Objects.requireNonNull( descriptor.getArgument(FIELD_ENDPOINT), "endpoint cannot be null"); diff --git a/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java b/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java index 478e268b8..424c9d398 100644 --- a/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java +++ b/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java @@ -42,9 +42,8 @@ import co.elastic.clients.transport.rest_client.RestClientTransport; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; import org.apache.flink.agents.api.vectorstores.Document; @@ -63,7 +62,6 @@ import java.io.StringReader; import java.nio.charset.StandardCharsets; import java.util.*; -import java.util.function.BiFunction; /** * Elasticsearch-backed implementation of a vector store. @@ -158,8 +156,8 @@ public class ElasticsearchVectorStore extends BaseVectorStore * @throws IllegalArgumentException if required arguments are missing or invalid */ public ElasticsearchVectorStore( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.storeInContentField = Objects.requireNonNullElse(descriptor.getArgument("store_in_content_field"), true); diff --git a/integrations/vector-stores/elasticsearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStoreTest.java b/integrations/vector-stores/elasticsearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStoreTest.java index b1c5ac860..e89365e1f 100644 --- a/integrations/vector-stores/elasticsearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStoreTest.java +++ b/integrations/vector-stores/elasticsearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStoreTest.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; @@ -73,7 +74,8 @@ public static void initialize() { .addInitialArgument("password", System.getenv("ES_PASSWORD")); store = new ElasticsearchVectorStore( - builder.build(), ElasticsearchVectorStoreTest::getResource); + builder.build(), + ResourceContext.fromGetResource(ElasticsearchVectorStoreTest::getResource)); } @Test diff --git a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java index 3a241b72a..0f7cbf1a5 100644 --- a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java +++ b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java @@ -23,9 +23,8 @@ import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import org.apache.flink.agents.api.RetryExecutor; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; import org.apache.flink.agents.api.vectorstores.Document; @@ -55,7 +54,6 @@ import java.util.Map; import java.util.Objects; import java.util.UUID; -import java.util.function.BiFunction; /** * OpenSearch vector store supporting both OpenSearch Serverless (AOSS) and OpenSearch Service @@ -123,9 +121,8 @@ public class OpenSearchVectorStore extends BaseVectorStore private final DefaultCredentialsProvider credentialsProvider; private final RetryExecutor retryExecutor; - public OpenSearchVectorStore( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public OpenSearchVectorStore(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.endpoint = descriptor.getArgument("endpoint"); if (this.endpoint == null || this.endpoint.isBlank()) { diff --git a/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java b/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java index 5d35349eb..43ee586ad 100644 --- a/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java +++ b/integrations/vector-stores/opensearch/src/test/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStoreTest.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; @@ -37,7 +38,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; @@ -49,7 +49,7 @@ */ public class OpenSearchVectorStoreTest { - private static final BiFunction NOOP = (a, b) -> null; + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); @Test @DisplayName("Constructor creates store with IAM auth") @@ -164,7 +164,10 @@ static void initialize() { builder.addInitialArgument("username", System.getenv("OPENSEARCH_USERNAME")); builder.addInitialArgument("password", System.getenv("OPENSEARCH_PASSWORD")); } - store = new OpenSearchVectorStore(builder.build(), OpenSearchVectorStoreTest::getResource); + store = + new OpenSearchVectorStore( + builder.build(), + ResourceContext.fromGetResource(OpenSearchVectorStoreTest::getResource)); } @Test diff --git a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java index 306c9a7d7..5c6e5e98f 100644 --- a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java +++ b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java @@ -19,9 +19,8 @@ package org.apache.flink.agents.integrations.vectorstores.s3vectors; import org.apache.flink.agents.api.RetryExecutor; -import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; import org.apache.flink.agents.api.vectorstores.Document; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; @@ -47,7 +46,6 @@ import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.function.BiFunction; /** * Amazon S3 Vectors vector store for flink-agents. @@ -87,9 +85,8 @@ public class S3VectorsVectorStore extends BaseVectorStore { private final String vectorIndex; private final RetryExecutor retryExecutor; - public S3VectorsVectorStore( - ResourceDescriptor descriptor, BiFunction getResource) { - super(descriptor, getResource); + public S3VectorsVectorStore(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); this.vectorBucket = descriptor.getArgument("vector_bucket"); if (this.vectorBucket == null || this.vectorBucket.isBlank()) { diff --git a/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java b/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java index 337fbdb2c..311cb3bec 100644 --- a/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java +++ b/integrations/vector-stores/s3vectors/src/test/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStoreTest.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.BaseVectorStore; @@ -35,7 +36,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; @@ -46,7 +46,7 @@ */ public class S3VectorsVectorStoreTest { - private static final BiFunction NOOP = (a, b) -> null; + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); @Test @DisplayName("Constructor creates store") @@ -114,7 +114,10 @@ static void initialize() { .addInitialArgument( "region", System.getenv().getOrDefault("AWS_REGION", "us-east-1")) .build(); - store = new S3VectorsVectorStore(desc, S3VectorsVectorStoreTest::getResource); + store = + new S3VectorsVectorStore( + desc, + ResourceContext.fromGetResource(S3VectorsVectorStoreTest::getResource)); } @Test diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java index 6ce0f4da4..a6268347d 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java @@ -18,6 +18,7 @@ package org.apache.flink.agents.plan.resource.python; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; @@ -27,7 +28,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.function.BiFunction; public class PythonMCPServer extends Resource implements PythonResourceWrapper { private final PyObject server; @@ -46,8 +46,8 @@ public PythonMCPServer( PythonResourceAdapter adapter, PyObject server, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); this.server = server; this.adapter = adapter; } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java index ff7ab7c71..ca8d65f39 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java @@ -19,11 +19,11 @@ package org.apache.flink.agents.plan.resourceprovider; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import java.lang.reflect.Constructor; -import java.util.function.BiFunction; /** Java Resource provider that carries resource instance to be used at runtime. */ public class JavaResourceProvider extends ResourceProvider { @@ -35,8 +35,7 @@ public JavaResourceProvider(String name, ResourceType type, ResourceDescriptor d } @Override - public Resource provide(BiFunction getResource) - throws Exception { + public Resource provide(ResourceContext resourceContext) throws Exception { String clazzName; if (descriptor.getModule() == null || descriptor.getModule().isEmpty()) { clazzName = descriptor.getClazz(); @@ -46,8 +45,8 @@ public Resource provide(BiFunction getResource) Class clazz = Class.forName(clazzName, true, Thread.currentThread().getContextClassLoader()); Constructor constructor = - clazz.getConstructor(ResourceDescriptor.class, BiFunction.class); - return (Resource) constructor.newInstance(descriptor, getResource); + clazz.getConstructor(ResourceDescriptor.class, ResourceContext.class); + return (Resource) constructor.newInstance(descriptor, resourceContext); } public ResourceDescriptor getDescriptor() { diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaSerializableResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaSerializableResourceProvider.java index bea54f3de..bb173d044 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaSerializableResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaSerializableResourceProvider.java @@ -22,11 +22,10 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; -import java.util.function.BiFunction; - /** * Serializable Resource Provider for Java-based resources. * @@ -78,8 +77,7 @@ public String getSerializedResource() { } @Override - public Resource provide(BiFunction getResource) - throws Exception { + public Resource provide(ResourceContext resourceContext) throws Exception { if (resource == null) { resource = (SerializableResource) diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java index 1f2c60e86..54fe1d301 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java @@ -23,6 +23,7 @@ import org.apache.flink.agents.api.embedding.model.python.PythonEmbeddingModelConnection; import org.apache.flink.agents.api.embedding.model.python.PythonEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; @@ -34,7 +35,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; -import java.util.function.BiFunction; import static org.apache.flink.util.Preconditions.checkState; @@ -74,8 +74,7 @@ public ResourceDescriptor getDescriptor() { } @Override - public Resource provide(BiFunction getResource) - throws Exception { + public Resource provide(ResourceContext resourceContext) throws Exception { checkState(pythonResourceAdapter != null, "PythonResourceAdapter is not set"); Class clazz = RESOURCE_TYPE_TO_CLASS.get(getType()); @@ -120,9 +119,10 @@ public Resource provide(BiFunction getResource) PythonResourceAdapter.class, PyObject.class, ResourceDescriptor.class, - BiFunction.class); + ResourceContext.class); return (Resource) - constructor.newInstance(pythonResourceAdapter, pyResource, descriptor, getResource); + constructor.newInstance( + pythonResourceAdapter, pyResource, descriptor, resourceContext); } @Override diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java index 283c59ceb..793ec04d7 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.plan.resourceprovider; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; import org.apache.flink.agents.plan.resource.python.PythonPrompt; @@ -26,7 +27,6 @@ import java.util.Map; import java.util.Objects; -import java.util.function.BiFunction; /** * Resource Provider that carries Resource object or serialized object. @@ -68,8 +68,7 @@ public SerializableResource getResource() { } @Override - public Resource provide(BiFunction getResource) - throws Exception { + public Resource provide(ResourceContext resourceContext) throws Exception { if (resource == null) { if (this.getType() == ResourceType.PROMPT) { resource = PythonPrompt.fromSerializedMap(serialized); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/ResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/ResourceProvider.java index e574444b2..a90bba58b 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/ResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/ResourceProvider.java @@ -21,12 +21,11 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.serializer.ResourceProviderJsonDeserializer; import org.apache.flink.agents.plan.serializer.ResourceProviderJsonSerializer; -import java.util.function.BiFunction; - /** * Resource provider that carries resource metadata to create Resource objects at runtime. * @@ -66,10 +65,10 @@ public ResourceType getType() { /** * Create resource at runtime. * - * @param getResource helper function to get other resources declared in the same Agent + * @param resourceContext context exposing helper for fetching other resources declared in the + * same Agent and skill-related capabilities * @return the created resource * @throws Exception if the resource cannot be created */ - public abstract Resource provide(BiFunction getResource) - throws Exception; + public abstract Resource provide(ResourceContext resourceContext) throws Exception; } diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java index 6729a5471..0ac51609e 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java @@ -32,6 +32,7 @@ import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; @@ -43,7 +44,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.junit.jupiter.api.Assertions.*; @@ -52,10 +52,8 @@ class AgentPlanDeclareChatModelTest { private AgentPlan agentPlan; public static class MockChatModel extends BaseChatModelSetup { - public MockChatModel( - ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + public MockChatModel(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override @@ -105,9 +103,11 @@ private Resource resolveResource(String name, ResourceType type) throws Exceptio .get(type) .get(name) .provide( - (n, t) -> { - throw new UnsupportedOperationException("No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); } @Test @@ -147,10 +147,11 @@ void jsonRoundTrip() throws Exception { .get(ResourceType.CHAT_MODEL) .get("testChatModel") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); ChatMessage reply = model.chat(Prompt.fromText("Hi").formatMessages(MessageRole.USER, new HashMap<>())); assertEquals("ok:Hi", reply.getContent()); @@ -177,10 +178,11 @@ void testAddChatModel() throws Exception { .get(ResourceType.CHAT_MODEL) .get("testChatModel") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); BaseChatModelSetup expectedChatModel = (BaseChatModelSetup) resolveResource("testChatModel", ResourceType.CHAT_MODEL); Assertions.assertEquals(expectedChatModel.getClass(), actualChatModel.getClass()); diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java index 552246014..719f195f5 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java @@ -26,6 +26,7 @@ import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceName; import org.apache.flink.agents.api.resource.ResourceType; @@ -192,9 +193,11 @@ private Resource resolveResource(String name, ResourceType type) throws Exceptio .get(type) .get(name) .provide( - (n, t) -> { - throw new UnsupportedOperationException("No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); } @AfterAll diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolFieldTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolFieldTest.java index bfbdbf8fe..c87cc53f3 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolFieldTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolFieldTest.java @@ -27,6 +27,7 @@ import org.apache.flink.agents.api.annotation.ToolParam; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.tools.ToolMetadata; @@ -125,9 +126,11 @@ private Resource resolveResource(String name, ResourceType type) throws Exceptio .get(type) .get(name) .provide( - (n, t) -> { - throw new UnsupportedOperationException("No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); } @Test diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolMethodTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolMethodTest.java index 88e7081b4..a3784be59 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolMethodTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareToolMethodTest.java @@ -28,6 +28,7 @@ import org.apache.flink.agents.api.annotation.ToolParam; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.tools.ToolMetadata; @@ -101,9 +102,11 @@ private Resource resolveResource(String name, ResourceType type) throws Exceptio .get(type) .get(name) .provide( - (n, t) -> { - throw new UnsupportedOperationException("No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); } @Test @@ -124,10 +127,11 @@ void checkToolCall(AgentPlan plan) throws Exception { .get(ResourceType.TOOL) .get("calculate") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); ToolParameters tp = new ToolParameters( new HashMap<>( @@ -145,10 +149,11 @@ void checkToolCall(AgentPlan plan) throws Exception { .get(ResourceType.TOOL) .get("getWeather") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); ToolResponse wr = weather.call( new ToolParameters( @@ -268,10 +273,11 @@ void testAgentPlanJsonSerializable() throws Exception { .get(ResourceType.TOOL) .get("calculate") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); ToolResponse r = calculator.call( new ToolParameters( diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java index 9aa0f72ed..29196620f 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java @@ -26,6 +26,7 @@ import org.apache.flink.agents.api.annotation.Tool; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; @@ -41,7 +42,6 @@ import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; @@ -106,8 +106,8 @@ public TestPythonResource( PythonResourceAdapter adapter, PyObject chatModel, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/plan/src/test/java/org/apache/flink/agents/plan/FunctionToolPlanTest.java b/plan/src/test/java/org/apache/flink/agents/plan/FunctionToolPlanTest.java index a7cef61d1..5977fbcb3 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/FunctionToolPlanTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/FunctionToolPlanTest.java @@ -25,6 +25,7 @@ import org.apache.flink.agents.api.agents.Agent; import org.apache.flink.agents.api.annotation.Tool; import org.apache.flink.agents.api.annotation.ToolParam; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.ToolMetadata; import org.apache.flink.agents.api.tools.ToolParameters; @@ -96,10 +97,11 @@ void functionToolAgentPlan() throws Exception { .get(ResourceType.TOOL) .get("javaTool") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); ToolResponse ok = javaTool.call( new ToolParameters( @@ -117,10 +119,11 @@ void functionToolAgentPlan() throws Exception { .get(ResourceType.TOOL) .get("pyTool") .provide( - (n, t) -> { - throw new UnsupportedOperationException( - "No dependencies expected"); - }); + ResourceContext.fromGetResource( + (n, t) -> { + throw new UnsupportedOperationException( + "No dependencies expected"); + })); ToolResponse err = pyTool.call(new ToolParameters(new HashMap<>(Map.of("x", 1)))); assertFalse(err.isSuccess()); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/PythonMCPResourceDiscovery.java b/runtime/src/main/java/org/apache/flink/agents/runtime/PythonMCPResourceDiscovery.java index 10e21968a..32af34280 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/PythonMCPResourceDiscovery.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/PythonMCPResourceDiscovery.java @@ -25,6 +25,7 @@ import org.apache.flink.agents.plan.resource.python.PythonMCPTool; import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; +import org.apache.flink.agents.runtime.resource.ResourceContextImpl; import java.util.Map; @@ -73,13 +74,14 @@ public static void discoverPythonMCPResources( PythonMCPServer server = (PythonMCPServer) provider.provide( - (name, type) -> { - try { - return cache.getResource(name, type); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); + new ResourceContextImpl( + (name, type) -> { + try { + return cache.getResource(name, type); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); for (PythonMCPTool tool : server.listTools()) { cache.put(tool.getName(), TOOL, tool); diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java b/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java index 604348b63..db8e5dc30 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java @@ -23,6 +23,7 @@ import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; +import org.apache.flink.agents.runtime.resource.ResourceContextImpl; import java.util.HashMap; import java.util.Map; @@ -87,13 +88,14 @@ public synchronized Resource getResource(String name, ResourceType type) throws Resource resource = provider.provide( - (anotherName, anotherType) -> { - try { - return this.getResource(anotherName, anotherType); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); + new ResourceContextImpl( + (anotherName, anotherType) -> { + try { + return this.getResource(anotherName, anotherType); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); resource.open(); cache.computeIfAbsent(type, k -> new ConcurrentHashMap<>()).put(name, resource); return resource; diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 926a5283b..2eb932916 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -60,6 +60,7 @@ import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; +import org.apache.flink.agents.runtime.resource.ResourceContextImpl; import org.apache.flink.agents.runtime.utils.EventUtil; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; @@ -677,7 +678,9 @@ private void initPythonEnvironment() throws Exception { this.resourceCache, this.jobIdentifier); - javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); + javaResourceAdapter = + new JavaResourceAdapter( + new ResourceContextImpl(this::getResource), pythonInterpreter); if (containPythonResource || mem0Configured) { initPythonResourceAdapter(); } @@ -753,13 +756,14 @@ private void initPythonActionExecutor() throws Exception { private void initPythonResourceAdapter() throws Exception { pythonResourceAdapter = new PythonResourceAdapterImpl( - (String anotherName, ResourceType anotherType) -> { - try { - return resourceCache.getResource(anotherName, anotherType); - } catch (Exception e) { - throw new RuntimeException(e); - } - }, + new ResourceContextImpl( + (String anotherName, ResourceType anotherType) -> { + try { + return resourceCache.getResource(anotherName, anotherType); + } catch (Exception e) { + throw new RuntimeException(e); + } + }), pythonInterpreter, javaResourceAdapter); pythonResourceAdapter.open(); diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java index 2d84a59a7..81336ecb2 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.vectorstores.Document; import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; @@ -28,17 +29,15 @@ import pemja.core.object.PyObject; import java.util.Map; -import java.util.function.BiFunction; /** Adapter for managing Java resources and facilitating Python-Java interoperability. */ public class JavaResourceAdapter { - private final BiFunction getResource; + private final ResourceContext resourceContext; private final transient PythonInterpreter interpreter; - public JavaResourceAdapter( - BiFunction getResource, PythonInterpreter interpreter) { - this.getResource = getResource; + public JavaResourceAdapter(ResourceContext resourceContext, PythonInterpreter interpreter) { + this.resourceContext = resourceContext; this.interpreter = interpreter; } @@ -52,7 +51,7 @@ public JavaResourceAdapter( * @throws Exception if the resource cannot be retrieved */ public Resource getResource(String name, String typeValue) throws Exception { - return getResource.apply(name, ResourceType.fromValue(typeValue)); + return resourceContext.getResource(name, ResourceType.fromValue(typeValue)); } /** diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java index 99d189b7c..238e0e8c9 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java @@ -21,6 +21,7 @@ import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; @@ -35,7 +36,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; public class PythonResourceAdapterImpl implements PythonResourceAdapter { @@ -72,16 +72,16 @@ public class PythonResourceAdapterImpl implements PythonResourceAdapter { static final String FROM_JAVA_VECTOR_STORE_QUERY = PYTHON_MODULE_PREFIX + "from_java_vector_store_query"; - private final BiFunction getResource; + private final ResourceContext resourceContext; private final PythonInterpreter interpreter; private final JavaResourceAdapter javaResourceAdapter; private PyObject pythonResourceContext; public PythonResourceAdapterImpl( - BiFunction getResource, + ResourceContext resourceContext, PythonInterpreter interpreter, JavaResourceAdapter javaResourceAdapter) { - this.getResource = getResource; + this.resourceContext = resourceContext; this.interpreter = interpreter; this.javaResourceAdapter = javaResourceAdapter; } @@ -92,8 +92,16 @@ public void open() { } public Object getResource(String resourceName, String resourceType) { - Resource resource = - this.getResource.apply(resourceName, ResourceType.fromValue(resourceType)); + Resource resource; + try { + resource = + this.resourceContext.getResource( + resourceName, ResourceType.fromValue(resourceType)); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException(e); + } if (resource instanceof PythonResourceWrapper) { PythonResourceWrapper pythonResource = (PythonResourceWrapper) resource; return pythonResource.getPythonResource(); diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java new file mode 100644 index 000000000..11c1cb63d --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.resource; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; + +/** + * Default {@link ResourceContext} implementation that delegates resource lookup to a {@link + * BiFunction} (typically the underlying {@code ResourceCache::getResource}). + * + *

Mirrors the Python {@code flink_agents.runtime.resource_context.ResourceContextImpl}. Skill + * methods return safe defaults; callers without skills configured see empty values. + */ +public class ResourceContextImpl implements ResourceContext { + + private final BiFunction getResource; + + public ResourceContextImpl(BiFunction getResource) { + this.getResource = getResource; + } + + @Override + public Resource getResource(String name, ResourceType type) throws Exception { + try { + return getResource.apply(name, type); + } catch (RuntimeException e) { + if (e.getCause() instanceof Exception) { + throw (Exception) e.getCause(); + } + throw e; + } + } + + @Override + public String generateAvailableSkillsPrompt(List skillNames) throws Exception { + return ""; + } + + @Override + public List getSkillDirs(List skillNames) throws Exception { + return Collections.emptyList(); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java index 1c21494c2..ecf8ef182 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java @@ -27,6 +27,7 @@ import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; @@ -41,7 +42,6 @@ import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -91,8 +91,8 @@ public TestPythonResource( PythonResourceAdapter adapter, PyObject chatModel, ResourceDescriptor descriptor, - BiFunction getResource) { - super(descriptor, getResource); + ResourceContext resourceContext) { + super(descriptor, resourceContext); } @Override diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java index 1aba02f1e..f8821bfb2 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImplTest.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import org.apache.flink.agents.api.tools.Tool; @@ -33,7 +34,6 @@ import java.util.HashMap; import java.util.Map; -import java.util.function.BiFunction; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; @@ -41,7 +41,7 @@ public class PythonResourceAdapterImplTest { @Mock private PythonInterpreter mockInterpreter; - @Mock private BiFunction getResource; + @Mock private ResourceContext resourceContext; private PythonResourceAdapterImpl pythonResourceAdapter; private AutoCloseable mocks; @@ -49,7 +49,8 @@ public class PythonResourceAdapterImplTest { @BeforeEach void setUp() throws Exception { mocks = MockitoAnnotations.openMocks(this); - pythonResourceAdapter = new PythonResourceAdapterImpl(getResource, mockInterpreter, null); + pythonResourceAdapter = + new PythonResourceAdapterImpl(resourceContext, mockInterpreter, null); } @AfterEach @@ -91,66 +92,67 @@ void testOpen() { } @Test - void testGetResourceWithPythonResourceWrapper() { + void testGetResourceWithPythonResourceWrapper() throws Exception { String resourceName = "test_resource"; String resourceType = "chat_model"; PythonResourceWrapper mockPythonChatModelSetup = mock(PythonChatModelSetup.class); Object expectedPythonResource = new Object(); - when(getResource.apply(resourceName, ResourceType.CHAT_MODEL)) + when(resourceContext.getResource(resourceName, ResourceType.CHAT_MODEL)) .thenReturn((Resource) mockPythonChatModelSetup); when(mockPythonChatModelSetup.getPythonResource()).thenReturn(expectedPythonResource); Object result = pythonResourceAdapter.getResource(resourceName, resourceType); assertThat(result).isEqualTo(expectedPythonResource); - verify(getResource).apply(resourceName, ResourceType.CHAT_MODEL); + verify(resourceContext).getResource(resourceName, ResourceType.CHAT_MODEL); verify(mockPythonChatModelSetup).getPythonResource(); } @Test - void testGetResourceWithTool() { + void testGetResourceWithTool() throws Exception { String resourceName = "test_tool"; String resourceType = "tool"; Tool mockTool = mock(Tool.class); Object expectedPythonTool = new Object(); - when(getResource.apply(resourceName, ResourceType.TOOL)).thenReturn(mockTool); + when(resourceContext.getResource(resourceName, ResourceType.TOOL)).thenReturn(mockTool); when(mockInterpreter.invoke(PythonResourceAdapterImpl.FROM_JAVA_TOOL, mockTool)) .thenReturn(expectedPythonTool); Object result = pythonResourceAdapter.getResource(resourceName, resourceType); assertThat(result).isEqualTo(expectedPythonTool); - verify(getResource).apply(resourceName, ResourceType.TOOL); + verify(resourceContext).getResource(resourceName, ResourceType.TOOL); verify(mockInterpreter).invoke(PythonResourceAdapterImpl.FROM_JAVA_TOOL, mockTool); } @Test - void testGetResourceWithPrompt() { + void testGetResourceWithPrompt() throws Exception { String resourceName = "test_prompt"; String resourceType = "prompt"; Prompt mockPrompt = mock(Prompt.class); Object expectedPythonPrompt = new Object(); - when(getResource.apply(resourceName, ResourceType.PROMPT)).thenReturn(mockPrompt); + when(resourceContext.getResource(resourceName, ResourceType.PROMPT)).thenReturn(mockPrompt); when(mockInterpreter.invoke(PythonResourceAdapterImpl.FROM_JAVA_PROMPT, mockPrompt)) .thenReturn(expectedPythonPrompt); Object result = pythonResourceAdapter.getResource(resourceName, resourceType); assertThat(result).isEqualTo(expectedPythonPrompt); - verify(getResource).apply(resourceName, ResourceType.PROMPT); + verify(resourceContext).getResource(resourceName, ResourceType.PROMPT); verify(mockInterpreter).invoke(PythonResourceAdapterImpl.FROM_JAVA_PROMPT, mockPrompt); } @Test - void testGetResourceWithRegularResource() { + void testGetResourceWithRegularResource() throws Exception { String resourceName = "test_resource"; String resourceType = "chat_model"; Resource mockResource = mock(Resource.class); - when(getResource.apply(resourceName, ResourceType.CHAT_MODEL)).thenReturn(mockResource); + when(resourceContext.getResource(resourceName, ResourceType.CHAT_MODEL)) + .thenReturn(mockResource); pythonResourceAdapter.getResource(resourceName, resourceType); From 5a768ab542b39c50c7fa8751cf85066b8a48185e Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Wed, 6 May 2026 17:51:51 +0800 Subject: [PATCH 3/4] [plan][java] Add bash tool in java. Co-Authored-By: Claude Opus 4.7 (1M context) --- plan/pom.xml | 11 + .../agents/plan/tools/bash/BashTool.java | 163 ++++++++++++ .../agents/plan/tools/bash/BashValidator.java | 244 ++++++++++++++++++ .../agents/plan/tools/bash/BashToolTest.java | 90 +++++++ .../plan/tools/bash/BashValidatorTest.java | 116 +++++++++ 5 files changed, 624 insertions(+) create mode 100644 plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java create mode 100644 plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java create mode 100644 plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java create mode 100644 plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java diff --git a/plan/pom.xml b/plan/pom.xml index ede26e085..02df3c2c3 100644 --- a/plan/pom.xml +++ b/plan/pom.xml @@ -50,6 +50,17 @@ under the License. com.fasterxml.jackson.core jackson-databind + + + io.github.bonede + tree-sitter + 0.25.3 + + + io.github.bonede + tree-sitter-bash + 0.23.3 + com.alibaba pemja diff --git a/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java new file mode 100644 index 000000000..57f1dc08f --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashTool.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +/** + * Standalone bash execution tool. + * + *

Mirrors the Python {@code flink_agents.plan.tools.bash.bash_tool.BashTool}. The framework + * (e.g. {@code ChatModelAction}) injects {@code allowed_commands} and {@code allowed_script_dirs} + * at call time; the model only sees {@code command}, {@code timeout} and {@code cwd}. + */ +public class BashTool extends Tool { + + private static final String DESCRIPTION = + "Execute a shell command. Only commands on the allowed list or scripts under the allowed directories may run."; + + private static final String INPUT_SCHEMA = + "{\"type\":\"object\"," + + "\"properties\":{" + + "\"command\":{\"type\":\"string\",\"description\":\"The shell command to execute.\"}," + + "\"timeout\":{\"type\":\"integer\",\"description\":\"Timeout in seconds. Defaults to 60.\",\"default\":60}," + + "\"cwd\":{\"type\":\"string\",\"description\":\"The working directory to run the command in. Defaults to the current directory. Use this instead of `cd` commands.\"}" + + "}," + + "\"required\":[\"command\"]}"; + + public BashTool(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(new ToolMetadata("bash", DESCRIPTION, INPUT_SCHEMA)); + this.resourceContext = resourceContext; + } + + @Override + public ToolType getToolType() { + return ToolType.FUNCTION; + } + + @Override + public ToolResponse call(ToolParameters parameters) { + @SuppressWarnings("unchecked") + List allowedCommands = + parameters.hasParameter("allowed_commands") + ? (List) parameters.getParameter("allowed_commands") + : Collections.emptyList(); + @SuppressWarnings("unchecked") + List allowedScriptDirs = + parameters.hasParameter("allowed_script_dirs") + ? (List) parameters.getParameter("allowed_script_dirs") + : Collections.emptyList(); + + String command = parameters.getParameter("command", String.class); + int timeout = + parameters.hasParameter("timeout") + ? parameters.getParameter("timeout", Integer.class) + : 60; + String cwd = + parameters.hasParameter("cwd") + ? parameters.getParameter("cwd", String.class) + : null; + + if (cwd != null && !BashValidator.isUnderAllowedDirs(cwd, allowedScriptDirs, null)) { + List sorted = new ArrayList<>(allowedScriptDirs); + Collections.sort(sorted); + return ToolResponse.success( + "Command rejected: cwd '" + + cwd + + "' is not under any allowed script dir. Allowed script dirs: " + + sorted + + "."); + } + + Optional error = + BashValidator.validate(command, allowedCommands, allowedScriptDirs, cwd); + if (error.isPresent()) { + return ToolResponse.success("Command rejected: " + error.get()); + } + + try { + ProcessBuilder pb = new ProcessBuilder("bash", "-c", command); + if (cwd != null) { + pb.directory(new File(cwd)); + } + Process process = pb.start(); + ByteArrayOutputStream stdout = new ByteArrayOutputStream(); + ByteArrayOutputStream stderr = new ByteArrayOutputStream(); + // Drain output streams to avoid blocking on pipe buffer fill. + Thread tOut = drainAsync(process.getInputStream(), stdout); + Thread tErr = drainAsync(process.getErrorStream(), stderr); + boolean finished = process.waitFor(timeout, TimeUnit.SECONDS); + if (!finished) { + process.destroyForcibly(); + return ToolResponse.success( + "Error: Command timed out after " + timeout + " seconds"); + } + tOut.join(); + tErr.join(); + int exit = process.exitValue(); + String stdoutStr = new String(stdout.toByteArray(), StandardCharsets.UTF_8).strip(); + String stderrStr = new String(stderr.toByteArray(), StandardCharsets.UTF_8).strip(); + if (exit == 0) { + return ToolResponse.success(stdoutStr.isEmpty() ? "Success" : stdoutStr); + } + return ToolResponse.success("Error (exit code " + exit + "): " + stderrStr); + } catch (IOException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + return ToolResponse.success("Error: " + e.getMessage()); + } + } + + private static Thread drainAsync(InputStream stream, ByteArrayOutputStream sink) { + Thread t = + new Thread( + () -> { + try (InputStream in = stream) { + byte[] buf = new byte[4096]; + int n; + while ((n = in.read(buf)) > 0) { + sink.write(buf, 0, n); + } + } catch (IOException ignored) { + // process exit closes stream + } + }); + t.setDaemon(true); + t.start(); + return t; + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java new file mode 100644 index 000000000..2d938ecfb --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/tools/bash/BashValidator.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.treesitter.TSNode; +import org.treesitter.TSParser; +import org.treesitter.TSTree; +import org.treesitter.TreeSitterBash; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +/** + * AST-based bash command validator backed by tree-sitter-bash. + * + *

Mirrors the Python {@code flink_agents.plan.tools.bash.bash_validator}: walks the parsed AST + * and rejects any named node whose type is not on the {@link #ALLOWED_NAMED} allowlist (e.g. {@code + * command_substitution}, {@code subshell}, {@code for_statement}, etc.); for every {@code command} + * node it requires the executable to be either in {@code allowedCommands} or to resolve to a path + * under one of {@code allowedScriptDirs}. + */ +public final class BashValidator { + + /** + * Named AST node types we accept. Anything named but missing from this set is treated as a + * potentially dangerous shell construct and rejected. Unnamed nodes (literal punctuation like + * {@code |}, {@code &&}, {@code (}) are always allowed — they're just syntax tokens. + * + *

Kept in sync with the Python {@code _ALLOWED_NAMED} set. + */ + public static final Set ALLOWED_NAMED = + Set.of( + "program", + "command", + "command_name", + // `export VAR=...`, `readonly`, `declare`, `local`, `typeset` + "declaration_command", + "pipeline", + "list", + "redirected_statement", + "file_redirect", + "file_descriptor", + "variable_assignment", + "variable_name", + "special_variable_name", // $@ $? $* $# + "word", + "string", + "string_content", + "raw_string", + "ansi_c_string", + "translated_string", + "concatenation", + "number", + "simple_expansion", // $VAR + "expansion", // ${VAR} + "arithmetic_expansion", // $((...)) + "binary_expression", + "unary_expression", + "parenthesized_expression", + "array"); + + private static final Object PARSER_LOCK = new Object(); + private static volatile TSParser parser; + + private BashValidator() {} + + private static TSParser parser() { + TSParser p = parser; + if (p == null) { + synchronized (PARSER_LOCK) { + p = parser; + if (p == null) { + TSParser created = new TSParser(); + created.setLanguage(new TreeSitterBash()); + parser = created; + p = created; + } + } + } + return p; + } + + /** + * Validate a bash command. Returns {@link Optional#empty()} when allowed, or a non-empty + * descriptive error otherwise. + */ + public static Optional validate( + String command, + List allowedCommands, + List allowedScriptDirs, + @Nullable String cwd) { + if (command == null || command.trim().isEmpty()) { + return Optional.of("Empty command."); + } + TSTree tree; + synchronized (PARSER_LOCK) { + tree = parser().parseString(null, command); + } + TSNode root = tree.getRootNode(); + if (root.hasError()) { + return Optional.of("Command has syntax errors."); + } + if (root.getChildCount() == 0) { + return Optional.of("Empty command."); + } + return walk(root, command, allowedCommands, allowedScriptDirs, cwd); + } + + private static Optional walk( + TSNode node, + String command, + List allowedCommands, + List allowedScriptDirs, + @Nullable String cwd) { + if (node.isNamed() && !ALLOWED_NAMED.contains(node.getType())) { + String snippet = nodeText(node, command); + if (snippet.length() > 80) { + snippet = snippet.substring(0, 80); + } + return Optional.of( + "Disallowed shell construct '" + node.getType() + "' in: '" + snippet + "'"); + } + if ("command".equals(node.getType())) { + Optional err = + validateCommand(node, command, allowedCommands, allowedScriptDirs, cwd); + if (err.isPresent()) { + return err; + } + } + for (int i = 0; i < node.getChildCount(); i++) { + Optional err = + walk(node.getChild(i), command, allowedCommands, allowedScriptDirs, cwd); + if (err.isPresent()) { + return err; + } + } + return Optional.empty(); + } + + private static Optional validateCommand( + TSNode commandNode, + String command, + List allowedCommands, + List allowedScriptDirs, + @Nullable String cwd) { + TSNode nameNode = commandNode.getChildByFieldName("name"); + if (nameNode == null || nameNode.isNull()) { + // Bare variable-assignment parsed as command — nothing to validate. + return Optional.empty(); + } + String executable = nodeText(nameNode, command); + if (allowedCommands.contains(executable)) { + return Optional.empty(); + } + if (isUnderAllowedDirs(executable, allowedScriptDirs, cwd)) { + return Optional.empty(); + } + Set sortedCommands = new HashSet<>(allowedCommands); + Set sortedDirs = new HashSet<>(allowedScriptDirs); + return Optional.of( + "Command '" + + executable + + "' is not allowed. Allowed commands: " + + sortedCommands + + ". Allowed script dirs: " + + sortedDirs + + "."); + } + + /** Return true when {@code pathStr} resolves to a path under any of the allowed dirs. */ + public static boolean isUnderAllowedDirs( + String pathStr, List allowedDirs, @Nullable String cwd) { + Path base; + try { + base = Path.of(pathStr); + } catch (Exception e) { + return false; + } + if (!base.isAbsolute() && cwd != null) { + base = Path.of(cwd).resolve(base); + } + Path resolved; + try { + resolved = base.toAbsolutePath().toRealPath(); + } catch (IOException e) { + try { + resolved = base.toAbsolutePath().normalize(); + } catch (Exception ee) { + return false; + } + } catch (Exception e) { + return false; + } + for (String allowed : allowedDirs) { + try { + Path allowedRoot; + try { + allowedRoot = Path.of(allowed).toAbsolutePath().toRealPath(); + } catch (IOException io) { + allowedRoot = Path.of(allowed).toAbsolutePath().normalize(); + } + if (resolved.startsWith(allowedRoot)) { + return true; + } + } catch (Exception ignored) { + // skip invalid allowed root + } + } + return false; + } + + private static String nodeText(TSNode node, String command) { + byte[] bytes = command.getBytes(StandardCharsets.UTF_8); + int start = node.getStartByte(); + int end = node.getEndByte(); + if (start < 0 || end > bytes.length || start > end) { + return ""; + } + return new String(bytes, start, end - start, StandardCharsets.UTF_8); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java new file mode 100644 index 000000000..3a134923e --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashToolTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BashToolTest { + + private static BashTool tool() { + return new BashTool( + new ResourceDescriptor(BashTool.class.getName(), Map.of()), + ResourceContext.fromGetResource((n, t) -> null)); + } + + private static ToolParameters args( + String command, List allowedCommands, List allowedScriptDirs) { + Map m = new HashMap<>(); + m.put("command", command); + m.put("allowed_commands", allowedCommands); + m.put("allowed_script_dirs", allowedScriptDirs); + return new ToolParameters(m); + } + + @Test + void allowedSimpleCommandRuns() { + ToolResponse r = tool().call(args("echo hello", List.of("echo"), List.of())); + assertEquals("hello", r.getResult()); + } + + @Test + void disallowedCommandRejected() { + ToolResponse r = tool().call(args("rm -rf /", List.of("echo"), List.of())); + String out = (String) r.getResult(); + assertTrue(out.startsWith("Command rejected:")); + assertTrue(out.contains("'rm' is not allowed")); + } + + @Test + void controlFlowRejected() { + ToolResponse r = + tool().call(args("for i in 1 2 3; do echo $i; done", List.of("echo"), List.of())); + String out = (String) r.getResult(); + assertTrue(out.startsWith("Command rejected:")); + } + + @Test + void successfulCommandWithEmptyOutput() { + ToolResponse r = tool().call(args("true", List.of("true"), List.of())); + assertEquals("Success", r.getResult()); + } + + @Test + void timeoutEnforced() { + Map m = new HashMap<>(); + m.put("command", "sleep 5"); + m.put("allowed_commands", List.of("sleep")); + m.put("allowed_script_dirs", List.of()); + m.put("timeout", 1); + ToolResponse r = tool().call(new ToolParameters(m)); + String out = (String) r.getResult(); + assertTrue(out.startsWith("Error: Command timed out")); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java new file mode 100644 index 000000000..7ff28f071 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/tools/bash/BashValidatorTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan.tools.bash; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BashValidatorTest { + + @Test + void emptyCommandRejected() { + assertEquals( + Optional.of("Empty command."), + BashValidator.validate("", List.of("echo"), List.of(), null)); + assertEquals( + Optional.of("Empty command."), + BashValidator.validate(" ", List.of("echo"), List.of(), null)); + } + + @Test + void simpleAllowedCommandPasses() { + assertEquals( + Optional.empty(), + BashValidator.validate("echo hello", List.of("echo"), List.of(), null)); + } + + @Test + void unknownCommandRejected() { + Optional r = BashValidator.validate("rm -rf /", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("'rm' is not allowed")); + } + + @Test + void pipelineAllowedWhenAllPartsAllowed() { + assertEquals( + Optional.empty(), + BashValidator.validate( + "echo hi | tr a-z A-Z", List.of("echo", "tr"), List.of(), null)); + } + + @Test + void pipelineRejectedWhenAnyPartUnknown() { + Optional r = + BashValidator.validate("echo hi | grep h", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("'grep'")); + } + + @Test + void variableExpansionAllowed() { + assertEquals( + Optional.empty(), + BashValidator.validate("echo $HOME", List.of("echo"), List.of(), null)); + } + + @Test + void arithmeticExpansionAllowed() { + assertEquals( + Optional.empty(), + BashValidator.validate("echo $((1+2))", List.of("echo"), List.of(), null)); + } + + @Test + void commandSubstitutionRejected() { + Optional r = + BashValidator.validate("echo $(rm /etc/passwd)", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("command_substitution")); + } + + @Test + void backticksRejected() { + Optional r = + BashValidator.validate("echo `whoami`", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + } + + @Test + void controlFlowRejected() { + Optional r = + BashValidator.validate( + "for i in 1 2 3; do echo $i; done", List.of("echo"), List.of(), null); + assertTrue(r.isPresent()); + assertTrue(r.get().contains("for_statement")); + } + + @Test + void redirectAllowed() { + // basic redirect of allowed command should pass + Optional r = + BashValidator.validate("echo hi > /tmp/x", List.of("echo"), List.of(), null); + assertEquals(Optional.empty(), r); + } +} From 9a28c67a816bcfb4e0527c303aa3e451b5a1697d Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Wed, 6 May 2026 17:54:29 +0800 Subject: [PATCH 4/4] [api][runtime][java] Support agent skills in Java. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../flink/agents/api/annotation/Skills.java | 39 +++ .../api/chat/model/BaseChatModelSetup.java | 52 ++++ .../flink/agents/api/skills/Skills.java | 86 +++++++ .../model/BaseChatModelSetupSkillsTest.java | 179 ++++++++++++++ .../agents/api/skills/SkillsResourceTest.java | 53 ++++ .../test/SkillsIntegrationAgent.java | 119 +++++++++ .../test/SkillsIntegrationTest.java | 226 ++++++++++++++++++ .../resources/skills/joke-generator/SKILL.md | 18 ++ .../skills/joke-generator/scripts/gen_joke.py | 20 ++ .../resources/skills/math-calculator/SKILL.md | 57 +++++ .../apache/flink/agents/plan/AgentPlan.java | 62 +++++ .../agents/plan/actions/ChatModelAction.java | 47 +++- .../plan/AgentPlanDeclareSkillsTest.java | 131 ++++++++++ .../runtime/java/java_resource_wrapper.py | 14 +- runtime/pom.xml | 6 + .../python/utils/JavaResourceAdapter.java | 15 ++ .../runtime/resource/ResourceContextImpl.java | 46 +++- .../agents/runtime/skill/AgentSkill.java | 147 ++++++++++++ .../agents/runtime/skill/LoadSkillTool.java | 146 +++++++++++ .../agents/runtime/skill/SkillManager.java | 148 ++++++++++++ .../agents/runtime/skill/SkillParser.java | 121 ++++++++++ .../runtime/skill/SkillPromptProvider.java | 50 ++++ .../agents/runtime/skill/SkillRepository.java | 44 ++++ .../repository/FileSystemSkillRepository.java | 184 ++++++++++++++ .../skill/FileSystemSkillRepositoryTest.java | 105 ++++++++ .../runtime/skill/LoadSkillToolTest.java | 117 +++++++++ .../runtime/skill/SkillManagerTest.java | 92 +++++++ .../agents/runtime/skill/SkillParserTest.java | 122 ++++++++++ .../test/resources/skill_discovery_prompt.txt | 24 ++ .../src/test/resources/skills/github/SKILL.md | 47 ++++ .../resources/skills/nano-banana-pro/SKILL.md | 130 ++++++++++ .../skills/nano-banana-pro/_meta.json | 6 + .../nano-banana-pro/scripts/generate_image.py | 165 +++++++++++++ 33 files changed, 2809 insertions(+), 9 deletions(-) create mode 100644 api/src/main/java/org/apache/flink/agents/api/annotation/Skills.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/skills/Skills.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupSkillsTest.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/skills/SkillsResourceTest.java create mode 100644 e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationAgent.java create mode 100644 e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationTest.java create mode 100644 e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/SKILL.md create mode 100755 e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/scripts/gen_joke.py create mode 100644 e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/math-calculator/SKILL.md create mode 100644 plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareSkillsTest.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/AgentSkill.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/LoadSkillTool.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillManager.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillParser.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillPromptProvider.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillRepository.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/skill/repository/FileSystemSkillRepository.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/skill/FileSystemSkillRepositoryTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/skill/LoadSkillToolTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillManagerTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillParserTest.java create mode 100644 runtime/src/test/resources/skill_discovery_prompt.txt create mode 100644 runtime/src/test/resources/skills/github/SKILL.md create mode 100644 runtime/src/test/resources/skills/nano-banana-pro/SKILL.md create mode 100644 runtime/src/test/resources/skills/nano-banana-pro/_meta.json create mode 100644 runtime/src/test/resources/skills/nano-banana-pro/scripts/generate_image.py diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/Skills.java b/api/src/main/java/org/apache/flink/agents/api/annotation/Skills.java new file mode 100644 index 000000000..072d360c3 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/Skills.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a static method that returns an {@link org.apache.flink.agents.api.skills.Skills} resource + * describing where to load agent skills from. + * + *

Mirrors the Python {@code @skills} decorator. Multiple {@code @Skills} methods on the same + * agent are merged at plan-build time. + * + *

Note: this annotation shares its simple name with {@link + * org.apache.flink.agents.api.skills.Skills} (different package). When importing both, one of them + * must be referenced by its fully-qualified name. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface Skills {} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index b0f73d80f..af7ed10b2 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -25,6 +25,7 @@ import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.util.Preconditions; @@ -42,6 +43,10 @@ public abstract class BaseChatModelSetup extends Resource { protected String model; protected Object prompt; protected List toolNames; + @Nullable protected List skills; + @Nullable protected String skillDiscoveryPrompt; + protected List allowedCommands; + protected List allowedScriptDirs; @Nullable protected BaseChatModelConnection connection; protected final List tools = new ArrayList<>(); @@ -52,6 +57,15 @@ public BaseChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourc this.model = descriptor.getArgument("model"); this.prompt = descriptor.getArgument("prompt"); this.toolNames = descriptor.getArgument("tools"); + this.skills = descriptor.getArgument("skills"); + List declaredCommands = descriptor.getArgument("allowed_commands"); + this.allowedCommands = + declaredCommands == null ? new ArrayList<>() : new ArrayList<>(declaredCommands); + List declaredScriptDirs = descriptor.getArgument("allowed_script_dirs"); + this.allowedScriptDirs = + declaredScriptDirs == null + ? new ArrayList<>() + : new ArrayList<>(declaredScriptDirs); } /** @@ -71,6 +85,19 @@ public void open() throws Exception { this.prompt = this.resourceContext.getResource((String) this.prompt, ResourceType.PROMPT); } + if (this.skills != null) { + this.skillDiscoveryPrompt = + this.resourceContext.generateAvailableSkillsPrompt(this.skills); + List mutable = + this.toolNames == null ? new ArrayList<>() : new ArrayList<>(this.toolNames); + if (!mutable.contains(Skills.LOAD_SKILL_TOOL)) { + mutable.add(Skills.LOAD_SKILL_TOOL); + } + if (!mutable.contains(Skills.BASH_TOOL)) { + mutable.add(Skills.BASH_TOOL); + } + this.toolNames = mutable; + } if (this.toolNames != null) { for (String name : this.toolNames) { this.tools.add((Tool) this.resourceContext.getResource(name, ResourceType.TOOL)); @@ -115,6 +142,13 @@ public ChatMessage chat(List messages, Map paramete messages = promptMessages; } + if (this.skillDiscoveryPrompt != null && !this.skillDiscoveryPrompt.isEmpty()) { + int idx = ChatMessage.findFirstSystemMessage(messages); + List mutated = new ArrayList<>(messages); + mutated.add(idx + 1, new ChatMessage(MessageRole.SYSTEM, this.skillDiscoveryPrompt)); + messages = mutated; + } + Map params = this.getParameters(); params.putAll(parameters); return connection.chat(messages, tools, params); @@ -144,4 +178,22 @@ public Object getPrompt() { public List getToolNames() { return toolNames; } + + @Nullable + public List getSkills() { + return skills; + } + + @Nullable + public String getSkillDiscoveryPrompt() { + return skillDiscoveryPrompt; + } + + public List getAllowedCommands() { + return allowedCommands; + } + + public List getAllowedScriptDirs() { + return allowedScriptDirs; + } } diff --git a/api/src/main/java/org/apache/flink/agents/api/skills/Skills.java b/api/src/main/java/org/apache/flink/agents/api/skills/Skills.java new file mode 100644 index 000000000..c4a8de7be --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/skills/Skills.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.skills; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.SerializableResource; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Configuration resource describing where to load agent skills from. + * + *

Mirrors the Python {@code flink_agents.api.skills.Skills}. Use {@link + * #fromLocalDir(String...)} to construct. + * + *

Multiple {@code @Skills} declarations on the same agent are merged at plan-build time. + */ +@JsonIgnoreProperties( + ignoreUnknown = true, + value = {"metricGroup", "resourceType"}) +public class Skills extends SerializableResource { + + /** Reserved resource name under which AgentPlan registers the merged Skills config. */ + public static final String SKILLS_CONFIG = "_skills_config"; + + /** Reserved name of the built-in skill loader tool. */ + public static final String LOAD_SKILL_TOOL = "load_skill"; + + /** Reserved name of the built-in bash tool used to execute skill scripts. */ + public static final String BASH_TOOL = "bash"; + + private List paths; + + /** Required by Jackson. */ + public Skills() { + this.paths = Collections.emptyList(); + } + + @JsonCreator + public Skills(@JsonProperty("paths") List paths) { + this.paths = paths == null ? Collections.emptyList() : List.copyOf(paths); + } + + /** + * Create a {@link Skills} resource from one or more local filesystem directories. + * + *

Each path points to a directory whose immediate subdirectories each contain a {@code + * SKILL.md} file. + */ + public static Skills fromLocalDir(String... paths) { + return new Skills(Arrays.asList(paths)); + } + + @JsonProperty("paths") + public List getPaths() { + return paths; + } + + @JsonIgnore + @Override + public ResourceType getResourceType() { + return ResourceType.SKILLS; + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupSkillsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupSkillsTest.java new file mode 100644 index 000000000..a31d99710 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupSkillsTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BaseChatModelSetupSkillsTest { + + /** Stub chat model setup that exposes a configurable parameters map. */ + private static class StubChatSetup extends BaseChatModelSetup { + public StubChatSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + } + + @Override + public Map getParameters() { + return new HashMap<>(); + } + } + + private static class StubConnection extends BaseChatModelConnection { + List capturedMessages; + List capturedTools; + + StubConnection(ResourceDescriptor d, ResourceContext c) { + super(d, c); + } + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + this.capturedMessages = new ArrayList<>(messages); + this.capturedTools = new ArrayList<>(tools); + return new ChatMessage(MessageRole.ASSISTANT, "ok"); + } + } + + private static class StubTool extends Tool { + public StubTool(String name) { + super(new ToolMetadata(name, "stub", "{}")); + } + + @Override + public ToolType getToolType() { + return ToolType.FUNCTION; + } + + @Override + public ToolResponse call(ToolParameters parameters) { + return ToolResponse.success(""); + } + } + + @Test + void openInjectsSkillToolsAndDiscoveryPrompt() throws Exception { + Map store = new HashMap<>(); + StubConnection connection = new StubConnection(new ResourceDescriptor("X", Map.of()), null); + Tool loadSkillTool = new StubTool(Skills.LOAD_SKILL_TOOL); + Tool bashTool = new StubTool(Skills.BASH_TOOL); + store.put("conn", connection); + store.put(Skills.LOAD_SKILL_TOOL, loadSkillTool); + store.put(Skills.BASH_TOOL, bashTool); + ResourceContext ctx = + new ResourceContext() { + @Override + public Resource getResource(String name, ResourceType type) { + return store.get(name); + } + + @Override + public String generateAvailableSkillsPrompt(List skillNames) { + return "\n\n" + + skillNames.get(0) + + "\n\n"; + } + + @Override + public List getSkillDirs(List skillNames) { + return List.of(); + } + }; + + Map args = new HashMap<>(); + args.put("connection", "conn"); + args.put("skills", Arrays.asList("github")); + ResourceDescriptor descriptor = new ResourceDescriptor("X", args); + + StubChatSetup setup = new StubChatSetup(descriptor, ctx); + setup.open(); + + assertNotNull(setup.getSkillDiscoveryPrompt()); + assertTrue(setup.getSkillDiscoveryPrompt().contains("")); + assertTrue(setup.getToolNames().contains(Skills.LOAD_SKILL_TOOL)); + assertTrue(setup.getToolNames().contains(Skills.BASH_TOOL)); + } + + @Test + void chatInjectsSkillPromptAfterFirstSystemMessage() throws Exception { + Map store = new HashMap<>(); + StubConnection connection = new StubConnection(new ResourceDescriptor("X", Map.of()), null); + store.put("conn", connection); + store.put(Skills.LOAD_SKILL_TOOL, new StubTool(Skills.LOAD_SKILL_TOOL)); + store.put(Skills.BASH_TOOL, new StubTool(Skills.BASH_TOOL)); + ResourceContext ctx = + new ResourceContext() { + @Override + public Resource getResource(String name, ResourceType type) { + return store.get(name); + } + + @Override + public String generateAvailableSkillsPrompt(List skillNames) { + return "marker"; + } + + @Override + public List getSkillDirs(List skillNames) { + return List.of(); + } + }; + Map args = new HashMap<>(); + args.put("connection", "conn"); + args.put("skills", Arrays.asList("github")); + StubChatSetup setup = new StubChatSetup(new ResourceDescriptor("X", args), ctx); + setup.open(); + + List input = + Arrays.asList( + new ChatMessage(MessageRole.SYSTEM, "you are an agent"), + new ChatMessage(MessageRole.USER, "hi")); + setup.chat(input); + + // Expected: SYSTEM, SYSTEM(skill_prompt), USER + assertEquals(3, connection.capturedMessages.size()); + assertEquals(MessageRole.SYSTEM, connection.capturedMessages.get(0).getRole()); + assertEquals("you are an agent", connection.capturedMessages.get(0).getContent()); + assertEquals(MessageRole.SYSTEM, connection.capturedMessages.get(1).getRole()); + assertTrue(connection.capturedMessages.get(1).getContent().contains("")); + assertEquals(MessageRole.USER, connection.capturedMessages.get(2).getRole()); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/skills/SkillsResourceTest.java b/api/src/test/java/org/apache/flink/agents/api/skills/SkillsResourceTest.java new file mode 100644 index 000000000..864eae42c --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/skills/SkillsResourceTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.skills; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class SkillsResourceTest { + + @Test + void fromLocalDirCarriesPaths() { + Skills skills = Skills.fromLocalDir("/tmp/a", "/tmp/b"); + assertEquals(List.of("/tmp/a", "/tmp/b"), skills.getPaths()); + assertEquals(ResourceType.SKILLS, skills.getResourceType()); + } + + @Test + void roundTripsThroughJackson() throws Exception { + Skills original = Skills.fromLocalDir("/tmp/skill1", "/tmp/skill2"); + ObjectMapper mapper = new ObjectMapper(); + String json = mapper.writeValueAsString(original); + Skills restored = mapper.readValue(json, Skills.class); + assertEquals(original.getPaths(), restored.getPaths()); + } + + @Test + void reservedNamesMatchPython() { + assertEquals("_skills_config", Skills.SKILLS_CONFIG); + assertEquals("load_skill", Skills.LOAD_SKILL_TOOL); + assertEquals("bash", Skills.BASH_TOOL); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationAgent.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationAgent.java new file mode 100644 index 000000000..b35338242 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationAgent.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.annotation.ChatModelConnection; +import org.apache.flink.agents.api.annotation.ChatModelSetup; +import org.apache.flink.agents.api.annotation.Prompt; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.event.ChatRequestEvent; +import org.apache.flink.agents.api.event.ChatResponseEvent; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceName; +import org.apache.flink.agents.api.skills.Skills; + +import java.net.URL; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Agent that exercises the agent-skills feature end-to-end. Mirrors the Python {@code + * agent_skills_test.SkillTestAgent}: declares two skills (math-calculator, joke-generator) and a + * system prompt that instructs the model to load the skill before answering. + */ +public class SkillsIntegrationAgent extends Agent { + + /** Same model name as the Python {@code agent_skills_test} (a dashscope-hosted Qwen). */ + public static final String MODEL = "qwen3.6-plus"; + + /** + * Same default endpoint as the Python test — overridden by the {@code ACTION_BASE_URL} CI env + * var. + */ + public static final String DEFAULT_BASE_URL = "https://coding.dashscope.aliyuncs.com/v1"; + + @ChatModelConnection + public static ResourceDescriptor openaiConnection() { + String apiKey = System.getenv("ACTION_API_KEY"); + String baseUrl = System.getenv().getOrDefault("ACTION_BASE_URL", DEFAULT_BASE_URL); + return ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OPENAI_COMPLETIONS_CONNECTION) + .addInitialArgument("api_key", apiKey) + .addInitialArgument("api_base_url", baseUrl) + .build(); + } + + @ChatModelSetup + public static ResourceDescriptor openaiSetup() { + return ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OPENAI_COMPLETIONS_SETUP) + .addInitialArgument("connection", "openaiConnection") + .addInitialArgument("model", MODEL) + .addInitialArgument("skills", List.of("math-calculator", "joke-generator")) + .addInitialArgument("allowed_commands", List.of("echo", "bc", "python", "python3")) + .addInitialArgument("prompt", "systemPrompt") + .build(); + } + + /** Resolve the {@code skills/} test resource directory to an absolute filesystem path. */ + @org.apache.flink.agents.api.annotation.Skills + public static Skills mySkills() { + URL url = + Objects.requireNonNull( + SkillsIntegrationAgent.class.getClassLoader().getResource("skills"), + "skills/ test resources are missing"); + Path path = Paths.get(url.getPath()); + return Skills.fromLocalDir(path.toString()); + } + + @Prompt + public static org.apache.flink.agents.api.prompt.Prompt systemPrompt() { + return org.apache.flink.agents.api.prompt.Prompt.fromMessages( + Collections.singletonList( + new ChatMessage( + MessageRole.SYSTEM, + "You are a helpful assistant. Use the math-calculator skill when " + + "asked to evaluate an expression, and the joke-generator " + + "skill when asked for a joke. You must load the skill " + + "first and strictly follow the instructions of the skill."))); + } + + @Action(listenEventTypes = {InputEvent.EVENT_TYPE}) + public static void process(InputEvent event, RunnerContext ctx) throws Exception { + ctx.sendEvent( + new ChatRequestEvent( + "openaiSetup", + Collections.singletonList( + new ChatMessage(MessageRole.USER, (String) event.getInput())))); + } + + @Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE}) + public static void processChatResponse(ChatResponseEvent event, RunnerContext ctx) { + ctx.sendEvent(new OutputEvent(event.getResponse().getContent())); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationTest.java new file mode 100644 index 000000000..65360910e --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/SkillsIntegrationTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.agents.ReActAgent; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceName; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; + +import java.net.URL; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY; +import static org.apache.flink.agents.api.agents.AgentExecutionOptions.MAX_RETRIES; + +/** + * End-to-end tests for agent skills. Mirrors the Python {@code + * python/flink_agents/e2e_tests/e2e_tests_integration/agent_skills_test.py}, including both the + * workflow-style agent and the {@link ReActAgent} variant. + * + *

    + *
  • {@link #testWorkflowWithSkills()} — feeds prompts through {@link SkillsIntegrationAgent} + * (workflow agent) and asserts on the math/joke responses. + *
  • {@link #testReActAgentWithSkills()} — uses {@link ReActAgent} with a structured output + * schema; asserts the parsed {@code result} field equals 8 ({@code 2 ^ 3}). + *
+ * + *

Skipped unless {@code ACTION_API_KEY} (the GitHub Actions-injected env var, mirroring the + * Python test) is set; small local models do not reliably handle the multi-turn skill-loading flow. + * {@code ACTION_BASE_URL} optionally overrides the default dashscope endpoint. + */ +public class SkillsIntegrationTest { + + @Test + public void testWorkflowWithSkills() throws Exception { + Assumptions.assumeTrue( + System.getenv().get("ACTION_API_KEY") != null, + "ACTION_API_KEY is required for the skills end-to-end test."); + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + DataStream inputStream = + env.fromData( + "Please evaluate the expression: (2 ^ 3)", "Tell me a joke about cat."); + + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + DataStream outputStream = + agentsEnv + .fromDataStream(inputStream, (KeySelector) value -> value) + .apply(new SkillsIntegrationAgent()) + .toDataStream(); + + CloseableIterator results = outputStream.collectAsync(); + agentsEnv.execute(); + + List responses = new ArrayList<>(); + while (results.hasNext()) { + responses.add(String.valueOf(results.next())); + } + + Assertions.assertEquals( + 2, responses.size(), String.format("Expected 2 responses, got: %s", responses)); + + String text = String.join("\n", responses); + Assertions.assertTrue( + text.contains("8"), + String.format("Math response should contain '8'. Full responses: %s", text)); + Assertions.assertTrue( + text.contains("Too many cheetahs"), + String.format( + "Joke response should contain script punchline 'Too many cheetahs'. " + + "Full responses: %s", + text)); + } + + @Test + public void testReActAgentWithSkills() throws Exception { + Assumptions.assumeTrue( + System.getenv().get("ACTION_API_KEY") != null, + "ACTION_API_KEY is required for the skills end-to-end test."); + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env); + + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env, tableEnv); + + String apiKey = System.getenv("ACTION_API_KEY"); + String baseUrl = + System.getenv() + .getOrDefault("ACTION_BASE_URL", SkillsIntegrationAgent.DEFAULT_BASE_URL); + + // Resolve the bundled skills/ test resource directory (same fixtures as the workflow test). + URL url = + Objects.requireNonNull( + SkillsIntegrationTest.class.getClassLoader().getResource("skills"), + "skills/ test resources are missing"); + String skillsPath = Paths.get(url.getPath()).toString(); + + agentsEnv + .addResource( + "openai", + ResourceType.CHAT_MODEL_CONNECTION, + ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OPENAI_COMPLETIONS_CONNECTION) + .addInitialArgument("api_key", apiKey) + .addInitialArgument("api_base_url", baseUrl) + .build()) + .addResource("my_skill", ResourceType.SKILLS, Skills.fromLocalDir(skillsPath)); + + agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY); + agentsEnv.getConfig().set(MAX_RETRIES, 3); + + ResourceDescriptor chatModelDescriptor = + ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OPENAI_COMPLETIONS_SETUP) + .addInitialArgument("connection", "openai") + .addInitialArgument("model", SkillsIntegrationAgent.MODEL) + .addInitialArgument("skills", List.of("math-calculator")) + .addInitialArgument("allowed_commands", List.of("echo", "bc")) + .build(); + + Prompt prompt = + Prompt.fromMessages( + List.of( + new ChatMessage( + MessageRole.SYSTEM, + "You are a math calculate assistant. Use the math-calculator " + + "skill when asked to evaluate an expression. You " + + "must load the skill first and strictly follow the " + + "instructions of the skill."), + new ChatMessage( + MessageRole.USER, + "Please evaluate the expression: {a} ^ {b}"))); + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO}, + new String[] {"result"}); + + Agent agent = new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo); + + Table inputTable = + tableEnv.fromValues( + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.INT())), + Row.of(2, 3)); + + Schema outputSchema = + Schema.newBuilder() + .column("f0", DataTypes.ROW(DataTypes.FIELD("result", DataTypes.INT()))) + .build(); + + Table outputTable = + agentsEnv + .fromTable( + inputTable, + (KeySelector) + value -> (Integer) ((Row) value).getField("a")) + .apply(agent) + .toTable(outputSchema); + + CloseableIterator results = + tableEnv.toDataStream(outputTable) + .map((MapFunction) x -> (Row) x.getField("f0")) + .collectAsync(); + + env.execute(); + + Assertions.assertTrue( + results.hasNext(), + "ReAct agent did not produce any output — the LLM response may not have matched the " + + "output schema; rerun if so."); + Row row = (Row) results.next(); + Object result = row.getField("result"); + Assertions.assertNotNull(result, String.format("Missing result field in row %s", row)); + Assertions.assertEquals( + 8, ((Integer) result).intValue(), String.format("Expected 2 ^ 3 = 8, got %s", row)); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/SKILL.md b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/SKILL.md new file mode 100644 index 000000000..eadf07abd --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/SKILL.md @@ -0,0 +1,18 @@ +--- +name: joke-generator +description: Tell a joke about cat. +--- + +# Math Calculator Skill + +This skill provides the ability to tell a joke about cat. + +## When to Use + +Use this skill when user want to get a joke about cat. + +## Methods + +```bash +python scripts/gen_joke.py +``` diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/scripts/gen_joke.py b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/scripts/gen_joke.py new file mode 100755 index 000000000..3c681eba2 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/joke-generator/scripts/gen_joke.py @@ -0,0 +1,20 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# + +if __name__ == "__main__": + print("Why don't cats play poker in the jungle? Too many cheetahs. 🐱") diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/math-calculator/SKILL.md b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/math-calculator/SKILL.md new file mode 100644 index 000000000..b00a1592b --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/skills/math-calculator/SKILL.md @@ -0,0 +1,57 @@ +--- +name: math-calculator +description: Calculate simple mathematical expressions using shell commands. Use when the user asks to perform arithmetic calculations like addition, subtraction, multiplication, division, or more complex math expressions. +license: Apache-2.0 +compatibility: Requires bash with bc (basic calculator) +--- + +# Math Calculator Skill + +This skill provides the ability to calculate mathematical expressions using shell commands. + +## When to Use + +Use this skill when: +- Performing arithmetic calculations (add, subtract, multiply, divide) +- Evaluating mathematical expressions with parentheses +- Computing percentages or powers +- Any numeric computation requested by the user + +## Methods + +### Using `bc` (Basic Calculator) + +The `bc` command is a powerful calculator that supports: +- Basic arithmetic: `+`, `-`, `*`, `/` +- Power: `^` +- Parentheses for grouping +- Scale for decimal precision + +**Example:** +```bash +echo "2 + 3 * 4" | bc +# Output: 14 + +echo "scale=2; 10 / 3" | bc +# Output: 3.33 + +echo "(2 + 3) * 4" | bc +# Output: 20 +``` + +## Supported Operations + +| Operation | Symbol | Example | +|-----------|--------|---------| +| Addition | `+` | `5 + 3 = 8` | +| Subtraction | `-` | `10 - 4 = 6` | +| Multiplication | `*` | `6 * 7 = 42` | +| Division | `/` | `15 / 3 = 5` | +| Power | `^` (bc) or `**` (Python) | `2 ^ 3 = 8` | +| Modulo | `%` | `17 % 5 = 2` | +| Square Root | `sqrt()` (bc) | `sqrt(16) = 4` | + +## Notes + +- Use `scale=N` in `bc` to set decimal precision (default is 0, integer only) +- For floating-point division, always set `scale` in `bc` diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index 0dde929e7..3fb151096 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -28,6 +28,7 @@ import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import org.apache.flink.agents.api.skills.Skills; import org.apache.flink.agents.api.tools.ToolMetadata; import org.apache.flink.agents.plan.actions.Action; import org.apache.flink.agents.plan.actions.ChatModelAction; @@ -42,6 +43,7 @@ import org.apache.flink.agents.plan.serializer.AgentPlanJsonSerializer; import org.apache.flink.agents.plan.tools.FunctionTool; import org.apache.flink.agents.plan.tools.ToolMetadataFactory; +import org.apache.flink.agents.plan.tools.bash.BashTool; import org.apache.flink.api.java.tuple.Tuple3; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +58,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -337,6 +341,9 @@ private void extractJavaMCPServer(Method method) throws Exception { private void extractResourceProvidersFromAgent(Agent agent) throws Exception { Class agentClass = agent.getClass(); + // Collect Skills declarations from both @Skills methods and Agent.addResource(SKILLS, ...) + Map skillsObjects = new LinkedHashMap<>(); + // Scan all fields in the agent class for @Tool and @ChatModel annotations for (Field field : agentClass.getDeclaredFields()) { field.setAccessible(true); // Allow access to private fields @@ -411,6 +418,17 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { extractResource(ResourceType.EMBEDDING_MODEL_CONNECTION, method); } else if (method.isAnnotationPresent(VectorStore.class)) { extractResource(ResourceType.VECTOR_STORE, method); + } else if (method.isAnnotationPresent( + org.apache.flink.agents.api.annotation.Skills.class) + && Modifier.isStatic(method.getModifiers())) { + Object value = method.invoke(null); + if (!(value instanceof Skills)) { + throw new IllegalStateException( + "@Skills method " + + method.getName() + + " must return org.apache.flink.agents.api.skills.Skills"); + } + skillsObjects.put(method.getName(), (Skills) value); } else if (method.isAnnotationPresent(MCPServer.class)) { // Check the MCPServer annotation version to determine which version to use. MCPServer MCPServerAnnotation = method.getAnnotation(MCPServer.class); @@ -476,8 +494,52 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { ((org.apache.flink.agents.api.tools.FunctionTool) kv.getValue()) .getMethod()); } + } else if (type == ResourceType.SKILLS) { + for (Map.Entry kv : entry.getValue().entrySet()) { + if (kv.getValue() instanceof Skills) { + skillsObjects.put(kv.getKey(), (Skills) kv.getValue()); + } + } } } + + addSkills(skillsObjects); + } + + /** + * Mirror of Python {@code _add_skills}: register the merged Skills config under {@link + * Skills#SKILLS_CONFIG} plus the built-in {@code load_skill} and {@code bash} tools. + * + *

{@link BashTool} lives in this module so we can reference its class directly; {@code + * LoadSkillTool} lives in the runtime module and is referenced by FQN string to avoid a reverse + * dependency. + */ + private void addSkills(Map skillsObjects) throws Exception { + if (skillsObjects.isEmpty()) { + return; + } + + addResourceProvider( + new JavaResourceProvider( + Skills.LOAD_SKILL_TOOL, + ResourceType.TOOL, + new ResourceDescriptor( + "org.apache.flink.agents.runtime.skill.LoadSkillTool", + new HashMap<>()))); + addResourceProvider( + new JavaResourceProvider( + Skills.BASH_TOOL, + ResourceType.TOOL, + new ResourceDescriptor(BashTool.class.getName(), new HashMap<>()))); + + LinkedHashSet paths = new LinkedHashSet<>(); + for (Skills s : skillsObjects.values()) { + paths.addAll(s.getPaths()); + } + Skills merged = Skills.fromLocalDir(paths.toArray(new String[0])); + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider( + Skills.SKILLS_CONFIG, ResourceType.SKILLS, merged)); } /** diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index becec4714..997fb28b9 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -36,6 +36,7 @@ import org.apache.flink.agents.api.event.ToolResponseEvent; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.api.java.typeutils.RowTypeInfo; @@ -182,6 +183,7 @@ private static void handleToolCalls( ChatMessage response, UUID initialRequestId, String model, + BaseChatModelSetup chatModel, List messages, Object outputSchema, RunnerContext ctx) @@ -192,6 +194,8 @@ private static void handleToolCalls( messages, Collections.singletonList(response)); + injectBashToolArgs(response.getToolCalls(), chatModel); + ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, response.getToolCalls()); saveToolRequestEventContext( @@ -204,6 +208,46 @@ private static void handleToolCalls( ctx.sendEvent(toolRequestEvent); } + /** + * Inject framework-controlled args ({@code allowed_commands}, {@code allowed_script_dirs}) into + * bash tool calls so they remain hidden from the LLM. Mirrors Python {@code + * _inject_bash_tool_args}. + */ + @SuppressWarnings("unchecked") + private static void injectBashToolArgs( + List> toolCalls, BaseChatModelSetup chatModel) throws Exception { + if (toolCalls == null || toolCalls.isEmpty()) { + return; + } + List scriptDirs = new ArrayList<>(chatModel.getAllowedScriptDirs()); + List declaredSkills = chatModel.getSkills(); + if (declaredSkills != null + && !declaredSkills.isEmpty() + && chatModel.getResourceContext() != null) { + scriptDirs.addAll(chatModel.getResourceContext().getSkillDirs(declaredSkills)); + } + for (Map call : toolCalls) { + Object function = call.get("function"); + if (!(function instanceof Map)) { + continue; + } + Map functionMap = (Map) function; + if (!Skills.BASH_TOOL.equals(functionMap.get("name"))) { + continue; + } + Object argsObj = functionMap.get("arguments"); + Map args; + if (argsObj instanceof Map) { + args = (Map) argsObj; + } else { + args = new HashMap<>(); + functionMap.put("arguments", args); + } + args.put("allowed_commands", new ArrayList<>(chatModel.getAllowedCommands())); + args.put("allowed_script_dirs", scriptDirs); + } + } + static String cleanLlmResponse(String rawResponse) { String trimmed = rawResponse.trim(); if (trimmed.startsWith("```")) { @@ -348,7 +392,8 @@ public ChatMessage call() throws Exception { } if (!Objects.requireNonNull(response).getToolCalls().isEmpty()) { - handleToolCalls(response, initialRequestId, model, messages, outputSchema, ctx); + handleToolCalls( + response, initialRequestId, model, chatModel, messages, outputSchema, ctx); } else { Map retryStats = getRetryStats(ctx.getSensoryMemory(), initialRequestId); int totalRetryCount = retryStats.get(TOTAL_RETRY_COUNT).intValue(); diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareSkillsTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareSkillsTest.java new file mode 100644 index 000000000..e7605368a --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareSkillsTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan; + +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; +import org.apache.flink.agents.plan.resourceprovider.JavaResourceProvider; +import org.apache.flink.agents.plan.resourceprovider.JavaSerializableResourceProvider; +import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class AgentPlanDeclareSkillsTest { + + public static class SingleSkillsAgent extends Agent { + @org.apache.flink.agents.api.annotation.Skills + public static Skills mySkills() { + return Skills.fromLocalDir("/tmp/skill-a", "/tmp/skill-b"); + } + } + + public static class MultiSkillsAgent extends Agent { + @org.apache.flink.agents.api.annotation.Skills + public static Skills first() { + return Skills.fromLocalDir("/tmp/skill-a", "/tmp/skill-b"); + } + + @org.apache.flink.agents.api.annotation.Skills + public static Skills second() { + return Skills.fromLocalDir("/tmp/skill-b", "/tmp/skill-c"); + } + } + + public static class NoSkillsAgent extends Agent {} + + @Test + void singleSkillsRegistersConfigAndBuiltInTools() throws Exception { + AgentPlan plan = new AgentPlan(new SingleSkillsAgent()); + Map> providers = plan.getResourceProviders(); + + // Skills config under reserved name + assertNotNull(providers.get(ResourceType.SKILLS)); + ResourceProvider configProvider = + providers.get(ResourceType.SKILLS).get(Skills.SKILLS_CONFIG); + assertNotNull(configProvider); + assertTrue(configProvider instanceof JavaSerializableResourceProvider); + + // load_skill + bash tools as JavaResourceProviders pointing at runtime / plan classes + Map tools = providers.get(ResourceType.TOOL); + assertNotNull(tools); + assertTrue(tools.get(Skills.LOAD_SKILL_TOOL) instanceof JavaResourceProvider); + assertEquals( + "org.apache.flink.agents.runtime.skill.LoadSkillTool", + ((JavaResourceProvider) tools.get(Skills.LOAD_SKILL_TOOL)) + .getDescriptor() + .getClazz()); + assertEquals( + "org.apache.flink.agents.plan.tools.bash.BashTool", + ((JavaResourceProvider) tools.get(Skills.BASH_TOOL)).getDescriptor().getClazz()); + } + + @Test + void multipleSkillsMethodsMergePathsWithDeduplication() throws Exception { + AgentPlan plan = new AgentPlan(new MultiSkillsAgent()); + ResourceProvider configProvider = + plan.getResourceProviders().get(ResourceType.SKILLS).get(Skills.SKILLS_CONFIG); + Skills merged = + (Skills) + ((JavaSerializableResourceProvider) configProvider) + .provide( + org.apache.flink.agents.api.resource.ResourceContext + .fromGetResource((n, t) -> null)); + // Order is preserved; "/tmp/skill-b" appears once. + assertEquals(3, merged.getPaths().size()); + assertTrue(merged.getPaths().contains("/tmp/skill-a")); + assertTrue(merged.getPaths().contains("/tmp/skill-b")); + assertTrue(merged.getPaths().contains("/tmp/skill-c")); + } + + @Test + void noSkillsLeavesNoConfigProvider() throws Exception { + AgentPlan plan = new AgentPlan(new NoSkillsAgent()); + Map skillsMap = + plan.getResourceProviders().getOrDefault(ResourceType.SKILLS, Map.of()); + assertNull(skillsMap.get(Skills.SKILLS_CONFIG)); + Map tools = + plan.getResourceProviders().getOrDefault(ResourceType.TOOL, Map.of()); + assertNull(tools.get(Skills.LOAD_SKILL_TOOL)); + assertNull(tools.get(Skills.BASH_TOOL)); + } + + @Test + void programmaticSkillsAddResourceParticipates() throws Exception { + Agent agent = new NoSkillsAgent(); + agent.addResource("more", ResourceType.SKILLS, Skills.fromLocalDir("/tmp/skill-d")); + AgentPlan plan = new AgentPlan(agent); + ResourceProvider configProvider = + plan.getResourceProviders().get(ResourceType.SKILLS).get(Skills.SKILLS_CONFIG); + assertNotNull(configProvider); + Skills merged = + (Skills) + ((JavaSerializableResourceProvider) configProvider) + .provide( + org.apache.flink.agents.api.resource.ResourceContext + .fromGetResource((n, t) -> null)); + assertEquals(java.util.List.of("/tmp/skill-d"), merged.getPaths()); + } +} diff --git a/python/flink_agents/runtime/java/java_resource_wrapper.py b/python/flink_agents/runtime/java/java_resource_wrapper.py index cc929b6ff..833057310 100644 --- a/python/flink_agents/runtime/java/java_resource_wrapper.py +++ b/python/flink_agents/runtime/java/java_resource_wrapper.py @@ -93,11 +93,17 @@ def get_resource(self, name: str, type: ResourceType) -> Resource: @override def generate_available_skills_prompt(self, *skill_names: str) -> str: - """Generate the skill discovery prompt for the given skill names.""" - # TODO: Implement after java supports agent skills. + """Generate the skill discovery prompt for the given skill names. + + Forwards to the Java ``JavaResourceAdapter#generateAvailableSkillsPrompt`` + so that a Python chat model running inside a Java agent can use skills + declared on the Java side. + """ + result = self._j_resource_adapter.generateAvailableSkillsPrompt(list(skill_names)) + return result if result is not None else "" @override def get_skill_dirs(self, *skill_names: str) -> List[str]: """Return absolute directory paths for the given skill names.""" - # TODO: Implement after java supports agent skills. - return [] + result = self._j_resource_adapter.getSkillDirs(list(skill_names)) + return list(result) if result is not None else [] diff --git a/runtime/pom.xml b/runtime/pom.xml index 39679389d..168fb2bed 100644 --- a/runtime/pom.xml +++ b/runtime/pom.xml @@ -41,6 +41,12 @@ under the License. ${project.version} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + org.apache.flink diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java index 81336ecb2..fbb0365e8 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java @@ -28,6 +28,7 @@ import pemja.core.PythonInterpreter; import pemja.core.object.PyObject; +import java.util.List; import java.util.Map; /** Adapter for managing Java resources and facilitating Python-Java interoperability. */ @@ -54,6 +55,20 @@ public Resource getResource(String name, String typeValue) throws Exception { return resourceContext.getResource(name, ResourceType.fromValue(typeValue)); } + /** + * Generate the available skills prompt for the given skill names. Used by the Python {@code + * JavaResourceContextWrapper} when a Python chat model running in a Java agent needs the skill + * discovery prompt. + */ + public String generateAvailableSkillsPrompt(List skillNames) throws Exception { + return resourceContext.generateAvailableSkillsPrompt(skillNames); + } + + /** Return absolute directory paths for the given skill names. */ + public List getSkillDirs(List skillNames) throws Exception { + return resourceContext.getSkillDirs(skillNames); + } + /** * Convert a Python chat message to a Java chat message. This method is intended for use by the * Python interpreter. diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java index 11c1cb63d..b7dc0eeda 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/resource/ResourceContextImpl.java @@ -21,6 +21,10 @@ import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; +import org.apache.flink.agents.runtime.skill.SkillManager; + +import javax.annotation.Nullable; import java.util.Collections; import java.util.List; @@ -30,13 +34,17 @@ * Default {@link ResourceContext} implementation that delegates resource lookup to a {@link * BiFunction} (typically the underlying {@code ResourceCache::getResource}). * - *

Mirrors the Python {@code flink_agents.runtime.resource_context.ResourceContextImpl}. Skill - * methods return safe defaults; callers without skills configured see empty values. + *

Mirrors the Python {@code flink_agents.runtime.resource_context.ResourceContextImpl}. The + * skill methods lazily build a {@link SkillManager} from the {@code _skills_config} resource — if + * no such resource is registered they return safe defaults (empty string / empty list). */ public class ResourceContextImpl implements ResourceContext { private final BiFunction getResource; + @Nullable private volatile SkillManager skillManager; + @Nullable private volatile Skills cachedSkillsConfig; + public ResourceContextImpl(BiFunction getResource) { this.getResource = getResource; } @@ -55,11 +63,41 @@ public Resource getResource(String name, ResourceType type) throws Exception { @Override public String generateAvailableSkillsPrompt(List skillNames) throws Exception { - return ""; + SkillManager manager = ensureSkillManager(); + return manager == null ? "" : manager.generateDiscoveryPrompt(skillNames); } @Override public List getSkillDirs(List skillNames) throws Exception { - return Collections.emptyList(); + SkillManager manager = ensureSkillManager(); + return manager == null ? Collections.emptyList() : manager.getSkillDirs(skillNames); + } + + /** + * Returns the cached {@link SkillManager} for this context, or {@code null} if not configured. + */ + @Nullable + public synchronized SkillManager getSkillManager() throws Exception { + return ensureSkillManager(); + } + + @Nullable + private synchronized SkillManager ensureSkillManager() throws Exception { + Skills config; + try { + Resource r = getResource(Skills.SKILLS_CONFIG, ResourceType.SKILLS); + if (!(r instanceof Skills)) { + return null; + } + config = (Skills) r; + } catch (Exception e) { + // No skills config registered — that's fine, return null. + return null; + } + if (config != cachedSkillsConfig) { + cachedSkillsConfig = config; + skillManager = new SkillManager(config); + } + return skillManager; } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/AgentSkill.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/AgentSkill.java new file mode 100644 index 000000000..7b32ade79 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/AgentSkill.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * Runtime representation of one parsed {@code SKILL.md}. + * + *

Mirrors the Python {@code flink_agents.runtime.skill.agent_skill.AgentSkill}. Resources are + * lazily loaded on first access. + */ +public final class AgentSkill { + + private final String name; + private final String description; + private final String content; + @Nullable private final String license; + @Nullable private final String compatibility; + @Nullable private final Map metadata; + @Nullable private volatile Map resources; + @Nullable private Supplier> resourceLoader; + private volatile boolean activated; + + public AgentSkill( + String name, + String description, + String content, + @Nullable String license, + @Nullable String compatibility, + @Nullable Map metadata) { + this(name, description, content, license, compatibility, metadata, null); + } + + public AgentSkill( + String name, + String description, + String content, + @Nullable String license, + @Nullable String compatibility, + @Nullable Map metadata, + @Nullable Map resources) { + Preconditions.checkArgument( + name != null && !name.isEmpty() && name.length() <= 64, + "Skill name must be 1..64 characters: %s", + name); + Preconditions.checkArgument( + description != null && !description.isEmpty() && description.length() <= 1024, + "Skill description must be 1..1024 characters"); + Preconditions.checkArgument( + content != null && !content.isEmpty(), "Skill content must not be empty"); + Preconditions.checkArgument( + compatibility == null || compatibility.length() <= 500, + "Skill compatibility must be at most 500 characters"); + this.name = name; + this.description = description; + this.content = content; + this.license = license; + this.compatibility = compatibility; + this.metadata = metadata; + this.resources = resources; + this.activated = resources != null; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public String getContent() { + return content; + } + + @Nullable + public String getLicense() { + return license; + } + + @Nullable + public String getCompatibility() { + return compatibility; + } + + @Nullable + public Map getMetadata() { + return metadata; + } + + /** Set a lazy resource loader. Must be called before the first {@link #getResource(String)}. */ + public void setResourceLoader(Supplier> loader) { + this.resourceLoader = loader; + } + + /** + * Return the content of the named resource (relative path from the skill root) or {@code null} + * if no such resource is registered. + */ + @Nullable + public String getResource(String relativePath) { + activate(); + return resources == null ? null : resources.get(relativePath); + } + + /** Return all registered resource relative paths (sorted, may be empty). */ + public List getResourcePaths() { + activate(); + if (resources == null) { + return List.of(); + } + List keys = new ArrayList<>(resources.keySet()); + keys.sort(String::compareTo); + return keys; + } + + private synchronized void activate() { + if (!activated && resourceLoader != null) { + resources = resourceLoader.get(); + activated = true; + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/LoadSkillTool.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/LoadSkillTool.java new file mode 100644 index 000000000..6e08bd769 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/LoadSkillTool.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; +import org.apache.flink.agents.runtime.resource.ResourceContextImpl; + +import java.nio.file.Path; +import java.util.List; + +/** + * Built-in tool that returns the body of a SKILL.md (default) or a specific bundled resource. + * + *

Mirrors the Python {@code flink_agents.runtime.skill.skill_tools.LoadSkillTool}. Auto-loaded + * by {@code AgentPlan.addSkills} when an agent declares any {@code @Skills} method. + */ +public class LoadSkillTool extends Tool { + + private static final String DESCRIPTION = + "Load a skill's content or a specific resource. Use this to access skill instructions and resources."; + + private static final String INPUT_SCHEMA = + "{\"type\":\"object\"," + + "\"properties\":{" + + "\"name\":{\"type\":\"string\"," + + "\"description\":\"The name of the skill to load (e.g., 'pdf-processing').\"}," + + "\"path\":{\"type\":\"string\"," + + "\"description\":\"Optional path to a specific resource within the skill. If not provided, returns the full SKILL.md content.\"," + + "\"default\":\"SKILL.md\"}}," + + "\"required\":[\"name\"]}"; + + public LoadSkillTool(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(new ToolMetadata("load_skill", DESCRIPTION, INPUT_SCHEMA)); + this.resourceContext = resourceContext; + } + + @Override + public ToolType getToolType() { + return ToolType.FUNCTION; + } + + @Override + public ToolResponse call(ToolParameters parameters) { + String name = parameters.getParameter("name", String.class); + String path = + parameters.hasParameter("path") + ? parameters.getParameter("path", String.class) + : "SKILL.md"; + + SkillManager manager; + try { + manager = resolveSkillManager(); + } catch (Exception e) { + return ToolResponse.success( + "Skill manager not available. No skills have been registered."); + } + if (manager == null) { + return ToolResponse.success( + "Skill manager not available. No skills have been registered."); + } + + AgentSkill skill; + try { + skill = manager.getSkill(name); + } catch (IllegalArgumentException e) { + List available = manager.getAllSkillNames(); + String availableStr = + available.isEmpty() ? "No skills available." : String.join(", ", available); + return ToolResponse.success( + "Skill '" + name + "' not found. Available skills: " + availableStr); + } + + if (path == null || "SKILL.md".equals(path)) { + Path skillDir = manager.getSkillDir(name); + if (skillDir != null) { + StringBuilder files = new StringBuilder(); + for (String rel : skill.getResourcePaths()) { + files.append("") + .append(skillDir.resolve(rel)) + .append("") + .append('\n'); + } + String filesSection = files.length() == 0 ? "" : files.toString().stripTrailing(); + return ToolResponse.success( + "\n" + + "# Skill: " + + name + + "\n\n" + + skill.getContent().strip() + + "\n\n" + + "Base directory for this skill: " + + skillDir + + "\n" + + "Relative paths in this skill are relative to this base directory.\n" + + "\n" + + filesSection + + "\n\n" + + ""); + } + return ToolResponse.success(skill.getContent()); + } + + String content = skill.getResource(path); + if (content == null) { + return ToolResponse.success( + "Resource '" + + path + + "' not found in skill '" + + name + + "', Available resources: " + + skill.getResourcePaths()); + } + return ToolResponse.success(content); + } + + private SkillManager resolveSkillManager() throws Exception { + if (resourceContext instanceof ResourceContextImpl) { + return ((ResourceContextImpl) resourceContext).getSkillManager(); + } + return null; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillManager.java new file mode 100644 index 000000000..8778bd0e7 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillManager.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.apache.flink.agents.api.skills.Skills; +import org.apache.flink.agents.runtime.skill.repository.FileSystemSkillRepository; + +import javax.annotation.Nullable; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Loads and indexes all skills referenced by a {@link Skills} configuration. + * + *

Mirrors the Python {@code flink_agents.runtime.skill.skill_manager.SkillManager}. + */ +public class SkillManager { + + private final Skills config; + private final Map skills = new LinkedHashMap<>(); + private final Map repos = new HashMap<>(); + + public SkillManager(Skills config) { + this.config = config; + loadFromPaths(); + } + + public int size() { + return skills.size(); + } + + public AgentSkill getSkill(String name) { + AgentSkill skill = skills.get(name); + if (skill == null) { + throw new IllegalArgumentException( + "Skill " + + name + + " not found, available skill names are: " + + getAllSkillNames()); + } + return skill; + } + + public List getAllSkillNames() { + return new ArrayList<>(skills.keySet()); + } + + @Nullable + public String loadSkillResource(String skillName, String resourcePath) { + return getSkill(skillName).getResource(resourcePath); + } + + /** Build the {@code } system prompt for the given skill names. */ + public String generateDiscoveryPrompt(List names) { + if (size() == 0) { + return ""; + } + StringBuilder sb = new StringBuilder(SkillPromptProvider.SKILL_DISCOVERY_PROMPT); + for (String name : names) { + AgentSkill skill = getSkill(name); + sb.append( + String.format( + SkillPromptProvider.AVAILABLE_SKILL_TEMPLATE, + skill.getName(), + skill.getDescription())); + } + sb.append(SkillPromptProvider.AVAILABLE_SKILLS_TAG_END); + return sb.toString(); + } + + /** + * Absolute directory paths for the listed skill names (filesystem-backed only). When called + * with an empty or {@code null} list, returns directories for all filesystem-backed skills. + */ + public List getSkillDirs(List names) { + Iterable selected = (names == null || names.isEmpty()) ? repos.keySet() : names; + List dirs = new ArrayList<>(); + for (String skillName : selected) { + SkillRepository repo = repos.get(skillName); + if (repo instanceof FileSystemSkillRepository) { + Path dir = ((FileSystemSkillRepository) repo).getBaseDir().resolve(skillName); + dirs.add(dir.toString()); + } + } + return dirs; + } + + /** Return absolute directory path for a single skill, if filesystem-backed. */ + @Nullable + public Path getSkillDir(String skillName) { + SkillRepository repo = repos.get(skillName); + if (repo instanceof FileSystemSkillRepository) { + return ((FileSystemSkillRepository) repo).getBaseDir().resolve(skillName); + } + return null; + } + + /** Resolve a skill resource's relative path to an absolute path, or {@code null} if missing. */ + @Nullable + public Path resolveResourcePath(String skillName, String resourcePath) { + SkillRepository repo = repos.get(skillName); + if (repo instanceof FileSystemSkillRepository) { + Path resolved = + ((FileSystemSkillRepository) repo) + .getBaseDir() + .resolve(skillName) + .resolve(resourcePath); + if (Files.isRegularFile(resolved)) { + return resolved; + } + } + return null; + } + + private void loadFromPaths() { + for (String path : config.getPaths()) { + FileSystemSkillRepository repo = new FileSystemSkillRepository(path); + for (AgentSkill skill : repo.getSkills()) { + final String skillName = skill.getName(); + skill.setResourceLoader(() -> repo.getResources(skillName)); + skills.put(skillName, skill); + repos.put(skillName, repo); + } + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillParser.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillParser.java new file mode 100644 index 000000000..02dcdfdaa --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillParser.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Parser that splits a {@code SKILL.md} file into YAML frontmatter and markdown body, then + * constructs an {@link AgentSkill}. + * + *

Mirrors the Python {@code flink_agents.runtime.skill.skill_parser}. + */ +public final class SkillParser { + + private static final Pattern FRONTMATTER = + Pattern.compile( + "^---\\s*[\\r\\n]+(.*?)[\\r\\n]*---(?:\\s*[\\r\\n]+)?(.*)", Pattern.DOTALL); + + private static final ObjectMapper YAML = new ObjectMapper(new YAMLFactory()); + + private SkillParser() {} + + /** Result of splitting a markdown file. */ + public static final class ParsedMarkdown { + public final Map metadata; + public final String content; + + public ParsedMarkdown(Map metadata, String content) { + this.metadata = metadata == null ? Collections.emptyMap() : metadata; + this.content = content == null ? "" : content; + } + } + + /** Split the markdown into YAML frontmatter and the remaining body. */ + public static ParsedMarkdown parseMarkdown(String markdown) { + if (markdown == null || markdown.isEmpty()) { + return new ParsedMarkdown(Collections.emptyMap(), ""); + } + Matcher m = FRONTMATTER.matcher(markdown); + if (!m.matches()) { + return new ParsedMarkdown(Collections.emptyMap(), markdown); + } + String yaml = m.group(1).trim(); + String body = m.group(2); + if (yaml.isEmpty()) { + return new ParsedMarkdown(Collections.emptyMap(), body); + } + try { + Map metadata = + YAML.readValue(yaml, new TypeReference>() {}); + return new ParsedMarkdown(metadata, body); + } catch (Exception e) { + throw new IllegalArgumentException( + "Invalid YAML frontmatter syntax: " + e.getMessage(), e); + } + } + + /** Parse a SKILL.md content into an {@link AgentSkill}. */ + public static AgentSkill parseSkill(String skillMdContent) { + ParsedMarkdown parsed = parseMarkdown(skillMdContent); + Map metadata = parsed.metadata; + + Object name = metadata.get("name"); + if (!(name instanceof String) || ((String) name).trim().isEmpty()) { + throw new IllegalArgumentException( + "The SKILL.md must have a YAML frontmatter including 'name' field."); + } + Object description = metadata.get("description"); + if (!(description instanceof String) || ((String) description).trim().isEmpty()) { + throw new IllegalArgumentException( + "The SKILL.md must have a YAML frontmatter including 'description' field."); + } + if (parsed.content == null || parsed.content.isEmpty()) { + throw new IllegalArgumentException( + "The SKILL.md must have a markdown content after YAML frontmatter."); + } + + Object license = metadata.get("license"); + Object compatibility = metadata.get("compatibility"); + Object inner = metadata.get("metadata"); + Map innerMetadata = null; + if (inner instanceof Map) { + innerMetadata = new HashMap<>(); + for (Map.Entry e : ((Map) inner).entrySet()) { + innerMetadata.put(String.valueOf(e.getKey()), String.valueOf(e.getValue())); + } + } + + return new AgentSkill( + ((String) name).trim(), + ((String) description).trim(), + parsed.content, + license == null ? null : license.toString(), + compatibility == null ? null : compatibility.toString(), + innerMetadata); + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillPromptProvider.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillPromptProvider.java new file mode 100644 index 000000000..2cf843866 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillPromptProvider.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +/** + * System-prompt templates for skill discovery and activation. + * + *

Mirrors the Python {@code flink_agents.runtime.skill.skill_prompt_provider} byte-for-byte. + * Descriptions injected into {@link #AVAILABLE_SKILL_TEMPLATE} must not contain {@code %} (no skill + * fixture currently does — the SKILL.md schema bounds descriptions to plain text). + */ +public final class SkillPromptProvider { + + public static final String SKILL_DISCOVERY_PROMPT = + "## Available Skills\n" + + "\n" + + "\n" + + "Skills provide specialized capabilities and domain knowledge. Use them when they match your current task.\n" + + "\n" + + "Load a skill with `load_skill(name=\"\")` to get its full instructions.\n" + + "Individual resources (scripts, references, assets) can be loaded with a `path` argument.\n" + + "\n" + + "The loaded content includes the skill's base directory and the absolute paths of its resources.\n" + + "\n" + + "\n" + + "\n"; + + public static final String AVAILABLE_SKILL_TEMPLATE = + "\n\n%s\n%s\n\n"; + + public static final String AVAILABLE_SKILLS_TAG_END = "\n\n"; + + private SkillPromptProvider() {} +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillRepository.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillRepository.java new file mode 100644 index 000000000..9d0e48b80 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/SkillRepository.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Map; + +/** + * Source of skills. Mirrors the Python {@code + * flink_agents.runtime.skill.skill_repository.SkillRepository}. + */ +public interface SkillRepository { + + /** Return the named skill, or {@code null} if not found. */ + @Nullable + AgentSkill getSkill(String name); + + /** Return all skills in the repository (order is implementation-specific). */ + List getSkills(); + + /** + * Return the resource map for the named skill — keys are relative paths from the skill root, + * values are the file contents. + */ + Map getResources(String name); +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/skill/repository/FileSystemSkillRepository.java b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/repository/FileSystemSkillRepository.java new file mode 100644 index 000000000..91e16df4d --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/skill/repository/FileSystemSkillRepository.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill.repository; + +import org.apache.flink.agents.runtime.skill.AgentSkill; +import org.apache.flink.agents.runtime.skill.SkillParser; +import org.apache.flink.agents.runtime.skill.SkillRepository; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.nio.charset.MalformedInputException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +/** + * Filesystem-backed {@link SkillRepository}. Each immediate subdirectory of the configured base + * directory that contains a {@code SKILL.md} file is treated as a skill. + * + *

Mirrors the Python {@code + * flink_agents.runtime.skill.repository.filesystem_repository.FileSystemSkillRepository}. + */ +public class FileSystemSkillRepository implements SkillRepository { + + private static final Logger LOG = LoggerFactory.getLogger(FileSystemSkillRepository.class); + + public static final String SKILL_MD_FILE = "SKILL.md"; + + private final Path baseDir; + + public FileSystemSkillRepository(Path baseDir) { + if (baseDir == null) { + throw new IllegalArgumentException("Base directory cannot be null"); + } + Path resolved = baseDir.toAbsolutePath().normalize(); + if (!Files.exists(resolved)) { + throw new IllegalArgumentException("Base directory does not exist: " + resolved); + } + if (!Files.isDirectory(resolved)) { + throw new IllegalArgumentException("Base directory is not a directory: " + resolved); + } + this.baseDir = resolved; + } + + public FileSystemSkillRepository(String baseDir) { + this(Path.of(baseDir)); + } + + public Path getBaseDir() { + return baseDir; + } + + @Override + @Nullable + public AgentSkill getSkill(String name) { + Path skillDir = baseDir.resolve(name); + Path skillMd = skillDir.resolve(SKILL_MD_FILE); + if (!Files.exists(skillMd)) { + return null; + } + return loadSkill(skillDir); + } + + @Override + public List getSkills() { + List skills = new ArrayList<>(); + for (String skillName : listSkillNames()) { + AgentSkill skill = getSkill(skillName); + if (skill != null) { + skills.add(skill); + } + } + return skills; + } + + @Override + public Map getResources(String name) { + Path skillDir = baseDir.resolve(name); + if (!Files.isDirectory(skillDir)) { + return Collections.emptyMap(); + } + return loadResources(skillDir); + } + + private List listSkillNames() { + List names = new ArrayList<>(); + try (Stream entries = Files.list(baseDir)) { + entries.forEach( + entry -> { + if (Files.isDirectory(entry) + && Files.exists(entry.resolve(SKILL_MD_FILE))) { + names.add(entry.getFileName().toString()); + } + }); + } catch (IOException e) { + throw new IllegalStateException("Failed to list skills under " + baseDir, e); + } + names.sort(String::compareTo); + return names; + } + + private AgentSkill loadSkill(Path skillDir) { + Path skillMd = skillDir.resolve(SKILL_MD_FILE); + if (!Files.exists(skillMd)) { + return null; + } + try { + String content = Files.readString(skillMd, StandardCharsets.UTF_8); + AgentSkill skill = SkillParser.parseSkill(content); + if (!skill.getName().equals(skillDir.getFileName().toString())) { + LOG.warn( + "The skill name {} is different from the base directory {}.", + skill.getName(), + skillDir.getFileName()); + } + return skill; + } catch (Exception e) { + throw new IllegalArgumentException("Failed to load skill from " + skillDir, e); + } + } + + private Map loadResources(Path skillDir) { + Map resources = new HashMap<>(); + try (Stream walk = Files.walk(skillDir)) { + walk.filter(Files::isRegularFile) + .forEach( + file -> { + if (file.getFileName().toString().equals(SKILL_MD_FILE)) { + return; + } + String rel = skillDir.relativize(file).toString(); + try { + resources.put( + rel, Files.readString(file, StandardCharsets.UTF_8)); + } catch (MalformedInputException mie) { + try { + byte[] bytes = Files.readAllBytes(file); + resources.put( + rel, + "base64: " + + Base64.getEncoder() + .encodeToString(bytes)); + } catch (IOException e) { + LOG.warn( + "Failed to read resource file {} as binary.", + file, + e); + } + } catch (IOException e) { + LOG.warn("Failed to read resource file {}.", file, e); + } + }); + } catch (IOException e) { + throw new IllegalStateException("Failed to walk skill dir " + skillDir, e); + } + return resources; + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/skill/FileSystemSkillRepositoryTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/FileSystemSkillRepositoryTest.java new file mode 100644 index 000000000..b4c3a7781 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/FileSystemSkillRepositoryTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.apache.flink.agents.runtime.skill.repository.FileSystemSkillRepository; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class FileSystemSkillRepositoryTest { + + private static Path resourcesRoot() { + return Path.of("src/test/resources/skills").toAbsolutePath(); + } + + @Test + void getSkillsReturnsSortedSkillNames() { + FileSystemSkillRepository repo = new FileSystemSkillRepository(resourcesRoot()); + List skills = repo.getSkills(); + assertEquals(2, skills.size()); + assertEquals("github", skills.get(0).getName()); + assertEquals("nano-banana-pro", skills.get(1).getName()); + } + + @Test + void getSkillReturnsNullForUnknown() { + FileSystemSkillRepository repo = new FileSystemSkillRepository(resourcesRoot()); + assertNull(repo.getSkill("does-not-exist")); + } + + @Test + void getResourcesReadsBundledFiles() { + FileSystemSkillRepository repo = new FileSystemSkillRepository(resourcesRoot()); + Map resources = repo.getResources("nano-banana-pro"); + assertNotNull(resources); + assertTrue( + resources.containsKey("scripts/generate_image.py"), + "expected scripts/generate_image.py to be loaded as a resource"); + assertTrue(resources.containsKey("_meta.json")); + } + + @Test + void resourceLoaderIsLazy() { + FileSystemSkillRepository repo = new FileSystemSkillRepository(resourcesRoot()); + AgentSkill skill = repo.getSkill("nano-banana-pro"); + assertNotNull(skill); + // resources are not loaded until requested through the loader hook. + skill.setResourceLoader(() -> repo.getResources("nano-banana-pro")); + assertEquals(2, skill.getResourcePaths().size()); + } + + @Test + void missingBaseDirRaises(@TempDir Path tempDir) { + Path missing = tempDir.resolve("missing"); + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, + () -> new FileSystemSkillRepository(missing)); + assertTrue(ex.getMessage().contains("does not exist")); + } + + @Test + void binaryResourceFallsBackToBase64(@TempDir Path tempDir) throws IOException { + Path skillDir = Files.createDirectory(tempDir.resolve("binary-skill")); + Files.writeString( + skillDir.resolve("SKILL.md"), + "---\nname: binary-skill\ndescription: holds a binary resource\n---\n# Body\n", + StandardCharsets.UTF_8); + // Bytes that are not valid UTF-8 (start of a 4-byte sequence with no continuation bytes). + byte[] bad = new byte[] {(byte) 0xF8, (byte) 0x88, (byte) 0x80, (byte) 0x80}; + Files.write(skillDir.resolve("blob.bin"), bad); + + FileSystemSkillRepository repo = new FileSystemSkillRepository(tempDir); + Map resources = repo.getResources("binary-skill"); + assertTrue(resources.get("blob.bin").startsWith("base64: ")); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/skill/LoadSkillToolTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/LoadSkillToolTest.java new file mode 100644 index 000000000..f048920df --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/LoadSkillToolTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.skills.Skills; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.runtime.resource.ResourceContextImpl; +import org.junit.jupiter.api.Test; + +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class LoadSkillToolTest { + + private static ResourceContextImpl contextWithSkills() { + Skills skills = + Skills.fromLocalDir( + Path.of("src/test/resources/skills").toAbsolutePath().toString()); + Map store = new HashMap<>(); + store.put(Skills.SKILLS_CONFIG, skills); + return new ResourceContextImpl( + (name, type) -> { + if (type == ResourceType.SKILLS) { + return store.get(name); + } + return null; + }); + } + + private static LoadSkillTool tool(ResourceContextImpl ctx) { + return new LoadSkillTool( + new ResourceDescriptor(LoadSkillTool.class.getName(), Map.of()), ctx); + } + + private static ToolParameters args(String name, String path) { + Map m = new HashMap<>(); + m.put("name", name); + if (path != null) { + m.put("path", path); + } + return new ToolParameters(m); + } + + @Test + void unknownSkillReturnsAvailableList() { + LoadSkillTool t = tool(contextWithSkills()); + ToolResponse resp = t.call(args("does-not-exist", null)); + String out = (String) resp.getResult(); + assertTrue(out.contains("not found")); + assertTrue(out.contains("github")); + assertTrue(out.contains("nano-banana-pro")); + } + + @Test + void defaultPathReturnsSkillContentEnvelope() { + LoadSkillTool t = tool(contextWithSkills()); + ToolResponse resp = t.call(args("github", null)); + String out = (String) resp.getResult(); + assertTrue(out.startsWith("")); + assertTrue(out.contains("# Skill: github")); + assertTrue(out.contains("Base directory for this skill: ")); + assertTrue(out.contains("")); + } + + @Test + void resourcePathReturnsRawContent() { + LoadSkillTool t = tool(contextWithSkills()); + ToolResponse resp = t.call(args("nano-banana-pro", "scripts/generate_image.py")); + String out = (String) resp.getResult(); + // The script file should be returned verbatim (not wrapped in ). + assertTrue(!out.startsWith(" 0); + } + + @Test + void missingResourceReportsAvailable() { + LoadSkillTool t = tool(contextWithSkills()); + ToolResponse resp = t.call(args("nano-banana-pro", "no-such.txt")); + String out = (String) resp.getResult(); + assertTrue(out.contains("Resource 'no-such.txt' not found")); + assertTrue(out.contains("Available resources")); + } + + @Test + void noSkillsRegisteredReturnsFriendlyMessage() { + // Empty resource context — no _skills_config registered. + ResourceContextImpl ctx = new ResourceContextImpl((name, type) -> null); + LoadSkillTool t = tool(ctx); + ToolResponse resp = t.call(args("anything", null)); + assertEquals( + "Skill manager not available. No skills have been registered.", resp.getResult()); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillManagerTest.java new file mode 100644 index 000000000..54e8178e5 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillManagerTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.apache.flink.agents.api.skills.Skills; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class SkillManagerTest { + + private static Skills configFromResources() { + return Skills.fromLocalDir( + Path.of("src/test/resources/skills").toAbsolutePath().toString()); + } + + @Test + void sizeAndAllSkillNames() { + SkillManager manager = new SkillManager(configFromResources()); + assertEquals(2, manager.size()); + assertEquals(List.of("github", "nano-banana-pro"), manager.getAllSkillNames()); + } + + @Test + void getSkillThrowsWithAvailableNames() { + SkillManager manager = new SkillManager(configFromResources()); + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> manager.getSkill("missing")); + assertTrue(ex.getMessage().contains("github")); + assertTrue(ex.getMessage().contains("nano-banana-pro")); + } + + @Test + void generateDiscoveryPromptMatchesGoldenFile() throws IOException { + SkillManager manager = new SkillManager(configFromResources()); + String prompt = manager.generateDiscoveryPrompt(List.of("github", "nano-banana-pro")); + String expected = + Files.readString( + Path.of("src/test/resources/skill_discovery_prompt.txt"), + StandardCharsets.UTF_8); + assertEquals(expected, prompt); + } + + @Test + void getSkillDirsEmptyArgumentReturnsAllFsBacked() { + SkillManager manager = new SkillManager(configFromResources()); + List dirs = manager.getSkillDirs(List.of()); + assertEquals(2, dirs.size()); + assertTrue(dirs.get(0).endsWith("github") || dirs.get(0).endsWith("nano-banana-pro")); + } + + @Test + void getSkillDirsReturnsNamedSkillsInOrder() { + SkillManager manager = new SkillManager(configFromResources()); + List dirs = manager.getSkillDirs(List.of("github")); + assertEquals(1, dirs.size()); + assertTrue(dirs.get(0).endsWith("github")); + } + + @Test + void resolveResourcePathLocatesBundledFile() { + SkillManager manager = new SkillManager(configFromResources()); + Path resolved = manager.resolveResourcePath("nano-banana-pro", "scripts/generate_image.py"); + assertNotNull(resolved); + assertTrue(Files.isRegularFile(resolved)); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillParserTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillParserTest.java new file mode 100644 index 000000000..270e1374d --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/skill/SkillParserTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.skill; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class SkillParserTest { + + @Test + void parseMarkdownSplitsFrontmatterFromBody() { + String md = + "---\n" + + "name: my-skill\n" + + "description: Does X\n" + + "---\n" + + "# Body\n" + + "Some markdown.\n"; + SkillParser.ParsedMarkdown parsed = SkillParser.parseMarkdown(md); + assertEquals("my-skill", parsed.metadata.get("name")); + assertEquals("Does X", parsed.metadata.get("description")); + assertTrue(parsed.content.startsWith("# Body")); + } + + @Test + void parseMarkdownNoFrontmatterReturnsRawContent() { + String md = "# Just a body\nNo frontmatter."; + SkillParser.ParsedMarkdown parsed = SkillParser.parseMarkdown(md); + assertTrue(parsed.metadata.isEmpty()); + assertEquals(md, parsed.content); + } + + @Test + void parseMarkdownHandlesCrlfLineEndings() { + String md = "---\r\nname: x\r\ndescription: y\r\n---\r\nBody\r\n"; + SkillParser.ParsedMarkdown parsed = SkillParser.parseMarkdown(md); + assertEquals("x", parsed.metadata.get("name")); + assertTrue(parsed.content.contains("Body")); + } + + @Test + void parseSkillRequiresName() { + String md = "---\ndescription: y\n---\nBody\n"; + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> SkillParser.parseSkill(md)); + assertTrue(ex.getMessage().contains("name")); + } + + @Test + void parseSkillRequiresDescription() { + String md = "---\nname: x\n---\nBody\n"; + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> SkillParser.parseSkill(md)); + assertTrue(ex.getMessage().contains("description")); + } + + @Test + void parseSkillRequiresBody() { + String md = "---\nname: x\ndescription: y\n---\n"; + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> SkillParser.parseSkill(md)); + assertTrue(ex.getMessage().contains("markdown content")); + } + + @Test + void parseSkillSurfacesYamlSyntaxError() { + String md = "---\nname: x\n bad-indent: : :\n---\nBody\n"; + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> SkillParser.parseSkill(md)); + assertTrue(ex.getMessage().startsWith("Invalid YAML frontmatter syntax")); + } + + @Test + void parseSkillTrimsNameAndDescription() { + String md = + "---\nname: \" trimmed-name \"\ndescription: \" trimmed desc \"\n---\nBody\n"; + AgentSkill skill = SkillParser.parseSkill(md); + assertEquals("trimmed-name", skill.getName()); + assertEquals("trimmed desc", skill.getDescription()); + } + + @Test + void parseSkillCarriesOptionalFields() { + String md = + "---\n" + + "name: test\n" + + "description: Does X\n" + + "license: Apache-2.0\n" + + "compatibility: macOS, Linux\n" + + "metadata:\n" + + " author: alice\n" + + " version: \"1.2\"\n" + + "---\n" + + "# Body\n"; + AgentSkill skill = SkillParser.parseSkill(md); + assertEquals("Apache-2.0", skill.getLicense()); + assertEquals("macOS, Linux", skill.getCompatibility()); + assertNotNull(skill.getMetadata()); + assertEquals("alice", skill.getMetadata().get("author")); + assertEquals("1.2", skill.getMetadata().get("version")); + } +} diff --git a/runtime/src/test/resources/skill_discovery_prompt.txt b/runtime/src/test/resources/skill_discovery_prompt.txt new file mode 100644 index 000000000..11d2e6aca --- /dev/null +++ b/runtime/src/test/resources/skill_discovery_prompt.txt @@ -0,0 +1,24 @@ +## Available Skills + + +Skills provide specialized capabilities and domain knowledge. Use them when they match your current task. + +Load a skill with `load_skill(name="")` to get its full instructions. +Individual resources (scripts, references, assets) can be loaded with a `path` argument. + +The loaded content includes the skill's base directory and the absolute paths of its resources. + + + + + +github +Interact with GitHub using the `gh` CLI. Use `gh issue`, `gh pr`, `gh run`, and `gh api` for issues, PRs, CI runs, and advanced queries. + + + +nano-banana-pro +Generate/edit images with Nano Banana Pro (Gemini 3 Pro Image). Use for image create/modify requests incl. edits. Supports text-to-image + image-to-image; 1K/2K/4K; use --input-image. + + + diff --git a/runtime/src/test/resources/skills/github/SKILL.md b/runtime/src/test/resources/skills/github/SKILL.md new file mode 100644 index 000000000..2c9356cef --- /dev/null +++ b/runtime/src/test/resources/skills/github/SKILL.md @@ -0,0 +1,47 @@ +--- +name: github +description: "Interact with GitHub using the `gh` CLI. Use `gh issue`, `gh pr`, `gh run`, and `gh api` for issues, PRs, CI runs, and advanced queries." +--- + +# GitHub Skill + +Use the `gh` CLI to interact with GitHub. Always specify `--repo owner/repo` when not in a git directory, or use URLs directly. + +## Pull Requests + +Check CI status on a PR: +```bash +gh pr checks 55 --repo owner/repo +``` + +List recent workflow runs: +```bash +gh run list --repo owner/repo --limit 10 +``` + +View a run and see which steps failed: +```bash +gh run view --repo owner/repo +``` + +View logs for failed steps only: +```bash +gh run view --repo owner/repo --log-failed +``` + +## API for Advanced Queries + +The `gh api` command is useful for accessing data not available through other subcommands. + +Get PR with specific fields: +```bash +gh api repos/owner/repo/pulls/55 --jq '.title, .state, .user.login' +``` + +## JSON Output + +Most commands support `--json` for structured output. You can use `--jq` to filter: + +```bash +gh issue list --repo owner/repo --json number,title --jq '.[] | "\(.number): \(.title)"' +``` \ No newline at end of file diff --git a/runtime/src/test/resources/skills/nano-banana-pro/SKILL.md b/runtime/src/test/resources/skills/nano-banana-pro/SKILL.md new file mode 100644 index 000000000..711ee3ff2 --- /dev/null +++ b/runtime/src/test/resources/skills/nano-banana-pro/SKILL.md @@ -0,0 +1,130 @@ +--- +name: nano-banana-pro +description: Generate/edit images with Nano Banana Pro (Gemini 3 Pro Image). Use for image create/modify requests incl. edits. Supports text-to-image + image-to-image; 1K/2K/4K; use --input-image. +--- + +# Nano Banana Pro Image Generation & Editing + +Generate new images or edit existing ones using Google's Nano Banana Pro API (Gemini 3 Pro Image). + +## Usage + +Run the script using absolute path (do NOT cd to skill directory first): + +**Generate new image:** +```bash +uv run ~/.codex/skills/nano-banana-pro/scripts/generate_image.py --prompt "your image description" --filename "output-name.png" [--resolution 1K|2K|4K] [--api-key KEY] +``` + +**Edit existing image:** +```bash +uv run ~/.codex/skills/nano-banana-pro/scripts/generate_image.py --prompt "editing instructions" --filename "output-name.png" --input-image "path/to/input.png" [--resolution 1K|2K|4K] [--api-key KEY] +``` + +**Important:** Always run from the user's current working directory so images are saved where the user is working, not in the skill directory. + +## Default Workflow (draft → iterate → final) + +Goal: fast iteration without burning time on 4K until the prompt is correct. + +- Draft (1K): quick feedback loop + - `uv run ~/.codex/skills/nano-banana-pro/scripts/generate_image.py --prompt "" --filename "yyyy-mm-dd-hh-mm-ss-draft.png" --resolution 1K` +- Iterate: adjust prompt in small diffs; keep filename new per run + - If editing: keep the same `--input-image` for every iteration until you’re happy. +- Final (4K): only when prompt is locked + - `uv run ~/.codex/skills/nano-banana-pro/scripts/generate_image.py --prompt "" --filename "yyyy-mm-dd-hh-mm-ss-final.png" --resolution 4K` + +## Resolution Options + +The Gemini 3 Pro Image API supports three resolutions (uppercase K required): + +- **1K** (default) - ~1024px resolution +- **2K** - ~2048px resolution +- **4K** - ~4096px resolution + +Map user requests to API parameters: +- No mention of resolution → `1K` +- "low resolution", "1080", "1080p", "1K" → `1K` +- "2K", "2048", "normal", "medium resolution" → `2K` +- "high resolution", "high-res", "hi-res", "4K", "ultra" → `4K` + +## API Key + +The script checks for API key in this order: +1. `--api-key` argument (use if user provided key in chat) +2. `GEMINI_API_KEY` environment variable + +If neither is available, the script exits with an error message. + +## Preflight + Common Failures (fast fixes) + +- Preflight: + - `command -v uv` (must exist) + - `test -n \"$GEMINI_API_KEY\"` (or pass `--api-key`) + - If editing: `test -f \"path/to/input.png\"` + +- Common failures: + - `Error: No API key provided.` → set `GEMINI_API_KEY` or pass `--api-key` + - `Error loading input image:` → wrong path / unreadable file; verify `--input-image` points to a real image + - “quota/permission/403” style API errors → wrong key, no access, or quota exceeded; try a different key/account + +## Filename Generation + +Generate filenames with the pattern: `yyyy-mm-dd-hh-mm-ss-name.png` + +**Format:** `{timestamp}-{descriptive-name}.png` +- Timestamp: Current date/time in format `yyyy-mm-dd-hh-mm-ss` (24-hour format) +- Name: Descriptive lowercase text with hyphens +- Keep the descriptive part concise (1-5 words typically) +- Use context from user's prompt or conversation +- If unclear, use random identifier (e.g., `x9k2`, `a7b3`) + +Examples: +- Prompt "A serene Japanese garden" → `2025-11-23-14-23-05-japanese-garden.png` +- Prompt "sunset over mountains" → `2025-11-23-15-30-12-sunset-mountains.png` +- Prompt "create an image of a robot" → `2025-11-23-16-45-33-robot.png` +- Unclear context → `2025-11-23-17-12-48-x9k2.png` + +## Image Editing + +When the user wants to modify an existing image: +1. Check if they provide an image path or reference an image in the current directory +2. Use `--input-image` parameter with the path to the image +3. The prompt should contain editing instructions (e.g., "make the sky more dramatic", "remove the person", "change to cartoon style") +4. Common editing tasks: add/remove elements, change style, adjust colors, blur background, etc. + +## Prompt Handling + +**For generation:** Pass user's image description as-is to `--prompt`. Only rework if clearly insufficient. + +**For editing:** Pass editing instructions in `--prompt` (e.g., "add a rainbow in the sky", "make it look like a watercolor painting") + +Preserve user's creative intent in both cases. + +## Prompt Templates (high hit-rate) + +Use templates when the user is vague or when edits must be precise. + +- Generation template: + - “Create an image of: . Style: