Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/cortex-cli/src/agent_cmd/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
#[cfg(test)]
mod tests {
use crate::agent_cmd::cli::{CopyArgs, ExportArgs};
use crate::agent_cmd::loader::{
load_builtin_agents, parse_frontmatter, read_file_with_encoding,
};
use crate::agent_cmd::loader::{load_builtin_agents, parse_frontmatter};
use crate::agent_cmd::types::AgentMode;
use crate::utils::file::read_file_with_encoding;

#[test]
fn test_read_file_with_utf8() {
Expand Down
2 changes: 1 addition & 1 deletion src/cortex-cli/src/run_cmd/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ pub struct RunCli {
/// Maximum tokens for response (used for token validation).
/// If specified, cortex will validate that prompt + max_tokens
/// does not exceed the model's context limit before making the API call.
#[arg(long = "max-tokens")]
#[arg(long = "max-tokens", value_parser = clap::value_parser!(u32).range(1..))]
pub max_tokens: Option<u32>,

/// Custom system prompt to use instead of the default.
Expand Down
157 changes: 117 additions & 40 deletions src/cortex-cli/src/run_cmd/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ use super::output::{copy_to_clipboard, send_notification};
use super::session::{SessionMode, resolve_session_id};
use super::system::check_file_descriptor_limits;

#[derive(Debug, PartialEq, Eq)]
struct DryRunTokenEstimate {
user_prompt_tokens: u32,
attachment_tokens: u32,
system_prompt_tokens: u32,
tool_tokens: u32,
tool_count: u32,
total_input_tokens: u32,
max_response_tokens: Option<u32>,
total_with_max_response: Option<u32>,
}

impl RunCli {
/// Run the command.
pub async fn run(self) -> Result<()> {
Expand Down Expand Up @@ -784,55 +796,28 @@ impl RunCli {

/// Run in dry-run mode - show token estimates without executing.
async fn run_dry_run(&self, message: &str, attachments: &[FileAttachment]) -> Result<()> {
use cortex_engine::tokenizer::TokenCounter;

let config = cortex_engine::Config::default();
let model = self
.model
.as_ref()
.map(|m| resolve_model_alias(m).to_string())
.unwrap_or_else(|| config.model.clone());

let mut counter = TokenCounter::for_model(&model);

// Count user prompt tokens
let user_prompt_tokens = counter.count(message);

// Count attachment tokens
let mut attachment_tokens = 0u32;
for attachment in attachments {
let content =
std::fs::read_to_string(&attachment.path).unwrap_or_else(|_| String::new());
attachment_tokens += counter.count(&content);
// Add overhead for file markers
attachment_tokens += 20; // Approximate overhead for "--- File: ... ---" markers
}

// Estimate system prompt tokens (typical system prompt is ~500-2000 tokens)
// This is an approximation as the actual system prompt varies
let system_prompt_tokens = 1500u32;

// Estimate tool definition tokens
// Each tool definition is approximately 100-200 tokens on average
// Common tools: Execute, Read, Write, Edit, LS, Grep, Glob, etc.
let tool_count = 15; // Approximate number of default tools
let tool_tokens = tool_count * 150; // ~150 tokens per tool definition

// Calculate totals
let total_input_tokens =
user_prompt_tokens + attachment_tokens + system_prompt_tokens + tool_tokens;
let estimate = self.estimate_dry_run_tokens(&model, message, attachments);

// Output based on format
if matches!(self.format, OutputFormat::Json | OutputFormat::Jsonl) {
let output = serde_json::json!({
"dry_run": true,
"model": model,
"token_estimates": {
"user_prompt": user_prompt_tokens,
"attachments": attachment_tokens,
"system_prompt": system_prompt_tokens,
"tool_definitions": tool_tokens,
"total_input": total_input_tokens,
"user_prompt": estimate.user_prompt_tokens,
"attachments": estimate.attachment_tokens,
"system_prompt": estimate.system_prompt_tokens,
"tool_definitions": estimate.tool_tokens,
"total_input": estimate.total_input_tokens,
"max_response": estimate.max_response_tokens,
"total_with_max_response": estimate.total_with_max_response,
},
"message_preview": if message.len() > 100 {
format!("{}...", &message[..100])
Expand All @@ -849,24 +834,36 @@ impl RunCli {
println!("Model: {}", model);
println!();
println!("Token Breakdown:");
println!(" User prompt: {:>8} tokens", user_prompt_tokens);
println!(
" User prompt: {:>8} tokens",
estimate.user_prompt_tokens
);
if !attachments.is_empty() {
println!(
" Attachments: {:>8} tokens ({} files)",
attachment_tokens,
estimate.attachment_tokens,
attachments.len()
);
}
println!(
" System prompt: {:>8} tokens (estimated)",
system_prompt_tokens
estimate.system_prompt_tokens
);
println!(
" Tool definitions: {:>8} tokens (estimated, {} tools)",
tool_tokens, tool_count
estimate.tool_tokens, estimate.tool_count
);
println!(" {}", "-".repeat(30));
println!(" Total input: {:>8} tokens", total_input_tokens);
println!(
" Total input: {:>8} tokens",
estimate.total_input_tokens
);
if let Some(max_tokens) = estimate.max_response_tokens {
println!(" Max response: {:>8} tokens", max_tokens);
if let Some(total_with_max_response) = estimate.total_with_max_response {
println!(" Input + response: {:>8} tokens", total_with_max_response);
}
}
println!();
println!("Note: System prompt and tool definition token counts are estimates.");
println!("Actual counts may vary based on agent configuration.");
Expand All @@ -884,4 +881,84 @@ impl RunCli {

Ok(())
}

fn estimate_dry_run_tokens(
&self,
model: &str,
message: &str,
attachments: &[FileAttachment],
) -> DryRunTokenEstimate {
use cortex_engine::tokenizer::TokenCounter;

let mut counter = TokenCounter::for_model(&model);

// 统计用户提示词 token。
let user_prompt_tokens = counter.count(message);

// 统计附件 token。
let mut attachment_tokens = 0u32;
for attachment in attachments {
let content =
std::fs::read_to_string(&attachment.path).unwrap_or_else(|_| String::new());
attachment_tokens += counter.count(&content);
// 加上文件标记的近似开销。
attachment_tokens += 20;
}

// 系统提示词会随配置变化,这里使用常见范围内的近似值。
let system_prompt_tokens = 1500u32;

// 工具定义按默认工具数量和单个工具的平均 token 估算。
let tool_count = 15;
let tool_tokens = tool_count * 150;

// 计算输入总量,并在传入 max_tokens 时纳入响应上限。
let total_input_tokens =
user_prompt_tokens + attachment_tokens + system_prompt_tokens + tool_tokens;
let total_with_max_response = self
.max_tokens
.map(|max| total_input_tokens.saturating_add(max));

DryRunTokenEstimate {
user_prompt_tokens,
attachment_tokens,
system_prompt_tokens,
tool_tokens,
tool_count,
total_input_tokens,
max_response_tokens: self.max_tokens,
total_with_max_response,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use clap::Parser;

#[test]
fn dry_run_estimate_includes_max_response_tokens() {
let cli = RunCli::try_parse_from(["run", "--dry-run", "--max-tokens", "4096", "Long task"])
.expect("max tokens above zero should parse");

let estimate = cli.estimate_dry_run_tokens("gpt-4o", "Long task", &[]);

assert_eq!(estimate.max_response_tokens, Some(4096));
assert_eq!(
estimate.total_with_max_response,
Some(estimate.total_input_tokens + 4096)
);
}

#[test]
fn run_rejects_zero_max_tokens() {
let error = RunCli::try_parse_from(["run", "--max-tokens", "0", "Long task"])
.expect_err("zero max tokens should be rejected");

assert!(
error.to_string().contains("invalid value"),
"unexpected error: {error}"
);
}
}