diff --git a/runtime/datamate-python/app/core/exception.py b/runtime/datamate-python/app/core/exception.py deleted file mode 100644 index 0c4b09600..000000000 --- a/runtime/datamate-python/app/core/exception.py +++ /dev/null @@ -1,17 +0,0 @@ -from enum import Enum - -class BusinessErrorCode: - def __init__(self, message: str, error_code: str): - self.message = message - self.error_code = error_code - - -class BusinessException(RuntimeError): - def __init__(self, business_error_code: BusinessErrorCode): - self.message = business_error_code.message - self.error_code = business_error_code.error_code - super().__init__(self.message) - - -class BusinessErrorCodeEnum(Enum): - TASK_TYPE_ERROR = BusinessErrorCode("任务类型错误", "evaluation.0001") diff --git a/runtime/datamate-python/app/core/exception/__init__.py b/runtime/datamate-python/app/core/exception/__init__.py new file mode 100644 index 000000000..439e6bcb6 --- /dev/null +++ b/runtime/datamate-python/app/core/exception/__init__.py @@ -0,0 +1,240 @@ +""" +核心异常处理模块 + +为应用程序提供统一、优雅的异常处理系统,确保所有 API 响应都符合标准化格式。 + +## 快速开始 + +### 1. 注册异常处理器(在 main.py 中) + +```python +from app.core.exception import register_exception_handlers, ExceptionHandlingMiddleware + +app = FastAPI() + +# 注册全局异常捕获中间件(最外层) +app.add_middleware(ExceptionHandlingMiddleware) + +# 注册异常处理器 +register_exception_handlers(app) +``` + +### 2. 在代码中抛出业务异常 + +```python +from app.core.exception import ErrorCodes, BusinessError + +# 资源不存在 +async def get_user(user_id: str): + user = await db.get_user(user_id) + if not user: + raise BusinessError(ErrorCodes.NOT_FOUND) + return user + +# 参数验证失败 +async def create_user(name: str): + if not name: + raise BusinessError(ErrorCodes.BAD_REQUEST) + # ... +``` + +### 3. 
返回成功响应 + +```python +from app.core.exception import SuccessResponse + +@router.get("/users/{user_id}") +async def get_user(user_id: str): + user = await service.get_user(user_id) + return SuccessResponse(data=user) + # 返回格式: {"code": "0", "message": "success", "data": {...}} +``` + +### 4. 使用事务管理 + +```python +from app.core.exception import transaction, BusinessError, ErrorCodes + +@router.post("/users") +async def create_user(request: CreateUserRequest, db: AsyncSession = Depends(get_db)): + async with transaction(db): + # 所有数据库操作 + user = User(name=request.name) + db.add(user) + await db.flush() + + # 如果抛出异常,事务会自动回滚 + if check_duplicate(user.name): + raise BusinessError(ErrorCodes.OPERATION_FAILED) + + # 事务已自动提交 + return SuccessResponse(data=user) +``` + +## 异常类型说明 + +### BusinessError(业务异常) +用于预期的业务错误,如资源不存在、权限不足、参数错误等。 + +特点: +- 不会记录完整的堆栈跟踪(因为这是预期内的错误) +- 返回对应的 HTTP 状态码(400、404 等) +- 客户端会收到标准化的错误响应 + +使用场景: +- 资源不存在 +- 参数验证失败 +- 权限不足 +- 业务规则违反 + +### SystemError(系统异常) +用于意外的系统错误,如数据库错误、网络错误、配置错误等。 + +特点: +- 记录完整的堆栈跟踪 +- 返回 HTTP 500 +- 不暴露敏感的系统信息给客户端 + +使用场景: +- 数据库连接失败 +- 网络超时 +- 配置错误 +- 编程错误 + +## 错误码定义 + +所有错误码在 `ErrorCodes` 类中集中定义,遵循规范:`{module}.{sequence}` + +### 通用错误码 +- `SUCCESS` (0): 操作成功 +- `BAD_REQUEST` (common.0001): 请求参数错误 +- `NOT_FOUND` (common.0002): 资源不存在 +- `FORBIDDEN` (common.0003): 权限不足 +- `UNAUTHORIZED` (common.0004): 未授权访问 +- `VALIDATION_ERROR` (common.0005): 数据验证失败 + +### 系统级错误码 +- `INTERNAL_ERROR` (system.0001): 服务器内部错误 +- `DATABASE_ERROR` (system.0002): 数据库错误 +- `NETWORK_ERROR` (system.0003): 网络错误 + +### 模块错误码 +- `annotation.*`: 标注模块相关错误 +- `collection.*`: 归集模块相关错误 +- `evaluation.*`: 评估模块相关错误 +- `generation.*`: 生成模块相关错误 +- `rag.*`: RAG 模块相关错误 +- `ratio.*`: 配比模块相关错误 + +## Result 类型(可选的函数式错误处理) + +如果你不喜欢使用异常,可以使用 Result 类型进行函数式错误处理: + +```python +from app.core.exception import Result, Ok, Err, ErrorCodes + +def get_user(user_id: str) -> Result[User]: + user = db.find_user(user_id) + if user: + return Ok(user) + return 
Err(ErrorCodes.NOT_FOUND) + +# 使用结果 +result = get_user("123") +if result.is_ok: + user = result.unwrap() + print(f"User: {user.name}") +else: + error = result.unwrap_err() + print(f"Error: {error.message}") + +# 链式操作 +result = (get_user("123") + .map(lambda user: user.name) + .and_then(validate_name)) +``` + +## 响应格式 + +### 成功响应 +```json +{ + "code": "0", + "message": "success", + "data": { + "id": 123, + "name": "张三" + } +} +``` + +### 错误响应 +```json +{ + "code": "common.0002", + "message": "资源不存在", + "data": null +} +``` + +## 最佳实践 + +1. **始终使用业务异常**:对于可预见的业务错误,使用 `BusinessError` 而不是 HTTPException +2. **集中定义错误码**:所有错误码在 `ErrorCodes` 中定义,不要硬编码 +3. **提供有用的数据**:在抛出异常时,可以通过 `data` 参数传递额外的错误信息 +4. **使用事务管理**:涉及多个数据库操作时,使用 `transaction` 上下文管理器 +5. **不要捕获 SystemError**:让系统错误由全局处理器统一处理 + +## 迁移指南 + +如果你有旧的异常处理代码,迁移步骤: + +1. 删除旧的异常类定义 +2. 将 `raise OldException(...)` 替换为 `raise BusinessError(ErrorCodes.XXX)` +3. 移除 try-except 中的异常转换逻辑,让全局处理器处理 +4. 更新导入语句:`from app.core.exception import ErrorCodes, BusinessError` + +## 测试 + +可以使用测试端点验证异常处理: + +```bash +curl http://localhost:8000/test-success +curl http://localhost:8000/test-business-error +curl http://localhost:8000/test-system-error +``` +""" + +from .base import BaseError, ErrorCode, SystemError, BusinessError +from .codes import ErrorCodes +from .handlers import ( + register_exception_handlers, + ErrorResponse, + SuccessResponse +) +from .middleware import ExceptionHandlingMiddleware +from .result import Result, Ok, Err +from .transaction import transaction, ensure_transaction_rollback + +__all__ = [ + # 基础异常类 + 'BaseError', + 'ErrorCode', + 'SystemError', + 'BusinessError', + # 错误码 + 'ErrorCodes', + # 处理器 + 'register_exception_handlers', + 'ErrorResponse', + 'SuccessResponse', + # 中间件 + 'ExceptionHandlingMiddleware', + # Result 类型 + 'Result', + 'Ok', + 'Err', + # 事务管理 + 'transaction', + 'ensure_transaction_rollback', +] diff --git a/runtime/datamate-python/app/core/exception/base.py 
@dataclass(frozen=True)
class ErrorCode:
    """Immutable description of a single application error.

    Instances are frozen, so they can be shared safely as class-level
    constants. Construction fails with ``ValueError`` when ``code`` is not a
    non-empty string.
    """

    # Error code string, e.g. "annotation.0001".
    code: str
    # Human-readable error message.
    message: str
    # HTTP status code sent with the error; business errors default to 400.
    http_status: int = 400

    def __post_init__(self) -> None:
        """Validate that ``code`` is a non-empty string.

        Raises:
            ValueError: if ``code`` is not a ``str`` or is empty.
        """
        raw_code = self.code
        if isinstance(raw_code, str):
            if raw_code:
                return
            raise ValueError("错误码不能为空")
        raise ValueError(f"错误码必须是字符串,实际类型: {type(raw_code)}")
class ErrorCodes:
    """Central registry of every error code used by the application.

    Codes follow the ``{module}.{sequence}`` convention and are defined once
    here as class attributes. The class is a pure namespace: it is never
    meant to be instantiated, and each attribute is an ``ErrorCode``.

    Example:
        from app.core.exception import ErrorCodes, BusinessError

        raise BusinessError(ErrorCodes.ANNOTATION_TASK_NOT_FOUND)
    """

    # NOTE: the previous ``__init__`` (which set ``self.message`` / ``self.code``
    # to None) was removed. It made the registry look like an error object and
    # pushed the docstring below it, so the class had no real ``__doc__``.

    # ========== Common ==========
    SUCCESS: Final = ErrorCode("0", "Success", 200)
    BAD_REQUEST: Final = ErrorCode("common.0001", "Bad request", 400)
    NOT_FOUND: Final = ErrorCode("common.0002", "Resource not found", 404)
    FORBIDDEN: Final = ErrorCode("common.0003", "Forbidden", 403)
    UNAUTHORIZED: Final = ErrorCode("common.0004", "Unauthorized", 401)
    VALIDATION_ERROR: Final = ErrorCode("common.0005", "Validation error", 422)
    OPERATION_FAILED: Final = ErrorCode("common.0006", "Operation failed", 500)

    # ========== System level ==========
    INTERNAL_ERROR: Final = ErrorCode("system.0001", "Internal server error", 500)
    DATABASE_ERROR: Final = ErrorCode("system.0002", "Database error", 500)
    NETWORK_ERROR: Final = ErrorCode("system.0003", "Network error", 500)
    CONFIG_ERROR: Final = ErrorCode("system.0004", "Configuration error", 500)
    SERVICE_UNAVAILABLE: Final = ErrorCode("system.0005", "Service unavailable", 503)

    # ========== Annotation module ==========
    ANNOTATION_TASK_NOT_FOUND: Final = ErrorCode("annotation.0001", "Annotation task not found", 404)
    ANNOTATION_PROJECT_NOT_FOUND: Final = ErrorCode("annotation.0002", "Annotation project not found", 404)
    ANNOTATION_TEMPLATE_NOT_FOUND: Final = ErrorCode("annotation.0003", "Annotation template not found", 404)
    ANNOTATION_FILE_NOT_FOUND: Final = ErrorCode("annotation.0004", "File not found", 404)
    ANNOTATION_TAG_UPDATE_FAILED: Final = ErrorCode("annotation.0005", "Failed to update tags", 500)

    # ========== Collection module ==========
    COLLECTION_TASK_NOT_FOUND: Final = ErrorCode("collection.0001", "Collection task not found", 404)
    COLLECTION_TEMPLATE_NOT_FOUND: Final = ErrorCode("collection.0002", "Collection template not found", 404)
    COLLECTION_EXECUTION_NOT_FOUND: Final = ErrorCode("collection.0003", "Execution record not found", 404)
    COLLECTION_LOG_NOT_FOUND: Final = ErrorCode("collection.0004", "Log file not found", 404)

    # ========== Evaluation module ==========
    EVALUATION_TASK_NOT_FOUND: Final = ErrorCode("evaluation.0001", "Evaluation task not found", 404)
    EVALUATION_TASK_TYPE_ERROR: Final = ErrorCode("evaluation.0002", "Invalid task type", 400)
    EVALUATION_MODEL_NOT_FOUND: Final = ErrorCode("evaluation.0003", "Evaluation model not found", 404)

    # ========== Generation module ==========
    GENERATION_TASK_NOT_FOUND: Final = ErrorCode("generation.0001", "Generation task not found", 404)
    GENERATION_FILE_NOT_FOUND: Final = ErrorCode("generation.0002", "Generation file not found", 404)
    GENERATION_CHUNK_NOT_FOUND: Final = ErrorCode("generation.0003", "Data chunk not found", 404)
    GENERATION_DATA_NOT_FOUND: Final = ErrorCode("generation.0004", "Generation data not found", 404)

    # ========== RAG module ==========
    RAG_CONFIG_ERROR: Final = ErrorCode("rag.0001", "RAG configuration error", 400)
    RAG_KNOWLEDGE_BASE_NOT_FOUND: Final = ErrorCode("rag.0002", "Knowledge base not found", 404)
    RAG_MODEL_NOT_FOUND: Final = ErrorCode("rag.0003", "RAG model not found", 404)
    RAG_QUERY_FAILED: Final = ErrorCode("rag.0004", "RAG query failed", 500)

    # ========== Ratio module ==========
    RATIO_TASK_NOT_FOUND: Final = ErrorCode("ratio.0001", "Ratio task not found", 404)
    RATIO_NAME_REQUIRED: Final = ErrorCode("ratio.0002", "Task name is required", 400)
    RATIO_ALREADY_EXISTS: Final = ErrorCode("ratio.0003", "Task already exists", 400)
    RATIO_DELETE_FAILED: Final = ErrorCode("ratio.0004", "Failed to delete task", 500)

    # ========== System module (model configuration) ==========
    SYSTEM_MODEL_NOT_FOUND: Final = ErrorCode("system.0006", "Model configuration not found", 404)
    SYSTEM_MODEL_HEALTH_CHECK_FAILED: Final = ErrorCode("system.0007", "Model health check failed", 500)
async def http_exception_handler(
    request: Request,
    exc: StarletteHTTPException
) -> JSONResponse:
    """Render a Starlette/FastAPI ``HTTPException`` in the standard envelope.

    The HTTP status of the exception is kept as the response status. The
    ``code`` field is looked up from the status; any status without an entry
    falls back to ``ErrorCodes.INTERNAL_ERROR``. Headers attached to the
    exception (for example ``WWW-Authenticate`` on a 401) are forwarded so
    that protocol-level information is not lost.

    Args:
        request: The incoming request (unused, required by the handler API).
        exc: The HTTP exception raised while handling the request.

    Returns:
        A JSON response with ``code``, ``message`` and ``data.detail``.
    """
    # Status -> application error code; unmapped statuses use INTERNAL_ERROR.
    status_to_error = {
        400: ErrorCodes.BAD_REQUEST,
        401: ErrorCodes.UNAUTHORIZED,
        403: ErrorCodes.FORBIDDEN,
        404: ErrorCodes.NOT_FOUND,
        422: ErrorCodes.VALIDATION_ERROR,
    }
    error_code = status_to_error.get(exc.status_code, ErrorCodes.INTERNAL_ERROR)

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "code": error_code.code,
            "message": "error",
            "data": {
                "detail": str(exc.detail)
            }
        },
        # getattr keeps this safe for exception objects without a headers attribute.
        headers=getattr(exc, "headers", None),
    )
class ExceptionHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware that turns escaping exceptions into standard JSON errors.

    Any exception that propagates out of the downstream application is
    converted into a ``{"code", "message", "data"}`` response. Stack traces
    are written to the log only and are never sent to the client.
    """

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app and convert any raised exception.

        Args:
            request: The incoming request.
            call_next: Callable that invokes the rest of the application.

        Returns:
            The downstream response, or a standardized JSON error response.
        """
        try:
            return await call_next(request)

        except BusinessError as exc:
            # Handlers registered with add_exception_handler run inside the
            # middleware stack, so a BusinessError that reaches this point
            # has already bypassed them. Re-raising would surface it as a
            # generic 500; render the expected business response instead
            # (same body as business_error_handler, no stack trace logged).
            return JSONResponse(
                status_code=exc.http_status,
                content=exc.to_dict()
            )

        except SystemError as exc:
            # Unexpected system failure: log with the full stack trace.
            logger.error(
                f"System error occurred at {request.method} {request.url.path}",
                exc_info=True
            )
            return self._error_response(
                code=exc.code,
                message=exc.message,
                http_status=exc.http_status
            )

        except BaseError as exc:
            # Any other application-defined error.
            logger.warning(f"BaseError occurred at {request.url.path}: {exc.message}")
            return self._error_response(
                code=exc.code,
                message=exc.message,
                http_status=exc.http_status
            )

        except Exception:
            # Last resort: log the trace, answer with a sanitized 500.
            logger.error(
                f"Unhandled exception occurred at {request.method} {request.url.path}",
                exc_info=True
            )
            return self._error_response(
                code=ErrorCodes.INTERNAL_ERROR.code,
                message=ErrorCodes.INTERNAL_ERROR.message,
                http_status=500
            )

    @staticmethod
    def _error_response(
        code: str,
        message: str,
        http_status: int
    ) -> StarletteResponse:
        """Build a standardized error response whose ``data`` is always null."""
        return JSONResponse(
            status_code=http_status,
            content={
                "code": code,
                "message": message,
                "data": None
            }
        )
class Result(Generic[T, E]):
    """Outcome of an operation: a success value (Ok) or an error code (Err).

    Lets callers handle failure through return values instead of exceptions.
    Create instances with ``Result.ok`` / ``Result.err``. ``is_ok`` and
    ``is_err`` are properties, so read them without calling them.
    """

    def __init__(self, value: Optional[T], error: Optional[E], is_ok: bool):
        self._is_ok = is_ok
        self._error = error
        self._value = value

    @staticmethod
    def ok(value: T) -> 'Result[T, E]':
        """Build a successful result wrapping ``value``."""
        return Result(value=value, error=None, is_ok=True)

    @staticmethod
    def err(error: E) -> 'Result[T, E]':
        """Build a failed result carrying ``error``."""
        return Result(value=None, error=error, is_ok=False)

    @property
    def is_ok(self) -> bool:
        """True when this result holds a success value."""
        return self._is_ok

    @property
    def is_err(self) -> bool:
        """True when this result holds an error."""
        return not self._is_ok

    def unwrap(self) -> T:
        """Return the success value.

        Raises:
            ValueError: if this result is an error.
        """
        if not self._is_ok:
            raise ValueError(
                f"Cannot unwrap error result: {self._error.message}"
            )
        return self._value

    def unwrap_err(self) -> E:
        """Return the error.

        Raises:
            ValueError: if this result is a success.
        """
        if self._is_ok:
            raise ValueError("Cannot unwrap error from successful result")
        return self._error

    def unwrap_or(self, default: T) -> T:
        """Return the success value, or ``default`` when this is an error."""
        if self._is_ok:
            return self._value
        return default

    def map(self, func) -> 'Result[Any, E]':
        """Apply ``func`` to the success value.

        An error result is returned unchanged. If ``func`` raises, the
        exception is absorbed and an ``INTERNAL_ERROR`` result is returned.
        """
        if not self._is_ok:
            return self
        try:
            mapped = func(self._value)
        except Exception:
            # A failing mapper becomes an error result rather than propagating.
            return Result.err(ErrorCodes.INTERNAL_ERROR)
        return Result.ok(mapped)

    def and_then(self, func) -> 'Result[Any, E]':
        """Chain an operation that itself returns a ``Result``.

        ``func`` receives the success value; an error result is returned
        unchanged without calling ``func``.
        """
        if not self._is_ok:
            return self
        return func(self._value)

    def or_else(self, func) -> 'Result[T, Any]':
        """Recover from an error with a fallback ``Result``.

        ``func`` receives the error; a success result is returned unchanged
        without calling ``func``.
        """
        if self._is_ok:
            return self
        return func(self._error)
AsyncSession) -> None: + """ + 确保事务回滚(用于错误处理) + + Args: + db: SQLAlchemy 异步会话 + """ + try: + await db.rollback() + logger.debug("Transaction rolled back") + except Exception as e: + logger.error(f"Failed to rollback transaction: {e}") diff --git a/runtime/datamate-python/app/exception.py b/runtime/datamate-python/app/exception.py deleted file mode 100644 index f776d8036..000000000 --- a/runtime/datamate-python/app/exception.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -全局自定义异常类定义 -""" -from fastapi.responses import JSONResponse -from fastapi.exceptions import RequestValidationError -from starlette.exceptions import HTTPException as StarletteHTTPException -from fastapi import FastAPI, Request, HTTPException, status - -from .core.logging import setup_logging, get_logger - -logger = get_logger(__name__) - -# 自定义异常处理器:StarletteHTTPException (包括404等) -async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException): - """将Starlette的HTTPException转换为标准响应格式""" - return JSONResponse( - status_code=exc.status_code, - content={ - "code": exc.status_code, - "message": "error", - "data": { - "detail": exc.detail - } - } - ) - -# 自定义异常处理器:FastAPI HTTPException -async def fastapi_http_exception_handler(request: Request, exc: HTTPException): - """将FastAPI的HTTPException转换为标准响应格式""" - return JSONResponse( - status_code=exc.status_code, - content={ - "code": exc.status_code, - "message": "error", - "data": { - "detail": exc.detail - } - } - ) - -# 自定义异常处理器:RequestValidationError -async def validation_exception_handler(request: Request, exc: RequestValidationError): - """将请求验证错误转换为标准响应格式""" - # 仅返回每个错误的简要 detail 文本(来自 Pydantic 错误的 `msg` 字段),不返回整个错误对象 - raw_errors = exc.errors() or [] - errors = [err.get("msg", "Validation error") for err in raw_errors] - - return JSONResponse( - status_code=422, - content={ - "code": 422, - "message": "error", - "data": { - "detail": "Validation error", - "errors": errors, - }, - }, - ) - -# 自定义异常处理器:未捕获的异常 -async def 
general_exception_handler(request: Request, exc: Exception): - """将未捕获的异常转换为标准响应格式""" - logger.error(f"Unhandled exception: {exc}", exc_info=True) - return JSONResponse( - status_code=500, - content={ - "code": 500, - "message": "error", - "data": { - "detail": "Internal server error" - } - } - ) - -class LabelStudioAdapterException(Exception): - """Label Studio Adapter 基础异常类""" - pass - -class DatasetMappingNotFoundError(LabelStudioAdapterException): - """数据集映射未找到异常""" - def __init__(self, mapping_id: str): - self.mapping_id = mapping_id - super().__init__(f"Dataset mapping not found: {mapping_id}") - -class NoDatasetInfoFoundError(LabelStudioAdapterException): - """无法获取数据集信息异常""" - def __init__(self, dataset_uuid: str): - self.dataset_uuid = dataset_uuid - super().__init__(f"Failed to get dataset info: {dataset_uuid}") - -class LabelStudioClientError(LabelStudioAdapterException): - """Label Studio 客户端错误""" - pass - -class DMServiceClientError(LabelStudioAdapterException): - """DM 服务客户端错误""" - pass - -class SyncServiceError(LabelStudioAdapterException): - """同步服务错误""" - pass diff --git a/runtime/datamate-python/app/main.py b/runtime/datamate-python/app/main.py index 5b08df569..00dcbd0d8 100644 --- a/runtime/datamate-python/app/main.py +++ b/runtime/datamate-python/app/main.py @@ -1,27 +1,25 @@ from contextlib import asynccontextmanager -from typing import Dict, Any, Literal +from typing import Literal from urllib.parse import urlparse, urlunparse -from fastapi import FastAPI, HTTPException -from fastapi.exceptions import RequestValidationError +from fastapi import FastAPI from fastapi_mcp import FastApiMCP from sqlalchemy import text -from starlette.exceptions import HTTPException as StarletteHTTPException -from app.middleware import UserContextMiddleware -from .core.config import settings -from .core.logging import setup_logging, get_logger -from .db.session import AsyncSessionLocal -from .exception import ( - starlette_http_exception_handler, - 
fastapi_http_exception_handler, - validation_exception_handler, - general_exception_handler +from app.core.config import settings +from app.core.exception import ( + register_exception_handlers, + SuccessResponse, + ExceptionHandlingMiddleware, + ErrorCodes, + BusinessError, ) -from .module import router -from .module.shared.schema import StandardResponse -from .module.collection.schedule import load_scheduled_collection_tasks, set_collection_scheduler -from .module.shared.schedule import Scheduler +from app.core.logging import setup_logging, get_logger +from app.db.session import AsyncSessionLocal +from app.middleware import UserContextMiddleware +from app.module import router +from app.module.collection.schedule import load_scheduled_collection_tasks, set_collection_scheduler +from app.module.shared.schedule import Scheduler setup_logging() logger = get_logger(__name__) @@ -79,46 +77,40 @@ def mask_db_url(url: str) -> Literal[b""] | str: lifespan=lifespan ) +# 注册全局异常捕获中间件(最外层,确保捕获所有异常) +# 这样即使 debug=True,也不会泄露堆栈信息给客户端 +app.add_middleware(ExceptionHandlingMiddleware) + +# 注册用户上下文中间件 app.add_middleware(UserContextMiddleware) -# CORS Middleware -# app.add_middleware( -# CORSMiddleware, -# allow_origins=settings.allowed_origins, -# allow_credentials=True, -# allow_methods=settings.allowed_methods, -# allow_headers=settings.allowed_headers, -# ) # 注册路由 app.include_router(router) -# 输出注册的路由(每行一个) -logger.debug(f"Registered routes refer to http://localhost:{settings.port}/redoc") - # 注册全局异常处理器 -app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler) # type: ignore -app.add_exception_handler(HTTPException, fastapi_http_exception_handler) # type: ignore -app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore -app.add_exception_handler(Exception, general_exception_handler) +register_exception_handlers(app) # 测试端点:验证异常处理 -@app.get("/test-404", include_in_schema=False) -async def test_404(): - 
@app.get("/test-system-error", include_in_schema=False)
async def test_system_error():
    """Diagnostic endpoint that raises the application's system error.

    Raises:
        app.core.exception.SystemError: always, with ``DATABASE_ERROR``.
    """
    # This module does not import SystemError from app.core.exception, so the
    # bare name resolves to Python's builtin SystemError and the request would
    # end in the generic catch-all path. Import the application class under an
    # alias so the system-error handling is what actually gets exercised.
    from app.core.exception import SystemError as AppSystemError

    raise AppSystemError(ErrorCodes.DATABASE_ERROR)
- If `jwt_enable` is True, missing header returns 401. + 读取 `User` 请求头并设置 DataScopeHandle 的 FastAPI 中间件。 + 如果 `jwt_enable` 为 True,缺少请求头将返回 401。 """ def __init__(self, app): @@ -24,8 +32,12 @@ async def dispatch(self, request: Request, call_next): user: Optional[str] = request.headers.get("User") logger.info(f"start filter, current user: {user}, need filter: {self.jwt_enable}") if self.jwt_enable and (user is None or user.strip() == ""): - payload = {"code": HTTP_401_UNAUTHORIZED, "message": "unauthorized"} - return Response(content=json.dumps(payload), status_code=HTTP_401_UNAUTHORIZED, media_type="application/json") + payload = {"code": "common.401", "message": "unauthorized", "data": None} + return Response( + content=json.dumps(payload), + status_code=HTTP_401_UNAUTHORIZED, + media_type="application/json" + ) DataScopeHandle.set_user_info(user) try: @@ -33,3 +45,9 @@ async def dispatch(self, request: Request, call_next): return response finally: DataScopeHandle.remove_user_info() + + +# Re-export ExceptionHandlingMiddleware for backward compatibility + +__all__ = ['UserContextMiddleware'] + diff --git a/runtime/datamate-python/app/module/annotation/interface/auto.py b/runtime/datamate-python/app/module/annotation/interface/auto.py index 05fb3d1d9..e4a7fc07f 100644 --- a/runtime/datamate-python/app/module/annotation/interface/auto.py +++ b/runtime/datamate-python/app/module/annotation/interface/auto.py @@ -418,7 +418,7 @@ async def list_auto_annotation_tasks( tasks = await service.list_tasks(db) return StandardResponse( - code=200, + code="0", message="success", data=tasks, ) @@ -486,7 +486,7 @@ async def create_auto_annotation_task( ) return StandardResponse( - code=200, + code="0", message="success", data=task, ) @@ -571,7 +571,7 @@ async def update_auto_annotation_task_files( ) return StandardResponse( - code=200, + code="0", message="success", data=updated, ) @@ -592,7 +592,7 @@ async def get_auto_annotation_task_status( raise HTTPException(status_code=404, 
detail="Task not found") return StandardResponse( - code=200, + code="0", message="success", data=task, ) @@ -638,7 +638,7 @@ async def get_auto_annotation_task_files( else: # 未显式记录 file_ids 时,回退为主数据集下所有 ACTIVE 文件 if not dataset_id: - return StandardResponse(code=200, message="success", data=[]) + return StandardResponse(code="0", message="success", data=[]) files_query = select(DatasetFiles).where( DatasetFiles.dataset_id == dataset_id, DatasetFiles.status == "ACTIVE", @@ -671,7 +671,7 @@ async def get_auto_annotation_task_files( } data.append(item) - return StandardResponse(code=200, message="success", data=data) + return StandardResponse(code="0", message="success", data=data) @router.get("/{task_id}/label-studio-project", response_model=StandardResponse[Dict[str, str]]) @@ -773,7 +773,7 @@ async def get_auto_annotation_label_studio_project( "datasetId": str(target.dataset_id), } - return StandardResponse(code=200, message="success", data=data) + return StandardResponse(code="0", message="success", data=data) @router.delete("/{task_id}", response_model=StandardResponse[bool]) @@ -861,7 +861,7 @@ async def delete_auto_annotation_task( ) return StandardResponse( - code=200, + code="0", message="success", data=True, ) @@ -1029,7 +1029,7 @@ async def sync_auto_annotation_to_label_studio( ) return StandardResponse( - code=200, + code="0", message="success", data=created_count, ) @@ -1190,7 +1190,7 @@ def _sanitize_base_name(raw: str) -> str: except Exception: pass - return StandardResponse(code=200, message="success", data=True) + return StandardResponse(code="0", message="success", data=True) @router.post("/{task_id}/sync-db", response_model=StandardResponse[int]) @@ -1272,4 +1272,4 @@ async def sync_auto_task_annotations_to_database( updated = await sync_service.sync_project_annotations_to_dm(project_id=str(project_id)) - return StandardResponse(code=200, message="success", data=updated) + return StandardResponse(code="0", message="success", data=updated) diff 
--git a/runtime/datamate-python/app/module/annotation/interface/config.py b/runtime/datamate-python/app/module/annotation/interface/config.py index 78501714d..5c70e0240 100644 --- a/runtime/datamate-python/app/module/annotation/interface/config.py +++ b/runtime/datamate-python/app/module/annotation/interface/config.py @@ -20,7 +20,7 @@ async def get_config(): """获取配置信息(已废弃,请使用 /api/annotation/about)""" return StandardResponse( - code=200, + code="0", message="success", data=ConfigResponse( label_studio_url=settings.label_studio_base_url, @@ -39,9 +39,9 @@ async def get_tag_config(): if not config: logger.error("Failed to load tag configuration") return StandardResponse( - code=500, + code="common.500", message="Failed to load tag configuration", data={"objects": {}, "controls": {}} ) - return StandardResponse(code=200, message="success", data=config) + return StandardResponse(code="0", message="success", data=config) diff --git a/runtime/datamate-python/app/module/annotation/interface/project.py b/runtime/datamate-python/app/module/annotation/interface/project.py index 2a17c7211..9f00ca7ff 100644 --- a/runtime/datamate-python/app/module/annotation/interface/project.py +++ b/runtime/datamate-python/app/module/annotation/interface/project.py @@ -294,7 +294,7 @@ async def create_mapping( ) return StandardResponse( - code=201, + code="0", message="success", data=response_data, ) @@ -338,7 +338,7 @@ async def get_manual_mapping_files( ) file_ids = list(existing_mapping.keys()) if not file_ids: - return StandardResponse(code=200, message="success", data=[]) + return StandardResponse(code="0", message="success", data=[]) files_result = await db.execute( _select(DatasetFiles).where(DatasetFiles.id.in_(file_ids)) @@ -377,7 +377,7 @@ async def get_manual_mapping_files( } data.append(item) - return StandardResponse(code=200, message="success", data=data) + return StandardResponse(code="0", message="success", data=data) @router.put("/{mapping_id}/files", 
response_model=StandardResponse[bool]) @@ -413,7 +413,7 @@ async def update_manual_mapping_files( requested_ids = {str(fid) for fid in (body.file_ids or [])} if not requested_ids: # 不做任何变更,但认为成功 - return StandardResponse(code=200, message="success", data=True) + return StandardResponse(code="0", message="success", data=True) existing_mapping = await sync_service.get_existing_dm_file_mapping( mapping.labeling_project_id @@ -423,7 +423,7 @@ async def update_manual_mapping_files( # 仅对新增文件创建任务 new_ids = sorted(requested_ids - existing_ids) if not new_ids: - return StandardResponse(code=200, message="success", data=True) + return StandardResponse(code="0", message="success", data=True) stmt = ( _select(DatasetFiles.dataset_id, DatasetFiles.id) @@ -498,7 +498,7 @@ async def update_manual_mapping_files( delete_orphans=False, ) - return StandardResponse(code=200, message="success", data=True) + return StandardResponse(code="0", message="success", data=True) @router.post("/{mapping_id}/sync-label-studio-back", response_model=StandardResponse[bool]) @@ -602,7 +602,7 @@ def _sanitize_base_name(raw: str) -> str: except Exception: pass - return StandardResponse(code=200, message="success", data=True) + return StandardResponse(code="0", message="success", data=True) @router.post("/{mapping_id}/sync-db", response_model=StandardResponse[int]) @@ -637,7 +637,7 @@ async def sync_manual_annotations_to_database( project_id=str(mapping.labeling_project_id), ) - return StandardResponse(code=200, message="success", data=updated) + return StandardResponse(code="0", message="success", data=updated) @router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]]) async def list_mappings( @@ -688,7 +688,7 @@ async def list_mappings( logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}") return StandardResponse( - code=200, + code="0", message="success", data=paginated_data ) @@ -728,7 +728,7 @@ async def 
get_mapping( logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}") return StandardResponse( - code=200, + code="0", message="success", data=mapping ) @@ -790,7 +790,7 @@ async def get_mappings_by_source( logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}") return StandardResponse( - code=200, + code="0", message="success", data=paginated_data ) @@ -864,7 +864,7 @@ async def delete_mapping( logger.info(f"Successfully deleted mapping: {id}, Label Studio project: {labeling_project_id}") return StandardResponse( - code=200, + code="0", message="success", data=DeleteDatasetResponse( id=id, diff --git a/runtime/datamate-python/app/module/annotation/interface/task.py b/runtime/datamate-python/app/module/annotation/interface/task.py index 13b667448..99639147f 100644 --- a/runtime/datamate-python/app/module/annotation/interface/task.py +++ b/runtime/datamate-python/app/module/annotation/interface/task.py @@ -1,15 +1,15 @@ -from fastapi import APIRouter, Depends, HTTPException, Query, Path +from fastapi import APIRouter, Depends, Query, Path, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional, Dict, Any from datetime import datetime from pydantic import BaseModel, Field, ConfigDict +from app.core.exception import ErrorCodes, BusinessError, SuccessResponse from app.db.session import get_db from app.module.shared.schema import StandardResponse from app.module.dataset import DatasetManagementService from app.core.logging import get_logger from app.core.config import settings -from app.exception import NoDatasetInfoFoundError, DatasetMappingNotFoundError from ..client import LabelStudioClient from ..service.sync import SyncService @@ -39,7 +39,7 @@ async def sync_dataset_content( ): """ Sync Dataset Content (Files and Annotations) - + 根据指定的mapping ID,同步DM程序数据集中的内容到Label Studio数据集中。 默认同时同步文件和标注数据。 """ @@ -59,14 +59,14 @@ async def 
sync_dataset_content( status_code=404, detail=f"Mapping not found: {request.id}" ) - + # Sync dataset files result = await sync_service.sync_dataset_files(request.id, request.batch_size) - + # Sync annotations if requested if request.sync_annotations: logger.info(f"Syncing annotations: direction={request.annotation_direction}") - + # 根据方向执行标注同步 if request.annotation_direction == "ls_to_dm": await sync_service.sync_annotations_from_ls_to_dm( @@ -87,26 +87,23 @@ async def sync_dataset_content( request.overwrite, request.overwrite_labeling_project ) - + logger.info(f"Sync completed: {result.synced_files}/{result.total_files} files") - + return StandardResponse( - code=200, + code="0", message="success", data=result ) - + except HTTPException: raise - except NoDatasetInfoFoundError as e: - logger.error(f"Failed to get dataset info: {e}") - raise HTTPException(status_code=404, detail=str(e)) - except DatasetMappingNotFoundError as e: - logger.error(f"Mapping not found: {e}") - raise HTTPException(status_code=404, detail=str(e)) + except BusinessError as e: + # 业务异常已经由全局异常处理器处理 + raise except Exception as e: logger.error(f"Error syncing dataset content: {e}") - raise HTTPException(status_code=500, detail="Internal server error") + raise @router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse]) @@ -136,7 +133,7 @@ async def sync_annotations( status_code=404, detail=f"Mapping not found: {request.id}" ) - + # 根据方向执行同步 if request.direction == "ls_to_dm": result = await sync_service.sync_annotations_from_ls_to_dm( @@ -162,15 +159,15 @@ async def sync_annotations( status_code=400, detail=f"Invalid direction: {request.direction}" ) - + logger.info(f"Annotation sync completed: synced_to_dm={result.synced_to_dm}, synced_to_ls={result.synced_to_ls}, conflicts_resolved={result.conflicts_resolved}") - + return StandardResponse( - code=200, + code="0", message="success", data=result ) - + except HTTPException: raise except Exception as e: @@ -190,17 
+187,17 @@ async def check_label_studio_connection(): base_url=settings.label_studio_base_url, token=settings.label_studio_user_token ) - + # 尝试获取项目列表来测试连接 try: response = await ls_client.client.get("/api/projects") response.raise_for_status() projects = response.json() - + token_display = settings.label_studio_user_token[:10] + "..." if settings.label_studio_user_token else "None" - + return StandardResponse( - code=200, + code="0", message="success", data={ "status": "connected", @@ -212,9 +209,9 @@ async def check_label_studio_connection(): ) except Exception as e: token_display = settings.label_studio_user_token[:10] + "..." if settings.label_studio_user_token else "None" - + return StandardResponse( - code=500, + code="common.500", message="error", data={ "status": "disconnected", @@ -247,74 +244,74 @@ async def update_file_tags( Update File Tags (Partial Update with Auto Format Conversion) 接收部分标签更新并合并到指定文件(只修改提交的标签,其余保持不变),并更新 `tags_updated_at`。 - + 支持两种标签格式: 1. 简化格式(外部用户提交): [{"from_name": "label", "to_name": "image", "values": ["cat", "dog"]}] - + 2. 
完整格式(内部存储): - [{"id": "...", "from_name": "label", "to_name": "image", "type": "choices", + [{"id": "...", "from_name": "label", "to_name": "image", "type": "choices", "value": {"choices": ["cat", "dog"]}}] - + 系统会自动根据数据集关联的模板将简化格式转换为完整格式。 请求与响应使用 Pydantic 模型 `UpdateFileTagsRequest` / `UpdateFileTagsResponse`。 """ service = DatasetManagementService(db) - + # 首先获取文件所属的数据集 from sqlalchemy.future import select from app.db.models import DatasetFiles - + result = await db.execute( select(DatasetFiles).where(DatasetFiles.id == file_id) ) file_record = result.scalar_one_or_none() - + if not file_record: raise HTTPException(status_code=404, detail=f"File not found: {file_id}") - + dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str - + # 查找数据集关联的模板ID from ..service.mapping import DatasetMappingService - + mapping_service = DatasetMappingService(db) template_id = await mapping_service.get_template_id_by_dataset_id(dataset_id) - + if template_id: logger.info(f"Found template {template_id} for dataset {dataset_id}, will auto-convert tag format") else: logger.warning(f"No template found for dataset {dataset_id}, tags must be in full format") - + # 更新标签(如果有模板ID则自动转换格式) success, error_msg, updated_at = await service.update_file_tags_partial( file_id=file_id, new_tags=request.tags, template_id=template_id # 传递模板ID以启用自动转换 ) - + if not success: if "not found" in (error_msg or "").lower(): raise HTTPException(status_code=404, detail=error_msg) raise HTTPException(status_code=500, detail=error_msg or "更新标签失败") - + # 重新获取更新后的文件记录(获取完整标签列表) result = await db.execute( select(DatasetFiles).where(DatasetFiles.id == file_id) ) file_record = result.scalar_one_or_none() - + if not file_record: raise HTTPException(status_code=404, detail=f"File not found: {file_id}") - + response_data = UpdateFileTagsResponse( fileId=file_id, tags=file_record.tags or [], # type: ignore tagsUpdatedAt=updated_at or datetime.now() ) - + return StandardResponse( - code=200, + code="0", 
message="标签更新成功", data=response_data ) diff --git a/runtime/datamate-python/app/module/annotation/interface/template.py b/runtime/datamate-python/app/module/annotation/interface/template.py index 6ab8a7655..795c10558 100644 --- a/runtime/datamate-python/app/module/annotation/interface/template.py +++ b/runtime/datamate-python/app/module/annotation/interface/template.py @@ -40,7 +40,7 @@ async def create_template( - **category**: 模板分类(默认custom) """ template = await template_service.create_template(db, request) - return StandardResponse(code=200, message="success", data=template) + return StandardResponse(code="0", message="success", data=template) @router.get( @@ -57,7 +57,7 @@ async def get_template( template = await template_service.get_template(db, template_id) if not template: raise HTTPException(status_code=404, detail="Template not found") - return StandardResponse(code=200, message="success", data=template) + return StandardResponse(code="0", message="success", data=template) @router.get( @@ -92,7 +92,7 @@ async def list_template( labeling_type=labelingType, built_in=builtIn ) - return StandardResponse(code=200, message="success", data=templates) + return StandardResponse(code="0", message="success", data=templates) @router.put( @@ -112,7 +112,7 @@ async def update_template( template = await template_service.update_template(db, template_id, request) if not template: raise HTTPException(status_code=404, detail="Template not found") - return StandardResponse(code=200, message="success", data=template) + return StandardResponse(code="0", message="success", data=template) @router.delete( @@ -129,4 +129,4 @@ async def delete_template( success = await template_service.delete_template(db, template_id) if not success: raise HTTPException(status_code=404, detail="Template not found") - return StandardResponse(code=200, message="success", data=True) + return StandardResponse(code="0", message="success", data=True) diff --git 
a/runtime/datamate-python/app/module/annotation/service/sync.py b/runtime/datamate-python/app/module/annotation/service/sync.py index 4f59ac5ed..2660fd1e4 100644 --- a/runtime/datamate-python/app/module/annotation/service/sync.py +++ b/runtime/datamate-python/app/module/annotation/service/sync.py @@ -7,7 +7,7 @@ from app.core.logging import get_logger from app.core.config import settings -from app.exception import NoDatasetInfoFoundError +from app.core.exception import ErrorCodes, BusinessError from ..client import LabelStudioClient from ..schema import ( @@ -388,7 +388,7 @@ async def sync_files( # 获取DM数据集信息 dataset_info = await self.dm_client.get_dataset(effective_dataset_id) if not dataset_info: - raise NoDatasetInfoFoundError(mapping.dataset_id) + raise BusinessError(ErrorCodes.NOT_FOUND, data={"dataset_id": mapping.dataset_id}) total_files = dataset_info.fileCount logger.debug(f"Total files in DM dataset: {total_files}") diff --git a/runtime/datamate-python/app/module/collection/interface/collection.py b/runtime/datamate-python/app/module/collection/interface/collection.py index 42a83f38f..ba660b268 100644 --- a/runtime/datamate-python/app/module/collection/interface/collection.py +++ b/runtime/datamate-python/app/module/collection/interface/collection.py @@ -4,10 +4,11 @@ import os from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession +from app.core.exception import ErrorCodes, BusinessError, SuccessResponse, transaction from app.core.logging import get_logger from app.db.models import Dataset from app.db.models.data_collection import CollectionTask, TaskExecution, CollectionTemplate @@ -32,18 +33,18 @@ async def create_task( db: AsyncSession = Depends(get_db) ): """创建归集任务""" - try: - template = await db.execute(select(CollectionTemplate).where(CollectionTemplate.id == request.template_id)) - 
template = template.scalar_one_or_none() - if not template: - raise HTTPException(status_code=400, detail="Template not found") - - task_id = str(uuid.uuid4()) - DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}") - task = convert_for_create(request, task_id) - task.template_name = template.name - dataset = None - + template = await db.execute(select(CollectionTemplate).where(CollectionTemplate.id == request.template_id)) + template = template.scalar_one_or_none() + if not template: + raise BusinessError(ErrorCodes.COLLECTION_TEMPLATE_NOT_FOUND, data={"template_id": request.template_id}) + + task_id = str(uuid.uuid4()) + DataxClient.generate_datx_config(request.config, template, f"/dataset/local/{task_id}") + task = convert_for_create(request, task_id) + task.template_name = template.name + dataset = None + + async with transaction(db): if request.dataset_name: target_dataset_id = uuid.uuid4() dataset = Dataset( @@ -61,80 +62,63 @@ async def create_task( task = await db.execute(select(CollectionTask).where(CollectionTask.id == task.id)) task = task.scalar_one_or_none() - await db.commit() - if task and task.sync_mode == SyncMode.SCHEDULED.value and task.schedule_expression: - schedule_collection_task(task.id, task.schedule_expression) - - return StandardResponse( - code=200, - message="Success", - data=converter_to_response(task) - ) - except HTTPException: - await db.rollback() - raise - except ValueError as e: - await db.rollback() - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - await db.rollback() - logger.error(f"Failed to create collection task: {str(e)}", e) - raise HTTPException(status_code=500, detail="Internal server error") + + # 事务已提交,执行调度 + if task and task.sync_mode == SyncMode.SCHEDULED.value and task.schedule_expression: + schedule_collection_task(task.id, task.schedule_expression) + + return SuccessResponse( + data=converter_to_response(task), + message="Success" + ) 
@router.get("", response_model=StandardResponse[PaginatedData[CollectionTaskBase]]) async def list_tasks( page: int = 1, size: int = 20, - name: Optional[str] = Query(None, description="任务名称模糊查询"), + name: Optional[str] = Query(None, description="Fuzzy search by task name"), db: AsyncSession = Depends(get_db) ): """分页查询归集任务""" - try: - # 构建查询条件 - page = page if page > 0 else 1 - size = size if size > 0 else 20 - query = select(CollectionTask) - - if name: - query = query.where(CollectionTask.name.ilike(f"%{name}%")) - - # 获取总数 - count_query = select(func.count()).select_from(query.subquery()) - total = (await db.execute(count_query)).scalar_one() - - # 分页查询 - offset = (page - 1) * size - tasks = (await db.execute( - query.order_by(CollectionTask.created_at.desc()) - .offset(offset) - .limit(size) - )).scalars().all() - - # 转换为响应模型 - items = [converter_to_response(task) for task in tasks] - total_pages = math.ceil(total / size) if total > 0 else 0 - - return StandardResponse( - code=200, - message="Success", - data=PaginatedData( - content=items, - total_elements=total, - total_pages=total_pages, - page=page, - size=size, - ) + # 构建查询条件 + page = page if page > 0 else 1 + size = size if size > 0 else 20 + query = select(CollectionTask) + + if name: + query = query.where(CollectionTask.name.ilike(f"%{name}%")) + + # 获取总数 + count_query = select(func.count()).select_from(query.subquery()) + total = (await db.execute(count_query)).scalar_one() + + # 分页查询 + offset = (page - 1) * size + tasks = (await db.execute( + query.order_by(CollectionTask.created_at.desc()) + .offset(offset) + .limit(size) + )).scalars().all() + + # 转换为响应模型 + items = [converter_to_response(task) for task in tasks] + total_pages = math.ceil(total / size) if total > 0 else 0 + + return SuccessResponse( + data=PaginatedData( + content=items, + total_elements=total, + total_pages=total_pages, + page=page, + size=size, ) - - except Exception as e: - logger.error(f"Failed to list evaluation tasks: 
{str(e)}", e) - raise HTTPException(status_code=500, detail="Internal server error") + ) @router.delete("", response_model=StandardResponse[str], status_code=200) async def delete_collection_tasks( - ids: list[str] = Query(..., description="要删除的任务ID列表"), + ids: list[str] = Query(..., description="List of task IDs to delete"), db: AsyncSession = Depends(get_db), ): """ @@ -147,44 +131,36 @@ async def delete_collection_tasks( Returns: StandardResponse[str]: 删除结果 """ - try: - # 检查任务是否存在 - task_id = ids[0] - task = await db.get(CollectionTask, task_id) - if not task: - raise HTTPException(status_code=404, detail="Collection task not found") - - # 删除任务执行记录 + # 检查任务是否存在 + task_id = ids[0] + task = await db.get(CollectionTask, task_id) + if not task: + raise BusinessError(ErrorCodes.COLLECTION_TASK_NOT_FOUND, data={"task_id": task_id}) + + # 删除任务执行记录(在事务内) + async with transaction(db): await db.execute( TaskExecution.__table__.delete() .where(TaskExecution.task_id == task_id) ) - remove_collection_task(task_id) - - target_path = f"/dataset/local/{task_id}" - if os.path.exists(target_path): - shutil.rmtree(target_path) - job_path = f"/flow/data-collection/{task_id}" - if os.path.exists(job_path): - shutil.rmtree(job_path) # 删除任务 await db.delete(task) - await db.commit() - return StandardResponse( - code=200, - message="Collection task deleted successfully", - data="success" - ) + # 事务提交后,删除文件系统和调度 + remove_collection_task(task_id) + + target_path = f"/dataset/local/{task_id}" + if os.path.exists(target_path): + shutil.rmtree(target_path) + job_path = f"/flow/data-collection/{task_id}" + if os.path.exists(job_path): + shutil.rmtree(job_path) - except HTTPException: - await db.rollback() - raise - except Exception as e: - await db.rollback() - logger.error(f"Failed to delete collection task: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + return SuccessResponse( + data="success", + message="Collection task deleted successfully" + ) 
@router.get("/{task_id}", response_model=StandardResponse[CollectionTaskBase]) async def get_task( @@ -192,22 +168,12 @@ async def get_task( db: AsyncSession = Depends(get_db) ): """获取归集任务详情""" - try: - # Query the task by ID - task = await db.get(CollectionTask, task_id) - if not task: - raise HTTPException( - status_code=404, - detail=f"Task with ID {task_id} not found" - ) - - return StandardResponse( - code=200, - message="Success", - data=converter_to_response(task) - ) - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to get task {task_id}: {str(e)}", e) - raise HTTPException(status_code=500, detail="Internal server error") + # 根据ID查询任务 + task = await db.get(CollectionTask, task_id) + if not task: + raise BusinessError(ErrorCodes.COLLECTION_TASK_NOT_FOUND, data={"task_id": task_id}) + + return SuccessResponse( + data=converter_to_response(task), + message="Success" + ) diff --git a/runtime/datamate-python/app/module/collection/interface/execution.py b/runtime/datamate-python/app/module/collection/interface/execution.py index 29e879d32..2ec0c9eb6 100644 --- a/runtime/datamate-python/app/module/collection/interface/execution.py +++ b/runtime/datamate-python/app/module/collection/interface/execution.py @@ -27,10 +27,10 @@ async def list_executions( page: int = 1, size: int = 20, - task_id: Optional[str] = Query(None, description="任务ID"), - task_name: Optional[str] = Query(None, description="任务名称模糊查询"), - start_time: Optional[datetime] = Query(None, description="开始执行时间范围-起(started_at >= start_time)"), - end_time: Optional[datetime] = Query(None, description="开始执行时间范围-止(started_at <= end_time)"), + task_id: Optional[str] = Query(None, description="Task ID"), + task_name: Optional[str] = Query(None, description="Fuzzy search by task name"), + start_time: Optional[datetime] = Query(None, description="Start time range from (started_at >= start_time)"), + end_time: Optional[datetime] = Query(None, description="Start time range to 
(started_at <= end_time)"), db: AsyncSession = Depends(get_db) ): """分页查询归集任务执行记录""" @@ -63,7 +63,7 @@ async def list_executions( total_pages = math.ceil(total / size) if total > 0 else 0 return StandardResponse( - code=200, + code="0", message="Success", data=PaginatedData( content=items, @@ -84,7 +84,7 @@ async def get_execution_log( execution_id: str, db: AsyncSession = Depends(get_db) ): - """获取执行记录对应的日志文件内容""" + """Get log file content for execution record""" try: execution = await db.get(TaskExecution, execution_id) if not execution: diff --git a/runtime/datamate-python/app/module/collection/interface/template.py b/runtime/datamate-python/app/module/collection/interface/template.py index 83d6125c4..a42aabdd2 100644 --- a/runtime/datamate-python/app/module/collection/interface/template.py +++ b/runtime/datamate-python/app/module/collection/interface/template.py @@ -23,8 +23,8 @@ async def list_templates( page: int = 1, size: int = 20, - name: Optional[str] = Query(None, description="模板名称模糊查询"), - built_in: Optional[bool] = Query(None, description="是否系统内置模板"), + name: Optional[str] = Query(None, description="Fuzzy search by template name"), + built_in: Optional[bool] = Query(None, description="Filter by system built-in template"), db: AsyncSession = Depends(get_db) ): """分页查询归集任务模板""" @@ -51,7 +51,7 @@ async def list_templates( total_pages = math.ceil(total / size) if total > 0 else 0 return StandardResponse( - code=200, + code="0", message="Success", data=PaginatedData( content=items, diff --git a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py index a718b4194..203a8d080 100644 --- a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py @@ -3,12 +3,12 @@ import math import json from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query, 
BackgroundTasks +from fastapi import APIRouter, Depends, Query from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, func, or_, text, and_ -from pydantic import ValidationError +from sqlalchemy import select, func from app.core.logging import get_logger +from app.core.exception import ErrorCodes, BusinessError, SuccessResponse from app.db.models.data_evaluation import EvaluationFile from app.db.session import get_db from app.db.models import EvaluationTask, EvaluationItem, DatasetFiles @@ -37,24 +37,13 @@ @router.get("/prompt-templates", response_model=StandardResponse[PromptTemplateResponse]) async def get_prompt_templates(): """ - Get all available evaluation prompt templates + 获取所有可用的评估提示模板 Returns: StandardResponse with list of prompt templates """ - try: - templates = PromptTemplateService.get_prompt_templates() - return StandardResponse( - code=200, - message="Success", - data=templates - ) - except Exception as e: - logger.error(f"Failed to get prompt templates: {str(e)}") - raise HTTPException( - status_code=500, - detail="Failed to retrieve prompt templates" - ) + templates = PromptTemplateService.get_prompt_templates() + return SuccessResponse(data=templates) @router.post("/tasks", response_model=StandardResponse[EvaluationTaskDetailResponse], status_code=201) @@ -72,62 +61,51 @@ async def create_evaluation_task( Returns: StandardResponse[EvaluationTaskDetailResponse]: 创建的任务详情 """ - try: - # 检查任务名称是否已存在 - existing_task = await db.execute( - select(EvaluationTask).where(EvaluationTask.name == request.name) - ) - if existing_task.scalar_one_or_none(): - raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists") - - models = await get_model_by_id(db, request.eval_config.model_id) - if not models: - raise HTTPException(status_code=400, detail=f"Model with id '{request.eval_config.model_id}' not found") - - # 创建评估任务 - task = EvaluationTask( - id=str(uuid.uuid4()), - name=request.name, - 
description=request.description, - task_type=request.task_type, - source_type=request.source_type, - source_id=request.source_id, - source_name=request.source_name, - eval_prompt=request.eval_prompt, - eval_config=json.dumps({ - "modelId": request.eval_config.model_id, - "modelName": models.model_name, - "dimensions": request.eval_config.dimensions, - }), - status=TaskStatus.PENDING.value, - eval_process=0.0, - ) + # 检查任务名称是否已存在 + existing_task = await db.execute( + select(EvaluationTask).where(EvaluationTask.name == request.name) + ) + if existing_task.scalar_one_or_none(): + raise BusinessError(ErrorCodes.EVALUATION_MODEL_NOT_FOUND, data={"name": request.name}) + + models = await get_model_by_id(db, request.eval_config.model_id) + if not models: + raise BusinessError(ErrorCodes.EVALUATION_MODEL_NOT_FOUND, data={"model_id": request.eval_config.model_id}) + + # 创建评估任务 + task = EvaluationTask( + id=str(uuid.uuid4()), + name=request.name, + description=request.description, + task_type=request.task_type, + source_type=request.source_type, + source_id=request.source_id, + source_name=request.source_name, + eval_prompt=request.eval_prompt, + eval_config=json.dumps({ + "modelId": request.eval_config.model_id, + "modelName": models.model_name, + "dimensions": request.eval_config.dimensions, + }), + status=TaskStatus.PENDING.value, + eval_process=0.0, + ) - db.add(task) - # Commit first to persist the task before scheduling background work - await db.commit() - # Schedule background execution without blocking the current request - asyncio.create_task(EvaluationTaskService.run_evaluation_task(task.id)) - - # Refresh the task to return latest state - await db.refresh(task) - - # 转换响应模型 - response = _map_to_task_detail_response(task) - return StandardResponse( - code=200, - message="Evaluation task created successfully", - data=response - ) + db.add(task) + # Commit first to persist the task before scheduling background work + await db.commit() + # Schedule background 
execution without blocking the current request + asyncio.create_task(EvaluationTaskService.run_evaluation_task(task.id)) + + # Refresh the task to return latest state + await db.refresh(task) - except ValidationError as e: - await db.rollback() - logger.error(f"Validation error: {str(e)}") - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - await db.rollback() - logger.error(f"Failed to create evaluation task: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + # 转换响应模型 + response = _map_to_task_detail_response(task) + return SuccessResponse( + data=response, + message="Evaluation task created successfully" + ) @router.get("/tasks", response_model=StandardResponse[PagedEvaluationTaskResponse]) @@ -153,51 +131,44 @@ async def list_evaluation_tasks( Returns: StandardResponse[PagedEvaluationTaskResponse]: 分页的评估任务列表 """ - try: - # 构建查询条件 - query = select(EvaluationTask) - - if name: - query = query.where(EvaluationTask.name.ilike(f"%{name}%")) - if status: - query = query.where(EvaluationTask.status == status) - if task_type: - query = query.where(EvaluationTask.task_type == task_type) - - # 获取总数 - count_query = select(func.count()).select_from(query.subquery()) - total = (await db.execute(count_query)).scalar_one() - - # 分页查询 - offset = (page - 1) * size - tasks = (await db.execute( - query.order_by(EvaluationTask.created_at.desc()) - .offset(offset) - .limit(size) - )).scalars().all() - - # 转换为响应模型 - items = [_map_to_task_detail_response(task) for task in tasks] - total_pages = math.ceil(total / size) if total > 0 else 0 - - return StandardResponse( - code=200, - message="Success", - data=PagedEvaluationTaskResponse( - content=items, - totalElements=total, - totalPages=total_pages, - page=page, - size=size, - ) + # 构建查询条件 + query = select(EvaluationTask) + + if name: + query = query.where(EvaluationTask.name.ilike(f"%{name}%")) + if status: + query = query.where(EvaluationTask.status == status) + if 
task_type: + query = query.where(EvaluationTask.task_type == task_type) + + # 获取总数 + count_query = select(func.count()).select_from(query.subquery()) + total = (await db.execute(count_query)).scalar_one() + + # 分页查询 + offset = (page - 1) * size + tasks = (await db.execute( + query.order_by(EvaluationTask.created_at.desc()) + .offset(offset) + .limit(size) + )).scalars().all() + + # 转换为响应模型 + items = [_map_to_task_detail_response(task) for task in tasks] + total_pages = math.ceil(total / size) if total > 0 else 0 + + return SuccessResponse( + data=PagedEvaluationTaskResponse( + content=items, + totalElements=total, + totalPages=total_pages, + page=page, + size=size, ) - - except Exception as e: - logger.error(f"Failed to list evaluation tasks: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + ) @router.get("/tasks/{task_id}/files", response_model=StandardResponse[PagedEvaluationFilesResponse]) -async def list_evaluation_items( +async def list_evaluation_files( task_id: str, page: int = Query(1, ge=1, description="页码,从1开始"), size: int = Query(10, ge=1, le=100, description="每页数量"), @@ -215,41 +186,36 @@ async def list_evaluation_items( Returns: StandardResponse[PagedEvaluationFilesResponse]: 分页的评估文件列表 """ - try: - task = await db.get(EvaluationTask, task_id) - if not task: - raise HTTPException(status_code=404, detail="Evaluation task not found") - offset = (page - 1) * size - query = select(EvaluationFile).where(EvaluationFile.task_id == task_id) - count_query = select(func.count()).select_from(query.subquery()) - total = (await db.execute(count_query)).scalar_one() - files = (await db.execute(query.offset(offset).limit(size))).scalars().all() - total_pages = math.ceil(total / size) if total > 0 else 0 - file_responses = [ - EvaluationFileResponse( - taskId=file.task_id, - fileId=file.file_id, - fileName=file.file_name, - totalCount=file.total_count, - evaluatedCount=file.evaluated_count, - pendingCount=file.total_count - 
file.evaluated_count - ) - for file in files - ] - return StandardResponse( - code=200, - message="Success", - data=PagedEvaluationFilesResponse( - content=file_responses, - totalElements=total, - totalPages=total_pages, - page=page, - size=size, - ) + task = await db.get(EvaluationTask, task_id) + if not task: + raise BusinessError(ErrorCodes.EVALUATION_TASK_NOT_FOUND, data={"task_id": task_id}) + + offset = (page - 1) * size + query = select(EvaluationFile).where(EvaluationFile.task_id == task_id) + count_query = select(func.count()).select_from(query.subquery()) + total = (await db.execute(count_query)).scalar_one() + files = (await db.execute(query.offset(offset).limit(size))).scalars().all() + total_pages = math.ceil(total / size) if total > 0 else 0 + file_responses = [ + EvaluationFileResponse( + taskId=file.task_id, + fileId=file.file_id, + fileName=file.file_name, + totalCount=file.total_count, + evaluatedCount=file.evaluated_count, + pendingCount=file.total_count - file.evaluated_count ) - except Exception as e: - logger.error(f"Failed to list evaluation items: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + for file in files + ] + return SuccessResponse( + data=PagedEvaluationFilesResponse( + content=file_responses, + totalElements=total, + totalPages=total_pages, + page=page, + size=size, + ) + ) @router.get("/tasks/{task_id}/items", response_model=StandardResponse[PagedEvaluationItemsResponse]) @@ -275,60 +241,54 @@ async def list_evaluation_items( Returns: StandardResponse[PagedEvaluationItemsResponse]: 分页的评估条目列表 """ - try: - # 检查任务是否存在 - task = await db.get(EvaluationTask, task_id) - if not task: - raise HTTPException(status_code=404, detail="Evaluation task not found") - - # 构建查询条件 - query = select(EvaluationItem).where(EvaluationItem.task_id == task_id) - - if status: - query = query.where(EvaluationItem.status == status) - - if file_id: - query = query.where(EvaluationItem.file_id == file_id) - - # 获取总数 - 
count_query = select(func.count()).select_from(query.subquery()) - total = (await db.execute(count_query)).scalar_one() - - # 分页查询 - offset = (page - 1) * size - items = (await db.execute(query.offset(offset).limit(size))).scalars().all() - - # 转换为响应模型 - item_responses = [ - EvaluationItemResponse( - id=item.id, - taskId=item.task_id, - itemId=item.item_id, - fileId=item.file_id, - evalContent=json.loads(item.eval_content) if item.eval_content else None, - evalScore=float(item.eval_score) if item.eval_score else None, - evalResult=json.loads(item.eval_result), - status=item.status - ) - for item in items - ] - - total_pages = math.ceil(total / size) if total > 0 else 0 - - return StandardResponse( - code=200, - message="Success", - data=PagedEvaluationItemsResponse( - content=item_responses, - totalElements=total, - totalPages=total_pages, - page=page, - size=size, - ) + # 检查任务是否存在 + task = await db.get(EvaluationTask, task_id) + if not task: + raise BusinessError(ErrorCodes.EVALUATION_TASK_NOT_FOUND, data={"task_id": task_id}) + + # 构建查询条件 + query = select(EvaluationItem).where(EvaluationItem.task_id == task_id) + + if status: + query = query.where(EvaluationItem.status == status) + + if file_id: + query = query.where(EvaluationItem.file_id == file_id) + + # 获取总数 + count_query = select(func.count()).select_from(query.subquery()) + total = (await db.execute(count_query)).scalar_one() + + # 分页查询 + offset = (page - 1) * size + items = (await db.execute(query.offset(offset).limit(size))).scalars().all() + + # 转换为响应模型 + item_responses = [ + EvaluationItemResponse( + id=item.id, + taskId=item.task_id, + itemId=item.item_id, + fileId=item.file_id, + evalContent=json.loads(item.eval_content) if item.eval_content else None, + evalScore=float(item.eval_score) if item.eval_score else None, + evalResult=json.loads(item.eval_result), + status=item.status + ) + for item in items + ] + + total_pages = math.ceil(total / size) if total > 0 else 0 + + return SuccessResponse( + 
data=PagedEvaluationItemsResponse( + content=item_responses, + totalElements=total, + totalPages=total_pages, + page=page, + size=size, ) - except Exception as e: - logger.error(f"Failed to list evaluation items: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + ) @router.get("/tasks/{task_id}", response_model=StandardResponse[EvaluationTaskDetailResponse]) @@ -346,24 +306,13 @@ async def get_evaluation_task( Returns: StandardResponse[EvaluationTaskDetailResponse]: 评估任务详情 """ - try: - task = await db.get(EvaluationTask, task_id) - if not task: - raise HTTPException(status_code=404, detail="Evaluation task not found") - - # 转换为响应模型 - response = _map_to_task_detail_response(task) - return StandardResponse( - code=200, - message="Success", - data=response - ) + task = await db.get(EvaluationTask, task_id) + if not task: + raise BusinessError(ErrorCodes.EVALUATION_TASK_NOT_FOUND, data={"task_id": task_id}) - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to get evaluation task: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + # 转换为响应模型 + response = _map_to_task_detail_response(task) + return SuccessResponse(data=response) @router.delete("/tasks", response_model=StandardResponse[str], status_code=200) @@ -381,42 +330,32 @@ async def delete_eval_tasks( Returns: StandardResponse[str]: 删除结果 """ - try: - # 检查任务是否存在 - task_id = ids[0] - task = await db.get(EvaluationTask, task_id) - if not task: - raise HTTPException(status_code=404, detail="Evaluation task not found") - - # 删除评估项 - await db.execute( - EvaluationItem.__table__.delete() - .where(EvaluationItem.task_id == task_id) - ) - - # 删除评估文件 - await db.execute( - EvaluationFile.__table__.delete() - .where(EvaluationFile.task_id == task_id) - ) + # 检查任务是否存在 + task_id = ids[0] + task = await db.get(EvaluationTask, task_id) + if not task: + raise BusinessError(ErrorCodes.EVALUATION_TASK_NOT_FOUND, data={"task_id": task_id}) 
+ + # 删除评估项 + await db.execute( + EvaluationItem.__table__.delete() + .where(EvaluationItem.task_id == task_id) + ) - # 删除任务 - await db.delete(task) - await db.commit() + # 删除评估文件 + await db.execute( + EvaluationFile.__table__.delete() + .where(EvaluationFile.task_id == task_id) + ) - return StandardResponse( - code=200, - message="Evaluation task deleted successfully", - data="success" - ) + # 删除任务 + await db.delete(task) + await db.commit() - except HTTPException: - await db.rollback() - raise - except Exception as e: - await db.rollback() - logger.error(f"Failed to delete evaluation task: {str(e)}") - raise HTTPException(status_code=500, detail="Internal server error") + return SuccessResponse( + data="success", + message="Evaluation task deleted successfully" + ) def _map_to_task_detail_response( diff --git a/runtime/datamate-python/app/module/evaluation/service/evaluation.py b/runtime/datamate-python/app/module/evaluation/service/evaluation.py index 02d617a62..9271f9c9f 100644 --- a/runtime/datamate-python/app/module/evaluation/service/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/service/evaluation.py @@ -5,7 +5,7 @@ from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession -from app.core.exception import BusinessErrorCodeEnum, BusinessException +from app.core.exception import ErrorCodes, BusinessError from app.core.logging import get_logger from app.db.models import EvaluationItem, EvaluationTask, DatasetFiles from app.db.models.data_evaluation import EvaluationFile @@ -185,7 +185,7 @@ def get_executor(self, source_type: str) -> EvaluationExecutor: for executor in self.executors: if executor.get_source_type().value == source_type: return executor - raise BusinessException(BusinessErrorCodeEnum.TASK_TYPE_ERROR.value) + raise BusinessError(ErrorCodes.EVALUATION_TASK_TYPE_ERROR) class EvaluationTaskService: diff --git a/runtime/datamate-python/app/module/generation/interface/generation_api.py 
b/runtime/datamate-python/app/module/generation/interface/generation_api.py index bee0fba97..67ca917ae 100644 --- a/runtime/datamate-python/app/module/generation/interface/generation_api.py +++ b/runtime/datamate-python/app/module/generation/interface/generation_api.py @@ -1,10 +1,11 @@ import uuid from typing import cast -from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks +from fastapi import APIRouter, Depends, BackgroundTasks from sqlalchemy import select, func, delete from sqlalchemy.ext.asyncio import AsyncSession +from app.core.exception import ErrorCodes, BusinessError, SuccessResponse, transaction from app.core.logging import get_logger from app.db.models.data_synthesis import ( save_synthesis_task, @@ -56,30 +57,32 @@ async def create_synthesis_task( ) dataset_files = ds_result.scalars().all() - # 保存任务到数据库 + # 保存任务到数据库(在事务中) request.source_file_id = [str(f.id) for f in dataset_files] - synthesis_task = await save_synthesis_task(db, request) - - # 将已有的 DatasetFiles 记录保存到 t_data_synthesis_file_instances - synth_files = [] - for f in dataset_files: - file_instance = DataSynthesisFileInstance( - id=str(uuid.uuid4()), # 使用新的 UUID 作为文件任务记录的主键,避免与 DatasetFiles 主键冲突 - synthesis_instance_id=synthesis_task.id, - file_name=f.file_name, - source_file_id=str(f.id), - status="pending", - total_chunks=0, - processed_chunks=0, - created_by="system", - updated_by="system", - ) - synth_files.append(file_instance) - if dataset_files: - db.add_all(synth_files) - await db.commit() + async with transaction(db): + synthesis_task = await save_synthesis_task(db, request) + + # 将已有的 DatasetFiles 记录保存到 t_data_synthesis_file_instances + synth_files = [] + for f in dataset_files: + file_instance = DataSynthesisFileInstance( + id=str(uuid.uuid4()), # 使用新的 UUID 作为文件任务记录的主键,避免与 DatasetFiles 主键冲突 + synthesis_instance_id=synthesis_task.id, + file_name=f.file_name, + source_file_id=str(f.id), + status="pending", + total_chunks=0, + processed_chunks=0, + 
created_by="system", + updated_by="system", + ) + synth_files.append(file_instance) + + if dataset_files: + db.add_all(synth_files) + # 事务已提交,启动后台任务 generation_service = GenerationService(db) # 异步处理任务:只传任务 ID,后台任务中使用新的 DB 会话重新加载任务对象 background_tasks.add_task(generation_service.process_task, synthesis_task.id) @@ -99,11 +102,7 @@ async def create_synthesis_task( updated_by=synthesis_task.updated_by, ) - return StandardResponse( - code=200, - message="success", - data=task_item, - ) + return SuccessResponse(data=task_item) @router.get("/task/{task_id}", response_model=StandardResponse[DataSynthesisTaskItem]) @@ -114,7 +113,7 @@ async def get_synthesis_task( """获取数据合成任务详情""" synthesis_task = await db.get(DataSynthInstance, task_id) if not synthesis_task: - raise HTTPException(status_code=404, detail="Synthesis task not found") + raise BusinessError(ErrorCodes.GENERATION_TASK_NOT_FOUND, data={"task_id": task_id}) task_item = DataSynthesisTaskItem( id=synthesis_task.id, @@ -128,11 +127,7 @@ async def get_synthesis_task( created_by=synthesis_task.created_by, updated_by=synthesis_task.updated_by, ) - return StandardResponse( - code=200, - message="success", - data=task_item, - ) + return SuccessResponse(data=task_item) @router.get("/tasks", response_model=StandardResponse[PagedDataSynthesisTaskResponse], status_code=200) @@ -209,11 +204,7 @@ async def list_synthesis_tasks( size=page_size, ) - return StandardResponse( - code=200, - message="Success", - data=paged, - ) + return SuccessResponse(data=paged) @router.delete("/task/{task_id}", response_model=StandardResponse) @@ -224,7 +215,7 @@ async def delete_synthesis_task( """删除数据合成任务""" task = await db.get(DataSynthInstance, task_id) if not task: - raise HTTPException(status_code=404, detail="Synthesis task not found") + raise BusinessError(ErrorCodes.GENERATION_TASK_NOT_FOUND, data={"task_id": task_id}) # 1. 
删除与该任务相关的 SynthesisData、Chunk、File 记录 # 先查出所有文件任务 ID @@ -258,11 +249,7 @@ async def delete_synthesis_task( await db.delete(task) await db.commit() - return StandardResponse( - code=200, - message="success", - data=None, - ) + return SuccessResponse(data=None) @router.delete("/task/{task_id}/{file_id}", response_model=StandardResponse) @@ -275,11 +262,11 @@ async def delete_synthesis_file_task( # 先获取任务和文件任务记录 task = await db.get(DataSynthInstance, task_id) if not task: - raise HTTPException(status_code=404, detail="Synthesis task not found") + raise BusinessError(ErrorCodes.GENERATION_TASK_NOT_FOUND, data={"task_id": task_id}) file_task = await db.get(DataSynthesisFileInstance, file_id) if not file_task: - raise HTTPException(status_code=404, detail="Synthesis file task not found") + raise BusinessError(ErrorCodes.GENERATION_FILE_NOT_FOUND, data={"file_id": file_id}) # 删除 SynthesisData(根据文件任务ID) await db.execute( @@ -310,11 +297,7 @@ async def delete_synthesis_file_task( await db.commit() await db.refresh(task) - return StandardResponse( - code=200, - message="success", - data=None, - ) + return SuccessResponse(data=None) @router.get("/prompt", response_model=StandardResponse[str]) @@ -322,11 +305,7 @@ async def get_prompt_by_type( synth_type: SynthesisType, ): prompt = get_prompt(synth_type) - return StandardResponse( - code=200, - message="Success", - data=prompt, - ) + return SuccessResponse(data=prompt) @router.get("/task/{task_id}/files", response_model=StandardResponse[PagedDataSynthesisFileTaskResponse]) @@ -340,7 +319,7 @@ async def list_synthesis_file_tasks( # 先校验任务是否存在 task = await db.get(DataSynthInstance, task_id) if not task: - raise HTTPException(status_code=404, detail="Synthesis task not found") + raise BusinessError(ErrorCodes.GENERATION_TASK_NOT_FOUND, data={"task_id": task_id}) base_query = select(DataSynthesisFileInstance).where( DataSynthesisFileInstance.synthesis_instance_id == task_id @@ -384,11 +363,7 @@ async def list_synthesis_file_tasks( 
size=page_size, ) - return StandardResponse( - code=200, - message="Success", - data=paged, - ) + return SuccessResponse(data=paged) @router.get("/file/{file_id}/chunks", response_model=StandardResponse[PagedDataSynthesisChunkResponse]) @@ -402,7 +377,7 @@ async def list_chunks_by_file( # 校验文件任务是否存在 file_task = await db.get(DataSynthesisFileInstance, file_id) if not file_task: - raise HTTPException(status_code=404, detail="Synthesis file task not found") + raise BusinessError(ErrorCodes.GENERATION_FILE_NOT_FOUND, data={"file_id": file_id}) base_query = select(DataSynthesisChunkInstance).where( DataSynthesisChunkInstance.synthesis_file_instance_id == file_id @@ -442,11 +417,7 @@ async def list_chunks_by_file( size=page_size, ) - return StandardResponse( - code=200, - message="Success", - data=paged, - ) + return SuccessResponse(data=paged) @router.get("/chunk/{chunk_id}/data", response_model=StandardResponse[list[SynthesisDataItem]]) @@ -458,7 +429,7 @@ async def list_synthesis_data_by_chunk( # 可选:校验 chunk 是否存在 chunk = await db.get(DataSynthesisChunkInstance, chunk_id) if not chunk: - raise HTTPException(status_code=404, detail="Chunk not found") + raise BusinessError(ErrorCodes.GENERATION_CHUNK_NOT_FOUND, data={"chunk_id": chunk_id}) result = await db.execute( select(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id) @@ -475,11 +446,7 @@ async def list_synthesis_data_by_chunk( for row in rows ] - return StandardResponse( - code=200, - message="Success", - data=items, - ) + return SuccessResponse(data=items) @router.post("/task/{task_id}/export-dataset/{dataset_id}", response_model=StandardResponse[str]) @@ -505,13 +472,9 @@ async def export_synthesis_task_to_dataset( dataset_id, e, ) - raise HTTPException(status_code=400, detail=str(e)) + raise BusinessError(ErrorCodes.OPERATION_FAILED, data={"error": str(e)}) - return StandardResponse( - code=200, - message="success", - data=dataset.id, - ) + return SuccessResponse(data=dataset.id) 
@router.delete("/chunk/{chunk_id}", response_model=StandardResponse) @@ -522,7 +485,7 @@ async def delete_chunk_with_data( """删除单条 t_data_synthesis_chunk_instances 记录及其关联的所有 t_data_synthesis_data""" chunk = await db.get(DataSynthesisChunkInstance, chunk_id) if not chunk: - raise HTTPException(status_code=404, detail="Chunk not found") + raise BusinessError(ErrorCodes.GENERATION_CHUNK_NOT_FOUND, data={"chunk_id": chunk_id}) # 先删除与该 chunk 关联的合成数据 await db.execute( @@ -538,7 +501,7 @@ async def delete_chunk_with_data( await db.commit() - return StandardResponse(code=200, message="success", data=None) + return SuccessResponse(data=None) @router.delete("/chunk/{chunk_id}/data", response_model=StandardResponse) @@ -549,7 +512,7 @@ async def delete_synthesis_data_by_chunk( """仅删除指定 chunk 下的全部 t_data_synthesis_data 记录,返回删除条数""" chunk = await db.get(DataSynthesisChunkInstance, chunk_id) if not chunk: - raise HTTPException(status_code=404, detail="Chunk not found") + raise BusinessError(ErrorCodes.GENERATION_CHUNK_NOT_FOUND, data={"chunk_id": chunk_id}) result = await db.execute( delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id) @@ -558,7 +521,7 @@ async def delete_synthesis_data_by_chunk( await db.commit() - return StandardResponse(code=200, message="success", data=deleted) + return SuccessResponse(data=deleted) @router.delete("/data/batch", response_model=StandardResponse) @@ -568,7 +531,7 @@ async def batch_delete_synthesis_data( ): """批量删除 t_data_synthesis_data 记录""" if not request.ids: - return StandardResponse(code=200, message="success", data=0) + return SuccessResponse(data=0) result = await db.execute( delete(SynthesisData).where(SynthesisData.id.in_(request.ids)) @@ -576,7 +539,7 @@ async def batch_delete_synthesis_data( deleted = int(getattr(result, "rowcount", 0) or 0) await db.commit() - return StandardResponse(code=200, message="success", data=deleted) + return SuccessResponse(data=deleted) @router.patch("/data/{data_id}", 
response_model=StandardResponse) @@ -591,7 +554,7 @@ async def update_synthesis_data_field( """ record = await db.get(SynthesisData, data_id) if not record: - raise HTTPException(status_code=404, detail="Synthesis data not found") + raise BusinessError(ErrorCodes.GENERATION_DATA_NOT_FOUND, data={"data_id": data_id}) # 直接整体覆盖 data 字段 record.data = body.data @@ -600,7 +563,7 @@ async def update_synthesis_data_field( await db.refresh(record) return StandardResponse( - code=200, + code="0", message="success", data=SynthesisDataItem( id=record.id, diff --git a/runtime/datamate-python/app/module/rag/interface/rag_interface.py b/runtime/datamate-python/app/module/rag/interface/rag_interface.py index b66af6463..40265bf76 100644 --- a/runtime/datamate-python/app/module/rag/interface/rag_interface.py +++ b/runtime/datamate-python/app/module/rag/interface/rag_interface.py @@ -1,6 +1,6 @@ -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import APIRouter, Depends +from app.core.exception import ErrorCodes, BusinessError, SuccessResponse from app.db.session import get_db from app.module.rag.service.rag_service import RAGService from app.module.shared.schema import StandardResponse @@ -11,27 +11,18 @@ @router.post("/process/{knowledge_base_id}") async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()): """ - Process all unprocessed files in a knowledge base. + 处理知识库中所有未处理的文件 """ - try: - await rag_service.init_graph_rag(knowledge_base_id) - return StandardResponse( - code=200, - message="Processing started for knowledge base.", - data=None - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + await rag_service.init_graph_rag(knowledge_base_id) + return SuccessResponse( + data=None, + message="Processing started for knowledge base." 
+ ) @router.post("/query") async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()): """ - Query the knowledge graph with the given query text and knowledge base ID. + 使用给定的查询文本和知识库 ID 查询知识图谱 """ - try: - result = await rag_service.query_rag(payload.query, payload.knowledge_base_id) - return StandardResponse(code=200, message="success", data=result) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + result = await rag_service.query_rag(payload.query, payload.knowledge_base_id) + return SuccessResponse(data=result) diff --git a/runtime/datamate-python/app/module/ratio/interface/ratio_task.py b/runtime/datamate-python/app/module/ratio/interface/ratio_task.py index 8a3e13f00..b020804e0 100644 --- a/runtime/datamate-python/app/module/ratio/interface/ratio_task.py +++ b/runtime/datamate-python/app/module/ratio/interface/ratio_task.py @@ -3,11 +3,12 @@ from typing import Set from datetime import datetime -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel, Field, field_validator from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import or_, func, delete, select +from app.core.exception import ErrorCodes, BusinessError, SuccessResponse, transaction from app.core.logging import get_logger from app.db.models import Dataset from app.db.session import get_db @@ -42,45 +43,34 @@ async def create_ratio_task( Path: /api/synthesis/ratio-task """ - try: - # 校验 config 中的 dataset_id 是否存在 - dm_service = DatasetManagementService(db) - source_types = await get_dataset_types(dm_service, req) + # 校验 config 中的 dataset_id 是否存在 + dm_service = DatasetManagementService(db) + source_types = await get_dataset_types(dm_service, req) - await valid_exists(db, req) + await valid_exists(db, req) + async with transaction(db): target_dataset = await create_target_dataset(db, req, source_types) - instance = 
await create_ratio_instance(db, req, target_dataset) - asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id)) - - response_data = CreateRatioTaskResponse( - id=instance.id, - name=instance.name, - description=instance.description, - totals=instance.totals or 0, - status=instance.status or TaskStatus.PENDING.name, - config=req.config, - targetDataset=TargetDatasetInfo( - id=str(target_dataset.id), - name=str(target_dataset.name), - datasetType=str(target_dataset.dataset_type), - status=str(target_dataset.status), - ) + # 事务已提交,启动后台任务 + asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id)) + + response_data = CreateRatioTaskResponse( + id=instance.id, + name=instance.name, + description=instance.description, + totals=instance.totals or 0, + status=instance.status or TaskStatus.PENDING.name, + config=req.config, + targetDataset=TargetDatasetInfo( + id=str(target_dataset.id), + name=str(target_dataset.name), + datasetType=str(target_dataset.dataset_type), + status=str(target_dataset.status), ) - return StandardResponse( - code=200, - message="success", - data=response_data ) - except HTTPException: - await db.rollback() - raise - except Exception as e: - await db.rollback() - logger.error(f"Failed to create ratio task: {e}") - raise HTTPException(status_code=500, detail="Internal server error") + return SuccessResponse(data=response_data) async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance: @@ -104,7 +94,12 @@ async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset: - # 创建目标数据集:名称使用“<任务名称>-时间戳” + """ + 创建目标数据集 + + 注意:此函数必须在 transaction 上下文中调用 + """ + # 创建目标数据集:名称使用"<任务名称>-时间戳" target_dataset_name = f"{req.name}-{datetime.now().strftime('%Y%m%d%H%M%S')}" target_type = get_target_dataset_type(source_types) @@ -119,7 +114,7 @@ async def 
create_target_dataset(db, req: CreateRatioTaskRequest, source_types: s path=f"/dataset/{target_dataset_id}", ) db.add(target_dataset) - await db.flush() # 获取 target_dataset.id + # 不需要 flush,事务会在结束时自动提交 return target_dataset @@ -132,66 +127,60 @@ async def list_ratio_tasks( db: AsyncSession = Depends(get_db), ): """分页查询配比任务,支持名称与状态过滤""" - try: - query = select(RatioInstance) - # filters - if name: - # simple contains filter - query = query.where(RatioInstance.name.like(f"%{name}%")) - if status: - query = query.where(RatioInstance.status == status) - - # count - count_q = select(func.count()).select_from(query.subquery()) - total = (await db.execute(count_q)).scalar_one() - - # page (1-based) - page_index = max(page, 1) - 1 - query = query.order_by(RatioInstance.created_at.desc()).offset(page_index * size).limit(size) - result = await db.execute(query) - items = result.scalars().all() - - # map to DTOs and attach dataset name - # preload datasets - ds_ids = {i.target_dataset_id for i in items if i.target_dataset_id} - ds_map = {} - if ds_ids: - ds_res = await db.execute(select(Dataset).where(Dataset.id.in_(list(ds_ids)))) - for d in ds_res.scalars().all(): - ds_map[d.id] = d - - content: list[RatioTaskItem] = [] - for i in items: - ds = ds_map.get(i.target_dataset_id) if i.target_dataset_id else None - content.append( - RatioTaskItem( - id=i.id, - name=i.name or "", - description=i.description, - status=i.status, - totals=i.totals, - target_dataset_id=i.target_dataset_id, - target_dataset_name=(ds.name if ds else None), - created_at=str(i.created_at) if getattr(i, "created_at", None) else None, - updated_at=str(i.updated_at) if getattr(i, "updated_at", None) else None, - ) + query = select(RatioInstance) + # filters + if name: + # simple contains filter + query = query.where(RatioInstance.name.like(f"%{name}%")) + if status: + query = query.where(RatioInstance.status == status) + + # count + count_q = select(func.count()).select_from(query.subquery()) + total = 
(await db.execute(count_q)).scalar_one() + + # page (1-based) + page_index = max(page, 1) - 1 + query = query.order_by(RatioInstance.created_at.desc()).offset(page_index * size).limit(size) + result = await db.execute(query) + items = result.scalars().all() + + # map to DTOs and attach dataset name + # preload datasets + ds_ids = {i.target_dataset_id for i in items if i.target_dataset_id} + ds_map = {} + if ds_ids: + ds_res = await db.execute(select(Dataset).where(Dataset.id.in_(list(ds_ids)))) + for d in ds_res.scalars().all(): + ds_map[d.id] = d + + content: list[RatioTaskItem] = [] + for i in items: + ds = ds_map.get(i.target_dataset_id) if i.target_dataset_id else None + content.append( + RatioTaskItem( + id=i.id, + name=i.name or "", + description=i.description, + status=i.status, + totals=i.totals, + target_dataset_id=i.target_dataset_id, + target_dataset_name=(ds.name if ds else None), + created_at=str(i.created_at) if getattr(i, "created_at", None) else None, + updated_at=str(i.updated_at) if getattr(i, "updated_at", None) else None, ) - - total_pages = (total + size - 1) // size if size > 0 else 0 - return StandardResponse( - code=200, - message="success", - data=PagedRatioTaskResponse( - content=content, - totalElements=total, - totalPages=total_pages, - page=page, - size=size, - ), ) - except Exception as e: - logger.error(f"Failed to list ratio tasks: {e}") - raise HTTPException(status_code=500, detail="Internal server error") + + total_pages = (total + size - 1) // size if size > 0 else 0 + return SuccessResponse( + data=PagedRatioTaskResponse( + content=content, + totalElements=total, + totalPages=total_pages, + page=page, + size=size, + ), + ) @router.delete("", response_model=StandardResponse[str], status_code=200) @@ -200,10 +189,10 @@ async def delete_ratio_tasks( db: AsyncSession = Depends(get_db), ): """删除配比任务,返回简单结果字符串。""" - try: - if not ids: - raise HTTPException(status_code=400, detail="ids is required") + if not ids: + raise 
BusinessError(ErrorCodes.BAD_REQUEST, data={"detail": "ids is required"}) + async with transaction(db): # 先删除关联关系 await db.execute( delete(RatioRelation).where(RatioRelation.ratio_instance_id.in_(ids)) @@ -212,30 +201,22 @@ async def delete_ratio_tasks( await db.execute( delete(RatioInstance).where(RatioInstance.id.in_(ids)) ) - await db.commit() - return StandardResponse(code=200, message="success", data="success") - except HTTPException: - await db.rollback() - raise - except Exception as e: - await db.rollback() - logger.error(f"Failed to delete ratio tasks: {e}") - raise HTTPException(status_code=500, detail=f"Fail to delete ratio task: {e}") + return SuccessResponse(data="success") async def valid_exists(db: AsyncSession, req: CreateRatioTaskRequest) -> None: """校验配比任务名称不能重复(精确匹配,去除首尾空格)。""" name = (req.name or "").strip() if not name: - raise HTTPException(status_code=400, detail="ratio task name is required") + raise BusinessError(ErrorCodes.RATIO_NAME_REQUIRED) # 查询是否已存在同名任务 ratio_task = await db.execute(select(RatioInstance.id).where(RatioInstance.name == name)) exists = ratio_task.scalar_one_or_none() if exists is not None: logger.error(f"create ratio task failed: ratio task '{name}' already exists (id={exists})") - raise HTTPException(status_code=400, detail=f"ratio task '{name}' already exists") + raise BusinessError(ErrorCodes.RATIO_ALREADY_EXISTS, data={"name": name}) async def get_dataset_types(dm_service: DatasetManagementService, req: CreateRatioTaskRequest) -> Set[str]: @@ -243,7 +224,7 @@ async def get_dataset_types(dm_service: DatasetManagementService, req: CreateRat for item in req.config: dataset = await dm_service.get_dataset(item.dataset_id) if not dataset: - raise HTTPException(status_code=400, detail=f"dataset_id not found: {item.dataset_id}") + raise BusinessError(ErrorCodes.NOT_FOUND, data={"dataset_id": item.dataset_id}) else: dtype = getattr(dataset, "dataset_type", None) or getattr(dataset, "datasetType", None) 
source_types.add(str(dtype).upper()) @@ -278,65 +259,57 @@ async def get_ratio_task( Path: /api/synthesis/ratio-task/{task_id} """ - try: - # 查询任务实例 - instance_res = await db.execute( - select(RatioInstance).where(RatioInstance.id == task_id) - ) - instance = instance_res.scalar_one_or_none() - if not instance: - raise HTTPException(status_code=404, detail="Ratio task not found") - - # 查询关联的配比关系 - relations_res = await db.execute( - select(RatioRelationModel).where(RatioRelationModel.ratio_instance_id == task_id) - ) - relations = list(relations_res.scalars().all()) + # 查询任务实例 + instance_res = await db.execute( + select(RatioInstance).where(RatioInstance.id == task_id) + ) + instance = instance_res.scalar_one_or_none() + if not instance: + raise BusinessError(ErrorCodes.RATIO_TASK_NOT_FOUND, data={"task_id": task_id}) - # 查询目标数据集 - target_ds = None - if instance.target_dataset_id: - ds_res = await db.execute( - select(Dataset).where(Dataset.id == instance.target_dataset_id) - ) - target_ds = ds_res.scalar_one_or_none() + # 查询关联的配比关系 + relations_res = await db.execute( + select(RatioRelationModel).where(RatioRelationModel.ratio_instance_id == task_id) + ) + relations = list(relations_res.scalars().all()) - # 构建响应 - config = [ - { - "dataset_id": rel.source_dataset_id, - "counts": str(rel.counts) if rel.counts is not None else "0", - "filter_conditions": rel.filter_conditions or "", - } - for rel in relations - ] - - target_dataset_info = { - "id": str(target_ds.id) if target_ds else None, - "name": target_ds.name if target_ds else None, - "type": target_ds.dataset_type if target_ds else None, - "status": target_ds.status if target_ds else None, - "file_count": target_ds.file_count if target_ds else 0, - "size_bytes": target_ds.size_bytes if target_ds else 0, + # 查询目标数据集 + target_ds = None + if instance.target_dataset_id: + ds_res = await db.execute( + select(Dataset).where(Dataset.id == instance.target_dataset_id) + ) + target_ds = ds_res.scalar_one_or_none() + + # 
构建响应 + config = [ + { + "dataset_id": rel.source_dataset_id, + "counts": str(rel.counts) if rel.counts is not None else "0", + "filter_conditions": rel.filter_conditions or "", } - - return StandardResponse( - code=200, - message="success", - data=RatioTaskDetailResponse( - id=instance.id, - name=instance.name or "", - description=instance.description, - status=instance.status or "UNKNOWN", - totals=instance.totals or 0, - config=config, - target_dataset=target_dataset_info, - created_at=instance.created_at, - updated_at=instance.updated_at, - ) + for rel in relations + ] + + target_dataset_info = { + "id": str(target_ds.id) if target_ds else None, + "name": target_ds.name if target_ds else None, + "type": target_ds.dataset_type if target_ds else None, + "status": target_ds.status if target_ds else None, + "file_count": target_ds.file_count if target_ds else 0, + "size_bytes": target_ds.size_bytes if target_ds else 0, + } + + return SuccessResponse( + data=RatioTaskDetailResponse( + id=instance.id, + name=instance.name or "", + description=instance.description, + status=instance.status or "UNKNOWN", + totals=instance.totals or 0, + config=config, + target_dataset=target_dataset_info, + created_at=instance.created_at, + updated_at=instance.updated_at, ) - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to get ratio task {task_id}: {e}") - raise HTTPException(status_code=500, detail="Internal server error") + ) diff --git a/runtime/datamate-python/app/module/shared/schema/common.py b/runtime/datamate-python/app/module/shared/schema/common.py index beaf1c7ec..7d2bf7e14 100644 --- a/runtime/datamate-python/app/module/shared/schema/common.py +++ b/runtime/datamate-python/app/module/shared/schema/common.py @@ -26,16 +26,32 @@ class StandardResponse(BaseResponseModel, Generic[T]): """ 标准API响应格式 - 所有API端点应返回此格式,确保响应的一致性 + 所有API端点(包括错误响应)都应返回此格式 """ - code: int = Field(..., description="HTTP状态码") + code: str = Field(..., 
description="业务状态码(字符串)。成功使用 '0',错误使用 '{module}.{sequence}' 格式(如 'rag.001')") message: str = Field(..., description="响应消息") - data: T = Field(..., description="响应数据") + data: T = Field(default=None, description="响应数据") class Config: populate_by_name = True alias_generator = to_camel + +class ResponseCode(str, Enum): + """通用响应码枚举""" + + # 成功响应 + SUCCESS = "0" # 操作成功 + + # 通用错误 + BAD_REQUEST = "common.400" # 错误的请求 + UNAUTHORIZED = "common.401" # 未授权 + FORBIDDEN = "common.403" # 禁止访问 + NOT_FOUND = "common.404" # 资源未找到 + VALIDATION_ERROR = "common.422" # 验证错误 + INTERNAL_ERROR = "common.500" # 服务器内部错误 + SERVICE_UNAVAILABLE = "common.503" # 服务不可用 + class PaginatedData(BaseResponseModel, Generic[T]): """分页数据容器""" page: int = Field(..., description="当前页码(从1开始)") diff --git a/runtime/datamate-python/app/module/system/interface/about.py b/runtime/datamate-python/app/module/system/interface/about.py index 1ad157fe9..69784af53 100644 --- a/runtime/datamate-python/app/module/system/interface/about.py +++ b/runtime/datamate-python/app/module/system/interface/about.py @@ -12,7 +12,7 @@ async def health_check(): """健康检查端点""" return StandardResponse( - code=200, + code="0", message="success", data=HealthResponse( status="healthy", diff --git a/runtime/datamate-python/app/module/system/interface/models.py b/runtime/datamate-python/app/module/system/interface/models.py index 7979eb3c5..0b85cf341 100644 --- a/runtime/datamate-python/app/module/system/interface/models.py +++ b/runtime/datamate-python/app/module/system/interface/models.py @@ -1,5 +1,6 @@ from fastapi import APIRouter, Depends, Query +from app.core.exception import SuccessResponse from app.module.shared.schema import StandardResponse, PaginatedData from app.module.system.schema.models import ( CreateModelRequest, @@ -17,7 +18,7 @@ async def get_providers(svc: ModelsService = Depends()): """获取厂商列表,与 Java GET /models/providers 一致。""" data = await svc.get_providers() - return StandardResponse(code=200, message="success", 
data=data) + return SuccessResponse(data=data) @router.get("/list", response_model=StandardResponse[PaginatedData[ModelsResponse]]) @@ -40,21 +41,21 @@ async def get_models( isDefault=isDefault, ) data = await svc.get_models(q) - return StandardResponse(code=200, message="success", data=data) + return SuccessResponse(data=data) @router.post("/create", response_model=StandardResponse[ModelsResponse]) async def create_model(req: CreateModelRequest, svc: ModelsService = Depends()): """创建模型配置,与 Java POST /models/create 一致。""" data = await svc.create_model(req) - return StandardResponse(code=200, message="success", data=data) + return SuccessResponse(data=data) @router.get("/{model_id}", response_model=StandardResponse[ModelsResponse]) async def get_model_detail(model_id: str, svc: ModelsService = Depends()): """获取模型详情,与 Java GET /models/{modelId} 一致。""" data = await svc.get_model_detail(model_id) - return StandardResponse(code=200, message="success", data=data) + return SuccessResponse(data=data) @router.put("/{model_id}", response_model=StandardResponse[ModelsResponse]) @@ -65,11 +66,11 @@ async def update_model( ): """更新模型配置,与 Java PUT /models/{modelId} 一致。""" data = await svc.update_model(model_id, req) - return StandardResponse(code=200, message="success", data=data) + return SuccessResponse(data=data) @router.delete("/{model_id}", response_model=StandardResponse[None]) async def delete_model(model_id: str, svc: ModelsService = Depends()): """删除模型配置,与 Java DELETE /models/{modelId} 一致。""" await svc.delete_model(model_id) - return StandardResponse(code=200, message="success", data=None) + return SuccessResponse(data=None)