Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions backend/app/api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,29 @@ async def create_tool(
return {"id": str(tool.id), "name": tool.name}


class BulkToolUpdateItem(BaseModel):
tool_id: str
enabled: bool

@router.put("/bulk")
async def update_tools_bulk(
updates: list[BulkToolUpdateItem],
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Bulk update the enabled status of multiple tools."""
tool_ids = [uuid.UUID(u.tool_id) for u in updates]
result = await db.execute(select(Tool).where(Tool.id.in_(tool_ids)))
tools_map = {str(t.id): t for t in result.scalars().all()}

for update in updates:
if update.tool_id in tools_map:
tools_map[update.tool_id].enabled = update.enabled

await db.commit()
return {"ok": True}


@router.put("/{tool_id}")
async def update_tool(
tool_id: uuid.UUID,
Expand All @@ -276,29 +299,6 @@ async def update_tool(
return {"ok": True}


class BulkToolUpdateItem(BaseModel):
tool_id: str
enabled: bool

@router.put("/bulk")
async def update_tools_bulk(
updates: list[BulkToolUpdateItem],
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Bulk update the enabled status of multiple tools."""
tool_ids = [uuid.UUID(u.tool_id) for u in updates]
result = await db.execute(select(Tool).where(Tool.id.in_(tool_ids)))
tools_map = {str(t.id): t for t in result.scalars().all()}

for update in updates:
if update.tool_id in tools_map:
tools_map[update.tool_id].enabled = update.enabled

await db.commit()
return {"ok": True}


@router.delete("/{tool_id}")
async def delete_tool(
tool_id: uuid.UUID,
Expand Down
70 changes: 70 additions & 0 deletions backend/tests/test_tools_bulk_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import uuid
from types import SimpleNamespace

import httpx
import pytest

from app.api import tools as tools_api
from app.core.security import get_current_user
from app.main import app


class FakeToolsResult:
def __init__(self, values):
self._values = list(values)

def scalars(self):
return self

def all(self):
return list(self._values)


class FakeDB:
def __init__(self, tools):
self._tools = tools
self.committed = False

async def execute(self, _statement):
return FakeToolsResult(self._tools)

async def commit(self):
self.committed = True


@pytest.mark.asyncio
async def test_put_tools_bulk_hits_bulk_route_and_updates_tools():
tool_a = SimpleNamespace(id=uuid.uuid4(), enabled=False)
tool_b = SimpleNamespace(id=uuid.uuid4(), enabled=True)
db = FakeDB([tool_a, tool_b])

async def override_db():
yield db

user = SimpleNamespace(
id=uuid.uuid4(),
role="platform_admin",
tenant_id=uuid.uuid4(),
is_active=True,
)

app.dependency_overrides[get_current_user] = lambda: user
app.dependency_overrides[tools_api.get_db] = override_db

transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
response = await ac.put(
"/api/tools/bulk",
json=[
{"tool_id": str(tool_a.id), "enabled": True},
{"tool_id": str(tool_b.id), "enabled": False},
],
)

app.dependency_overrides.clear()

assert response.status_code == 200
assert response.json() == {"ok": True}
assert tool_a.enabled is True
assert tool_b.enabled is False
assert db.committed is True