Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 55 additions & 14 deletions litellm/proxy/management_endpoints/internal_user_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,37 +101,77 @@ def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> d
return data_json


async def _check_duplicate_user_email(
user_email: Optional[str], prisma_client: Any
async def _check_duplicate_user_field(
field_name: str,
field_value: Optional[str],
prisma_client: Any,
*,
case_insensitive: bool = False,
label: Optional[str] = None,
) -> None:
"""
Helper function to check if a user email already exists in the database.
Helper function to check if a field already exists in the user table.

Args:
user_email (Optional[str]): Email to check
prisma_client (Any): Database client instance
field_name (str): Database field name to check.
field_value (Optional[str]): Value to check for duplicates.
prisma_client (Any): Database client instance.
case_insensitive (bool): Whether to use case-insensitive comparison.
label (Optional[str]): Human readable label for error messages.

Raises:
Exception: If database is not connected
HTTPException: If user with email already exists
Exception: If database is not connected.
HTTPException: If a user with the given field value already exists.
"""
if user_email:
if field_value:
if prisma_client is None:
raise Exception("Database not connected")

value = field_value.strip()
where_clause = {field_name: {"equals": value}}
if case_insensitive:
where_clause[field_name]["mode"] = "insensitive"

existing_user = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": {"equals": user_email.strip(), "mode": "insensitive"}}
where=where_clause
)

if existing_user is not None:
existing_value = getattr(existing_user, field_name, value)
error_label = label or field_name
raise HTTPException(
status_code=400,
detail={
"error": f"User with email {existing_user.user_email} already exists"
},
status_code=409,
detail={"error": f"User with {error_label} {existing_value} already exists"},
)


async def _check_duplicate_user_email(
user_email: Optional[str], prisma_client: Any
) -> None:
"""
Helper function to check if a user email already exists in the database.
"""
await _check_duplicate_user_field(
field_name="user_email",
field_value=user_email,
prisma_client=prisma_client,
case_insensitive=True,
label="email",
)


async def _check_duplicate_user_id(user_id: Optional[str], prisma_client: Any) -> None:
"""
Helper function to check if a user id already exists in the database.
"""
await _check_duplicate_user_field(
field_name="user_id",
field_value=user_id,
prisma_client=prisma_client,
label="id",
)


async def _add_user_to_organizations(
user_id: str,
organizations: List[str],
Expand Down Expand Up @@ -361,7 +401,8 @@ async def new_user(
status_code=500,
detail=CommonProxyErrors.db_not_connected_error.value,
)
# Check for duplicate email
# Check for duplicate user_id or email
await _check_duplicate_user_id(data.user_id, prisma_client)
await _check_duplicate_user_email(data.user_email, prisma_client)

# Check if license is over limit
Expand Down
2 changes: 1 addition & 1 deletion tests/proxy_admin_ui_tests/test_key_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ async def test_get_users(prisma_client):
# Create some test users
test_users = [
NewUserRequest(
user_id=f"test_user_{i}",
user_id=f"test_user_{i}_{uuid.uuid4()}",
user_role=(
LitellmUserRoles.INTERNAL_USER.value
if i % 2 == 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,21 @@ async def mock_count(*args, **kwargs):

mock_prisma_client.db.litellm_usertable.count = mock_count

# Mock check_duplicate_user_email to pass
# Mock duplicate checks to pass
async def mock_check_duplicate_user_email(*args, **kwargs):
return None # No duplicate found

async def mock_check_duplicate_user_id(*args, **kwargs):
return None # No duplicate found

mocker.patch(
"litellm.proxy.management_endpoints.internal_user_endpoints._check_duplicate_user_email",
mock_check_duplicate_user_email,
)
mocker.patch(
"litellm.proxy.management_endpoints.internal_user_endpoints._check_duplicate_user_id",
mock_check_duplicate_user_id,
)

# Mock the license check to return True (over limit)
mock_license_check = mocker.MagicMock()
Expand Down Expand Up @@ -449,14 +456,21 @@ async def mock_count(*args, **kwargs):

mock_prisma_client.db.litellm_usertable.count = mock_count

# Mock check_duplicate_user_email to pass
# Mock duplicate checks to pass
async def mock_check_duplicate_user_email(*args, **kwargs):
return None # No duplicate found

async def mock_check_duplicate_user_id(*args, **kwargs):
return None # No duplicate found

mocker.patch(
"litellm.proxy.management_endpoints.internal_user_endpoints._check_duplicate_user_email",
mock_check_duplicate_user_email,
)
mocker.patch(
"litellm.proxy.management_endpoints.internal_user_endpoints._check_duplicate_user_id",
mock_check_duplicate_user_id,
)

# Mock the license check to return False (under limit)
mock_license_check = mocker.MagicMock()
Expand Down Expand Up @@ -737,7 +751,7 @@ async def mock_find_first_duplicate(*args, **kwargs):
with pytest.raises(HTTPException) as exc_info:
await _check_duplicate_user_email("user@example.com", mock_prisma_client)

assert exc_info.value.status_code == 400
assert exc_info.value.status_code == 409
assert "User with email User@Example.com already exists" in str(
exc_info.value.detail
)
Expand Down Expand Up @@ -770,6 +784,56 @@ async def mock_find_first_no_duplicate(*args, **kwargs):
) # Should not raise exception


@pytest.mark.asyncio
async def test_check_duplicate_user_id(mocker):
"""
Test that _check_duplicate_user_id detects duplicates and does not use case insensitive matching.
"""
from fastapi import HTTPException

from litellm.proxy.management_endpoints.internal_user_endpoints import (
_check_duplicate_user_id,
)

mock_prisma_client = mocker.MagicMock()

# Duplicate user_id should raise
mock_existing_user = mocker.MagicMock()
mock_existing_user.user_id = "existing-user-id"

async def mock_find_first_duplicate(*args, **kwargs):
where_clause = kwargs.get("where", {})
user_id_clause = where_clause.get("user_id", {})
assert user_id_clause.get("equals") == "existing-user-id"
assert "mode" not in user_id_clause
return mock_existing_user

mock_prisma_client.db.litellm_usertable.find_first = mock_find_first_duplicate

with pytest.raises(HTTPException) as exc_info:
await _check_duplicate_user_id("existing-user-id", mock_prisma_client)

assert exc_info.value.status_code == 409
assert "User with id existing-user-id already exists" in str(
exc_info.value.detail
)

# No duplicate should pass
async def mock_find_first_no_duplicate(*args, **kwargs):
where_clause = kwargs.get("where", {})
user_id_clause = where_clause.get("user_id", {})
assert user_id_clause.get("equals") == "new-user-id"
assert "mode" not in user_id_clause
return None

mock_prisma_client.db.litellm_usertable.find_first = mock_find_first_no_duplicate

await _check_duplicate_user_id("new-user-id", mock_prisma_client)

# None user_id should no-op
await _check_duplicate_user_id(None, mock_prisma_client)


def test_process_keys_for_user_info_filters_dashboard_keys(monkeypatch):
"""
Test that _process_keys_for_user_info filters out keys with team_id='litellm-dashboard'
Expand Down
10 changes: 6 additions & 4 deletions tests/test_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,26 +683,28 @@ async def test_team_alias():
@pytest.mark.asyncio
async def test_users_in_team_budget():
"""
- Create Team
- Create User
- Create Team with User
- Add User to team with budget = 0.0000001
- Make Call 1 -> pass
- Make Call 2 -> fail
"""
get_user = f"krrish_{time.time()}@berri.ai"
async with aiohttp.ClientSession() as session:
team = await new_team(session, 0, user_id=get_user)
print("New team=", team)
# Create user first to avoid user_id collision when creating team
key_gen = await new_user(
session,
0,
user_id=get_user,
budget=10,
budget_duration="5s",
team_id=team["team_id"],
models=["fake-openai-endpoint"],
)
key = key_gen["key"]

# Create team with the user (user already exists, so it will just add them)
team = await new_team(session, 0, user_id=get_user)
print("New team=", team)

# update user to have budget = 0.0000001
await update_member(
Expand Down
Loading