diff --git a/.claude/INSTALL.md b/.claude/INSTALL.md new file mode 100644 index 0000000..042cdb3 --- /dev/null +++ b/.claude/INSTALL.md @@ -0,0 +1,132 @@ +# LMeterX Claude Code Skills 安装指南 + +## 概述 + +本项目提供三个 Claude Code Skills,用于通过自然语言驱动 LMeterX 平台执行压力测试: + +| Skill | 斜杠命令 | 用途 | +|-------|---------|------| +| `lmeterx-llm-loadtest` | `/llm-loadtest` | LLM API 压测(OpenAI/Claude 兼容) | +| `lmeterx-http-loadtest` | `/http-loadtest` | 业务 HTTP API 压测(REST/GraphQL) | +| `lmeterx-web-loadtest` | `/web-loadtest` | 网站/网页压测(自动分析页面 API) | + +## 前置条件 + +- [Claude Code CLI](https://docs.anthropic.com/en/docs/claude-code) 已安装 +- Python 3.8+(脚本会自动安装 `httpx` 依赖,无需手动安装) + +## 安装步骤 + +### Step 1: 复制 Skills 到全局目录 + +```bash +# 从项目根目录执行 +cp -r .claude/skills/lmeterx-llm-loadtest ~/.claude/skills/ +cp -r .claude/skills/lmeterx-http-loadtest ~/.claude/skills/ +cp -r .claude/skills/lmeterx-web-loadtest ~/.claude/skills/ +``` + +安装后的目录结构: + +``` +~/.claude/skills/ +├── lmeterx-llm-loadtest/ +│ ├── SKILL.md +│ └── scripts/run.py +├── lmeterx-http-loadtest/ +│ ├── SKILL.md +│ └── scripts/run.py +└── lmeterx-web-loadtest/ + ├── SKILL.md + └── scripts/run.py +``` + +### Step 2: 配置权限 + +在 `~/.claude/settings.local.json` 的 `permissions.allow` 数组中添加以下条目: + +```json +{ + "permissions": { + "allow": [ + "Bash(python ~/.claude/skills/lmeterx-llm-loadtest/scripts/run.py*)", + "Bash(python ~/.claude/skills/lmeterx-http-loadtest/scripts/run.py*)", + "Bash(python ~/.claude/skills/lmeterx-web-loadtest/scripts/run.py*)", + "Bash(export LMETERX_AUTH_TOKEN*)" + ] + } +} +``` + +> 如果文件已有其他 permissions,将上述条目追加到 `allow` 数组末尾即可。 + +### Step 3: 配置环境变量(可选) + +Skills 内置了默认值,可直接使用。如需自定义,在 shell 配置文件(`~/.bashrc` 或 `~/.zshrc`)中添加: + +```bash +export LMETERX_BASE_URL="" # LMeterX 后端地址 +export LMETERX_AUTH_TOKEN="" # Service Token +``` + +| 变量 | 默认值 | 说明 | +|------|--------|------| +| `LMETERX_BASE_URL` | `` | LMeterX 后端地址 | +| `LMETERX_AUTH_TOKEN` | `` | Service Token,通过 `X-Authorization` 头传递 | + +## 验证安装 + +重启 Claude Code 后,输入 `/llm-loadtest`、`/http-loadtest` 或 `/web-loadtest` 即可验证 Skills 是否正常加载。 + +也可以运行以下命令验证脚本(首次运行会自动安装 `httpx` 依赖到脚本本地 `.deps/` 目录): + +```bash +python ~/.claude/skills/lmeterx-llm-loadtest/scripts/run.py --help +python ~/.claude/skills/lmeterx-http-loadtest/scripts/run.py --help +python ~/.claude/skills/lmeterx-web-loadtest/scripts/run.py --help +``` + +## 使用示例 + +### LLM API 压测 + +``` +/llm-loadtest 压测这个接口,并发10,持续5分钟: +curl https://api.openai.com/v1/chat/completions \ + -H "Authorization: Bearer sk-xxx" \ + -d '{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}' +``` + +### 业务 HTTP API 压测 + +``` +/http-loadtest 并发50压测10分钟: +curl -X GET https://api.example.com/users \ + -H "Authorization: Bearer token123" +``` + +### 网站压测 + +``` +/web-loadtest 帮我压测这个网站,并发20:https://example.com +``` + +## 路由规则 + +当用户请求压测时,根据以下规则选择 Skill: + +``` +URL 以 /v1/chat/completions 或 /v1/messages 结尾 → /llm-loadtest +URL 是浏览器可访问的普通网页 → /web-loadtest +其他 API 端点(REST/GraphQL/curl 命令) → /http-loadtest +``` + +## 卸载 + +```bash +rm -rf ~/.claude/skills/lmeterx-llm-loadtest +rm -rf ~/.claude/skills/lmeterx-http-loadtest +rm -rf ~/.claude/skills/lmeterx-web-loadtest +``` + +并从 `~/.claude/settings.local.json` 中移除对应的权限条目。 diff --git a/.openclaw/skills/lmeterx-http-loadtest/SKILL.md b/.openclaw/skills/lmeterx-http-loadtest/SKILL.md new file mode 100644 index 0000000..b77a475 --- /dev/null +++ b/.openclaw/skills/lmeterx-http-loadtest/SKILL.md @@ -0,0 +1,127 @@ +--- +name: lmeterx-http-loadtest +emoji: "\U0001F310" +description: | + LMeterX HTTP API Load Test tool. When a user provides a **business/regular API endpoint URL** + or a curl command targeting a non-LLM HTTP API, this skill executes a script to pre-check + connectivity and create a load testing task. For REST APIs, GraphQL, and any HTTP endpoints + that are NOT LLM model APIs. +triggers: + - 压测这个API + - 压测这个接口 + - 压测这个端点 + - 帮我压测这个API接口 + - 压测这个 curl + - 帮我压测这个HTTP接口 + - 压测这个REST API + - 帮我压测这个业务接口 + - load test this API + - load test this endpoint + - stress test this curl + - load test this HTTP API +requires: + env: + - LMETERX_BASE_URL +--- + +# Skill: lmeterx-http-loadtest + +## Intent Routing Rules (Highest Priority) + +### When to USE this Skill + +- User provides a regular HTTP API URL (e.g. `https://api.example.com/users`, `https://app.com/api/orders`) +- User provides a curl command targeting a business API +- URL contains paths like `/api/`, `/graphql`, `/v2/`, `/rest/` etc. (but NOT `/v1/chat/completions` or `/v1/messages`) +- User says "压测这个API/接口/端点" and the URL is NOT an LLM endpoint + +### When NOT to use this Skill + +| Condition | Use Instead | +|-----------|------------| +| URL ends with `/v1/chat/completions` or `/v1/messages` | `lmeterx-llm-loadtest` | +| User mentions "LLM", "大模型", "OpenAI", "Claude" | `lmeterx-llm-loadtest` | +| URL is a webpage (e.g. `https://www.baidu.com`) | `lmeterx-web-loadtest` | +| User says "网站/网页/页面" | `lmeterx-web-loadtest` | + +### Quick Decision Rule + +``` +URL ends with /v1/chat/completions or /v1/messages → lmeterx-llm-loadtest +URL is a normal webpage (HTML page for browsers) → lmeterx-web-loadtest +Everything else (REST/GraphQL/business API) → THIS SKILL +``` + +## Execution Rules + +1. **Mandatory:** You **must and may only** execute the provided script via Bash. +2. **Prohibition:** Do NOT manually construct HTTP requests using `curl` or `requests` to call LMeterX APIs. +3. **Prohibition:** Do NOT fabricate results. Execute the script and respond based on actual stdout output. + +## The Only Correct Way to Execute + +### With URL: + +```bash +export LMETERX_AUTH_TOKEN="${LMETERX_AUTH_TOKEN:-}" +python "${SKILL_DIR}/scripts/run.py" \ + --url "" \ + --method GET +``` + +### With curl command: + +```bash +export LMETERX_AUTH_TOKEN="${LMETERX_AUTH_TOKEN:-}" +python "${SKILL_DIR}/scripts/run.py" \ + --curl '' +``` + +### With custom load parameters: + +```bash +export LMETERX_AUTH_TOKEN="${LMETERX_AUTH_TOKEN:-}" +python "${SKILL_DIR}/scripts/run.py" \ + --url "" \ + --method POST \ + --header "Authorization: Bearer " \ + --header "Content-Type: application/json" \ + --body '{"key": "value"}' \ + --concurrent-users 100 \ + --duration 600 \ + --spawn-rate 50 +``` + +## Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--url` | (required, or use --curl) | API endpoint URL | +| `--curl` | (required, or use --url) | Full curl command string | +| `--method` | POST (auto: POST if body, else GET) | HTTP method | +| `--header` | [] | Request header (repeatable, format: `Key: Value`) | +| `--body` | "" | Request body string | +| `--cookie` | [] | Cookie (repeatable, format: `Key=Value`) | +| `--concurrent-users` | 50 | Concurrent users (1-5000) | +| `--duration` | 300 | Duration in seconds (1-172800) | +| `--spawn-rate` | 30 | User spawn rate | +| `--name` | (auto-generated) | Task name | + +## Presenting Results to the User + +After execution, present: + +1. **Target Info:** Method + URL +2. **Pre-check Result:** Pass/Fail with categorized failure reason +3. **Task ID and Report URL:** `{LMETERX_BASE_URL}/http-results/{task_id}` + +## Exception Handling + +| Error Scenario | Output Message | +|---------|---------| +| HTTP 401/403 | LMeterX token is invalid or expired; check `LMETERX_AUTH_TOKEN` | +| HTTP 5xx | LMeterX platform service error; try again later | +| Connection Failure | Cannot connect to LMeterX service; check network | +| Target API 401 | Target API requires auth; check Authorization header | +| Target API 404 | Target API path not found; check URL | +| Target API timeout | Target API timed out; check target service status | diff --git a/.openclaw/skills/lmeterx-http-loadtest/scripts/run.py b/.openclaw/skills/lmeterx-http-loadtest/scripts/run.py new file mode 100644 index 0000000..d7d8f3e --- /dev/null +++ b/.openclaw/skills/lmeterx-http-loadtest/scripts/run.py @@ -0,0 +1,516 @@ +#!/usr/bin/env python3 +""" +lmeterx-http-loadtest — LMeterX HTTP API Load Test Skill Script. + +Workflow: + 1. Parse input (curl command or --url/--method/--body/--header params) + 2. Validate that URL is NOT an LLM API endpoint + 3. POST /api/http-tasks/test → Pre-check connectivity + 4. POST /api/http-tasks → Create load test task + +Security constraints: + - Only calls whitelisted LMeterX paths: /health, /api/auth/profile, /api/http-tasks/* + - All requests automatically inject X-Authorization: + - Concurrent number limit [1, 5000], duration limit [1, 172800] +""" + +import argparse +import json +import os +import re +import shlex +import subprocess +import sys +import uuid +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEPS_DIR = os.path.join(_SCRIPT_DIR, ".deps") + + +def _ensure_httpx(): + try: + import httpx + + return httpx + except ImportError: + pass + if os.path.isdir(_DEPS_DIR): + sys.path.insert(0, _DEPS_DIR) + try: + import httpx + + return httpx + except ImportError: + pass + print("📦 首次运行,自动安装依赖 httpx ...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "httpx", "-t", _DEPS_DIR, "-q"], + stdout=subprocess.DEVNULL, + ) + sys.path.insert(0, _DEPS_DIR) + import httpx + + return httpx + + +httpx = _ensure_httpx() + +# ── Global configuration ────────────────────────────────────────────────────── + +LMETERX_BASE_URL: str = os.getenv("LMETERX_BASE_URL", "").rstrip( + "/" +) + +LMETERX_AUTH_TOKEN: str = os.getenv("LMETERX_AUTH_TOKEN") or "" + +TIMEOUT = 60.0 + +# ── LLM patterns (for rejection) ───────────────────────────────────────────── + +LLM_PATH_SUFFIXES = ("/v1/chat/completions", "/v1/messages") + +# ── Pre-check failure classification ───────────────────────────────────────── + +_FAILURE_CATEGORIES: Dict[str, Tuple[str, str]] = { + "401": ("🔐 认证失败 (401)", "目标 API 需要认证,请检查 Authorization 或 API Key"), + "403": ("🚫 权限不足 (403)", "已认证但无访问权限,请确认账号权限"), + "404": ("🔗 地址无效 (404)", "API 路径不存在,请检查 URL"), + "405": ("⛔ 方法不允许 (405)", "HTTP 方法不匹配,请检查 GET/POST 等"), + "429": ("⏳ 请求限流 (429)", "目标 API 限流中,稍后重试"), + "4xx": ("⚠️ 客户端错误 (4xx)", "目标 API 返回客户端错误"), + "5xx": ("💥 服务端错误 (5xx)", "目标服务内部异常"), + "connection": ("🌐 连接失败", "无法连接目标主机,请检查 URL 和网络"), + "timeout": ("⏱ 请求超时", "目标 API 响应超时"), + "ssl": ("🔒 SSL/TLS 错误", "证书验证或 TLS 握手失败"), + "unknown": ("❓ 未知错误", "发生意外错误"), +} + + +def _classify_failure( + *, http_status: Optional[int] = None, error_msg: str = "" +) -> Tuple[str, str, str]: + if http_status is not None: + key = str(http_status) + if key in _FAILURE_CATEGORIES: + label, hint = _FAILURE_CATEGORIES[key] + return key, label, hint + if 400 <= http_status < 500: + label, hint = _FAILURE_CATEGORIES["4xx"] + return "4xx", f"{label} ({http_status})", hint + if http_status >= 500: + label, hint = _FAILURE_CATEGORIES["5xx"] + return "5xx", f"{label} ({http_status})", hint + return "unknown", f"❓ 异常状态码 ({http_status})", "" + + err = error_msg.lower() + if "timeout" in err: + cat = "timeout" + elif "connection" in err: + cat = "connection" + elif "ssl" in err: + cat = "ssl" + else: + cat = "unknown" + label, hint = _FAILURE_CATEGORIES[cat] + return cat, label, hint + + +# ── Utility functions ───────────────────────────────────────────────────────── + + +def _bounded_int(value: Any, default: int, lo: int, hi: int) -> int: + try: + v = int(value) + except (TypeError, ValueError): + v = default + return max(lo, min(v, hi)) + + +def _headers_to_kv_list(hdr_dict: Dict[str, str]) -> List[Dict[str, str]]: + return [{"key": k, "value": v} for k, v in hdr_dict.items()] + + +def _cookies_to_kv_list(cookie_dict: Dict[str, str]) -> List[Dict[str, str]]: + return [{"key": k, "value": v} for k, v in cookie_dict.items()] + + +def _make_headers() -> Dict[str, str]: + return { + "Content-Type": "application/json", + "X-Authorization": LMETERX_AUTH_TOKEN, + } + + +# ── curl parser ─────────────────────────────────────────────────────────────── + + +def _parse_curl(curl_cmd: str) -> Dict[str, Any]: + cmd = curl_cmd.replace("\\\n", " ").replace("\\\r\n", " ").strip() + if re.match(r"^curl\s", cmd, re.IGNORECASE): + cmd = re.sub(r"^curl\s+", "", cmd, count=1, flags=re.IGNORECASE) + + try: + tokens = shlex.split(cmd) + except ValueError: + tokens = cmd.split() + + url = "" + method = "" + req_headers: Dict[str, str] = {} + body = "" + cookies: Dict[str, str] = {} + + SKIP_FLAGS_WITH_ARG = { + "--connect-timeout", + "--max-time", + "-m", + "--retry", + "-o", + "--output", + "-u", + "--user", + "-e", + "--referer", + "-A", + "--user-agent", + "--proxy", + "-x", + "--cert", + "--key", + "--cacert", + } + SKIP_FLAGS_NO_ARG = { + "--compressed", + "--insecure", + "-k", + "-v", + "--verbose", + "-s", + "--silent", + "-S", + "--show-error", + "-L", + "--location", + "-i", + "--include", + "-f", + "--fail", + "-N", + "--no-buffer", + } + + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok in ("-X", "--request"): + i += 1 + if i < len(tokens): + method = tokens[i].upper() + elif tok in ("-H", "--header"): + i += 1 + if i < len(tokens) and ":" in tokens[i]: + key, val = tokens[i].split(":", 1) + req_headers[key.strip()] = val.strip() + elif tok in ("-d", "--data", "--data-raw", "--data-binary", "--data-ascii"): + i += 1 + if i < len(tokens): + body = tokens[i] + elif tok in ("-b", "--cookie"): + i += 1 + if i < len(tokens): + for part in tokens[i].split(";"): + part = part.strip() + if "=" in part: + k, v = part.split("=", 1) + cookies[k.strip()] = v.strip() + elif tok in SKIP_FLAGS_NO_ARG: + pass + elif tok in SKIP_FLAGS_WITH_ARG: + i += 1 + elif tok.startswith("http://") or tok.startswith("https://"): + url = tok + elif not tok.startswith("-") and not url and "://" in tok: + url = tok + i += 1 + + if not method: + method = "POST" if body else "GET" + + return { + "url": url, + "method": method, + "headers": req_headers, + "body": body, + "cookies": cookies, + } + + +# ── Preflight check ─────────────────────────────────────────────────────────── + + +def _preflight_check() -> None: + try: + resp = httpx.get(f"{LMETERX_BASE_URL}/health", timeout=10.0, verify=False) + if resp.status_code != 200: + print(f"❌ LMeterX 后端健康检查异常: HTTP {resp.status_code}") + print(f" 请确认 LMETERX_BASE_URL={LMETERX_BASE_URL} 是否正确") + sys.exit(1) + except httpx.ConnectError: + print(f"❌ 无法连接 LMeterX 后端: {LMETERX_BASE_URL}") + print(" 请确认后端服务已启动且网络畅通") + sys.exit(1) + except httpx.TimeoutException: + print(f"❌ 连接 LMeterX 后端超时: {LMETERX_BASE_URL}") + sys.exit(1) + + try: + profile_resp = httpx.get( + f"{LMETERX_BASE_URL}/api/auth/profile", + headers=_make_headers(), + timeout=10.0, + verify=False, + ) + if profile_resp.status_code == 200: + profile = profile_resp.json() + user = profile.get("username", "") + if user and user not in ("anonymous", "-"): + print(f" 👤 已认证用户: {user}") + elif profile_resp.status_code == 401: + if not os.getenv("LMETERX_AUTH_TOKEN"): + print("❌ LMeterX 后端已启用认证,但未配置 LMETERX_AUTH_TOKEN") + else: + print("❌ LMETERX_AUTH_TOKEN 无效或已过期") + sys.exit(1) + except Exception as e: + print(f"⚠️ 认证检查异常 ({e}),继续执行...") + + +# ── Main flow ───────────────────────────────────────────────────────────────── + + +def main() -> None: + parser = argparse.ArgumentParser( + description="LMeterX HTTP API Load Test (REST / GraphQL / Business APIs)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""\ +Examples: + # GET request + python run.py --url https://api.example.com/users --method GET + + # POST with body + python run.py --url https://api.example.com/orders \\ + --method POST \\ + --header "Authorization: Bearer token123" \\ + --body '{"item": "book", "qty": 1}' + + # curl mode + python run.py --curl 'curl -X GET https://api.example.com/users \\ + -H "Authorization: Bearer token123"' +""", + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument("--curl", help="Complete curl command string") + input_group.add_argument("--url", help="HTTP API endpoint URL") + + parser.add_argument( + "--method", default="", help="HTTP method (default: POST if body, else GET)" + ) + parser.add_argument("--body", default="", help="Request body string") + parser.add_argument( + "--header", + action="append", + default=[], + help="Request header (repeatable, format: 'Key: Value')", + ) + parser.add_argument( + "--cookie", + action="append", + default=[], + help="Cookie (repeatable, format: 'Key=Value')", + ) + parser.add_argument( + "--concurrent-users", type=int, default=50, help="Concurrent users (default 50)" + ) + parser.add_argument( + "--duration", type=int, default=300, help="Duration in seconds (default 300)" + ) + parser.add_argument( + "--spawn-rate", type=int, default=30, help="Spawn rate (default 30)" + ) + parser.add_argument( + "--name", default="", help="Task name (auto-generated if empty)" + ) + + args = parser.parse_args() + + # ── Step 0: Preflight ───────────────────────────────────────────────── + print("\n🔑 Step 0: 检查 LMeterX 后端连通性与认证 ...") + _preflight_check() + print(" ✅ 后端连通,认证正常") + + # ── Parse input ─────────────────────────────────────────────────────── + if args.curl: + parsed = _parse_curl(args.curl) + url = parsed["url"] + method = parsed["method"] + req_headers = parsed["headers"] + body = parsed["body"] + cookies = parsed["cookies"] + if not url: + print("❌ 无法从 curl 命令中解析出 URL") + sys.exit(1) + print(f"\n📋 已解析 curl 命令:") + print(f" URL: {url}") + print(f" Method: {method}") + print(f" Headers: {len(req_headers)} 个") + else: + url = args.url + body = args.body + method = args.method.upper() if args.method else ("POST" if body else "GET") + req_headers: Dict[str, str] = {} + for h in args.header: + if ":" in h: + k, v = h.split(":", 1) + req_headers[k.strip()] = v.strip() + cookies: Dict[str, str] = {} + for c in args.cookie: + if "=" in c: + k, v = c.split("=", 1) + cookies[k.strip()] = v.strip() + + # ── Validate NOT an LLM URL ─────────────────────────────────────────── + parsed_path = urlparse(url).path.rstrip("/") + for suffix in LLM_PATH_SUFFIXES: + if parsed_path.endswith(suffix): + print(f"\n❌ 该 URL 是 LLM API 端点: {url}") + print(" 请使用 lmeterx-llm-loadtest 进行 LLM API 压测") + sys.exit(1) + + # ── Validate URL format ─────────────────────────────────────────────── + if not re.match(r"^https?://", url): + print(f"❌ 无效的 URL 格式: {url}(必须以 http:// 或 https:// 开头)") + sys.exit(1) + + print(f"\n🔍 API 类型: 🌐 普通 HTTP 业务 API") + print(f" 请求方法: {method}") + print(f" 目标 URL: {url}") + + # ── Prepare parameters ──────────────────────────────────────────────── + concurrent_users = _bounded_int(args.concurrent_users, 50, 1, 5000) + duration = _bounded_int(args.duration, 300, 1, 172800) + spawn_rate = _bounded_int(args.spawn_rate, 30, 1, 10000) + + parsed_url = urlparse(url) + auto_name = f"{parsed_url.netloc}{parsed_url.path}" + if len(auto_name) > 80: + auto_name = auto_name[:80] + task_name = args.name or auto_name + + # ── Step 1: Pre-check ───────────────────────────────────────────────── + print(f"\n🔗 Step 1/2: 预检 API 连通性 ...") + + test_payload = { + "method": method, + "target_url": url, + "headers": _headers_to_kv_list(req_headers), + "cookies": _cookies_to_kv_list(cookies), + "request_body": body or "", + } + + with httpx.Client(timeout=TIMEOUT, verify=False) as client: + try: + test_resp = client.post( + f"{LMETERX_BASE_URL}/api/http-tasks/test", + headers=_make_headers(), + json=test_payload, + ) + + if test_resp.status_code != 200: + print(f" ❌ 连通性测试失败: HTTP {test_resp.status_code}") + try: + err_data = test_resp.json() + print(f" 详情: {json.dumps(err_data, ensure_ascii=False)}") + except Exception: + print(f" 响应: {test_resp.text[:500]}") + sys.exit(1) + + test_data = test_resp.json() + if test_data.get("status") == "success": + http_code = test_data.get("http_status") + if isinstance(http_code, int) and http_code >= 400: + _, label, hint = _classify_failure(http_status=http_code) + print(f" ❌ 连通性测试未通过: {label}") + if hint: + print(f" 💡 {hint}") + sys.exit(1) + print(f" ✅ 连通性正常 → HTTP {http_code or '?'}") + else: + error = test_data.get("error", "N/A") + _, label, hint = _classify_failure(error_msg=error) + print(f" ❌ 连通性测试未通过: {label}") + print(f" 错误: {error}") + if hint: + print(f" 💡 {hint}") + sys.exit(1) + + except SystemExit: + raise + except Exception as e: + _, label, hint = _classify_failure(error_msg=str(e)) + print(f" ❌ 连通性测试异常: {label}") + print(f" 详情: {e}") + if hint: + print(f" 💡 {hint}") + sys.exit(1) + + # ── Step 2: Create task ─────────────────────────────────────────── + print(f"\n🚀 Step 2/2: 创建 HTTP 压测任务 ...") + + temp_task_id = f"http_{uuid.uuid4().hex[:8]}" + create_payload = { + "temp_task_id": temp_task_id, + "name": task_name, + "method": method, + "target_url": url, + "headers": _headers_to_kv_list(req_headers), + "cookies": _cookies_to_kv_list(cookies), + "request_body": body or "", + "concurrent_users": concurrent_users, + "duration": duration, + "spawn_rate": spawn_rate, + "load_mode": "fixed", + } + + try: + create_resp = client.post( + f"{LMETERX_BASE_URL}/api/http-tasks", + headers=_make_headers(), + json=create_payload, + ) + create_resp.raise_for_status() + result = create_resp.json() + task_id = result.get("task_id", "") + except Exception as e: + print(f" ❌ 任务创建失败: {e}") + sys.exit(1) + + # ── Summary ─────────────────────────────────────────────────────────── + print(f"\n{'=' * 60}") + print(" 📊 执行摘要") + print(f"{'=' * 60}") + print(f" API 类型: 🌐 普通 HTTP 业务 API") + print(f" 请求方法: {method}") + print(f" 目标 URL: {url}") + print(f" 并发用户: {concurrent_users}") + print(f" 持续时间: {duration}s") + print(f" Task ID: {task_id}") + print(f"\n 📈 查看报告:") + print(f" → {LMETERX_BASE_URL}/http-results/{task_id}") + print() + + +if __name__ == "__main__": + main() diff --git a/.openclaw/skills/lmeterx-llm-loadtest/SKILL.md b/.openclaw/skills/lmeterx-llm-loadtest/SKILL.md new file mode 100644 index 0000000..71d7f3a --- /dev/null +++ b/.openclaw/skills/lmeterx-llm-loadtest/SKILL.md @@ -0,0 +1,123 @@ +--- +name: lmeterx-llm-loadtest +emoji: "\U0001F916" +description: | + LMeterX LLM API Load Test tool. When a user provides an **LLM API endpoint URL** + (ending with `/v1/chat/completions` or `/v1/messages`) or a curl command targeting + an LLM API, this skill executes a script to pre-check connectivity and create + a load testing task. Supports OpenAI-compatible and Claude-compatible APIs. +triggers: + - 压测这个LLM API + - 压测这个大模型接口 + - 压测这个模型API + - 帮我压测 OpenAI + - 帮我压测 Claude API + - 压测 chat completions + - 压测 messages 接口 + - load test this LLM API + - load test OpenAI endpoint + - stress test this model API + - 压测这个AI接口 +requires: + env: + - LMETERX_BASE_URL +--- + +# Skill: lmeterx-llm-loadtest + +## Intent Routing Rules (Highest Priority) + +### When to USE this Skill + +- URL ends with `/v1/chat/completions` (OpenAI-compatible) +- URL ends with `/v1/messages` (Claude/Anthropic-compatible) +- User mentions "LLM", "大模型", "模型接口", "OpenAI", "Claude", "chat completions" +- curl command body contains `"model"` and `"messages"` fields + +### When NOT to use this Skill + +| Condition | Use Instead | +|-----------|------------| +| URL is a webpage (e.g. `https://www.baidu.com`) | `lmeterx-web-loadtest` | +| URL is a regular API without LLM path patterns (e.g. `/api/users`, `/graphql`) | `lmeterx-http-loadtest` | +| User says "网站/网页/页面" | `lmeterx-web-loadtest` | + +### Quick Decision Rule + +``` +URL ends with /v1/chat/completions or /v1/messages → THIS SKILL +URL is a normal webpage → lmeterx-web-loadtest +Everything else (REST API, GraphQL, etc.) → lmeterx-http-loadtest +``` + +## Execution Rules + +1. **Mandatory:** You **must and may only** execute the provided script via Bash. +2. **Prohibition:** Do NOT manually construct HTTP requests using `curl` or `requests` to call LMeterX APIs. +3. **Prohibition:** Do NOT fabricate results. Execute the script and respond based on actual stdout output. + +## The Only Correct Way to Execute + +### With URL: + +```bash +export LMETERX_AUTH_TOKEN="${LMETERX_AUTH_TOKEN:-}" +python "${SKILL_DIR}/scripts/run.py" \ + --url "" \ + --header "Authorization: Bearer " +``` + +### With curl command: + +```bash +export LMETERX_AUTH_TOKEN="${LMETERX_AUTH_TOKEN:-}" +python "${SKILL_DIR}/scripts/run.py" \ + --curl '' +``` + +### With custom load parameters: + +```bash +export LMETERX_AUTH_TOKEN="${LMETERX_AUTH_TOKEN:-}" +python "${SKILL_DIR}/scripts/run.py" \ + --url "" \ + --header "Authorization: Bearer " \ + --body '{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}' \ + --concurrent-users 50 \ + --duration 300 \ + --spawn-rate 30 +``` + +## Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--url` | (required, or use --curl) | LLM API endpoint URL | +| `--curl` | (required, or use --url) | Full curl command string | +| `--header` | [] | Request header (repeatable, format: `Key: Value`) | +| `--body` | "" | Request body JSON string | +| `--model` | (auto-extracted from body) | Model name | +| `--stream` / `--no-stream` | true | Enable/disable streaming | +| `--concurrent-users` | 50 | Concurrent users (1-5000) | +| `--duration` | 300 | Duration in seconds (1-172800) | +| `--spawn-rate` | 30 | User spawn rate | +| `--name` | (auto-generated) | Task name | +| `--test-data` | "" | Dataset: `""` for none, `"default"` for built-in dataset | + +## Presenting Results to the User + +After execution, present: + +1. **API Type:** OpenAI Chat or Claude Chat +2. **Pre-check Result:** Pass/Fail with categorized failure reason +3. **Task ID and Report URL:** `{LMETERX_BASE_URL}/results/{task_id}` + +## Exception Handling + +| Error Scenario | Output Message | +|---------|---------| +| HTTP 401/403 | LMeterX token is invalid or expired; check `LMETERX_AUTH_TOKEN` | +| HTTP 5xx | LMeterX platform service error; try again later | +| Connection Failure | Cannot connect to LMeterX service; check network | +| Target API 401 | Target LLM API requires auth; check Authorization header | +| Target API timeout | Target LLM API timed out; may be overloaded | diff --git a/.openclaw/skills/lmeterx-llm-loadtest/scripts/run.py b/.openclaw/skills/lmeterx-llm-loadtest/scripts/run.py new file mode 100644 index 0000000..4530fd0 --- /dev/null +++ b/.openclaw/skills/lmeterx-llm-loadtest/scripts/run.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +""" +lmeterx-llm-loadtest — LMeterX LLM API Load Test Skill Script. + +Workflow: + 1. Parse input (curl command or --url/--header/--body params) + 2. Detect LLM API type: openai-chat (/v1/chat/completions) or claude-chat (/v1/messages) + 3. Split URL into target_host + api_path + 4. POST /api/llm-tasks/test → Pre-check connectivity + 5. POST /api/llm-tasks → Create load test task + +Security constraints: + - Only calls 3 whitelisted LMeterX paths: /health, /api/auth/profile, /api/llm-tasks/* + - All requests automatically inject X-Authorization: + - Concurrent number limit [1, 5000], duration limit [1, 172800] +""" + +import argparse +import json +import os +import re +import shlex +import subprocess +import sys +import uuid +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEPS_DIR = os.path.join(_SCRIPT_DIR, ".deps") + +def _ensure_httpx(): + try: + import httpx + return httpx + except ImportError: + pass + if os.path.isdir(_DEPS_DIR): + sys.path.insert(0, _DEPS_DIR) + try: + import httpx + return httpx + except ImportError: + pass + print("📦 首次运行,自动安装依赖 httpx ...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "httpx", "-t", _DEPS_DIR, "-q"], + stdout=subprocess.DEVNULL, + ) + sys.path.insert(0, _DEPS_DIR) + import httpx + return httpx + +httpx = _ensure_httpx() + +# ── Global configuration ────────────────────────────────────────────────────── + +LMETERX_BASE_URL: str = os.getenv( + "LMETERX_BASE_URL", "" +).rstrip("/") + +LMETERX_AUTH_TOKEN: str = os.getenv("LMETERX_AUTH_TOKEN") or "" + +TIMEOUT = 60.0 + +# ── LLM path patterns ──────────────────────────────────────────────────────── + +LLM_PATTERNS: List[Tuple[str, str]] = [ + ("/v1/chat/completions", "openai-chat"), + ("/v1/messages", "claude-chat"), +] + +# ── Pre-check failure classification ───────────────────────────────────────── + +_FAILURE_CATEGORIES: Dict[str, Tuple[str, str]] = { + "401": ("🔐 认证失败 (401)", "目标 API 需要认证,请检查 Authorization 或 API Key"), + "403": ("🚫 权限不足 (403)", "已认证但无访问权限,请确认账号权限"), + "404": ("🔗 地址无效 (404)", "API 路径不存在,请检查 URL"), + "405": ("⛔ 方法不允许 (405)", "HTTP 方法不匹配"), + "429": ("⏳ 请求限流 (429)", "目标 API 限流中,稍后重试"), + "4xx": ("⚠️ 客户端错误 (4xx)", "目标 API 返回客户端错误"), + "5xx": ("💥 服务端错误 (5xx)", "目标服务内部异常"), + "connection": ("🌐 连接失败", "无法连接目标主机,请检查 URL 和网络"), + "timeout": ("⏱ 请求超时", "目标 API 响应超时"), + "ssl": ("🔒 SSL/TLS 错误", "证书验证或 TLS 握手失败"), + "unknown": ("❓ 未知错误", "发生意外错误"), +} + + +def _classify_failure( + *, http_status: Optional[int] = None, error_msg: str = "" +) -> Tuple[str, str, str]: + if http_status is not None: + key = str(http_status) + if key in _FAILURE_CATEGORIES: + label, hint = _FAILURE_CATEGORIES[key] + return key, label, hint + if 400 <= http_status < 500: + label, hint = _FAILURE_CATEGORIES["4xx"] + return "4xx", f"{label} ({http_status})", hint + if http_status >= 500: + label, hint = _FAILURE_CATEGORIES["5xx"] + return "5xx", f"{label} ({http_status})", hint + return "unknown", f"❓ 异常状态码 ({http_status})", "" + + err = error_msg.lower() + if "timeout" in err: + cat = "timeout" + elif "connection" in err: + cat = "connection" + elif "ssl" in err: + cat = "ssl" + else: + cat = "unknown" + label, hint = _FAILURE_CATEGORIES[cat] + return cat, label, hint + + +# ── Utility functions ───────────────────────────────────────────────────────── + + +def _bounded_int(value: Any, default: int, lo: int, hi: int) -> int: + try: + v = int(value) + except (TypeError, ValueError): + v = default + return max(lo, min(v, hi)) + + +def _headers_to_kv_list(hdr_dict: Dict[str, str]) -> List[Dict[str, str]]: + return [{"key": k, "value": v} for k, v in hdr_dict.items()] + + +def _cookies_to_kv_list(cookie_dict: Dict[str, str]) -> List[Dict[str, str]]: + return [{"key": k, "value": v} for k, v in cookie_dict.items()] + + +def _make_headers() -> Dict[str, str]: + return { + "Content-Type": "application/json", + "X-Authorization": LMETERX_AUTH_TOKEN, + } + + +# ── curl parser ─────────────────────────────────────────────────────────────── + + +def _parse_curl(curl_cmd: str) -> Dict[str, Any]: + cmd = curl_cmd.replace("\\\n", " ").replace("\\\r\n", " ").strip() + if re.match(r"^curl\s", cmd, re.IGNORECASE): + cmd = re.sub(r"^curl\s+", "", cmd, count=1, flags=re.IGNORECASE) + + try: + tokens = shlex.split(cmd) + except ValueError: + tokens = cmd.split() + + url = "" + method = "" + req_headers: Dict[str, str] = {} + body = "" + cookies: Dict[str, str] = {} + + SKIP_FLAGS_WITH_ARG = { + "--connect-timeout", "--max-time", "-m", "--retry", "-o", "--output", + "-u", "--user", "-e", "--referer", "-A", "--user-agent", "--proxy", + "-x", "--cert", "--key", "--cacert", + } + SKIP_FLAGS_NO_ARG = { + "--compressed", "--insecure", "-k", "-v", "--verbose", "-s", "--silent", + "-S", "--show-error", "-L", "--location", "-i", "--include", "-f", + "--fail", "-N", "--no-buffer", + } + + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok in ("-X", "--request"): + i += 1 + if i < len(tokens): + method = tokens[i].upper() + elif tok in ("-H", "--header"): + i += 1 + if i < len(tokens) and ":" in tokens[i]: + key, val = tokens[i].split(":", 1) + req_headers[key.strip()] = val.strip() + elif tok in ("-d", "--data", "--data-raw", "--data-binary", "--data-ascii"): + i += 1 + if i < len(tokens): + body = tokens[i] + elif tok in ("-b", "--cookie"): + i += 1 + if i < len(tokens): + for part in tokens[i].split(";"): + part = part.strip() + if "=" in part: + k, v = part.split("=", 1) + cookies[k.strip()] = v.strip() + elif tok in SKIP_FLAGS_NO_ARG: + pass + elif tok in SKIP_FLAGS_WITH_ARG: + i += 1 + elif tok.startswith("http://") or tok.startswith("https://"): + url = tok + elif not tok.startswith("-") and not url and "://" in tok: + url = tok + i += 1 + + if not method: + method = "POST" if body else "GET" + + return {"url": url, "method": method, "headers": req_headers, "body": body, "cookies": cookies} + + +# ── LLM API detection and URL splitting ─────────────────────────────────────── + + +def _detect_llm_type(url: str) -> Tuple[bool, str]: + parsed = urlparse(url) + path = parsed.path.rstrip("/") + for pattern, api_type in LLM_PATTERNS: + if path.endswith(pattern): + return True, api_type + return False, "" + + +def _split_llm_url(url: str, api_type: str) -> Tuple[str, str]: + parsed = urlparse(url) + path = parsed.path.rstrip("/") + + if api_type == "openai-chat": + suffix = "/chat/completions" + elif api_type == "claude-chat": + suffix = "/messages" + else: + base = f"{parsed.scheme}://{parsed.netloc}" + return base, path or "/" + + idx = path.rfind(suffix) + if idx >= 0: + prefix_path = path[:idx] + target_host = f"{parsed.scheme}://{parsed.netloc}{prefix_path}" + return target_host, suffix + else: + base = f"{parsed.scheme}://{parsed.netloc}" + return base, path or "/" + + +def _extract_model_from_body(body: str) -> str: + if not body: + return "" + try: + return json.loads(body).get("model", "") + except (json.JSONDecodeError, AttributeError): + return "" + + +def _extract_stream_from_body(body: str) -> Optional[bool]: + if not body: + return None + try: + val = json.loads(body).get("stream") + return val if isinstance(val, bool) else None + except (json.JSONDecodeError, AttributeError): + return None + + +# ── Preflight check ─────────────────────────────────────────────────────────── + + +def _preflight_check() -> None: + try: + resp = httpx.get(f"{LMETERX_BASE_URL}/health", timeout=10.0, verify=False) + if resp.status_code != 200: + print(f"❌ LMeterX 后端健康检查异常: HTTP {resp.status_code}") + print(f" 请确认 LMETERX_BASE_URL={LMETERX_BASE_URL} 是否正确") + sys.exit(1) + except httpx.ConnectError: + print(f"❌ 无法连接 LMeterX 后端: {LMETERX_BASE_URL}") + print(" 请确认后端服务已启动且网络畅通") + sys.exit(1) + except httpx.TimeoutException: + print(f"❌ 连接 LMeterX 后端超时: {LMETERX_BASE_URL}") + sys.exit(1) + + try: + profile_resp = httpx.get( + f"{LMETERX_BASE_URL}/api/auth/profile", + headers=_make_headers(), + timeout=10.0, + verify=False, + ) + if profile_resp.status_code == 200: + profile = profile_resp.json() + user = profile.get("username", "") + if user and user not in ("anonymous", "-"): + print(f" 👤 已认证用户: {user}") + elif profile_resp.status_code == 401: + if not os.getenv("LMETERX_AUTH_TOKEN"): + print("❌ LMeterX 后端已启用认证,但未配置 LMETERX_AUTH_TOKEN") + else: + print("❌ LMETERX_AUTH_TOKEN 无效或已过期") + sys.exit(1) + except Exception as e: + print(f"⚠️ 认证检查异常 ({e}),继续执行...") + + +# ── Main flow ───────────────────────────────────────────────────────────────── + + +def main() -> None: + parser = argparse.ArgumentParser( + description="LMeterX LLM API Load Test (OpenAI / Claude compatible)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""\ +Examples: + # URL mode + python run.py --url https://api.openai.com/v1/chat/completions \\ + --header "Authorization: Bearer sk-xxx" \\ + --body '{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}' + + # curl mode + python run.py --curl 'curl https://api.openai.com/v1/chat/completions \\ + -H "Authorization: Bearer sk-xxx" \\ + -d \\'{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"Hi\"}]}\\'' +""", + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument("--curl", help="Complete curl command string") + input_group.add_argument("--url", help="LLM API endpoint URL") + + parser.add_argument("--body", default="", help="Request body (JSON string)") + parser.add_argument( + "--header", action="append", default=[], help="Request header (repeatable, format: 'Key: Value')" + ) + parser.add_argument( + "--cookie", action="append", default=[], help="Cookie (repeatable, format: 'Key=Value')" + ) + parser.add_argument("--model", default="", help="Model name (auto-extracted from body if not set)") + parser.add_argument("--stream", dest="stream_mode", action="store_true", default=None, help="Enable streaming") + parser.add_argument("--no-stream", dest="stream_mode", action="store_false", help="Disable streaming") + parser.add_argument("--concurrent-users", type=int, default=50, help="Concurrent users (default 50)") + parser.add_argument("--duration", type=int, default=300, help="Duration in seconds (default 300)") + parser.add_argument("--spawn-rate", type=int, default=30, help="Spawn rate (default 30)") + parser.add_argument("--name", default="", help="Task name (auto-generated if empty)") + parser.add_argument( + "--test-data", default="", + help="Test dataset: '' for none (default), 'default' for built-in dataset" + ) + + args = parser.parse_args() + + # ── Step 0: Preflight ───────────────────────────────────────────────── + print("\n🔑 Step 0: 检查 LMeterX 后端连通性与认证 ...") + _preflight_check() + print(" ✅ 后端连通,认证正常") + + # ── Parse input ─────────────────────────────────────────────────────── + if args.curl: + parsed = _parse_curl(args.curl) + url = parsed["url"] + req_headers = parsed["headers"] + body = parsed["body"] + cookies = parsed["cookies"] + if not url: + print("❌ 无法从 curl 命令中解析出 URL") + sys.exit(1) + print(f"\n📋 已解析 curl 命令:") + print(f" URL: {url}") + print(f" Headers: {len(req_headers)} 个") + else: + url = args.url + body = args.body + req_headers: Dict[str, str] = {} + for h in args.header: + if ":" in h: + k, v = h.split(":", 1) + req_headers[k.strip()] = v.strip() + cookies: Dict[str, str] = {} + for c in args.cookie: + if "=" in c: + k, v = c.split("=", 1) + cookies[k.strip()] = v.strip() + + # ── Validate LLM URL ────────────────────────────────────────────────── + is_llm, api_type = _detect_llm_type(url) + if not is_llm: + print(f"\n❌ URL 不是 LLM API 端点: {url}") + print(" LLM API 必须以 /v1/chat/completions 或 /v1/messages 结尾") + print(" 如需压测普通 HTTP API,请使用 lmeterx-http-loadtest") + sys.exit(1) + + target_host, api_path = _split_llm_url(url, api_type) + model = args.model or _extract_model_from_body(body) + + if args.stream_mode is not None: + stream_mode = args.stream_mode + else: + extracted = _extract_stream_from_body(body) + stream_mode = extracted if extracted is not None else True + + # Filter Content-Type from user headers + filtered_headers = {k: v for k, v in req_headers.items() if k.lower() != "content-type"} + + api_type_label = "OpenAI Chat" if api_type == "openai-chat" else "Claude Chat" + print(f"\n🔍 API 类型: 🤖 {api_type_label}") + print(f" 目标主机: {target_host}") + print(f" API 路径: {api_path}") + print(f" 模型: {model or '(auto)'}") + print(f" 流式: {stream_mode}") + + # ── Prepare parameters ──────────────────────────────────────────────── + concurrent_users = _bounded_int(args.concurrent_users, 50, 1, 5000) + duration = _bounded_int(args.duration, 300, 1, 172800) + spawn_rate = _bounded_int(args.spawn_rate, 30, 1, 10000) + + parsed_url = urlparse(url) + auto_name = f"{parsed_url.netloc}{parsed_url.path}" + if len(auto_name) > 80: + auto_name = auto_name[:80] + task_name = args.name or auto_name + + # ── Step 1: Pre-check ───────────────────────────────────────────────── + print(f"\n🔗 Step 1/2: 预检 LLM API 连通性 ...") + + test_payload = { + "target_host": target_host, + "api_path": api_path, + "model": model, + "stream_mode": stream_mode, + "headers": _headers_to_kv_list(filtered_headers), + "cookies": _cookies_to_kv_list(cookies), + "request_payload": body or "", + "api_type": api_type, + } + + with httpx.Client(timeout=TIMEOUT, verify=False) as client: + try: + test_resp = client.post( + f"{LMETERX_BASE_URL}/api/llm-tasks/test", + headers=_make_headers(), + json=test_payload, + ) + + if test_resp.status_code != 200: + print(f" ❌ 连通性测试失败: HTTP {test_resp.status_code}") + try: + err_data = test_resp.json() + print(f" 详情: {json.dumps(err_data, ensure_ascii=False)}") + except Exception: + print(f" 响应: {test_resp.text[:500]}") + sys.exit(1) + + test_data = test_resp.json() + if test_data.get("status") == "success": + http_code = test_data.get("http_status") + if isinstance(http_code, int) and http_code >= 400: + _, label, hint = _classify_failure(http_status=http_code) + print(f" ❌ 连通性测试未通过: {label}") + if hint: + print(f" 💡 {hint}") + sys.exit(1) + print(f" ✅ 连通性正常 → HTTP {http_code or '?'}") + else: + error = test_data.get("error", "N/A") + _, label, hint = _classify_failure(error_msg=error) + print(f" ❌ 连通性测试未通过: {label}") + print(f" 错误: {error}") + if hint: + print(f" 💡 {hint}") + sys.exit(1) + + except SystemExit: + raise + except Exception as e: + _, label, hint = _classify_failure(error_msg=str(e)) + print(f" ❌ 连通性测试异常: {label}") + print(f" 详情: {e}") + if hint: + print(f" 💡 {hint}") + sys.exit(1) + + # ── Step 2: Create task ─────────────────────────────────────────── + print(f"\n🚀 Step 2/2: 创建 LLM 压测任务 ...") + + temp_task_id = f"llm_{uuid.uuid4().hex[:8]}" + test_data = args.test_data + chat_type = 2 if test_data else 0 + create_payload = { + "temp_task_id": temp_task_id, + "name": task_name, + "target_host": target_host, + "api_path": api_path, + "model": model, + "duration": duration, + "concurrent_users": concurrent_users, + "spawn_rate": spawn_rate, + "stream_mode": stream_mode, + "headers": _headers_to_kv_list(filtered_headers), + "cookies": _cookies_to_kv_list(cookies), + "request_payload": body or "", + "api_type": api_type, + "test_data": test_data, + "chat_type": chat_type, + "warmup_enabled": True, + "warmup_duration": 120, + "load_mode": "fixed", + } + + try: + create_resp = client.post( + f"{LMETERX_BASE_URL}/api/llm-tasks", + headers=_make_headers(), + json=create_payload, + ) + create_resp.raise_for_status() + result = create_resp.json() + task_id = result.get("task_id", "") + except Exception as e: + print(f" ❌ 任务创建失败: {e}") + sys.exit(1) + + # ── Summary ─────────────────────────────────────────────────────────── + print(f"\n{'=' * 60}") + print(" 📊 执行摘要") + print(f"{'=' * 60}") + print(f" API 类型: 🤖 {api_type_label}") + print(f" 目标主机: {target_host}") + print(f" 模型: {model or '(auto)'}") + print(f" 数据集: {test_data or '(无,使用 request_payload)'}") + print(f" 并发用户: {concurrent_users}") + print(f" 持续时间: {duration}s") + print(f" Task ID: {task_id}") + print(f"\n 📈 查看报告:") + print(f" → {LMETERX_BASE_URL}/results/{task_id}") + print() + + +if __name__ == "__main__": + main() diff --git a/.openclaw/skills/lmeterx-web-loadtest/README.md b/.openclaw/skills/lmeterx-web-loadtest/README.md index 57ecad1..01e62c1 100644 --- a/.openclaw/skills/lmeterx-web-loadtest/README.md +++ b/.openclaw/skills/lmeterx-web-loadtest/README.md @@ -10,8 +10,8 @@ The script comes with the following built-in default values ​​and can be run | Variable | Default Value | Description | |------|--------|------| -| `LMETERX_BASE_URL` | `http://localhost:8080` | LMeterX Backend Address | -| `LMETERX_AUTH_TOKEN` | `localhost_lmeterx` | Service Token: Binds to an Agent User | +| `LMETERX_BASE_URL` | `` | LMeterX Backend Address | +| `LMETERX_AUTH_TOKEN` | `` | Service Token: Binds to an Agent User | ### 2. Run diff --git a/.openclaw/skills/lmeterx-web-loadtest/scripts/run.py b/.openclaw/skills/lmeterx-web-loadtest/scripts/run.py index f0f7a47..0624395 100644 --- a/.openclaw/skills/lmeterx-web-loadtest/scripts/run.py +++ b/.openclaw/skills/lmeterx-web-loadtest/scripts/run.py @@ -30,7 +30,7 @@ ) # Prioritize getting Service Token from environment variables; if not configured, use the built-in default value "localhost_lmeterx". -LMETERX_AUTH_TOKEN: str = os.getenv("LMETERX_AUTH_TOKEN") or "localhost_lmeterx" +LMETERX_AUTH_TOKEN: str = os.getenv("LMETERX_AUTH_TOKEN") or "" # Only allow calling the following 3 whitelisted interface paths _ALLOWED_PATHS = frozenset( diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000..a026efc --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,28 @@ +DB_HOST="localhost" +DB_PORT="3306" +DB_USER="lmeterx" +DB_PASSWORD="lmeterx" +DB_NAME="lmeterx" +LOG_LEVEL="DEBUG" +DETAIL_LOG_LEVEL="DEBUG" + +VICTORIA_METRICS_URL=http://localhost:8428 + +JWT_SECRET_KEY="" +JWT_ALGORITHM=HS256 +JWT_EXPIRE_MINUTES=10080 +JWT_ISSUER=lmeterx + +LDAP_ENABLED=on +LDAP_SERVER=ldaps:// +LDAP_PORT= +LDAP_USE_SSL=true +LDAP_TIMEOUT=5 +LDAP_SEARCH_BASE= +LDAP_SEARCH_FILTER= +LDAP_USER_DN_TEMPLATE= +LDAP_BIND_DN= +LDAP_BIND_PASSWORD= + +ADMIN_USERNAMES= + diff --git a/backend/middleware/auth_middleware.py b/backend/middleware/auth_middleware.py index cadeade..cdfe26d 100644 --- a/backend/middleware/auth_middleware.py +++ b/backend/middleware/auth_middleware.py @@ -27,6 +27,8 @@ "/api/skills/analyze-url", "/api/http-tasks/test", "/api/http-tasks", + "/api/llm-tasks/test", + "/api/llm-tasks", } ) diff --git a/backend/tests/test_service_token_auth.py b/backend/tests/test_service_token_auth.py index 089b292..2d798f0 100644 --- a/backend/tests/test_service_token_auth.py +++ b/backend/tests/test_service_token_auth.py @@ -166,8 +166,8 @@ def test_non_whitelist_profile_no_auth(self): assert resp.status_code == 200 assert resp.json()["username"] == "anonymous" - def test_non_whitelist_llm_tasks_no_auth(self): - """非白名单接口 /api/llm-tasks 无需 Token。""" + def test_whitelist_llm_tasks_no_auth(self): + """白名单接口 /api/llm-tasks 无需 Token。""" resp = self.client.get("/api/llm-tasks") assert resp.status_code == 200 @@ -227,10 +227,10 @@ def test_non_whitelist_profile_returns_403(self): resp = self.client.get("/api/auth/profile") assert resp.status_code == 403 - def test_non_whitelist_llm_tasks_returns_403(self): - """非白名单接口无 Token → 403。""" + def test_whitelist_llm_tasks_returns_401(self): + """白名单接口无 Token → 401。""" resp = self.client.get("/api/llm-tasks") - assert resp.status_code == 403 + assert resp.status_code == 401 # ═════════════════════════════════════════════════════════════════════════════ @@ -288,10 +288,11 @@ def test_non_whitelist_profile_returns_403(self): resp = self.client.get("/api/auth/profile", headers=self.auth_header) assert resp.status_code == 403 - def test_non_whitelist_llm_tasks_returns_403(self): - """Service Token + 非白名单接口 /api/llm-tasks → 403 Forbidden。""" + def test_whitelist_llm_tasks_success(self): + """白名单接口 /api/llm-tasks + 正确 Token → 200。""" resp = self.client.get("/api/llm-tasks", headers=self.auth_header) - assert resp.status_code == 403 + assert resp.status_code == 200 + assert resp.json()["user"]["sub"] == "agent" def test_non_whitelist_response_body(self): """403 响应体应包含明确的错误信息。""" @@ -347,10 +348,10 @@ def test_non_whitelist_wrong_token_returns_403(self): resp = self.client.get("/api/system", headers=self.wrong_header) assert resp.status_code == 403 - def test_non_whitelist_llm_tasks_wrong_token_returns_403(self): - """非白名单接口 + 错误 Token → 403。""" + def test_whitelist_llm_tasks_wrong_token_401(self): + """白名单接口 /api/llm-tasks + 错误 Token → 401。""" resp = self.client.get("/api/llm-tasks", headers=self.wrong_header) - assert resp.status_code == 403 + assert resp.status_code == 401 # ── 白名单路径 + 无 Token: 仍然 401 ── @@ -392,6 +393,8 @@ def test_whitelist_paths_accepted(self): "/api/skills/analyze-url", "/api/http-tasks/test", "/api/http-tasks", + "/api/llm-tasks/test", + "/api/llm-tasks", } ) for path in allowed: @@ -404,11 +407,12 @@ def test_non_whitelist_paths_blocked(self): "/api/skills/analyze-url", "/api/http-tasks/test", "/api/http-tasks", + "/api/llm-tasks/test", + "/api/llm-tasks", } ) blocked_paths = [ "/api/system", - "/api/llm-tasks", "/api/auth/profile", "/api/auth/login", "/api/analyze", @@ -427,6 +431,8 @@ def test_safe_request_blocks_non_whitelist(self): "/api/skills/analyze-url", "/api/http-tasks/test", "/api/http-tasks", + "/api/llm-tasks/test", + "/api/llm-tasks", } ) @@ -444,6 +450,8 @@ def test_safe_request_allows_whitelist(self): "/api/skills/analyze-url", "/api/http-tasks/test", "/api/http-tasks", + "/api/llm-tasks/test", + "/api/llm-tasks", } ) diff --git a/frontend/public/locales/en/translation.json b/frontend/public/locales/en/translation.json index 4604f45..d7774ae 100644 --- a/frontend/public/locales/en/translation.json +++ b/frontend/public/locales/en/translation.json @@ -170,6 +170,16 @@ "p95ResponseTime": "95% Response Time (s)", "medianResponseTime": "Median Response Time (s)", "failureCount": "Failure Requests", + "successRequestCount": "Success Requests", + "totalQpm": "Total QPM", + "successQpm": "Success QPM", + "failureQpm": "Failure QPM", + "successAvgResponseTime": "Success Avg Time (s)", + "successP95ResponseTime": "Success P95 (s)", + "successRps": "Success RPS", + "failureAvgResponseTime": "Failure Avg Time (s)", + "failureP95ResponseTime": "Failure P95 (s)", + "failureRps": "Failure RPS", "avgContentLength": "Avg Content Length", "rps": "RPS (req/s)", "ttft": "TTFT (s)", @@ -594,9 +604,11 @@ "jsonlFormatDescription": "Required format: .jsonl file with each line containing {\"id\": \"...\", \"prompt\": \"...\"}", "datasetFileFormatDescription": "Supports JSON (ShareGPT format) and JSONL formats:\n• JSON: [{\"id\": \"...\", \"conversations\": [...]}, ...]\n• JSONL: one JSON object per line {\"id\": \"...\", \"prompt\": \"...\"}", "datasetFileFormatDescriptionChat": "Supports JSONL format:\n• JSONL: one JSON object per line {\"id\": \"...\", \"messages\": [...]} or {\"id\": \"...\", \"prompt\": \"...\"}", + "datasetFileFormatDescriptionCustom": "Supports JSONL format:\n• JSONL: one valid JSON object per line representing the full request payload sent to the API.", "datasetImageMountWarning": "⚠️ If the dataset contains images paths, ensure image files are mounted to the container before starting the service. See DATASET_GUIDE for details.", "jsonlData": "JSONL Data", "jsonlDataTooltip": "Each line must be a valid JSON object with \"id\" and \"prompt\" fields.", + "jsonlDataTooltipPayload": "Each line must be a complete request payload and a valid JSON object.", "testDuration": "Test Duration (seconds)", "testDurationTooltip": "How long the load test should run", "concurrentUsersTooltip": "Maximum number of virtual users sending requests simultaneously", @@ -671,6 +683,7 @@ "pleaseUploadDatasetFile": "Please upload dataset file", "pleaseEnterJsonlData": "Please enter JSONL data", "invalidJsonlFormat": "Invalid JSONL format. Each line must be valid JSON with required fields.", + "invalidJsonlFormatBasic": "Invalid JSONL format. Each line must be valid JSON.", "eachLineMustContainFields": "Each line must contain \"id\" and \"prompt\" fields", "eachLineMustContainIdAndMessages": "Each line must contain \"id\" and \"prompt\" or \"messages\" array", "pleaseEnterTestDuration": "Please enter test duration", diff --git a/frontend/public/locales/zh/translation.json b/frontend/public/locales/zh/translation.json index b65de83..58edade 100644 --- a/frontend/public/locales/zh/translation.json +++ b/frontend/public/locales/zh/translation.json @@ -170,6 +170,16 @@ "medianResponseTime": "中位数响应时间(秒)", "avgContentLength": "平均响应长度", "failureCount": "失败请求数", + "successRequestCount": "成功请求数", + "totalQpm": "总 QPM", + "successQpm": "成功 QPM", + "failureQpm": "失败 QPM", + "successAvgResponseTime": "成功平均响应时间(秒)", + "successP95ResponseTime": "成功 P95 响应时间(秒)", + "successRps": "成功 RPS", + "failureAvgResponseTime": "失败平均响应时间(秒)", + "failureP95ResponseTime": "失败 P95 响应时间(秒)", + "failureRps": "失败 RPS", "rps": "RPS (请求/秒)", "ttft": "首Token时延 (秒)", "totalTps": "Total Tokens 吞吐量 (Tokens/秒)", @@ -559,7 +569,7 @@ "responseMode": "响应模式", "responseModeTooltip": "选择流式和非流式响应模式", "requestPayload": "请求参数", - "requestPayloadTooltip": "可选择补充除模型名称、流式模式之外的API的请求参数,若没有可填写{}。请在此处使用简单的测试prompt快速调试,后续可通过选择数据集进行压测。", + "requestPayloadTooltip": "请填写完整的请求参数,请在此处使用简单的测试prompt快速调试,后续可通过选择数据集进行压测。", "systemPrompt": "系统提示词", "systemPromptTooltip": "将与每个请求一起发送的系统级指令", "advancedSettings": "高级设置", @@ -595,9 +605,11 @@ "jsonlFormatDescription": "必需格式:.jsonl文件,每行包含{\"id\": \"...\", \"prompt\": \"...\"}", "datasetFileFormatDescription": "支持 JSON (ShareGPT 格式) 和 JSONL 格式:\n• JSON: [{\"id\": \"...\", \"conversations\": [...]}, ...]\n• JSONL: 每行一个 JSON 对象 {\"id\": \"...\", \"prompt\": \"...\"}", "datasetFileFormatDescriptionChat": "支持 JSONL 格式:\n• JSONL: 每行一个 JSON 对象 {\"id\": \"...\", \"messages\": [...]} 或 {\"id\": \"...\", \"prompt\": \"...\"}", + "datasetFileFormatDescriptionCustom": "支持 JSONL 格式:\n• JSONL: 每行一个有效 JSON 对象,且为完整请求体。", "datasetImageMountWarning": "⚠️ 若上传数据集包含图片路径,请确保启动服务前图片已挂载到容器对应目录下,详见 DATASET_GUIDE。", "jsonlData": "JSONL数据", "jsonlDataTooltip": "每行必须是具有id和messages字段的有效JSON对象。", + "jsonlDataTooltipPayload": "每行必须是完整的请求体且是有效JSON对象。", "testDuration": "测试持续时间(秒)", "testDurationTooltip": "负载测试应运行多长时间", "concurrentUsersTooltip": "同时发送请求的最大虚拟用户数", @@ -672,6 +684,7 @@ "pleaseUploadDatasetFile": "请上传数据集文件", "pleaseEnterJsonlData": "请输入JSONL数据", "invalidJsonlFormat": "无效的JSONL格式。每行必须是具有必需字段的有效JSON。", + "invalidJsonlFormatBasic": "无效的JSONL格式。每行必须是有效的JSON。", "eachLineMustContainFields": "每行必须包含id和prompt字段", "eachLineMustContainIdAndMessages": "每行必须包含id和prompt或messages数组", "pleaseEnterTestDuration": "请输入测试持续时间", diff --git a/frontend/src/components/AddToCollectionModal.tsx b/frontend/src/components/AddToCollectionModal.tsx index afa7698..07b7874 100644 --- a/frontend/src/components/AddToCollectionModal.tsx +++ b/frontend/src/components/AddToCollectionModal.tsx @@ -226,12 +226,12 @@ const AddToCollectionModal: React.FC = ({ setCollections(prev => [newCollection, ...prev]); form.setFieldsValue({ collection_id: newCollection.id }); setSearchValue(''); - message.success(t('collections.createSuccess')); + message.success(t('pages.collections.createSuccess')); // Automatically submit to add tasks to the newly created collection await handleSubmit(); } catch (error) { - message.error(t('collections.createFailed')); + message.error(t('pages.collections.createFailed')); setSubmitting(false); } }; diff --git a/frontend/src/components/CreateHttpTaskForm.tsx b/frontend/src/components/CreateHttpTaskForm.tsx index 9677d79..63667a3 100644 --- a/frontend/src/components/CreateHttpTaskForm.tsx +++ b/frontend/src/components/CreateHttpTaskForm.tsx @@ -258,13 +258,6 @@ const CreateHttpTaskForm: React.FC = ({ }, [methodValue]); const buildPayload = (values: any, includeTempId: boolean = false) => { - const curlCommand = (values.curl_command || '').trim(); - const maxCurlLength = 8000; - const isCurlTooLong = curlCommand.length > maxCurlLength; - const safeCurlCommand = isCurlTooLong - ? curlCommand.slice(0, maxCurlLength) - : curlCommand; - const hasBody = METHODS_WITH_BODY.has((values.method || '').toUpperCase()); const datasetFile = hasBody && values.dataset_source === 'upload' @@ -311,7 +304,6 @@ const CreateHttpTaskForm: React.FC = ({ request_body: hasBody ? values.request_body || '' : '', dataset_file: datasetFile, dataset_source: hasBody ? values.dataset_source || 'none' : 'none', - curl_command: safeCurlCommand, success_assert: successAssert || null, headers: (values.headers || '').trim() ? values.headers @@ -331,6 +323,8 @@ const CreateHttpTaskForm: React.FC = ({ delete payload.success_assert_field; delete payload.success_assert_operator; delete payload.success_assert_value; + // Raw curl text can trip upstream WAF rules; parsed fields above are sufficient. + delete payload.curl_command; if (mode === 'fixed') { // Clear stepped fields for fixed mode diff --git a/frontend/src/components/CreateLlmTaskForm.tsx b/frontend/src/components/CreateLlmTaskForm.tsx index 7e6d58a..47fab4d 100644 --- a/frontend/src/components/CreateLlmTaskForm.tsx +++ b/frontend/src/components/CreateLlmTaskForm.tsx @@ -394,7 +394,8 @@ const CreateLlmTaskFormContent: React.FC = ({ // Add chat_type validation when using default dataset and chat API const currentTestDataInputType = - form.getFieldValue('test_data_input_type') || 'default'; + form.getFieldValue('test_data_input_type') || + (isStandardChatApi ? 'default' : 'none'); if ( currentTestDataInputType === 'default' && (currentApiType === 'openai-chat' || currentApiType === 'claude-chat') @@ -432,7 +433,6 @@ const CreateLlmTaskFormContent: React.FC = ({ // Form values states to replace Form.useWatch const [concurrentUsers, setConcurrentUsers] = useState(); - const [streamMode, setStreamMode] = useState(true); const [isFormReady, setIsFormReady] = useState(false); // Initialize form ready state @@ -451,7 +451,6 @@ const CreateLlmTaskFormContent: React.FC = ({ const currentStreamMode = form.getFieldValue('stream_mode'); if (currentApiType === 'embeddings' && currentStreamMode !== false) { form.setFieldsValue({ stream_mode: false }); - setStreamMode(false); // Update request_payload for embeddings API const currentModel = form.getFieldValue('model') || ''; @@ -631,6 +630,28 @@ const CreateLlmTaskFormContent: React.FC = ({ dataToFill.test_data_file = extractFilename(dataToFill.test_data); } + // Dataset source defaults must match API type (no built-in option for embeddings/custom-chat) + const fillApiType = dataToFill.api_type || 'openai-chat'; + const isFillChatApi = + fillApiType === 'openai-chat' || fillApiType === 'claude-chat'; + if (isFillChatApi) { + if ( + dataToFill.test_data_input_type === undefined || + dataToFill.test_data_input_type === null || + dataToFill.test_data_input_type === '' + ) { + dataToFill.test_data_input_type = 'default'; + } + } else if ( + dataToFill.test_data_input_type === undefined || + dataToFill.test_data_input_type === null || + dataToFill.test_data_input_type === '' || + dataToFill.test_data_input_type === 'default' + ) { + dataToFill.test_data_input_type = 'none'; + dataToFill.chat_type = undefined; + } + // clean fields that should not be copied directly or provided by the user delete dataToFill.id; delete dataToFill.status; @@ -640,11 +661,6 @@ const CreateLlmTaskFormContent: React.FC = ({ form.setFieldsValue(dataToFill); - // Update stream mode state for proper field mapping - if (dataToFill.stream_mode !== undefined) { - setStreamMode(dataToFill.stream_mode); - } - if ( dataToFill.concurrent_users && dataToFill.spawn_rate && @@ -1140,7 +1156,11 @@ const CreateLlmTaskFormContent: React.FC = ({ } // Handle test data input type - const inputType = values.test_data_input_type || 'default'; + const submitApiType = values.api_type || 'openai-chat'; + const isSubmitChatApi = + submitApiType === 'openai-chat' || submitApiType === 'claude-chat'; + const inputType = + values.test_data_input_type || (isSubmitChatApi ? 'default' : 'none'); if (inputType === 'upload') { // validateFields() may omit unmounted fields; read from form store as fallback. const storedTestData = form.getFieldValue('test_data'); @@ -1714,28 +1734,41 @@ const CreateLlmTaskFormContent: React.FC = ({ - - - {t('components.createJobForm.modelName')} - - - - - } - rules={[ - { - max: 255, - message: t('components.createJobForm.modelNameLengthLimit'), - }, - ]} - normalize={value => value?.trim() || ''} - > - - - + + {({ getFieldValue }) => { + const apiType = getFieldValue('api_type'); + const isEmbeddingsOrCustom = + apiType === 'embeddings' || apiType === 'custom-chat'; + return !isEmbeddingsOrCustom ? ( + + + {t('components.createJobForm.modelName')} + + + + + } + rules={[ + { + max: 255, + message: t( + 'components.createJobForm.modelNameLengthLimit' + ), + }, + ]} + normalize={value => value?.trim() || ''} + > + + + + ) : null; + }} + @@ -2186,12 +2219,15 @@ const CreateLlmTaskFormContent: React.FC = ({ ) || 'Supports JSONL format:\n• JSONL: one JSON object per line {"id": "...", "messages": [...]}' : t( - 'components.createJobForm.datasetFileFormatDescription' - )} - - - {t('components.createJobForm.datasetImageMountWarning')} + 'components.createJobForm.datasetFileFormatDescriptionCustom' + ) || + 'Supports JSONL format:\n• JSONL: one JSON object per line, representing the full request payload'} + {isChatApi && ( + + {t('components.createJobForm.datasetImageMountWarning')} + + )} @@ -2202,7 +2238,9 @@ const CreateLlmTaskFormContent: React.FC = ({ {t('components.createJobForm.jsonlData')} - {t('components.createJobForm.jsonlDataTooltip')} + {isChatApi + ? t('components.createJobForm.jsonlDataTooltip') + : t('components.createJobForm.jsonlDataTooltipPayload')} = ({ 'Each line must contain "id" and "prompt" or "messages" array'; throw new Error(customError); } - } else if (!jsonObj.id || !jsonObj.prompt) { - customError = t( - 'components.createJobForm.eachLineMustContainFields' - ); - throw new Error(customError); } }); return Promise.resolve(); @@ -2263,9 +2296,14 @@ const CreateLlmTaskFormContent: React.FC = ({ ) || 'Each line must contain "id" and "prompt" or "messages" array') ? e.message - : t( - 'components.createJobForm.invalidJsonlFormat' - ) + : isChatApi + ? t( + 'components.createJobForm.invalidJsonlFormat' + ) + : t( + 'components.createJobForm.invalidJsonlFormatBasic' + ) || + 'Invalid JSONL format. Each line must be valid JSON.' ) ); } @@ -2278,7 +2316,9 @@ const CreateLlmTaskFormContent: React.FC = ({ placeholder={ isChatApi ? `{"id": "1", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"}]}\n{"id": "2", "messages": [{"role": "user", "content": "What is AI?"}]}` - : `{"id": "1", "prompt": "Hello, how are you?"}\n{"id": "2", "prompt": "What is artificial intelligence?"}\n{"id": "3", "prompt": "Explain machine learning in simple terms"}` + : currentApiType === 'custom-chat' + ? `{"model":"custom-chat-model","stream":true,"messages":[{"role":"user","content":"Hello, how are you?"}]}\n{"model":"custom-chat-model","stream":true,"messages":[{"role":"user","content":"What is artificial intelligence?"}]}` + : `{"id": "1", "input": "Hello, how are you?", "model": "text-embedding-3-small"}\n{"id": "2", "input": "What is artificial intelligence?", "model": "text-embedding-3-small"}` } maxLength={50000} showCount @@ -2334,18 +2374,34 @@ const CreateLlmTaskFormContent: React.FC = ({ size='large' placeholder={t('components.createJobForm.datasetSource')} > - - {t('components.createJobForm.builtInDataset')} - - - {t('components.createJobForm.customJsonlData')} - - - {t('components.createJobForm.uploadJsonlFile')} - - - {t('components.createJobForm.noDataset')} - + {isChatApi ? ( + <> + + {t('components.createJobForm.builtInDataset')} + + + {t('components.createJobForm.noDataset')} + + + {t('components.createJobForm.uploadJsonlFile')} + + + {t('components.createJobForm.customJsonlData')} + + + ) : ( + <> + + {t('components.createJobForm.noDataset')} + + + {t('components.createJobForm.uploadJsonlFile')} + + + {t('components.createJobForm.customJsonlData')} + + + )} {additionalContent} @@ -2815,22 +2871,25 @@ const CreateLlmTaskFormContent: React.FC = ({ {({ getFieldValue }) => { const currentApiType = getFieldValue('api_type') || 'openai-chat'; const isEmbedType = currentApiType === 'embeddings'; + const currentStreamMode = getFieldValue('stream_mode'); // Get placeholders based on API type const getContentPlaceholder = () => { if (currentApiType === 'claude-chat') { - return streamMode ? 'content.-1.text' : 'content.-1.text'; + return currentStreamMode ? 'content.-1.text' : 'content.-1.text'; } - return streamMode + return currentStreamMode ? 'choices.0.delta.content' : 'choices.0.message.content'; }; const getReasoningPlaceholder = () => { if (currentApiType === 'claude-chat') { - return streamMode ? 'content.0.thinking' : 'content.0.thinking'; + return currentStreamMode + ? 'content.0.thinking' + : 'content.0.thinking'; } - return streamMode + return currentStreamMode ? 'choices.0.delta.reasoning_content' : 'choices.0.message.reasoning_content'; }; @@ -2862,7 +2921,7 @@ const CreateLlmTaskFormContent: React.FC = ({ return null; } - return streamMode ? ( + return currentStreamMode ? ( // Streaming mode configuration <> {/* Stream Data Configuration */} @@ -3509,9 +3568,6 @@ const CreateLlmTaskFormContent: React.FC = ({ if ('concurrent_users' in changedValues) { setConcurrentUsers(changedValues.concurrent_users); } - if ('stream_mode' in changedValues) { - setStreamMode(changedValues.stream_mode); - } } else { // Handle API type changes — regenerate entire payload since structure differs if ('api_type' in changedValues && !isCopyMode) { @@ -3525,7 +3581,6 @@ const CreateLlmTaskFormContent: React.FC = ({ const isEmbedType = newApiType === 'embeddings'; if (isEmbedType) { form.setFieldsValue({ stream_mode: false }); - setStreamMode(false); } // Update dataset settings based on API type @@ -3571,7 +3626,6 @@ const CreateLlmTaskFormContent: React.FC = ({ // Handle stream_mode changes — update only `stream` field in payload JSON if ('stream_mode' in changedValues) { - setStreamMode(changedValues.stream_mode); // Update field_mapping default values when stream mode changes (but not in copy mode) if (!isCopyMode) { const currentApiType = @@ -3621,7 +3675,6 @@ const CreateLlmTaskFormContent: React.FC = ({ extracted.stream !== form.getFieldValue('stream_mode') ) { updates.stream_mode = extracted.stream; - setStreamMode(extracted.stream); } if (Object.keys(updates).length > 0) { isSyncingRef.current = true; diff --git a/frontend/src/pages/HttpResults.tsx b/frontend/src/pages/HttpResults.tsx index 19323a3..2719477 100644 --- a/frontend/src/pages/HttpResults.tsx +++ b/frontend/src/pages/HttpResults.tsx @@ -265,8 +265,20 @@ const HttpResults: React.FC = () => { () => results.find(r => r.metric_type === 'total') || results[0], [results] ); + const totalSuccessRow = useMemo( + () => results.find(r => r.metric_type === 'total::success'), + [results] + ); + const totalFailureRow = useMemo( + () => results.find(r => r.metric_type === 'total::failure'), + [results] + ); const totalRequests = totalRow?.request_count ?? 0; const failureCount = totalRow?.failure_count ?? 0; + const hasOutcomeSplit = Boolean(totalSuccessRow || totalFailureRow); + const failureRequests = hasOutcomeSplit + ? (totalFailureRow?.request_count ?? 0) + : failureCount; // Smart format success rate: if close to 100% but not 100%, show more decimal places const calculateSuccessRate = (total: number, failures: number): number => { @@ -287,7 +299,12 @@ const HttpResults: React.FC = () => { totalRow?.rps != null && totalRow.rps !== undefined ? Number(totalRow.rps) : 0; - const qpm = Number((rawRps * 60).toFixed(2)); + const successRps = hasOutcomeSplit + ? Number(totalSuccessRow?.rps ?? 0) + : failureCount === 0 + ? rawRps + : 0; + const successQpm = Number((successRps * 60).toFixed(2)); const avgTimeSec = totalRow?.avg_response_time != null ? Number((totalRow.avg_response_time / 1000).toFixed(3)) @@ -296,6 +313,50 @@ const HttpResults: React.FC = () => { totalRow?.percentile_95_response_time != null ? Number((totalRow.percentile_95_response_time / 1000).toFixed(3)) : 0; + const successAvgTimeSec = hasOutcomeSplit + ? Number(((totalSuccessRow?.avg_response_time ?? 0) / 1000).toFixed(3)) + : failureCount === 0 + ? avgTimeSec + : 0; + const successP95TimeSec = hasOutcomeSplit + ? Number( + ((totalSuccessRow?.percentile_95_response_time ?? 0) / 1000).toFixed(3) + ) + : failureCount === 0 + ? p95TimeSec + : 0; + + const metricsDetailRows = useMemo(() => { + const baseNames = new Set(); + results.forEach(r => { + const mt = r.metric_type as string | undefined; + if ( + !mt || + mt === 'total' || + mt === 'total::success' || + mt === 'total::failure' + ) { + return; + } + if (mt.endsWith('::success')) { + baseNames.add(mt.replace('::success', '')); + } else if (mt.endsWith('::failure')) { + baseNames.add(mt.replace('::failure', '')); + } else { + baseNames.add(mt); + } + }); + + const orderedMetricTypes: string[] = []; + Array.from(baseNames).forEach(name => { + orderedMetricTypes.push(`${name}::success`, `${name}::failure`); + }); + orderedMetricTypes.push('total'); + + return orderedMetricTypes + .map(metricType => results.find(r => r.metric_type === metricType)) + .filter(Boolean); + }, [results]); // Whether the task is currently in a stoppable state const isTaskRunning = @@ -650,7 +711,21 @@ const HttpResults: React.FC = () => { key: 'metric_type', width: 140, ellipsis: true, - render: (text: string) => text, + render: (text: string) => { + if (text.endsWith('::success')) { + return `${text.replace('::success', '')} (${t( + 'common.success', + 'Success' + )})`; + } + if (text.endsWith('::failure')) { + return `${text.replace('::failure', '')} (${t( + 'common.error', + 'Failure' + )})`; + } + return text; + }, }, { title: t('pages.results.totalRequests', 'Requests'), @@ -1059,8 +1134,8 @@ const HttpResults: React.FC = () => {
- - + + { valueStyle={statisticValueStyle} /> - + { valueStyle={statisticValueStyle} /> - + + 0 ? 'var(--color-error)' : undefined, + }} + /> + + - + + + + - {p95TimeSec > 0 && ( - - - - )}
@@ -1124,7 +1209,7 @@ const HttpResults: React.FC = () => {
self._logger_cache_maxsize: + self._logger_cache.popitem(last=False) + return bound_logger + + def evict_task_logger(self, task_id: str) -> None: + """Remove a specific task_id from the logger cache.""" + if self._gevent_lock is not None: + with self._gevent_lock: + self._logger_cache.pop(task_id, None) + else: + self._logger_cache.pop(task_id, None) class SimpleLock: @@ -566,9 +582,10 @@ def validate_config(config: GlobalConfig, task_logger) -> bool: task_logger.error("Task ID is required but not provided") return False - if not config.model_name: - task_logger.error("Model name is required") - return False + if getattr(config, "api_type", "") in ("openai-chat", "claude-chat"): + if not config.model_name: + task_logger.error("Model name is required for this API type") + return False if not config.request_payload: task_logger.error("Request payload is required for all API endpoints") diff --git a/st_engine/engine/http_locustfile.py b/st_engine/engine/http_locustfile.py index 1750609..f37d608 100644 --- a/st_engine/engine/http_locustfile.py +++ b/st_engine/engine/http_locustfile.py @@ -8,6 +8,7 @@ import queue import tempfile import uuid +from collections import Counter from typing import Any, Dict, Optional import gevent @@ -15,6 +16,12 @@ from locust import HttpUser, events, task from urllib3.exceptions import InsecureRequestWarning +from config.base import ( + HTTP_OUTCOME_EXACT_LATENCY_MS, + HTTP_OUTCOME_LATENCY_BUCKET_MS, + MAX_QUEUE_SIZE, +) +from utils.error_handler import _safe_repr_truncate from utils.logger import logger, setup_clean_log_format from utils.realtime_metrics import realtime_metrics_greenlet @@ -23,6 +30,181 @@ urllib3.disable_warnings(InsecureRequestWarning) +class _OutcomeBucket: + """Aggregate response-time stats without keeping one item per request.""" + + def __init__(self) -> None: + self.count = 0 + self.total_latency = 0.0 + self.total_content_length = 0 + self.min_latency: Optional[float] = None + self.max_latency: Optional[float] = None + self.response_times: Counter[int] = Counter() + + @staticmethod + def _bucket_latency(latency: float) -> int: + rounded_latency = int(round(max(0.0, latency))) + if rounded_latency <= HTTP_OUTCOME_EXACT_LATENCY_MS: + return rounded_latency + bucket_size = max(1, HTTP_OUTCOME_LATENCY_BUCKET_MS) + return int(round(rounded_latency / bucket_size) * bucket_size) + + def record( + self, + response_time: Optional[float], + response_length: Optional[int] = None, + ) -> None: + latency = float(response_time or 0.0) + self.count += 1 + self.total_latency += latency + self.total_content_length += int(response_length or 0) + self.min_latency = ( + latency if self.min_latency is None else min(self.min_latency, latency) + ) + self.max_latency = ( + latency if self.max_latency is None else max(self.max_latency, latency) + ) + self.response_times[self._bucket_latency(latency)] += 1 + + def merge(self, payload: Dict[str, Any]) -> None: + count = int(payload.get("count", 0) or 0) + self.count += count + self.total_latency += float(payload.get("total_latency", 0.0) or 0.0) + self.total_content_length += int(payload.get("total_content_length", 0) or 0) + + min_latency = payload.get("min_latency") + if min_latency is not None: + min_latency = float(min_latency) + self.min_latency = ( + min_latency + if self.min_latency is None + else min(self.min_latency, min_latency) + ) + + max_latency = payload.get("max_latency") + if max_latency is not None: + max_latency = float(max_latency) + self.max_latency = ( + max_latency + if self.max_latency is None + else max(self.max_latency, max_latency) + ) + + for latency, latency_count in (payload.get("response_times") or {}).items(): + self.response_times[self._bucket_latency(float(latency))] += int( + latency_count + ) + + def avg(self) -> float: + return self.total_latency / self.count if self.count else 0.0 + + def avg_content_length(self) -> float: + return self.total_content_length / self.count if self.count else 0.0 + + def percentile(self, percentile: float) -> float: + if not self.count: + return 0.0 + threshold = self.count * percentile + seen = 0 + for latency in sorted(self.response_times): + seen += self.response_times[latency] + if seen >= threshold: + return float(latency) + return float(max(self.response_times.keys(), default=0)) + + def to_dict(self) -> Dict[str, Any]: + return { + "count": self.count, + "total_latency": self.total_latency, + "total_content_length": self.total_content_length, + "min_latency": self.min_latency, + "max_latency": self.max_latency, + "response_times": dict(self.response_times), + } + + +class _OutcomeStats: + """Track success/failure stats per request name and for the total row.""" + + def __init__(self) -> None: + self.rows: Dict[str, Dict[str, _OutcomeBucket]] = {} + + def record( + self, + name: Optional[str], + response_time: Optional[float], + response_length: Optional[int], + failed: bool, + ) -> None: + metric_name = str(name or "unknown") + for row_name in (metric_name, "total"): + buckets = self.rows.setdefault( + row_name, + {"success": _OutcomeBucket(), "failure": _OutcomeBucket()}, + ) + buckets["failure" if failed else "success"].record( + response_time, response_length + ) + + def merge(self, payload: Dict[str, Any]) -> None: + for row_name, row_payload in (payload or {}).items(): + buckets = self.rows.setdefault( + str(row_name), + {"success": _OutcomeBucket(), "failure": _OutcomeBucket()}, + ) + for outcome in ("success", "failure"): + if outcome in row_payload: + buckets[outcome].merge(row_payload[outcome]) + + def build_rows(self, task_id: str, name: str, total_rps: float) -> list[Dict]: + buckets = self.rows.get(name) + if not buckets: + return [] + + success = buckets["success"] + failure = buckets["failure"] + total_requests = success.count + failure.count + + def proportional_rps(count: int) -> float: + if total_requests <= 0: + return 0.0 + return float(total_rps or 0.0) * count / total_requests + + rows = [] + for suffix, bucket in (("success", success), ("failure", failure)): + if bucket.count == 0: + continue + rows.append( + { + "task_id": task_id, + "metric_type": f"{name}::{suffix}", + "num_requests": bucket.count, + "num_failures": bucket.count if suffix == "failure" else 0, + "avg_latency": bucket.avg(), + "min_latency": bucket.min_latency or 0.0, + "max_latency": bucket.max_latency or 0.0, + "median_latency": bucket.percentile(0.5), + "p95_latency": bucket.percentile(0.95), + "rps": proportional_rps(bucket.count), + "avg_content_length": bucket.avg_content_length(), + } + ) + return rows + + def to_dict(self) -> Dict[str, Any]: + return { + row_name: { + "success": buckets["success"].to_dict(), + "failure": buckets["failure"].to_dict(), + } + for row_name, buckets in self.rows.items() + } + + +_OUTCOME_STATS = _OutcomeStats() +_WORKER_OUTCOME_STATS: Dict[str, Dict[str, Any]] = {} + + def _parse_kv(json_str: str) -> Dict[str, str]: """Safely parse headers/cookies JSON string.""" if not json_str: @@ -250,11 +432,46 @@ def on_locust_init(environment, **kwargs): setup_clean_log_format() +@events.request.add_listener +def on_request( + request_type=None, + name=None, + response_time=None, + response_length=None, + response=None, + context=None, + exception=None, + start_time=None, + url=None, + **kwargs, +): + """Collect success/failure split metrics for final HTTP results.""" + _OUTCOME_STATS.record( + name, + response_time, + response_length, + exception is not None, + ) + + +@events.report_to_master.add_listener +def on_report_to_master(client_id, data, **kwargs): + """Send worker-local success/failure split metrics to the master.""" + data["http_outcome_stats"] = _OUTCOME_STATS.to_dict() + + +@events.worker_report.add_listener +def on_worker_report(client_id, data, **kwargs): + """Keep the latest cumulative split metrics received from each worker.""" + _WORKER_OUTCOME_STATS[str(client_id)] = data.get("http_outcome_stats") or {} + + def _preload_dataset(environment) -> None: """Pre-load dataset file into a shared queue on the environment. Called once during ``test_start`` so that all users share the same queue without racing during ``on_start``. + In multiprocess mode, uses mmap-backed SharedDatasetReader for memory efficiency. """ options = environment.parsed_options dataset_file = getattr(options, "dataset_file", "") or "" @@ -264,10 +481,55 @@ def _preload_dataset(environment) -> None: task_id = options.task_id or os.environ.get("TASK_ID", "unknown") task_logger = logger.bind(task_id=task_id) + is_multiprocess = os.environ.get("LOCUST_PROCESSES", "1") != "1" + try: + if is_multiprocess: + from utils.shared_dataset import DatasetQueueAdapter, SharedDatasetReader + + items = [] + with open(dataset_file, "r", encoding="utf-8") as f: + for line in f: + if len(items) >= MAX_QUEUE_SIZE: + task_logger.warning( + f"Dataset file {dataset_file} exceeded " + f"MAX_QUEUE_SIZE={MAX_QUEUE_SIZE}; remaining records skipped." + ) + break + line = line.strip() + if not line: + continue + try: + payload = json.loads(line) + items.append({"json": payload}) + except Exception: + items.append({"text": line}) + + if items: + reader = SharedDatasetReader.from_items(items, task_logger) + environment.dataset_queue = DatasetQueueAdapter(reader) + task_logger.info( + f"Using SharedDatasetReader ({len(reader)} items) " + f"for multiprocess HTTP mode" + ) + return + else: + environment.dataset_queue = None + task_logger.warning( + f"Dataset file {dataset_file} contained no valid records." + ) + return + + # Single-process fallback: standard queue dq: queue.Queue = queue.Queue() with open(dataset_file, "r", encoding="utf-8") as f: for line in f: + if dq.qsize() >= MAX_QUEUE_SIZE: + task_logger.warning( + f"Dataset file {dataset_file} exceeded " + f"MAX_QUEUE_SIZE={MAX_QUEUE_SIZE}; remaining records skipped." + ) + break line = line.strip() if not line: continue @@ -306,6 +568,10 @@ def on_test_start(environment, **kwargs): Dataset pre-loading runs on ALL processes (master, worker, local) because each process has its own address space and its own User instances. """ + global _OUTCOME_STATS, _WORKER_OUTCOME_STATS + _OUTCOME_STATS = _OutcomeStats() + _WORKER_OUTCOME_STATS = {} + task_id = environment.parsed_options.task_id or os.environ.get("TASK_ID", "unknown") task_logger = logger.bind(task_id=task_id) load_mode = os.environ.get("LOAD_MODE", "fixed") @@ -370,6 +636,11 @@ def on_test_stop(environment, **kwargs): locust_stats = [] try: + outcome_stats = _OutcomeStats() + outcome_stats.merge(_OUTCOME_STATS.to_dict()) + for worker_payload in _WORKER_OUTCOME_STATS.values(): + outcome_stats.merge(worker_payload) + # Locust `stats.entries` keys are (name, method) tuples. # Use `stat.name` for a clean string metric_type. for entry_key, stat in environment.stats.entries.items(): @@ -380,6 +651,9 @@ def on_test_stop(environment, **kwargs): row = _build_stat_row(task_id, stat_name, stat) if row: locust_stats.append(row) + locust_stats.extend( + outcome_stats.build_rows(task_id, stat_name, stat.total_rps) + ) except Exception as e: # pragma: no cover - defensive task_logger.warning(f"Failed to build stat row for '{stat_name}': {e}") @@ -389,6 +663,9 @@ def on_test_stop(environment, **kwargs): total_row = _build_stat_row(task_id, "total", total_stat) if total_row: locust_stats.append(total_row) + locust_stats.extend( + outcome_stats.build_rows(task_id, "total", total_stat.total_rps) + ) except Exception as e: # pragma: no cover - defensive task_logger.warning(f"Failed to build total stat row: {e}") except Exception as e: # pragma: no cover - defensive @@ -402,6 +679,11 @@ def on_test_stop(environment, **kwargs): except Exception as e: # pragma: no cover - defensive task_logger.error(f"Failed to write result file: {e}", exc_info=True) + # Release shared dataset mmap resources if applicable + dataset_queue = getattr(environment, "dataset_queue", None) + if dataset_queue and hasattr(dataset_queue, "close"): + dataset_queue.close() + # --------------------------------------------------------------------------- # Stepped load shape (conditionally activated via LOAD_MODE env var) @@ -531,13 +813,7 @@ def invoke_api(self): self.task_logger.opt(lazy=True).debug( "[{req_id}] Request Payload: {payload}", req_id=lambda: req_id, - payload=lambda: ( - lambda s: ( - s[:500] + "... (truncated)" - if len(s) > 500 - else s - ) - )(repr(payload_data)), + payload=lambda: _safe_repr_truncate(payload_data, 500), ) else: resp.failure(reason) @@ -557,11 +833,7 @@ def invoke_api(self): self.task_logger.opt(lazy=True).debug( "[{req_id}] Request Payload: {payload}", req_id=lambda: req_id, - payload=lambda: ( - lambda s: ( - s[:500] + "... (truncated)" if len(s) > 500 else s - ) - )(repr(payload_data)), + payload=lambda: _safe_repr_truncate(payload_data, 500), ) except Exception as e: # pragma: no cover - network dependent # Log failure with request context diff --git a/st_engine/engine/llm_locustfile.py b/st_engine/engine/llm_locustfile.py index 44ec294..8fb5055 100644 --- a/st_engine/engine/llm_locustfile.py +++ b/st_engine/engine/llm_locustfile.py @@ -117,6 +117,7 @@ def _ensure_prompt_queue(environment, options, task_logger): Ensure the Locust environment carries a prompt_queue. Returns the queue instance for chaining. In warmup mode, skip dataset loading - use default prompt only. + In multiprocess mode, uses SharedDatasetReader for memory efficiency. """ if hasattr(environment, "prompt_queue"): return environment.prompt_queue @@ -127,11 +128,34 @@ def _ensure_prompt_queue(environment, options, task_logger): return environment.prompt_queue try: + # In multiprocess mode, use mmap-backed shared dataset to avoid + # duplicating the entire dataset in each worker process. + is_multiprocess = os.environ.get("LOCUST_PROCESSES", "1") != "1" + + if is_multiprocess: + from utils.dataset_loader import init_shared_dataset + from utils.shared_dataset import DatasetQueueAdapter + + reader = init_shared_dataset( + chat_type=int(getattr(options, "chat_type", 0)), + test_data=getattr(options, "test_data", "") or "", + api_type=getattr(options, "api_type", ""), + task_logger=task_logger, + ) + if reader is not None: + environment.prompt_queue = DatasetQueueAdapter(reader) + task_logger.info( + f"Using SharedDatasetReader ({len(reader)} items) for multiprocess mode" + ) + return environment.prompt_queue + + # Fallback: standard queue (single-process or shared dataset failed) from utils.dataset_loader import init_prompt_queue environment.prompt_queue = init_prompt_queue( chat_type=int(getattr(options, "chat_type", 0)), test_data=getattr(options, "test_data", "") or "", + api_type=getattr(options, "api_type", ""), task_logger=task_logger, ) except Exception as exc: @@ -480,6 +504,11 @@ def on_test_stop(environment, **kwargs): except Exception as e: task_logger.error(f"Error in on_test_stop: {e}", exc_info=True) + finally: + # Release shared dataset mmap resources if applicable + prompt_queue = getattr(environment, "prompt_queue", None) + if prompt_queue and hasattr(prompt_queue, "close"): + prompt_queue.close() # --------------------------------------------------------------------------- diff --git a/st_engine/engine/llm_runner.py b/st_engine/engine/llm_runner.py index 514ca9f..004ba69 100644 --- a/st_engine/engine/llm_runner.py +++ b/st_engine/engine/llm_runner.py @@ -11,12 +11,16 @@ import tempfile import threading import time -from queue import Queue +from collections import deque from typing import Dict, List, Tuple import psutil -from config.base import LOCUST_STOP_TIMEOUT, LOCUST_WAIT_TIMEOUT_BUFFER +from config.base import ( + LOCUST_STOP_TIMEOUT, + LOCUST_WAIT_TIMEOUT_BUFFER, + MAX_CAPTURED_OUTPUT_BYTES, +) from config.multiprocess import ( get_cpu_count, get_process_count, @@ -33,6 +37,45 @@ from utils.logger import logger +class _OutputTailBuffer: + """Bound retained subprocess output while preserving recent diagnostics.""" + + def __init__(self, max_bytes: int) -> None: + self.max_bytes = max(0, int(max_bytes or 0)) + self._lines: deque[str] = deque() + self._bytes = 0 + self.truncated = False + + @staticmethod + def _line_size(line: str) -> int: + return len(line.encode("utf-8", errors="replace")) + + def append(self, line: str) -> None: + if self.max_bytes <= 0: + self.truncated = True + return + + size = self._line_size(line) + if size > self.max_bytes: + encoded = line.encode("utf-8", errors="replace")[-self.max_bytes :] + line = encoded.decode("utf-8", errors="ignore") + size = self._line_size(line) + self._lines.clear() + self._bytes = 0 + self.truncated = True + + self._lines.append(line) + self._bytes += size + + while self._bytes > self.max_bytes and self._lines: + dropped = self._lines.popleft() + self._bytes -= self._line_size(dropped) + self.truncated = True + + def text(self) -> str: + return "".join(self._lines) + + class LlmLocustRunner: """ Enhanced Locust runner with robust multiprocess management. @@ -42,6 +85,7 @@ class LlmLocustRunner: _stopped_task_ids: set[str] = ( set() ) # Track task IDs that have been requested to stop + _STOPPED_IDS_HARD_CAP = 500 _WARMUP_DURATION_SECONDS = 120 _WARMUP_COOLDOWN_SECONDS = 3 _WARMUP_STOP_TIMEOUT_SECONDS = 10 @@ -53,6 +97,23 @@ def __init__(self, base_dir: str): self.base_dir, "engine", "llm_locustfile.py" ) + def _cleanup_stale_stopped_ids(self) -> int: + """Remove stopped task IDs that have no corresponding active process. + + Returns the number of entries removed. + """ + if len(self._stopped_task_ids) > self._STOPPED_IDS_HARD_CAP: + logger.warning( + f"_stopped_task_ids exceeded hard cap ({len(self._stopped_task_ids)}). " + f"Force clearing to prevent memory leak." + ) + self._stopped_task_ids.clear() + return 0 + stale_ids = self._stopped_task_ids - set(self._process_dict.keys()) + for task_id in stale_ids: + self._stopped_task_ids.discard(task_id) + return len(stale_ids) + # --- Shared stepped load helpers --- @staticmethod @@ -115,6 +176,11 @@ def run_locust_process(self, task: Task) -> dict: """ task_logger = logger.bind(task_id=task.id) + # Opportunistic cleanup of stale stopped IDs to prevent memory leak + cleaned = self._cleanup_stale_stopped_ids() + if cleaned: + task_logger.debug(f"Cleaned {cleaned} stale stopped task IDs") + try: # Step 1: Prepare environment self._prepare_task(task, task_logger) @@ -153,22 +219,15 @@ def run_locust_process(self, task: Task) -> dict: "locust_result": {}, } finally: - # Ensure cleanup is attempted even if an exception occurs - # This is a safety net. `_finalize_task` should handle normal cleanup. + # Always discard from stopped set to prevent memory leak + self._stopped_task_ids.discard(task.id) + + # Emergency cleanup if process still tracked (abnormal exit) if task.id in self._process_dict: task_logger.warning( f"Task {task.id} exited abnormally. Triggering emergency cleanup." ) - # We create a dummy process object just to satisfy the _cleanup_task signature. - # The actual PID might be invalid, but _cleanup_task will handle it gracefully. - dummy_true = shutil.which("true") or "/bin/true" - dummy_process = subprocess.Popen( - [dummy_true], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) # nosec B607,B603 - absolute path; no untrusted input - dummy_process.pid = -1 # Mark it as invalid - self._cleanup_task(task, dummy_process, task_logger) + self._cleanup_task_resources(task, task_logger) def _prepare_task(self, task: Task, task_logger) -> None: """Prepare task environment: validate config and files.""" @@ -235,10 +294,10 @@ def _run_warmup_phase(self, task: Task, task_logger) -> None: env = os.environ.copy() env["TASK_ID"] = warmup_task_id env["LOCUST_CONCURRENT_USERS"] = str(warmup_users) - # Force the child process to output DEBUG logs so we can capture payloads - # for the detailed task log. - if "LOG_LEVEL" not in env or env["LOG_LEVEL"] == "INFO": - env["LOG_LEVEL"] = "DEBUG" + # Warmup uses INFO level to avoid OOM from large payloads (e.g. base64 + # images) being repr'd into DEBUG log lines in each worker process. + if "LOG_LEVEL" not in env: + env["LOG_LEVEL"] = "INFO" existing_pythonpath = env.get("PYTHONPATH", "") env["PYTHONPATH"] = ( @@ -585,6 +644,15 @@ def _start_process( env = os.environ.copy() env["TASK_ID"] = str(task.id) env["LOCUST_CONCURRENT_USERS"] = str(task.concurrent_users) + + # Expose process count so locustfiles can detect multiprocess mode + # and use shared memory for datasets instead of per-process copies. + try: + proc_idx = cmd.index("--processes") + env["LOCUST_PROCESSES"] = cmd[proc_idx + 1] + except (ValueError, IndexError): + env["LOCUST_PROCESSES"] = "1" + # Ensure Locust subprocess can import project modules # Force the child process to output DEBUG logs so we can capture payloads # for the detailed task log. @@ -642,14 +710,16 @@ def _monitor_and_capture( self, process: subprocess.Popen, task: Task, task_logger ) -> Tuple[str, str]: """Monitor process execution and capture real-time output.""" - stdout_queue: Queue[str] = Queue() - stderr_queue: Queue[str] = Queue() + stdout_buffer = _OutputTailBuffer(MAX_CAPTURED_OUTPUT_BYTES) + stderr_buffer = _OutputTailBuffer(MAX_CAPTURED_OUTPUT_BYTES) - def read_stream(pipe, q, name): + def read_stream(pipe, output_buffer, name): + if pipe is None: + return try: for line in iter(pipe.readline, ""): if line.strip(): - q.put(line) + output_buffer.append(line) if " | DEBUG | " in line or " | DEBUG | " in line: task_logger.opt(raw=True).debug(line) elif " | WARNING | " in line or " | WARNING | " in line: @@ -663,10 +733,10 @@ def read_stream(pipe, q, name): task_logger.error(f"Error reading {name}: {e}") stdout_thread = threading.Thread( - target=read_stream, args=(process.stdout, stdout_queue, "stdout") + target=read_stream, args=(process.stdout, stdout_buffer, "stdout") ) stderr_thread = threading.Thread( - target=read_stream, args=(process.stderr, stderr_queue, "stderr") + target=read_stream, args=(process.stderr, stderr_buffer, "stderr") ) stdout_thread.daemon = True @@ -708,9 +778,13 @@ def read_stream(pipe, q, name): stdout_thread.join(timeout=10) stderr_thread.join(timeout=10) - # Drain queues - stdout = "".join(list(stdout_queue.queue)) - stderr = "".join(list(stderr_queue.queue)) + stdout = stdout_buffer.text() + stderr = stderr_buffer.text() + if stdout_buffer.truncated or stderr_buffer.truncated: + task_logger.warning( + "Subprocess output exceeded in-memory capture limit; " + f"retained last {MAX_CAPTURED_OUTPUT_BYTES} bytes per stream." + ) return stdout, stderr @@ -774,6 +848,10 @@ def _finalize_task( def _cleanup_task(self, task: Task, process: subprocess.Popen, task_logger) -> None: """Perform comprehensive cleanup after task completion.""" + self._cleanup_task_resources(task, task_logger) + + def _cleanup_task_resources(self, task, task_logger) -> None: + """Core cleanup logic that does not require a process object.""" task_id = task.id task_logger.info(f"Starting cleanup for task {task_id}") diff --git a/st_engine/engine/process_manager.py b/st_engine/engine/process_manager.py index 91d7ae6..164d514 100644 --- a/st_engine/engine/process_manager.py +++ b/st_engine/engine/process_manager.py @@ -239,6 +239,61 @@ def cleanup_task(self, task_id: str) -> None: # Remove from tracking del self._process_groups[task_id] + def find_locust_processes_by_task_id(self, task_id: str) -> List[psutil.Process]: + """Find running Locust processes whose command line references task_id.""" + matches: List[psutil.Process] = [] + for proc in self._iter_external_locust_processes(): + try: + cmdline = proc.info.get("cmdline", []) + if cmdline is None: + cmdline = [] + cmdline_str = ( + " ".join(cmdline) + if isinstance(cmdline, (list, tuple)) + else str(cmdline) + ) + if task_id in cmdline_str: + matches.append(proc) + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + return matches + + def terminate_locust_processes_by_task_id( + self, task_id: str, timeout: float = 5.0 + ) -> int: + """Terminate any Locust processes associated with a task id.""" + processes = self.find_locust_processes_by_task_id(task_id) + terminated_count = 0 + + for process in processes: + try: + if not process.is_running(): + continue + process.terminate() + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + for process in processes: + try: + if not process.is_running(): + terminated_count += 1 + continue + try: + process.wait(timeout=timeout) + except psutil.TimeoutExpired: + process.kill() + process.wait(timeout=timeout) + terminated_count += 1 + except (psutil.NoSuchProcess, psutil.AccessDenied): + terminated_count += 1 + except psutil.TimeoutExpired: + logger.warning( + f"Timed out while force killing Locust process {process.pid} " + f"for task {task_id}" + ) + + return min(terminated_count, len(processes)) + def get_process_group_status(self, task_id: str) -> Optional[LocustProcessGroup]: """Get status of a process group.""" with self._lock: @@ -419,6 +474,16 @@ def cleanup_task_resources(task_id: str) -> None: _multiprocess_manager.cleanup_task(task_id) +def find_locust_processes_by_task_id(task_id: str) -> List[psutil.Process]: + """Find running Locust processes associated with a task id.""" + return _multiprocess_manager.find_locust_processes_by_task_id(task_id) + + +def terminate_locust_processes_by_task_id(task_id: str, timeout: float = 5.0) -> int: + """Terminate running Locust processes associated with a task id.""" + return _multiprocess_manager.terminate_locust_processes_by_task_id(task_id, timeout) + + def get_task_process_status(task_id: str) -> Optional[LocustProcessGroup]: """Get status of a task's process group.""" return _multiprocess_manager.get_process_group_status(task_id) diff --git a/st_engine/engine/request_processor.py b/st_engine/engine/request_processor.py index ea10223..efc9c51 100644 --- a/st_engine/engine/request_processor.py +++ b/st_engine/engine/request_processor.py @@ -21,7 +21,7 @@ StreamMetrics, ) from utils.common import encode_image, safe_int_convert -from utils.error_handler import ErrorResponse +from utils.error_handler import ErrorResponse, _safe_repr_truncate from utils.event_handler import EventManager global_state = GlobalStateManager() @@ -29,6 +29,7 @@ # Maximum accumulated content size per streaming response (10 MB) # Prevents OOM from malicious or buggy servers sending unbounded data MAX_STREAM_CONTENT_SIZE = 10 * 1024 * 1024 # 10 MB +STOP_REASON_CHECK_API_TYPES = {"openai-chat", "claude-chat", "custom-chat"} # === LAZY IMAGE ENCODING === @@ -182,6 +183,29 @@ def check_end_field_stop( end_value = str(end_value) if end_value else "" return end_value == stop_flag + @staticmethod + def iter_stop_reasons(data: Any): + """Yield every stop_reason value from a nested response object.""" + if isinstance(data, dict): + for key, value in data.items(): + if key == "stop_reason": + yield value + yield from StreamProcessor.iter_stop_reasons(value) + elif isinstance(data, list): + for item in data: + yield from StreamProcessor.iter_stop_reasons(item) + + @staticmethod + def check_stop_reason_error(data: Any, api_type: str) -> Optional[str]: + """Return an error message when stop_reason explicitly reports an error.""" + if api_type not in STOP_REASON_CHECK_API_TYPES: + return None + + for stop_reason in StreamProcessor.iter_stop_reasons(data): + if isinstance(stop_reason, str) and stop_reason.strip().lower() == "error": + return f"Response stop_reason is error: {data}" + return None + @staticmethod def extract_metrics_from_chunk( chunk_data: Dict[str, Any], @@ -309,6 +333,7 @@ def process_stream_chunk( start_time: float, metrics: StreamMetrics, task_logger, + api_type: str = "", ) -> Tuple[bool, Optional[str], StreamMetrics]: """ Process a single stream chunk according to the specified logic. @@ -365,6 +390,12 @@ def process_stream_chunk( if error_msg: return True, error_msg, metrics # Error occurred + stop_reason_error = StreamProcessor.check_stop_reason_error( + chunk_data, api_type + ) + if stop_reason_error: + return True, stop_reason_error, metrics + # Extract and update metrics BEFORE checking end_field, # because the final chunk may carry both the end signal # AND usage/content data that must not be lost. @@ -473,6 +504,17 @@ def prepare_request_kwargs( # Unified payload handling for all APIs self._update_payload_with_prompt_data(payload, prompt_data) + # For embeddings and custom-chat, if the payload was completely replaced by raw_data, + # the original user_prompt (extracted by dataset_loader) might be empty or incorrect. + # We re-extract it from the newly replaced payload using field mapping. + api_type = getattr(self.config, "api_type", "") + if api_type in ("embeddings", "custom-chat") and prompt_data.get( + "raw_data" + ): + extracted_prompt = self._extract_prompt_from_payload(payload) + if extracted_prompt: + user_prompt = extracted_prompt + # Set request name based on API path request_name = self.config.api_path @@ -527,18 +569,23 @@ def _update_payload_with_prompt_data( image_base64 = prompt_data.get("image_base64", "") image_path = prompt_data.get("image_path", "") + # Get API type + api_type = getattr(self.config, "api_type", "") + # Lazy encode: if we have a file path but no pre-encoded base64, # encode on-demand with LRU cache to minimize memory usage. - if image_path and not image_base64: + # Skip this for embeddings and custom-chat as they use raw_data directly. + if ( + image_path + and not image_base64 + and api_type not in ("embeddings", "custom-chat") + ): image_base64 = _encode_image_cached(image_path) if not image_base64: self.task_logger.warning( f"Failed to lazy-encode image: {image_path}" ) - # Get API type - api_type = getattr(self.config, "api_type", "") - # Route to appropriate updater based on API type if api_type == "openai-chat": self._update_openai_chat_payload( @@ -548,23 +595,34 @@ def _update_payload_with_prompt_data( self._update_claude_chat_payload( payload, user_prompt, image_url, image_base64, prompt_data ) - elif api_type == "embeddings": - self._update_embeddings_payload(payload, user_prompt) - elif api_type == "custom-chat": - # For custom-chat, rely on explicit field mapping without auto defaults - field_mapping = ConfigManager.resolve_field_mapping( - self.config, - required_fields=("prompt", "image"), - fallback_to_api_defaults=False, - ) - if field_mapping.prompt or field_mapping.image: - self._update_payload_by_field_mapping( - payload, user_prompt, image_url, image_base64, field_mapping - ) + elif api_type in ("embeddings", "custom-chat"): + raw_data = prompt_data.get("raw_data") + if raw_data: + # Completely replace the payload with the JSON object from the dataset line + payload.clear() + payload.update(raw_data) else: - self.task_logger.warning( - "The field_mapping configuration is empty, the original payload will be used for the request." - ) + # Fallback if raw_data is somehow missing + if api_type == "embeddings": + self._update_embeddings_payload(payload, user_prompt) + else: + field_mapping = ConfigManager.resolve_field_mapping( + self.config, + required_fields=("prompt", "image"), + fallback_to_api_defaults=False, + ) + if field_mapping.prompt or field_mapping.image: + self._update_payload_by_field_mapping( + payload, + user_prompt, + image_url, + image_base64, + field_mapping, + ) + else: + self.task_logger.warning( + "The field_mapping configuration is empty, the original payload will be used for the request." + ) else: # Fallback: No dataset integration for unknown types self.task_logger.debug( @@ -925,9 +983,7 @@ def _log_error_with_payload( self, req_id: str, error_msg: str, payload_data: Any ) -> None: """Log an error message with truncated payload and req_id.""" - payload_str = repr(payload_data) if payload_data else "" - if len(payload_str) > 500: - payload_str = payload_str[:500] + "... (truncated)" + payload_str = _safe_repr_truncate(payload_data, 500) if payload_data else "" self.task_logger.error(f"[{req_id}] {error_msg} | Payload: {payload_str}") def _iter_stream_lines(self, response) -> Any: @@ -991,6 +1047,7 @@ def handle_stream_request( actual_start_time, metrics, self.task_logger, + getattr(self.config, "api_type", ""), ) ) @@ -1044,13 +1101,7 @@ def handle_stream_request( self.task_logger.opt(lazy=True).debug( "[{req_id}] Request Payload: {payload}", req_id=lambda: req_id, - payload=lambda: ( - lambda s: ( - s[:500] + "... (truncated)" - if len(s) > 500 - else s - ) - )(repr(payload_data)), + payload=lambda: _safe_repr_truncate(payload_data, 500), ) self.task_logger.opt(lazy=True).debug( "[{req_id}] Stream Response Content: reasoning_content={r_content}, content={content}", @@ -1271,6 +1322,22 @@ def handle_non_stream_request( ) return "", "", usage + stop_reason_error = StreamProcessor.check_stop_reason_error( + resp_json, getattr(self.config, "api_type", "") + ) + if stop_reason_error: + self.error_handler._handle_general_exception_event( + error_msg=stop_reason_error, + response=response, + response_time=total_time, + additional_context={ + "api_path": self.config.api_path, + }, + req_id=req_id, + payload_data=payload_data, + ) + return "", "", usage + EventManager.fire_metric_event( METRIC_TTT, total_time, @@ -1308,11 +1375,7 @@ def handle_non_stream_request( self.task_logger.opt(lazy=True).debug( "[{req_id}] Request Payload: {payload}", req_id=lambda: req_id, - payload=lambda: ( - lambda s: ( - s[:500] + "... (truncated)" if len(s) > 500 else s - ) - )(repr(payload_data)), + payload=lambda: _safe_repr_truncate(payload_data, 500), ) return reasoning_content, content, usage diff --git a/st_engine/service/http_task_service.py b/st_engine/service/http_task_service.py index 75f444a..0e39381 100644 --- a/st_engine/service/http_task_service.py +++ b/st_engine/service/http_task_service.py @@ -25,8 +25,10 @@ from engine.http_runner import HttpLocustRunner from engine.process_manager import ( cleanup_task_resources, + find_locust_processes_by_task_id, get_task_process_status, terminate_locust_process_group, + terminate_locust_processes_by_task_id, ) from model.http_task import HttpTask from service.http_result_service import HttpResultService @@ -250,41 +252,30 @@ def reconcile_tasks_on_startup(self, session: Session): task_logger.warning( f" Task {task.id} was {task.status} during restart. Checking for orphaned process and failing it." ) - try: - cmd = ["pgrep", "-f", f"locust .*--task-id {task.id}"] - subprocess.check_output( - cmd, stderr=subprocess.DEVNULL - ) # nosec B603 - + task_id = str(task.id) + orphaned_processes = find_locust_processes_by_task_id(task_id) + if orphaned_processes: task_logger.warning( - " Orphaned Locust process detected after engine restart. Terminating and marking task as FAILED." + " Orphaned Locust process detected after engine restart. " + "Terminating and marking task as FAILED." + ) + terminated_count = terminate_locust_processes_by_task_id( + task_id ) - try: - kill_cmd = ["pkill", "-f", f"locust .*--task-id {task.id}"] - subprocess.run(kill_cmd, check=True) # nosec B603 + if terminated_count: task_logger.info( - " Successfully terminated orphaned process." + f" Successfully terminated {terminated_count} orphaned Locust process(es)." ) - except subprocess.CalledProcessError as e: - if e.returncode > 1: - task_logger.error( - f" Failed to kill orphaned process: {e}" - ) - else: - task_logger.warning( - f" Orphaned process cleanup interrupted or already gone (exit code {e.returncode})." - ) - except Exception as kill_e: - task_logger.error( - f" Unexpected error while killing orphaned process: {kill_e}" + else: + task_logger.warning( + " Orphaned process disappeared before cleanup completed." ) error_message = "Task process was orphaned by an engine restart and has been terminated." self.update_task_status( session, task, TASK_STATUS_FAILED, error_message ) - except subprocess.CalledProcessError: - # pgrep did not find a process; mark failed with explanation + else: task_logger.warning( " Task was running during restart, but no active process found. Marking as FAILED." ) @@ -294,13 +285,6 @@ def reconcile_tasks_on_startup(self, session: Session): self.update_task_status( session, task, TASK_STATUS_FAILED, error_message ) - except FileNotFoundError as e: - # Minimal images may not include procps (pgrep/pkill). - # Keep current status to avoid false negatives during scaling. - task_logger.warning( - f" Process inspection command is missing: {e}. " - "Skipping startup reconciliation for this task and keeping current status." - ) finally: if handler_id is not None: remove_task_log_sink(handler_id) diff --git a/st_engine/service/llm_task_service.py b/st_engine/service/llm_task_service.py index c2f96bb..cf34022 100644 --- a/st_engine/service/llm_task_service.py +++ b/st_engine/service/llm_task_service.py @@ -23,8 +23,10 @@ from engine.llm_runner import LlmLocustRunner from engine.process_manager import ( cleanup_task_resources, + find_locust_processes_by_task_id, get_task_process_status, terminate_locust_process_group, + terminate_locust_processes_by_task_id, ) from model.llm_task import Task from service.llm_result_service import LlmResultService @@ -216,22 +218,15 @@ def _kill_orphaned_process(self, task: Task, task_logger): task (Task): The task whose orphaned process should be killed. task_logger: A logger instance bound to the task ID. """ - try: - kill_cmd = ["pkill", "-f", f"locust .*--task-id {task.id}"] - subprocess.run(kill_cmd, check=True) # nosec B603 - task_logger.info(f"Successfully terminated orphaned process.") - except subprocess.CalledProcessError as e: - if e.returncode > 1: - task_logger.error(f"Failed to kill orphaned process: {e}") - else: - task_logger.warning( - f"Orphaned process cleanup for task {task.id} was interrupted " - f"or the process was already gone (exit code {e.returncode}). " - "This is likely safe to ignore." - ) - except Exception as kill_e: - task_logger.error( - f"An unexpected error occurred while trying to kill orphaned process: {kill_e}" + task_id = str(task.id) + terminated_count = terminate_locust_processes_by_task_id(task_id) + if terminated_count: + task_logger.info( + f"Successfully terminated {terminated_count} orphaned Locust process(es)." + ) + else: + task_logger.warning( + f"No orphaned Locust process remained for task {task_id}." ) def _reconcile_running_task(self, session: Session, task: Task, task_logger): @@ -244,34 +239,25 @@ def _reconcile_running_task(self, session: Session, task: Task, task_logger): task (Task): The running task to reconcile. task_logger: A logger instance bound to the task ID. """ - try: - # Use pgrep to check if a locust process with a specific task-id exists. - cmd = ["pgrep", "-f", f"locust .*--task-id {task.id}"] - subprocess.check_output(cmd, stderr=subprocess.DEVNULL) # nosec B603 - - # If pgrep succeeds, the process exists and is now an orphan. + task_id = str(task.id) + orphaned_processes = find_locust_processes_by_task_id(task_id) + if orphaned_processes: task_logger.warning( - f"Something went wrong with engine service." - f"Terminating it and marking task as FAILED." + f"Task {task.id} was still running during engine restart. " + f"Found {len(orphaned_processes)} orphaned Locust process(es); " + "terminating and marking task as FAILED." ) self._kill_orphaned_process(task, task_logger) error_message = "Task process was orphaned by an engine restart and has been terminated." self.update_task_status(session, task, TASK_STATUS_FAILED, error_message) - - except subprocess.CalledProcessError: - # pgrep returns a non-zero exit code, meaning no process was found. + else: task_logger.warning( f"Task {task.id} was in '{task.status}' state, but no active process found. " f"Marking as FAILED. This likely occurred during an engine restart." ) error_message = "Task process was not found after an engine restart." self.update_task_status(session, task, TASK_STATUS_FAILED, error_message) - except FileNotFoundError as e: - task_logger.warning( - f"Process inspection command is missing: {e}. " - "Skipping startup reconciliation for this task and keeping current status." - ) def _reconcile_single_task(self, session: Session, task: Task): """ diff --git a/st_engine/tests/test_http_locustfile.py b/st_engine/tests/test_http_locustfile.py index 44ad6de..1cdf61a 100644 --- a/st_engine/tests/test_http_locustfile.py +++ b/st_engine/tests/test_http_locustfile.py @@ -17,6 +17,7 @@ _build_stat_row, _check_success_assert, _format_context, + _OutcomeStats, _parse_kv, _parse_request_body, _preload_dataset, @@ -310,6 +311,31 @@ def test_normal_stat(self): assert row["p95_latency"] == 180.0 stat.get_response_time_percentile.assert_called_with(0.95) + def test_builds_success_failure_split_rows(self): + outcome_stats = _OutcomeStats() + outcome_stats.record("GET /api/users", 10.0, 100, failed=False) + outcome_stats.record("GET /api/users", 20.0, 200, failed=False) + outcome_stats.record("GET /api/users", 30.0, 300, failed=False) + outcome_stats.record("GET /api/users", 200.0, 40, failed=True) + + rows = outcome_stats.build_rows("task-001", "GET /api/users", 8.0) + rows_by_type = {row["metric_type"]: row for row in rows} + + success_row = rows_by_type["GET /api/users::success"] + failure_row = rows_by_type["GET /api/users::failure"] + assert success_row["num_requests"] == 3 + assert success_row["num_failures"] == 0 + assert success_row["avg_latency"] == 20.0 + assert success_row["p95_latency"] == 30.0 + assert success_row["rps"] == 6.0 + assert success_row["avg_content_length"] == 200.0 + assert failure_row["num_requests"] == 1 + assert failure_row["num_failures"] == 1 + assert failure_row["avg_latency"] == 200.0 + assert failure_row["p95_latency"] == 200.0 + assert failure_row["rps"] == 2.0 + assert failure_row["avg_content_length"] == 40.0 + def test_exception_returns_empty(self): stat = Mock( spec=[] diff --git a/st_engine/tests/test_http_task_service.py b/st_engine/tests/test_http_task_service.py index 3fd4b17..246cad5 100644 --- a/st_engine/tests/test_http_task_service.py +++ b/st_engine/tests/test_http_task_service.py @@ -4,7 +4,7 @@ - get_and_lock_task - stop_task (process not found / already finished) - pipeline: soft-delete check, status resolution, exception handling - - reconciliation on startup (owned, other-engine, locked, pgrep missing) + - reconciliation on startup (owned, other-engine, locked, missing process) """ from unittest.mock import Mock, patch @@ -380,10 +380,10 @@ def test_locked_task_marked_failed(self, task_service): or "restart" in call_args[0][3].lower() ) - def test_keeps_running_when_pgrep_missing(self, task_service): + def test_marks_running_failed_when_process_missing(self, task_service): session = Mock() task = Mock() - task.id = "task-pgrep-missing" + task.id = "task-process-missing" task.status = TASK_STATUS_RUNNING with patch("service.http_task_service.ENGINE_ID", "my-engine"): @@ -398,13 +398,16 @@ def test_keeps_running_when_pgrep_missing(self, task_service): patch("service.http_task_service.remove_task_log_sink"), patch.object(task_service, "update_task_status") as mock_update, patch( - "service.http_task_service.subprocess.check_output", - side_effect=FileNotFoundError("pgrep"), + "service.http_task_service.find_locust_processes_by_task_id", + return_value=[], ), ): task_service.reconcile_tasks_on_startup(session) - mock_update.assert_not_called() + mock_update.assert_called_once() + call_args = mock_update.call_args + assert call_args[0][2] == TASK_STATUS_FAILED + assert "not found" in call_args[0][3].lower() def test_empty_engine_id_skipped(self, task_service): session = Mock() diff --git a/st_engine/tests/test_llm_task_service.py b/st_engine/tests/test_llm_task_service.py index 906d4e7..d5b83a4 100644 --- a/st_engine/tests/test_llm_task_service.py +++ b/st_engine/tests/test_llm_task_service.py @@ -148,7 +148,7 @@ def test_pipeline_runs_non_deleted_task(task_service): mock_start_task.assert_called_once_with(mock_task) -def test_reconcile_keeps_running_when_pgrep_missing(task_service): +def test_reconcile_marks_running_failed_when_process_missing(task_service): mock_session = Mock() mock_task = Mock() mock_task.id = "task-owned-engine" @@ -164,10 +164,13 @@ def test_reconcile_keeps_running_when_pgrep_missing(task_service): patch("service.llm_task_service.remove_task_log_sink"), patch.object(task_service, "update_task_status") as mock_update_status, patch( - "service.llm_task_service.subprocess.check_output", - side_effect=FileNotFoundError("pgrep"), + "service.llm_task_service.find_locust_processes_by_task_id", + return_value=[], ), ): task_service.reconcile_tasks_on_startup(mock_session) - mock_update_status.assert_not_called() + mock_update_status.assert_called_once() + call_args = mock_update_status.call_args + assert call_args[0][2] == TASK_STATUS_FAILED + assert "not found" in call_args[0][3].lower() diff --git a/st_engine/tests/test_stop_reason_validation.py b/st_engine/tests/test_stop_reason_validation.py new file mode 100644 index 0000000..da8ad92 --- /dev/null +++ b/st_engine/tests/test_stop_reason_validation.py @@ -0,0 +1,96 @@ +"""Tests for stop reason validation.""" + +import time +from unittest.mock import Mock + +from engine.core import FieldMapping, GlobalConfig, StreamMetrics +from engine.request_processor import APIClient, StreamProcessor + + +class FakeResponse: + """Fake response for testing.""" + + def __init__(self, payload): + """Initialize FakeResponse.""" + self.status_code = 200 + self._payload = payload + self.success = Mock() + self.failure = Mock() + + def __enter__(self): + """Enter context manager.""" + return self + + def __exit__(self, exc_type, exc, tb): + """Exit context manager.""" + return False + + def json(self): + """Return the JSON payload.""" + return self._payload + + +class FakeClient: + """Fake client for testing.""" + + def __init__(self, response): + """Initialize FakeClient.""" + self.response = response + + def post(self, *args, **kwargs): + """Simulate a POST request.""" + return self.response + + +def test_stream_stop_reason_error_returns_failure_message(): + """Test that stream processing returns failure when stop_reason is error.""" + field_mapping = FieldMapping(stream_prefix="data:", data_format="json") + metrics = StreamMetrics() + + should_break, error_message, _ = StreamProcessor.process_stream_chunk( + b'data: {"delta": {"stop_reason": "error"}}', + field_mapping, + time.perf_counter(), + metrics, + Mock(), + api_type="openai-chat", + ) + + assert should_break is True + assert "stop_reason is error" in error_message + + +def test_non_stream_stop_reason_error_marks_response_failure(monkeypatch): + """Test that non-stream processing marks response failure on error stop_reason.""" + monkeypatch.setattr( + "engine.request_processor.EventManager.fire_failure_event", + lambda *args, **kwargs: None, + ) + monkeypatch.setattr( + "engine.request_processor.EventManager.fire_metric_event", + lambda *args, **kwargs: None, + ) + + config = GlobalConfig() + config.api_type = "claude-chat" + config.stream_mode = False + response = FakeResponse( + { + "content": [{"type": "text", "text": "partial"}], + "stop_reason": "error", + } + ) + api_client = APIClient(config, Mock()) + + reasoning_content, content, usage = api_client.handle_non_stream_request( + FakeClient(response), + {"json": {"messages": []}, "name": "chat"}, + time.perf_counter(), + ) + + assert reasoning_content == "" + assert content == "" + assert usage["completion_tokens"] == 0 + response.failure.assert_called_once() + response.success.assert_not_called() + assert "stop_reason is error" in response.failure.call_args.args[0] diff --git a/st_engine/utils/config.py b/st_engine/utils/config.py index 2b06112..daacc09 100644 --- a/st_engine/utils/config.py +++ b/st_engine/utils/config.py @@ -84,5 +84,5 @@ TOKEN_COUNT_CACHE_SIZE = 8192 # === DATA VALIDATION === -MAX_QUEUE_SIZE = 10000 +MAX_QUEUE_SIZE = 20000 MIN_PROMPT_LENGTH = 1 diff --git a/st_engine/utils/dataset_loader.py b/st_engine/utils/dataset_loader.py index c96b66f..c1e92fd 100644 --- a/st_engine/utils/dataset_loader.py +++ b/st_engine/utils/dataset_loader.py @@ -38,6 +38,7 @@ def __init__( image_url: str = "", image_path: str = "", messages: Optional[List[Dict[str, Any]]] = None, + raw_data: Optional[Dict[str, Any]] = None, ): """Initialize PromptData with prompt information and optional image data. @@ -48,6 +49,7 @@ def __init__( image_url: URL to image (optional) image_path: Local file path for lazy encoding (optional, memory-efficient) messages: Optional list of messages for chat formats + raw_data: The full original JSON object (optional) """ self.id = prompt_id self.prompt = prompt @@ -55,6 +57,7 @@ def __init__( self.image_url = image_url self.image_path = image_path self.messages = messages or [] + self.raw_data = raw_data or {} def to_dict(self) -> Dict[str, Any]: """Convert to dictionary format.""" @@ -67,6 +70,8 @@ def to_dict(self) -> Dict[str, Any]: result["image_path"] = self.image_path if self.messages: result["messages"] = self.messages + if self.raw_data: + result["raw_data"] = self.raw_data return result @classmethod @@ -79,6 +84,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "PromptData": image_url=data.get("image_url", ""), image_path=data.get("image_path", ""), messages=data.get("messages", []), + raw_data=data.get("raw_data", {}), ) @@ -174,7 +180,9 @@ def extract_prompt_from_messages(messages: List[Dict[str, str]]) -> str: # === LINE PARSING === -def parse_data_line(line: str, line_num: int, task_logger=None) -> Optional[PromptData]: +def parse_data_line( + line: str, line_num: int, api_type: str = "", task_logger=None +) -> Optional[PromptData]: """Parse a single data line (JSONL or JSON object) into PromptData. Supports both standard JSONL format and ShareGPT format. @@ -220,8 +228,14 @@ def parse_data_line(line: str, line_num: int, task_logger=None) -> Optional[Prom prompt = extract_prompt_from_messages(messages_list) if not prompt and not messages_list: - # Skip silently without error for missing prompt and messages - return None + # For embeddings and custom-chat, the prompt might be in another field (e.g. "input"). + # We shouldn't skip it if the json object has data. + if api_type in ("embeddings", "custom-chat"): + if not json_obj: + return None + else: + # For chat APIs, skip silently without error for missing prompt and messages + return None # Handle images - unified image field processing # Support both "image" and "image_path" fields @@ -229,30 +243,38 @@ def parse_data_line(line: str, line_num: int, task_logger=None) -> Optional[Prom image_url = "" image_path = "" - # Try "image" field first, then "image_path" as fallback - image_field_value = json_obj.get("image") or json_obj.get("image_path") - - if image_field_value: - # Extract image value (string or list) - image_value = normalize_image_path(image_field_value) - - if image_value: - # Check if it's a URL - if is_url(image_value): - image_url = image_value - else: - # Store file path for lazy encoding at request time - # This avoids loading large base64 image data into the prompt queue, - # which is critical for memory efficiency in multiprocess mode. - if os.path.exists(image_value): - image_path = image_value + # Skip image parsing for embeddings and custom-chat + if api_type not in ("embeddings", "custom-chat"): + # Try "image" field first, then "image_path" as fallback + image_field_value = json_obj.get("image") or json_obj.get("image_path") + + if image_field_value: + # Extract image value (string or list) + image_value = normalize_image_path(image_field_value) + + if image_value: + # Check if it's a URL + if is_url(image_value): + image_url = image_value else: - effective_logger.warning( - f"Image file not found in dataset: {image_value}" - ) + # Store file path for lazy encoding at request time + # This avoids loading large base64 image data into the prompt queue, + # which is critical for memory efficiency in multiprocess mode. + if os.path.exists(image_value): + image_path = image_value + else: + effective_logger.warning( + f"Image file not found in dataset: {image_value}" + ) return PromptData( - prompt_id, prompt, image_base64, image_url, image_path, messages_list + prompt_id, + prompt, + image_base64, + image_url, + image_path, + messages_list, + json_obj, ) except json.JSONDecodeError as e: @@ -266,7 +288,9 @@ def parse_data_line(line: str, line_num: int, task_logger=None) -> Optional[Prom # === FILE LOADING === -def load_dataset_file(data_file: str, task_logger=None) -> List[Dict[str, Any]]: +def load_dataset_file( + data_file: str, api_type: str = "", task_logger=None +) -> List[Dict[str, Any]]: """Load all stress test data from file. Supports both JSONL format (one JSON object per line) and JSON array format (ShareGPT). @@ -313,7 +337,7 @@ def load_dataset_file(data_file: str, task_logger=None) -> List[Dict[str, Any]]: # Convert dict to JSON string for parse_data_line line = json.dumps(json_obj, ensure_ascii=False) - prompt_data = parse_data_line(line, idx, task_logger) + prompt_data = parse_data_line(line, idx, api_type, task_logger) if prompt_data: prompts.append(prompt_data.to_dict()) @@ -329,7 +353,7 @@ def load_dataset_file(data_file: str, task_logger=None) -> List[Dict[str, Any]]: if not line.strip(): continue - prompt_data = parse_data_line(line, line_num, task_logger) + prompt_data = parse_data_line(line, line_num, api_type, task_logger) if prompt_data: prompts.append(prompt_data.to_dict()) @@ -341,7 +365,9 @@ def load_dataset_file(data_file: str, task_logger=None) -> List[Dict[str, Any]]: return prompts -def load_dataset_string(content: str, task_logger=None) -> List[Dict[str, Any]]: +def load_dataset_string( + content: str, api_type: str = "", task_logger=None +) -> List[Dict[str, Any]]: """Load dataset from string content. Supports both JSONL format and JSON array format (ShareGPT). @@ -382,7 +408,7 @@ def load_dataset_string(content: str, task_logger=None) -> List[Dict[str, Any]]: # Convert dict to JSON string for parse_data_line line = json.dumps(json_obj, ensure_ascii=False) - prompt_data = parse_data_line(line, idx, task_logger) + prompt_data = parse_data_line(line, idx, api_type, task_logger) if prompt_data: prompts.append(prompt_data.to_dict()) @@ -396,7 +422,7 @@ def load_dataset_string(content: str, task_logger=None) -> List[Dict[str, Any]]: if not line.strip(): continue - prompt_data = parse_data_line(line, line_num, task_logger) + prompt_data = parse_data_line(line, line_num, api_type, task_logger) if prompt_data: prompts.append(prompt_data.to_dict()) @@ -407,7 +433,9 @@ def load_dataset_string(content: str, task_logger=None) -> List[Dict[str, Any]]: # === QUEUE INITIALIZATION === -def init_prompt_queue_from_string(content: str, task_logger=None) -> queue.Queue: +def init_prompt_queue_from_string( + content: str, api_type: str = "", task_logger=None +) -> queue.Queue: """Initializes the test data queue from JSONL or JSON array string content. Supports both JSONL format (one JSON object per line) and JSON array format (ShareGPT). @@ -429,7 +457,7 @@ def init_prompt_queue_from_string(content: str, task_logger=None) -> queue.Queue raise ValueError("Empty content provided") try: - prompts = load_dataset_string(content, task_logger) + prompts = load_dataset_string(content, api_type, task_logger) if not prompts: raise ValueError("No valid prompts were parsed from the content") @@ -451,7 +479,9 @@ def init_prompt_queue_from_string(content: str, task_logger=None) -> queue.Queue raise RuntimeError(f"Failed to initialize prompt queue from content: {e}") -def init_prompt_queue_from_file(file_path: str, task_logger=None) -> queue.Queue: +def init_prompt_queue_from_file( + file_path: str, api_type: str = "", task_logger=None +) -> queue.Queue: """Initializes the test data queue from a custom file. Supports both JSONL format and JSON array format (ShareGPT). @@ -473,7 +503,7 @@ def init_prompt_queue_from_file(file_path: str, task_logger=None) -> queue.Queue raise ValueError(f"Custom data file not found: {file_path}") try: - prompts = load_dataset_file(file_path, task_logger) + prompts = load_dataset_file(file_path, api_type, task_logger) if not prompts: raise ValueError("No prompts were loaded from the custom data file") @@ -495,6 +525,7 @@ def init_prompt_queue_from_file(file_path: str, task_logger=None) -> queue.Queue def init_prompt_queue( chat_type: int = 0, test_data: str = "", + api_type: str = "", task_logger=None, ) -> queue.Queue: """Initializes the test data queue based on the chat type and custom test data. @@ -542,25 +573,79 @@ def init_prompt_queue( if not os.path.exists(data_file): raise ValueError(f"Default data file not found: {data_file}") - return init_prompt_queue_from_file(data_file, task_logger) + return init_prompt_queue_from_file(data_file, api_type, task_logger) # Case 3: test_data is JSONL content string (starts with "{") or JSON array (starts with "[") if test_data.strip().startswith("{") or test_data.strip().startswith("["): - return init_prompt_queue_from_string(test_data, task_logger) + return init_prompt_queue_from_string(test_data, api_type, task_logger) # Case 4: test_data is a file path - handle both absolute and relative paths # Try to resolve the path using FilePathUtils for upload files try: - return init_prompt_queue_from_file(test_data, task_logger) + return init_prompt_queue_from_file(test_data, api_type, task_logger) except (ValueError, FileNotFoundError) as e: effective_logger.warning(f"Failed to resolve as upload file path: {e}") # Fallback: try as direct file path for backward compatibility if os.path.exists(test_data): - return init_prompt_queue_from_file(test_data, task_logger) + return init_prompt_queue_from_file(test_data, api_type, task_logger) # Invalid test_data provided raise ValueError( f"Invalid test_data provided: '{test_data}'. " f"Expected empty string, 'default', JSONL/JSON content string, or valid file path." ) + + +def init_shared_dataset( + chat_type: int = 0, + test_data: str = "", + api_type: str = "", + task_logger=None, +): + """Initialize dataset as a shared mmap reader for multiprocess mode. + + Returns a SharedDatasetReader instance, or None if no dataset is configured + or if creation fails (caller should use queue-based fallback). + """ + from utils.shared_dataset import SharedDatasetReader + + effective_logger = task_logger or logger + + if not test_data or test_data.strip() == "": + return None + + try: + if test_data.strip().lower() == "default": + dataset_index = DEFAULT_CHAT_TYPE + try: + dataset_index = int(chat_type) + except (TypeError, ValueError): + pass + dataset_filename = BUILTIN_DATASET_FILES.get( + dataset_index, BUILTIN_DATASET_FILES[DEFAULT_CHAT_TYPE] + ) + data_file = os.path.join(DATA_DIR, dataset_filename) + items = load_dataset_file(data_file, api_type, task_logger) + elif test_data.strip().startswith("{") or test_data.strip().startswith("["): + items = load_dataset_string(test_data, api_type, task_logger) + else: + items = load_dataset_file(test_data, api_type, task_logger) + if not items and os.path.exists(test_data): + items = load_dataset_file(test_data, api_type, task_logger) + + if not items: + return None + + if len(items) > MAX_QUEUE_SIZE: + effective_logger.warning( + f"Dataset ({len(items)} items) exceeds MAX_QUEUE_SIZE={MAX_QUEUE_SIZE}; truncating." + ) + items = items[:MAX_QUEUE_SIZE] + + return SharedDatasetReader.from_items(items, task_logger) + except Exception as e: + effective_logger.warning( + f"Failed to create shared dataset reader: {e}. Falling back to queue." + ) + return None diff --git a/st_engine/utils/error_handler.py b/st_engine/utils/error_handler.py index 6a5a750..31a0a34 100644 --- a/st_engine/utils/error_handler.py +++ b/st_engine/utils/error_handler.py @@ -10,6 +10,30 @@ from engine.core import GlobalConfig from utils.event_handler import EventManager +_PAYLOAD_LOG_LIMIT = 500 + + +def _safe_repr_truncate(obj: Any, limit: int = _PAYLOAD_LOG_LIMIT) -> str: + """Return a truncated repr without allocating the full string first. + + For large objects (e.g. dicts containing base64 images), calling repr() + then slicing still creates the entire multi-MB string in memory. Instead + we convert to str with a hard size cap using iterative key inspection for + dicts and a direct str() fallback for other types. + """ + try: + if isinstance(obj, dict): + preview = str(obj)[:limit] + elif isinstance(obj, (list, tuple)): + preview = str(obj)[:limit] + else: + preview = repr(obj)[:limit] + except Exception: + preview = "" + if len(preview) >= limit: + return preview[:limit] + "... (truncated)" + return preview + # === ERROR HANDLING === class ErrorResponse: @@ -79,9 +103,7 @@ def _handle_general_exception_event( if req_id: log_msg = f"[{req_id}] {log_msg}" if payload_data is not None: - payload_str = repr(payload_data) - if len(payload_str) > 500: - payload_str = payload_str[:500] + "... (truncated)" + payload_str = _safe_repr_truncate(payload_data, 500) log_msg += f" | Payload: {payload_str}" self.task_logger.error(log_msg) diff --git a/st_engine/utils/shared_dataset.py b/st_engine/utils/shared_dataset.py new file mode 100644 index 0000000..d3e4630 --- /dev/null +++ b/st_engine/utils/shared_dataset.py @@ -0,0 +1,159 @@ +""" +Author: Charm +Copyright (c) 2025, All Rights Reserved. + +Memory-efficient shared dataset reader using mmap. + +Serializes the dataset to a temporary file, then mmap's it. +After fork (Locust --processes N), all child processes share the same +physical memory pages (copy-on-write semantics, but data is never modified). +Each process maintains its own read index for round-robin access. +""" + +import mmap +import os +import struct +import tempfile +from typing import Any, Dict, List + +import orjson + +from utils.logger import logger + +# 4 bytes for length prefix per item +_LENGTH_PREFIX_SIZE = 4 +_LENGTH_STRUCT = struct.Struct(" "SharedDatasetReader": + """Create a SharedDatasetReader from a list of dict items. + + Serializes items using orjson (fast, compact), writes to a temp file, + then mmap's it for shared read access across forked processes. + """ + effective_logger = task_logger or logger + + if not items: + raise ValueError("Cannot create SharedDatasetReader from empty items") + + tmp_fd, tmp_path = tempfile.mkstemp(prefix="lmeterx_dataset_", suffix=".mmap") + try: + offsets: List[int] = [] + current_offset = 0 + + with os.fdopen(tmp_fd, "wb") as f: + for item in items: + blob = orjson.dumps(item) + length_bytes = _LENGTH_STRUCT.pack(len(blob)) + offsets.append(current_offset) + f.write(length_bytes) + f.write(blob) + current_offset += _LENGTH_PREFIX_SIZE + len(blob) + + fd = os.open(tmp_path, os.O_RDONLY) + try: + mm = mmap.mmap(fd, 0, access=mmap.ACCESS_READ) + finally: + os.close(fd) + + effective_logger.info( + f"SharedDatasetReader created: {len(items)} items, " + f"{current_offset / 1024:.1f} KB mmap'd" + ) + + return cls(mm, offsets, len(items)) + finally: + # Unlink immediately - the mmap file descriptor keeps the data alive. + # After fork, children inherit the mmap mapping, not the path reference. + try: + os.unlink(tmp_path) + except OSError: + pass + + def next(self) -> Dict[str, Any]: + """Get next item in round-robin order.""" + idx = self._index % self._total + self._index += 1 + return self._read_item(idx) + + def _read_item(self, idx: int) -> Dict[str, Any]: + """Deserialize item at given index from mmap.""" + offset = self._offsets[idx] + length = _LENGTH_STRUCT.unpack_from(self._mm, offset)[0] + data = self._mm[ + offset + _LENGTH_PREFIX_SIZE : offset + _LENGTH_PREFIX_SIZE + length + ] + return orjson.loads(data) + + def __len__(self) -> int: + """Return total number of items in the dataset.""" + return self._total + + @property + def empty(self) -> bool: + """Check if the dataset is empty.""" + return self._total == 0 + + def close(self) -> None: + """Release mmap resources.""" + try: + self._mm.close() + except Exception as e: + logger.debug(f"Ignored error closing mmap: {e}") + + +class DatasetQueueAdapter: + """Adapter that makes SharedDatasetReader behave like a queue.Queue. + + Provides get_nowait()/put_nowait()/empty()/qsize() interface for backward + compatibility with existing code that uses queue.Queue for datasets. + """ + + def __init__(self, reader: SharedDatasetReader): + """Initialize adapter with a SharedDatasetReader instance.""" + self._reader = reader + + def get_nowait(self) -> Dict[str, Any]: + """Get next item (round-robin). Never raises Empty.""" + return self._reader.next() + + def put_nowait(self, item) -> None: + """No-op: data is read-only in shared memory.""" + pass + + def put(self, item) -> None: + """No-op: data is read-only in shared memory.""" + pass + + def empty(self) -> bool: + """Always False if dataset has items (infinite round-robin).""" + return self._reader.empty + + def qsize(self) -> int: + """Return the size of the underlying dataset.""" + return len(self._reader) + + def close(self) -> None: + """Release underlying mmap resources.""" + self._reader.close() diff --git a/st_engine/uv.lock b/st_engine/uv.lock new file mode 100644 index 0000000..9431a63 --- /dev/null +++ b/st_engine/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.11"