Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.agentscope.core.memory.StaticLongTermMemoryHook;
import io.agentscope.core.message.ContentBlock;
import io.agentscope.core.message.GenerateReason;
import io.agentscope.core.message.MessageMetadataKeys;
import io.agentscope.core.message.Msg;
import io.agentscope.core.message.MsgRole;
import io.agentscope.core.message.TextBlock;
Expand Down Expand Up @@ -859,6 +860,11 @@ private Mono<Void> notifyReasoningChunk(Msg chunkMsg, ReasoningContext context)
.role(chunkMsg.getRole())
.content(accumulatedContent)
.build();
if (context.getChatUsage() != null) {
accumulated
.getMetadata()
.put(MessageMetadataKeys.CHAT_USAGE, context.getChatUsage());
}
ReasoningChunkEvent event =
new ReasoningChunkEvent(
this, model.getModelName(), null, chunkMsg, accumulated);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,20 @@ public ToolUseBlock getAccumulatedToolCall(String id) {
public List<ToolUseBlock> getAllAccumulatedToolCalls() {
return toolCallsAcc.getAllAccumulatedToolCalls();
}

/**
* Get the accumulated ChatUsage.
*
* @return ChatUsage with accumulated tokens, or null if no usage data
*/
public ChatUsage getChatUsage() {
if (inputTokens > 0 || outputTokens > 0 || time > 0) {
return ChatUsage.builder()
.inputTokens(inputTokens)
.outputTokens(outputTokens)
.time(time)
.build();
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,111 @@ reactor.core.publisher.Mono<T> onEvent(T event) {
"Second tool should be calculator");
}

@Test
@DisplayName("Should include ChatUsage in accumulated message metadata when available")
void testChatUsageInAccumulatedMessageMetadata() {
// Track received ChatUsage from ReasoningChunkEvent
final java.util.List<ChatUsage> capturedChatUsages =
new java.util.concurrent.CopyOnWriteArrayList<>();
final java.util.List<Msg> capturedAccumulatedMsgs =
new java.util.concurrent.CopyOnWriteArrayList<>();

// Create a hook to capture ReasoningChunkEvent and check metadata
io.agentscope.core.hook.Hook captureHook =
new io.agentscope.core.hook.Hook() {
@Override
public <T extends io.agentscope.core.hook.HookEvent>
reactor.core.publisher.Mono<T> onEvent(T event) {
if (event
instanceof io.agentscope.core.hook.ReasoningChunkEvent chunkEvent) {
// Capture accumulated message and check its metadata
Msg accumulated = chunkEvent.getAccumulated();
if (accumulated != null) {
capturedAccumulatedMsgs.add(accumulated);

// Capture ChatUsage from metadata
Object usage =
accumulated
.getMetadata()
.get(
io.agentscope.core.message
.MessageMetadataKeys.CHAT_USAGE);
if (usage instanceof ChatUsage) {
capturedChatUsages.add((ChatUsage) usage);
}
}
}
return reactor.core.publisher.Mono.just(event);
}
};

// Setup model to return response with ChatUsage
MockModel modelWithUsage =
new MockModel(
messages -> {
return List.of(
ChatResponse.builder()
.content(
List.of(
TextBlock.builder()
.text("Test response")
.build()))
.usage(new ChatUsage(100, 50, 1.5))
.build());
});

agent =
ReActAgent.builder()
.name(TestConstants.TEST_REACT_AGENT_NAME)
.sysPrompt(TestConstants.DEFAULT_SYS_PROMPT)
.model(modelWithUsage)
.toolkit(mockToolkit)
.memory(memory)
.hook(captureHook)
.build();

// Create user message
Msg userMsg = TestUtils.createUserMessage("User", "Test message");

// Get response
Msg response =
agent.call(userMsg).block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));

// Verify response
assertNotNull(response, "Response should not be null");

// Verify ChatUsage was captured in events
assertFalse(capturedChatUsages.isEmpty(), "Should capture ChatUsage from events");

// Verify ChatUsage values
ChatUsage capturedUsage = capturedChatUsages.get(0);
assertEquals(100, capturedUsage.getInputTokens(), "Input tokens should match");
assertEquals(50, capturedUsage.getOutputTokens(), "Output tokens should match");
assertEquals(1.5, capturedUsage.getTime(), "Time should match");

// Verify accumulated messages were captured
assertFalse(
capturedAccumulatedMsgs.isEmpty(),
"Should capture accumulated messages from events");

// Verify metadata contains CHAT_USAGE
Msg accumulatedMsg = capturedAccumulatedMsgs.get(0);
Object metadataUsage =
accumulatedMsg
.getMetadata()
.get(io.agentscope.core.message.MessageMetadataKeys.CHAT_USAGE);
assertNotNull(metadataUsage, "Accumulated message metadata should contain CHAT_USAGE");
assertTrue(
metadataUsage instanceof ChatUsage,
"Metadata CHAT_USAGE should be ChatUsage instance");

ChatUsage metadataChatUsage = (ChatUsage) metadataUsage;
assertEquals(100, metadataChatUsage.getInputTokens(), "Metadata input tokens should match");
assertEquals(
50, metadataChatUsage.getOutputTokens(), "Metadata output tokens should match");
assertEquals(1.5, metadataChatUsage.getTime(), "Metadata time should match");
}

// Helper method to create tool call response
private static ChatResponse createToolCallResponseHelper(
String toolName, String toolCallId, Map<String, Object> arguments) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ void testSingleChunkUsage() {
assertEquals(50, resultUsage.getOutputTokens());
assertEquals(150, resultUsage.getTotalTokens());
assertEquals(1.5, resultUsage.getTime(), 0.001);

resultUsage = context.getChatUsage();
assertNotNull(resultUsage);
assertEquals(100, resultUsage.getInputTokens());
assertEquals(50, resultUsage.getOutputTokens());
assertEquals(150, resultUsage.getTotalTokens());
assertEquals(1.5, resultUsage.getTime(), 0.001);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import reactor.core.publisher.Flux;

@Timeout(10)
class QuartzFixedDelayIntegrationTest {

private QuartzAgentScheduler scheduler;
Expand Down Expand Up @@ -94,14 +96,14 @@ void testFixedDelayReschedulesAndRunsMultipleTimes() throws InterruptedException
QuartzScheduleAgentTask qt = (QuartzScheduleAgentTask) task;

long start = System.currentTimeMillis();
long timeoutMs = 2000;
long timeoutMs = 5000;
long count = 0;
while (System.currentTimeMillis() - start < timeoutMs) {
count = qt.getExecutionCount();
if (count >= 2) {
break;
}
Thread.sleep(50);
Thread.sleep(100);
}
assertTrue(count >= 2);
}
Expand Down
Loading