Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 61 additions & 74 deletions lib/Service/DiscoveryService.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class DiscoveryService {
public const INVALIDATE_JWKS_CACHE_AFTER_SECONDS = 3600;

/**
*
* @var string[]
*/
private const SUPPORTED_JWK_ALGS = [
Expand Down Expand Up @@ -52,63 +51,54 @@ public function __construct(
public function obtainDiscovery(Provider $provider): array {
$cacheKey = 'discovery-2-' . $provider->getDiscoveryEndpoint();
$cachedDiscovery = $this->cache->get($cacheKey);

if ($cachedDiscovery === null) {
$url = $provider->getDiscoveryEndpoint();
$this->logger->debug('Obtaining discovery endpoint: ' . $url);

$freshDiscovery = $this->clientService->get($url);
$parsedDiscovery = json_decode($freshDiscovery, true, 512, JSON_THROW_ON_ERROR);
$this->cache->set($cacheKey, $freshDiscovery, self::INVALIDATE_DISCOVERY_CACHE_AFTER_SECONDS);
} else {
$parsedDiscovery = json_decode($cachedDiscovery, true, 512, JSON_THROW_ON_ERROR);

return json_decode($freshDiscovery, true, 512, JSON_THROW_ON_ERROR);
}

return $parsedDiscovery;
return json_decode($cachedDiscovery, true, 512, JSON_THROW_ON_ERROR);
}

/**
* @param Provider $provider
* @param string $tokenToDecode This is used to potentially fix the missing alg in
* @param bool $useCache
* @return array
* @throws \JsonException
*/
public function obtainJWK(Provider $provider, string $tokenToDecode, bool $useCache = true): array {
$lastJwksRefresh = $this->providerService->getSetting($provider->getId(), ProviderService::SETTING_JWKS_CACHE_TIMESTAMP);
if ($lastJwksRefresh !== '' && $useCache && (int)$lastJwksRefresh > time() - self::INVALIDATE_JWKS_CACHE_AFTER_SECONDS) {
$rawJwks = $this->providerService->getSetting($provider->getId(), ProviderService::SETTING_JWKS_CACHE);
$rawJwks = json_decode($rawJwks, true);
$providerId = $provider->getId();
$lastJwksRefresh = $this->providerService->getSetting($providerId, ProviderService::SETTING_JWKS_CACHE_TIMESTAMP);

if ($useCache && $lastJwksRefresh !== '' && (int)$lastJwksRefresh > time() - self::INVALIDATE_JWKS_CACHE_AFTER_SECONDS) {
$rawJwks = $this->providerService->getSetting($providerId, ProviderService::SETTING_JWKS_CACHE);
$rawJwks = json_decode($rawJwks, true, 512, JSON_THROW_ON_ERROR);
$this->logger->debug('[obtainJWK] jwks cache content', ['jwks_cache' => $rawJwks]);
} else {
$discovery = $this->obtainDiscovery($provider);
$responseBody = $this->clientService->get($discovery['jwks_uri']);
$rawJwks = json_decode($responseBody, true);
$rawJwks = json_decode($responseBody, true, 512, JSON_THROW_ON_ERROR);

$this->logger->debug('[obtainJWK] getting fresh jwks', ['jwks' => $rawJwks]);
// cache jwks
$this->providerService->setSetting($provider->getId(), ProviderService::SETTING_JWKS_CACHE, $responseBody);
$this->logger->debug('[obtainJWK] setting cache', ['jwks_cache' => $responseBody]);
$this->providerService->setSetting($provider->getId(), ProviderService::SETTING_JWKS_CACHE_TIMESTAMP, strval(time()));

$this->providerService->setSetting($providerId, ProviderService::SETTING_JWKS_CACHE, $responseBody);
$this->providerService->setSetting($providerId, ProviderService::SETTING_JWKS_CACHE_TIMESTAMP, (string)time());
}

$fixedJwks = $this->fixJwksAlg($rawJwks, $tokenToDecode);
$this->logger->debug('[obtainJWK] fixed jwks', ['fixed_jwks' => $fixedJwks]);
$jwks = JWK::parseKeySet($fixedJwks, 'RS256');
$this->logger->debug('Parsed the jwks');
return $jwks;

return JWK::parseKeySet($fixedJwks, 'RS256');
}

/**
* @param string $authorizationEndpoint
* @param array $extraGetParameters
* @return string
*/
public function buildAuthorizationUrl(string $authorizationEndpoint, array $extraGetParameters = []): string {
$parsedUrl = parse_url($authorizationEndpoint);

$urlWithoutParams
= (isset($parsedUrl['scheme']) ? $parsedUrl['scheme'] . '://' : '')
$urlWithoutParams = ($parsedUrl['scheme'] ?? 'https') . '://'
. ($parsedUrl['host'] ?? '')
. (isset($parsedUrl['port']) ? ':' . strval($parsedUrl['port']) : '')
. (isset($parsedUrl['port']) ? ':' . $parsedUrl['port'] : '')
. ($parsedUrl['path'] ?? '');

$queryParams = $extraGetParameters;
Expand All @@ -117,19 +107,18 @@ public function buildAuthorizationUrl(string $authorizationEndpoint, array $extr
$queryParams = array_merge($queryParams, $parsedQueryParams);
}

// sanitize everything before the query parameters
// and trust http_build_query to sanitize the query parameters
return htmlentities(filter_var($urlWithoutParams, FILTER_SANITIZE_URL), ENT_QUOTES)
. (empty($queryParams) ? '' : '?' . http_build_query($queryParams));
$finalUrl = $urlWithoutParams . ($queryParams ? '?' . http_build_query($queryParams) : '');

return htmlspecialchars($finalUrl, ENT_QUOTES, 'UTF-8');
}

/**
* Validates the strength and correctness of a cryptographic key.
*
* This method checks:
* - RSA keys have a modulus of at least 2048 bits.
* - EC keys use one of the allowed curves: P-256, P-384, P-521.
* - OKP (EdDSA) keys use the Ed25519 curve.
* - RSA keys have a modulus of at least 2048 bits.
* - EC keys use one of the allowed curves: P-256, P-384, P-521.
* - OKP (EdDSA) keys use the Ed25519 curve.
*
* @param array $key The key data as an associative array (JWK format).
* @param string $alg The algorithm intended to be used with this key (e.g., 'RS256', 'ES256').
Expand All @@ -140,44 +129,37 @@ public function buildAuthorizationUrl(string $authorizationEndpoint, array $extr
private function validateKeyStrength(array $key, string $alg): void {
$kty = $key['kty'] ?? throw new \RuntimeException('Key missing kty');

switch ($kty) {
case 'RSA':
if (empty($key['n'])) {
throw new \RuntimeException('RSA key missing modulus (n)');
}
$modulus = JWT::urlsafeB64Decode($key['n']);
match ($kty) {
'RSA' => (function () use ($key) {
$modulus = JWT::urlsafeB64Decode($key['n'] ?? throw new \RuntimeException('RSA key missing modulus (n)'));
$modulusBits = strlen($modulus) * 8;

if ($modulusBits < 2048) {
throw new \RuntimeException('RSA key too short: ' . $modulusBits . ' bits');
}
break;
})(),

case 'EC':
'EC' => (function () use ($key) {
$curve = $key['crv'] ?? throw new \RuntimeException('EC key missing crv');
$allowedCurves = ['P-256', 'P-384', 'P-521'];
if (!in_array($curve, $allowedCurves, true)) {

if (!in_array($curve, ['P-256', 'P-384', 'P-521'], true)) {
throw new \RuntimeException('Unsupported EC curve: ' . $curve);
}
break;
})(),

case 'OKP':
'OKP' => (function () use ($key) {
$curve = $key['crv'] ?? throw new \RuntimeException('OKP key missing crv');

if ($curve !== 'Ed25519') {
throw new \RuntimeException('Unsupported OKP curve: ' . $curve);
}
break;
})(),

default:
throw new \RuntimeException('Unsupported key type: ' . $kty);
}
default => throw new \RuntimeException('Unsupported key type: ' . $kty)
};
}

/**
* Inspired by https://github.com/snake/moodle/compare/880462a1685...MDL-77077-master
*
* @param array $jwks The JSON Web Key Set
* @param string $jwt The JWT token
* @return array The modified JWKS
* @throws \RuntimeException if no matching key is found or algorithm is unsupported
*/
public function fixJwksAlg(array $jwks, string $jwt): array {
Expand All @@ -196,54 +178,59 @@ public function fixJwksAlg(array $jwks, string $jwt): array {
throw new \RuntimeException('Invalid JWKS: missing "keys" array');
}

// Check validation config once before loop
$oidcSystemConfig = $this->config->getSystemValue('user_oidc', []);
$shouldValidateStrength = !isset($oidcSystemConfig['validate_jwk_strength'])
|| !($oidcSystemConfig['validate_jwk_strength'] === false
|| $oidcSystemConfig['validate_jwk_strength'] === 'false'
|| $oidcSystemConfig['validate_jwk_strength'] === 0
|| $oidcSystemConfig['validate_jwk_strength'] === '0');

$matchingIndex = null;

foreach ($keys as $index => $key) {
$keyKty = $key['kty'] ?? null;
$keyUse = $key['use'] ?? null;

// Skip keys with incompatible type
if ($keyKty !== $expectedKty) {
if (($key['kty'] ?? null) !== $expectedKty) {
continue;
}

// Skip keys not intended for signature
$keyUse = $key['use'] ?? null;
if ($keyUse !== null && $keyUse !== 'sig') {
continue;
}

$oidcSystemConfig = $this->config->getSystemValue('user_oidc', []);
if (!isset($oidcSystemConfig['validate_jwk_strength'])
|| !in_array($oidcSystemConfig['validate_jwk_strength'], [false, 'false', 0, '0'], true)) {
// Validate key strength
$this->validateKeyStrength($key, $alg);
}

// If JWT has a kid, match strictly
if ($kid !== null) {
if (($key['kid'] ?? null) !== $kid) {
continue;
}

if ($shouldValidateStrength) {
$this->validateKeyStrength($key, $alg);
}

$matchingIndex = $index;
break;
}

// If no kid, select the first compatible key
if ($matchingIndex === null) {
if ($shouldValidateStrength) {
$this->validateKeyStrength($key, $alg);
}
$matchingIndex = $index;
}
}

if ($matchingIndex === null) {
throw new \RuntimeException(sprintf(
'No matching key found in JWKS (alg=%s, kid=%s)',
$alg ?? 'unknown',
$kid ?? 'none'
));
throw new \RuntimeException(
'No matching key found in JWKS (alg=' . ($alg ?? 'unknown') . ', kid=' . ($kid ?? 'none') . ')'
);
}

// Set 'alg' field if missing
if (empty($jwks['keys'][$matchingIndex]['alg'])) {
if (!isset($jwks['keys'][$matchingIndex]['alg'])) {
$jwks['keys'][$matchingIndex]['alg'] = $alg;
}

Expand Down
Loading