-
Notifications
You must be signed in to change notification settings - Fork 1
Finish implementation of stubbed roles APIs #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,13 +12,15 @@ | |
| from api.core.database import get_osm_session, get_task_session | ||
| from api.core.jwt import validate_and_decode_token | ||
| from api.core.logging import get_logger | ||
| from api.src.workspaces.schemas import WorkspaceUserRoleType | ||
| from api.src.users.schemas import WorkspaceUserRoleType | ||
|
|
||
| # Set up logger for this module | ||
| logger = get_logger(__name__) | ||
|
|
||
| # TTL cache for token validation (1 hour TTL, max 1000 entries) | ||
| _token_cache: cachetools.TTLCache[str, "UserInfo"] = cachetools.TTLCache( | ||
| # TTL cache keyed by a user's OIDC subject. Evict entries when roles change. We | ||
| # still validate the JWT signature and expiry on every request before reading a | ||
| # cached record. | ||
| _user_info_cache: cachetools.TTLCache[UUID, "UserInfo"] = cachetools.TTLCache( | ||
| maxsize=1000, ttl=60 * 60 | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
|
|
@@ -41,6 +43,17 @@ async def close_tdei_client() -> None: | |
| _tdei_client = None | ||
|
|
||
|
|
||
| def evict_user_from_cache(auth_uid: UUID) -> None: | ||
| """ | ||
| Evict a user's cached UserInfo object so that their next request re-fetches | ||
| permissions. | ||
|
|
||
| Call this after modifying a user's roles in the OSM DB to ensure the change | ||
| takes effect on their next request rather than after the cache TTL expires. | ||
| """ | ||
| _user_info_cache.pop(auth_uid, None) | ||
|
|
||
|
|
||
| security = HTTPBearer() | ||
|
|
||
|
|
||
|
|
@@ -72,6 +85,7 @@ class UserInfo: | |
| credentials: str | ||
| user_uuid: UUID | ||
| user_name: str | ||
| token_jti: str # JWT ID used to detect token rotation on cache hits | ||
|
|
||
| # workspaceId, role from OSM DB | ||
| osmWorkspaceRoles: dict[int, list[WorkspaceUserRoleType]] | ||
|
|
@@ -125,6 +139,13 @@ def isWorkspaceContributor(self, workspaceId: int) -> bool: | |
| return True | ||
| return False | ||
|
|
||
| def effective_role(self, workspaceId: int) -> WorkspaceUserRoleType: | ||
| if self.isWorkspaceLead(workspaceId): | ||
| return WorkspaceUserRoleType.LEAD | ||
| if self.isWorkspaceValidator(workspaceId): | ||
| return WorkspaceUserRoleType.VALIDATOR | ||
| return WorkspaceUserRoleType.CONTRIBUTOR | ||
|
|
||
|
|
||
| # can't use the ORM here since the ORM uses us! (circular dependency) | ||
| def get_osm_db_session( | ||
|
|
@@ -144,9 +165,13 @@ async def validate_token( | |
| osm_db_session: AsyncSession = Depends(get_osm_db_session), | ||
| task_db_session: AsyncSession = Depends(get_task_db_session), | ||
| ) -> UserInfo: | ||
| """Dependency to get current authenticated user from TDEI/KeyCloak token and APIs. | ||
| """ | ||
| Dependency that gets the current authenticated user from the TDEI/KeyCloak | ||
| access token and fetches permissions from TDEI APIs. | ||
|
|
||
| Results are cached by token for 1 hour to avoid repeated validation calls. | ||
| We validate the JWT's signature and expiry on every request. The expensive | ||
| TDEI API and DB lookups are cached for 1 hour and should be evicted when a | ||
| user's role changes via evict_user_from_cache(). | ||
| """ | ||
| token = credentials.credentials | ||
|
|
||
|
|
@@ -161,27 +186,39 @@ async def validate_token( | |
| except Exception: | ||
| raise credentials_exception | ||
|
|
||
| user_id: str | None = payload.get("sub") | ||
| if user_id is None: | ||
| user_id_str: str | None = payload.get("sub") | ||
| if user_id_str is None: | ||
| raise credentials_exception | ||
|
|
||
| # Check cache first | ||
| if token in _token_cache: | ||
| logger.info("Token validation cache hit") | ||
| return _token_cache[token] | ||
| try: | ||
| user_uuid = UUID(user_id_str) | ||
| except ValueError: | ||
| raise credentials_exception from None | ||
|
|
||
| # Cache miss - perform full validation | ||
| # Cache keyed by user UUID. If the token rotated (new "jti") since we | ||
| # created the cache entry, evict it so we fetch fresh claims: | ||
| # | ||
| if user_uuid in _user_info_cache: | ||
| cached = _user_info_cache[user_uuid] | ||
| current_jti = payload.get("jti", "") | ||
| if cached.token_jti == current_jti: | ||
| logger.info("Token validation cache hit") | ||
| return cached | ||
| logger.info("Token validation cache miss: token rotated") | ||
| del _user_info_cache[user_uuid] | ||
|
|
||
| # Cache miss: fetch TDEI roles and DB data: | ||
| user_info = await _validate_token_uncached( | ||
| token, user_id, payload, osm_db_session, task_db_session | ||
| token, user_uuid, payload, osm_db_session, task_db_session | ||
| ) | ||
| _token_cache[token] = user_info | ||
| _user_info_cache[user_uuid] = user_info | ||
|
|
||
| return user_info | ||
|
|
||
|
|
||
| async def _validate_token_uncached( | ||
| token: str, | ||
| user_id: str, | ||
| user_uuid: UUID, | ||
| payload: dict, | ||
| osm_db_session: AsyncSession, | ||
| task_db_session: AsyncSession, | ||
|
|
@@ -200,21 +237,17 @@ async def _validate_token_uncached( | |
| } | ||
|
|
||
| r = UserInfo() | ||
|
|
||
| try: | ||
| r.user_uuid = UUID(user_id) | ||
| except ValueError: | ||
| raise credentials_exception from None | ||
|
|
||
| r.user_uuid = user_uuid | ||
| r.credentials = token | ||
| r.token_jti = payload.get("jti", "") | ||
| r.user_name = payload.get("preferred_username", "unknown") | ||
|
|
||
| # get user's project groups and roles from TDEI | ||
| pgs = [] | ||
|
|
||
| try: | ||
| response = await _tdei_client.get( | ||
| f"project-group-roles/{user_id}", | ||
| f"project-group-roles/{user_uuid}", | ||
| headers=headers, | ||
| params={"page_no": 1, "page_size": 1000}, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reminder I think it caps at 50, may still need pagination?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't see a limit in the code. The portal API doesn't seem to have upper bounds on these (unlike the TDEI API). |
||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also shorten the TTL? Your call... the big advantage is those rapid fire Rapid tile requests, even a short period would help with those.