Skip to content

Commit 18f31d1

Browse files
feat: enhance domain handling and validation in ApiClient and caching mechanisms
1 parent ba10d5f commit 18f31d1

8 files changed

Lines changed: 333 additions & 33 deletions

File tree

docs/Caching.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ api_client = ApiClient(ApiClientOptions(
7171
))
7272
```
7373

74-
When a custom adapter is provided, both the discovery cache and JWKS cache use it. Cache keys are inherently distinct — discovery keys are normalized issuer URLs (e.g., `https://tenant.auth0.com/`) and JWKS keys are `jwks_uri` values (e.g., `https://tenant.auth0.com/.well-known/jwks.json`).
74+
When a custom adapter is provided, both the discovery cache and JWKS cache use the same adapter instance. Cache keys are inherently distinct — discovery keys are normalized issuer URLs (e.g., `https://tenant.auth0.com/`) and JWKS keys are `jwks_uri` values (e.g., `https://tenant.auth0.com/.well-known/jwks.json`).
75+
76+
**Note:** Because both caches share one adapter, entries share the same LRU eviction pool. A JWKS entry could evict a discovery entry (or vice versa) under memory pressure. Set `cache_max_entries` accordingly — recommended: `number_of_issuers × 3`. With the default `InMemoryCache`, discovery and JWKS caches are separate and each gets its own `max_entries` budget.
7577

7678
## Tuning Recommendations
7779

src/auth0_api_python/api_client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import time
23
from collections.abc import Mapping, Sequence
34
from typing import Any, Optional, Union
@@ -63,6 +64,10 @@ def __init__(self, options: ApiClientOptions):
6364
# Static list validation
6465
if len(options.domains) == 0:
6566
raise ConfigurationError("domains list cannot be empty")
67+
if not all(isinstance(d, str) and d.strip() for d in options.domains):
68+
raise ConfigurationError(
69+
"domains list must contain only non-empty strings"
70+
)
6671
# Normalize and store domains
6772
self._allowed_domains = [normalize_domain(d) for d in options.domains]
6873
elif callable(options.domains):
@@ -145,9 +150,11 @@ async def _resolve_allowed_domains(
145150
'unverified_iss': unverified_iss
146151
}
147152

148-
# Invoke resolver
153+
# Invoke resolver (supports both sync and async resolvers)
149154
try:
150155
result = self._allowed_domains(context)
156+
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
157+
result = await result
151158
except Exception as e:
152159
raise DomainsResolverError(
153160
f"Domains resolver function failed: {str(e)}"
@@ -164,6 +171,11 @@ async def _resolve_allowed_domains(
164171
"Domains resolver returned an empty list"
165172
)
166173

174+
if not all(isinstance(d, str) and d.strip() for d in result):
175+
raise DomainsResolverError(
176+
"Domains resolver must return a list of non-empty strings"
177+
)
178+
167179
# Normalize domains from resolver
168180
allowed_domains = [normalize_domain(d) for d in result]
169181
else:
@@ -984,11 +996,11 @@ async def _discover(self, issuer: Optional[str] = None) -> dict[str, Any]:
984996
OIDC discovery metadata dictionary
985997
"""
986998
if issuer:
999+
cache_key = issuer # Already normalized by caller
9871000
domain = issuer.replace('https://', '').replace('http://', '').rstrip('/')
9881001
else:
9891002
domain = self.options.domain
990-
991-
cache_key = normalize_domain(f"https://{domain}")
1003+
cache_key = normalize_domain(f"https://{domain}")
9921004

9931005
cached = self._discovery_cache.get(cache_key)
9941006
if cached:

src/auth0_api_python/cache.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import time
12
from abc import ABC, abstractmethod
2-
from datetime import datetime, timedelta
33
from typing import Any, Optional
44

55

@@ -74,8 +74,12 @@ class InMemoryCache(CacheAdapter):
7474
"""
7575
Default in-memory cache implementation with LRU eviction.
7676
77+
Designed for asyncio (single-threaded).
78+
For multi-threaded environments, implement a custom CacheAdapter
79+
with appropriate locking.
80+
7781
Features:
78-
- TTL (time-to-live) support per entry
82+
- TTL (time-to-live) support per entry using monotonic clock
7983
- LRU (Least Recently Used) eviction when max_entries reached
8084
- No external dependencies
8185
@@ -96,7 +100,7 @@ def __init__(self, max_entries: int = 100):
96100
Args:
97101
max_entries: Maximum number of cache entries (default: 100)
98102
"""
99-
self._cache: dict[str, tuple[Any, Optional[datetime]]] = {}
103+
self._cache: dict[str, tuple[Any, Optional[float]]] = {}
100104
self._max_entries = max_entries
101105

102106
def get(self, key: str) -> Optional[Any]:
@@ -116,7 +120,7 @@ def get(self, key: str) -> Optional[Any]:
116120

117121
value, expiry = self._cache[key]
118122

119-
if expiry and datetime.now() > expiry:
123+
if expiry is not None and time.monotonic() > expiry:
120124
del self._cache[key]
121125
return None
122126

@@ -145,8 +149,8 @@ def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None) -> None:
145149
del self._cache[oldest_key]
146150

147151
expiry = None
148-
if ttl_seconds:
149-
expiry = datetime.now() + timedelta(seconds=ttl_seconds)
152+
if ttl_seconds is not None:
153+
expiry = time.monotonic() + ttl_seconds
150154

151155
self._cache[key] = (value, expiry)
152156

src/auth0_api_python/types.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
Type definitions for auth0-api-python SDK
33
"""
44

5-
from typing import Callable, Optional, TypedDict
5+
from collections.abc import Awaitable, Callable
6+
from typing import Optional, TypedDict, Union
67

78

89
class DomainsResolverContext(TypedDict, total=False):
@@ -12,19 +13,20 @@ class DomainsResolverContext(TypedDict, total=False):
1213
Attributes:
1314
request_url: The URL the API request was made to (optional)
1415
request_headers: Request headers dict (e.g., Host, X-Forwarded-Host) (optional)
15-
unverified_iss: The issuer claim from the unverified token (required)
16+
unverified_iss: The issuer claim from the unverified token
1617
"""
1718
request_url: Optional[str]
1819
request_headers: Optional[dict]
19-
unverified_iss: str # This is required, others are optional
20+
unverified_iss: str
2021

21-
22-
DomainsResolver = Callable[[DomainsResolverContext], list[str]]
22+
DomainsResolver = Callable[
23+
[DomainsResolverContext], Union[list[str], Awaitable[list[str]]]
24+
]
2325
"""
2426
Type alias for domains resolver function.
2527
26-
A DomainsResolver is a function that receives a DomainsResolverContext and returns
27-
a list of allowed domain strings.
28+
A DomainsResolver is a sync or async function that receives a DomainsResolverContext
29+
and returns a list of allowed domain strings.
2830
2931
Args:
3032
context (DomainsResolverContext): Dictionary containing:
@@ -35,14 +37,17 @@ class DomainsResolverContext(TypedDict, total=False):
3537
Returns:
3638
list[str]: List of allowed domain strings (e.g., ['tenant.auth0.com'])
3739
38-
Example:
40+
Example (sync):
3941
from auth0_api_python import DomainsResolverContext
4042
4143
def my_resolver(context: DomainsResolverContext) -> list[str]:
42-
unverified_iss = context['unverified_iss']
43-
request_url = context.get('request_url')
44-
request_headers = context.get('request_headers')
45-
46-
# Fetch allowed domains based on context
47-
return ['tenant1.auth0.com', 'tenant2.auth0.com']
44+
host = (context.get('request_headers') or {}).get('host')
45+
if host == 'api.brand.com':
46+
return ['brand.custom-domain.com']
47+
return ['tenant.auth0.com']
48+
49+
Example (async):
50+
async def my_async_resolver(context: DomainsResolverContext) -> list[str]:
51+
domains = await db.lookup_domains(context['unverified_iss'])
52+
return domains
4853
"""

src/auth0_api_python/utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,33 @@ def normalize_domain(domain: str) -> str:
5353
Normalized issuer URL (e.g., "https://tenant.auth0.com/")
5454
5555
"""
56+
if not isinstance(domain, str) or not domain.strip():
57+
raise ValueError("domain must be a non-empty string")
58+
5659
domain = domain.strip().lower()
57-
domain = domain.replace('http://', '').replace('https://', '')
58-
domain = domain.rstrip('/')
59-
return f"https://{domain}/"
60+
61+
# Reject http:// explicitly
62+
if domain.startswith('http://'):
63+
raise ValueError("invalid domain URL (https required)")
64+
65+
# Strip https:// prefix
66+
domain = domain.replace('https://', '')
67+
68+
# Split host from any path/query/fragment
69+
host = domain.split('/')[0].split('?')[0].split('#')[0]
70+
71+
# Reject credentials
72+
if '@' in host:
73+
raise ValueError("invalid domain URL (credentials are not allowed)")
74+
75+
# Check for path segments, query, or fragment
76+
bare = domain.rstrip('/')
77+
if bare != host:
78+
raise ValueError(
79+
"invalid domain URL (path/query/fragment are not allowed)"
80+
)
81+
82+
return f"https://{host}/"
6083

6184

6285
async def fetch_oidc_metadata(

0 commit comments

Comments
 (0)