diff --git a/backend/app/api/tools.py b/backend/app/api/tools.py index 3250a25b..4aa2580c 100644 --- a/backend/app/api/tools.py +++ b/backend/app/api/tools.py @@ -252,6 +252,29 @@ async def create_tool( return {"id": str(tool.id), "name": tool.name} +class BulkToolUpdateItem(BaseModel): + tool_id: str + enabled: bool + +@router.put("/bulk") +async def update_tools_bulk( + updates: list[BulkToolUpdateItem], + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """Bulk update the enabled status of multiple tools.""" + tool_ids = [uuid.UUID(u.tool_id) for u in updates] + result = await db.execute(select(Tool).where(Tool.id.in_(tool_ids))) + tools_map = {str(t.id): t for t in result.scalars().all()} + + for update in updates: + if update.tool_id in tools_map: + tools_map[update.tool_id].enabled = update.enabled + + await db.commit() + return {"ok": True} + + @router.put("/{tool_id}") async def update_tool( tool_id: uuid.UUID, @@ -276,29 +299,6 @@ async def update_tool( return {"ok": True} -class BulkToolUpdateItem(BaseModel): - tool_id: str - enabled: bool - -@router.put("/bulk") -async def update_tools_bulk( - updates: list[BulkToolUpdateItem], - current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), -): - """Bulk update the enabled status of multiple tools.""" - tool_ids = [uuid.UUID(u.tool_id) for u in updates] - result = await db.execute(select(Tool).where(Tool.id.in_(tool_ids))) - tools_map = {str(t.id): t for t in result.scalars().all()} - - for update in updates: - if update.tool_id in tools_map: - tools_map[update.tool_id].enabled = update.enabled - - await db.commit() - return {"ok": True} - - @router.delete("/{tool_id}") async def delete_tool( tool_id: uuid.UUID, diff --git a/backend/tests/test_tools_bulk_api.py b/backend/tests/test_tools_bulk_api.py new file mode 100644 index 00000000..19ee6e4d --- /dev/null +++ b/backend/tests/test_tools_bulk_api.py @@ -0,0 +1,70 @@ +import uuid +from types import SimpleNamespace + +import httpx +import pytest + +from app.api import tools as tools_api +from app.core.security import get_current_user +from app.main import app + + +class FakeToolsResult: + def __init__(self, values): + self._values = list(values) + + def scalars(self): + return self + + def all(self): + return list(self._values) + + +class FakeDB: + def __init__(self, tools): + self._tools = tools + self.committed = False + + async def execute(self, _statement): + return FakeToolsResult(self._tools) + + async def commit(self): + self.committed = True + + +@pytest.mark.asyncio +async def test_put_tools_bulk_hits_bulk_route_and_updates_tools(): + tool_a = SimpleNamespace(id=uuid.uuid4(), enabled=False) + tool_b = SimpleNamespace(id=uuid.uuid4(), enabled=True) + db = FakeDB([tool_a, tool_b]) + + async def override_db(): + yield db + + user = SimpleNamespace( + id=uuid.uuid4(), + role="platform_admin", + tenant_id=uuid.uuid4(), + is_active=True, + ) + + app.dependency_overrides[get_current_user] = lambda: user + app.dependency_overrides[tools_api.get_db] = override_db + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac: + response = await ac.put( + "/api/tools/bulk", + json=[ + {"tool_id": str(tool_a.id), "enabled": True}, + {"tool_id": str(tool_b.id), "enabled": False}, + ], + ) + + app.dependency_overrides.clear() + + assert response.status_code == 200 + assert response.json() == {"ok": True} + assert tool_a.enabled is True + assert tool_b.enabled is False + assert db.committed is True