Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 65 additions & 12 deletions src/cortex-cli/src/stats_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,23 +202,23 @@ fn get_cortex_home() -> PathBuf {
fn get_model_pricing(model: &str) -> ModelPricing {
// First check for custom pricing from environment
let custom_pricing = load_custom_pricing();
let model_lower = model.to_lowercase();
let model_key = normalize_model_key(model);

// Check for exact match in custom pricing
if let Some(pricing) = custom_pricing.get(&model_lower) {
if let Some(pricing) = custom_pricing.get(&model_key) {
return pricing.clone();
}

// Check for partial match in custom pricing (e.g., "gpt-4o" matches "gpt-4o-mini")
for (key, pricing) in &custom_pricing {
if model_lower.contains(key) {
if model_key.contains(key) {
return pricing.clone();
}
}

// Fall back to default pricing (may be outdated - users can override via CORTEX_PRICING_*)
// Pricing per 1M tokens (as of late 2024/early 2025 - may change)
match model {
match model_key.as_str() {
// Anthropic
m if m.contains("claude-opus-4") || m.contains("opus-4") => ModelPricing {
input_per_million: 15.0,
Expand Down Expand Up @@ -294,6 +294,10 @@ fn get_model_pricing(model: &str) -> ModelPricing {
}
}

fn normalize_model_key(model: &str) -> String {
model.to_lowercase()
}

/// Calculate cost for token usage.
fn calculate_cost(model: &str, input_tokens: u64, output_tokens: u64) -> f64 {
let pricing = get_model_pricing(model);
Expand Down Expand Up @@ -406,14 +410,13 @@ async fn collect_stats(sessions_dir: &PathBuf, cli: &StatsCli) -> Result<UsageSt
provider_stats.estimated_cost_usd += session_cost;

// Per-model stats
let model_stats =
stats
.by_model
.entry(model.to_string())
.or_insert_with(|| ModelStats {
provider: provider.clone(),
..Default::default()
});
let model_stats = stats
.by_model
.entry(normalize_model_key(model))
.or_insert_with(|| ModelStats {
provider: provider.clone(),
..Default::default()
});
model_stats.sessions += 1;
model_stats.messages += session_data.message_count;
model_stats.input_tokens += session_data.input_tokens;
Expand Down Expand Up @@ -733,6 +736,56 @@ mod tests {
// GPT-4o: $2.50/$10 per 1M
let cost = calculate_cost("gpt-4o", 1_000_000, 1_000_000);
assert!((cost - 12.5).abs() < 0.001);

// Model pricing should be case-insensitive.
let cost = calculate_cost("GPT-4O", 1_000_000, 1_000_000);
assert!((cost - 12.5).abs() < 0.001);
}

#[tokio::test]
async fn test_collect_stats_normalizes_model_case() {
let temp_dir = tempfile::tempdir().unwrap();
let sessions_dir = temp_dir.path().to_path_buf();

std::fs::write(
sessions_dir.join("lower.json"),
r#"{
"created_at": "2026-04-09T00:00:00Z",
"model": "gpt-4o",
"messages": [{"role": "user", "content": "a"}],
"usage": {"input_tokens": 100, "output_tokens": 100}
}"#,
)
.unwrap();

std::fs::write(
sessions_dir.join("upper.json"),
r#"{
"created_at": "2026-04-09T00:01:00Z",
"model": "GPT-4O",
"messages": [{"role": "user", "content": "b"}],
"usage": {"input_tokens": 200, "output_tokens": 200}
}"#,
)
.unwrap();

let cli = StatsCli {
days: 3650,
provider: None,
model: None,
json: false,
verbose: true,
};

let stats = collect_stats(&sessions_dir, &cli).await.unwrap();

assert_eq!(stats.total_sessions, 2);
assert_eq!(stats.by_model.len(), 1);
let model_stats = stats.by_model.get("gpt-4o").unwrap();
assert_eq!(model_stats.sessions, 2);
assert_eq!(model_stats.messages, 2);
assert_eq!(model_stats.input_tokens, 300);
assert_eq!(model_stats.output_tokens, 300);
}

#[test]
Expand Down