Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public DataProtectionStore(
_memoryCache = memoryCache;
_dataProtector = provider.CreateProtector(nameof(KeyMaterial)); ;
}
public Task Store(KeyMaterial securityParameters)
public Task<KeyMaterial> Store(KeyMaterial securityParameters)
{
var possiblyEncryptedKeyElement = _dataProtector.Protect(JsonSerializer.Serialize(securityParameters));

Expand All @@ -74,16 +74,16 @@ public Task Store(KeyMaterial securityParameters)
KeyRepository.StoreElement(keyElement, friendlyName);
ClearCache();

return Task.CompletedTask;
return Task.FromResult(securityParameters);
}



public async Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws)
public async Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws, bool bypassCache = false)
{
var cacheKey = JwkContants.CurrentJwkCache + jwtKeyType;

if (!_memoryCache.TryGetValue(cacheKey, out KeyMaterial keyMaterial))
if (bypassCache || !_memoryCache.TryGetValue(cacheKey, out KeyMaterial keyMaterial))
{
var keys = await GetLastKeys(1, jwtKeyType);
keyMaterial = keys.FirstOrDefault();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ internal class InMemoryStore : IJsonWebKeyStore
private readonly SemaphoreSlim _slim = new(1);
internal const string DefaultRevocationReason = "Revoked";

public Task Store(KeyMaterial keyMaterial)
public Task<KeyMaterial> Store(KeyMaterial keyMaterial)
{
if (keyMaterial is null) throw new InvalidOperationException("Can't store empty value.");

_slim.Wait();
_store.Add(keyMaterial);
_slim.Release();

return Task.CompletedTask;
return Task.FromResult(keyMaterial);
}

public Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws)
public Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws, bool bypassCache = false)
{
return Task.FromResult(_store.Where(s => s.Use == (jwtKeyType == JwtKeyType.Jws ? "sig" : "enc")).OrderByDescending(s => s.CreationDate).FirstOrDefault());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ namespace NetDevPack.Security.Jwt.Core.Interfaces;

public interface IJsonWebKeyStore
{
Task Store(KeyMaterial keyMaterial);
Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws);
Task<KeyMaterial> Store(KeyMaterial keyMaterial);
Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws, bool bypassCache = false);
Task Revoke(KeyMaterial keyMaterial, string reason=default);
Task<ReadOnlyCollection<KeyMaterial>> GetLastKeys(int quantity, JwtKeyType? jwtKeyType = null);
Task<KeyMaterial> Get(string keyId);
Expand Down
49 changes: 39 additions & 10 deletions src/NetDevPack.Security.Jwt.Core/Jwt/JwtService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,34 @@ internal class JwtService : IJwtService
{
private readonly IJsonWebKeyStore _store;
private readonly IOptions<JwtOptions> _options;
// Process-wide lock so a scoped service across concurrent requests rotates once, not once per request.
private static readonly SemaphoreSlim RotationLock = new(1, 1);

public JwtService(IJsonWebKeyStore store, IOptions<JwtOptions> options)
{
_store = store;
_options = options;
}
public async Task<SecurityKey> GenerateKey(JwtKeyType jwtKeyType = JwtKeyType.Jws)
{
var current = await _store.GetCurrent(jwtKeyType, bypassCache: true);
// if current is null, get the highest version ever created (manually revoked/first-run)
current ??= (await _store.GetLastKeys(1, jwtKeyType)).FirstOrDefault();
return await GenerateKey(jwtKeyType, current);
}

private async Task<SecurityKey> GenerateKey(JwtKeyType jwtKeyType, KeyMaterial previous)
{
var key = new CryptographicKey(jwtKeyType == JwtKeyType.Jws ? _options.Value.Jws : _options.Value.Jwe);

var model = new KeyMaterial(key);
await _store.Store(model);
// Next rotation version. Seeded at 1 on cold start and for pre-column rows (Version defaults to 0).
model.Version = (previous?.Version ?? 0) + 1;
// Store returns the persisted key when it wins the insert, or the key another replica already stored
// for this version, so we never sign with a key that isn't published.
var persisted = await _store.Store(model);

return model.GetSecurityKey();
return persisted.GetSecurityKey();
}

public async Task<SecurityKey> GetCurrentSecurityKey(JwtKeyType jwtKeyType = JwtKeyType.Jws)
Expand All @@ -33,10 +47,24 @@ public async Task<SecurityKey> GetCurrentSecurityKey(JwtKeyType jwtKeyType = Jwt

if (NeedsUpdate(current))
{
// According NIST - https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-57pt1r4.pdf - Private key should be removed when no longer needs
await _store.Revoke(current);
var newKey = await GenerateKey(jwtKeyType);
return newKey;
await RotationLock.WaitAsync();
try
{
// Re-check under the lock, bypassing the cache
current = await _store.GetCurrent(jwtKeyType, bypassCache: true);
// No active key: fall back to the newest key including revoked.
current ??= (await _store.GetLastKeys(1, jwtKeyType)).FirstOrDefault();
if (NeedsUpdate(current))
{
// According NIST - https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-57pt1r4.pdf - Private key should be removed when no longer needs
await _store.Revoke(current);
return await GenerateKey(jwtKeyType, current);
}
}
finally
{
RotationLock.Release();
}
}

// options has change. Change current key
Expand Down Expand Up @@ -74,7 +102,7 @@ private async Task<bool> CheckCompatibility(KeyMaterial currentKey, JwtKeyType j
if (jwtKeyType == JwtKeyType.Jws && currentKey.Type != _options.Value.Jws.Kty()
|| jwtKeyType == JwtKeyType.Jwe && currentKey.Type != _options.Value.Jwe.Kty())
{
await GenerateKey(jwtKeyType);
await GenerateKey(jwtKeyType, currentKey);
return false;
}
return true;
Expand All @@ -89,10 +117,11 @@ public async Task RevokeKey(string keyId, string reason = null)

public async Task<SecurityKey> GenerateNewKey(JwtKeyType jwtKeyType = JwtKeyType.Jws)
{
var oldCurrent = await _store.GetCurrent(jwtKeyType);
var oldCurrent = await _store.GetCurrent(jwtKeyType, bypassCache: true);
// if current is null, get the highest version ever created (manually revoked/first-run)
oldCurrent ??= (await _store.GetLastKeys(1, jwtKeyType)).FirstOrDefault();
await _store.Revoke(oldCurrent);
return await GenerateKey(jwtKeyType);

return await GenerateKey(jwtKeyType, oldCurrent);
}

private bool NeedsUpdate(KeyMaterial current)
Expand Down
1 change: 1 addition & 0 deletions src/NetDevPack.Security.Jwt.Core/Model/Key.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public KeyMaterial(CryptographicKey cryptographicKey)
}

public Guid Id { get; set; } = Guid.NewGuid();
public long Version { get; set; }
public string KeyId { get; set; }
public string Type { get; set; }
public string Use { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
Expand Down Expand Up @@ -32,20 +34,50 @@ public DatabaseJsonWebKeyStore(TContext context, ILogger<DatabaseJsonWebKeyStore
_logger = logger;
}

public async Task Store(KeyMaterial securityParamteres)
public async Task<KeyMaterial> Store(KeyMaterial securityParamteres)
{
await _context.SecurityKeys.AddAsync(securityParamteres);
// Deterministic Id per use + kty + rotation version
// every replica replacing the same key computes the same Id
// Concurrent inserts collide on the primary key
securityParamteres.Id = DeterministicId(securityParamteres.Use, securityParamteres.Type, securityParamteres.Version);

_logger.LogInformation($"Saving new SecurityKeyWithPrivate {securityParamteres.Id}", typeof(TContext).Name);
await _context.SaveChangesAsync();
try
{
await _context.SecurityKeys.AddAsync(securityParamteres);
await _context.SaveChangesAsync();
}
catch
{
// Lost the race or a transient fault.
_context.Entry(securityParamteres).State = EntityState.Detached;
var winner = await _context.SecurityKeys.AsNoTracking()
.FirstOrDefaultAsync(k => k.Id == securityParamteres.Id);
if (winner == null)
throw;

// Return the persisted winner so the caller signs with the published key, not our orphan.
ClearCache();
return winner;
}
ClearCache();
return securityParamteres;
}

private static Guid DeterministicId(string use, string kty, long version)
{
using var sha = SHA256.Create();
var hash = sha.ComputeHash(Encoding.UTF8.GetBytes($"{use}:{kty}:{version}"));
var guidBytes = new byte[16];
Array.Copy(hash, guidBytes, 16);
return new Guid(guidBytes);
}

public async Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws)
public async Task<KeyMaterial> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws, bool bypassCache = false)
{
var cacheKey = JwkContants.CurrentJwkCache + jwtKeyType;

if (!_memoryCache.TryGetValue(cacheKey, out KeyMaterial credentials))
if (bypassCache || !_memoryCache.TryGetValue(cacheKey, out KeyMaterial credentials))
{
var keyType = (jwtKeyType == JwtKeyType.Jws ? "sig" : "enc");
#if NET5_0_OR_GREATER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private string GetCurrentFile(JwtKeyType jwtKeyType)
return Path.Combine(KeysPath.FullName, $"{_options.Value.KeyPrefix}current.{jwtKeyType}.key");
}

public async Task Store(KeyMaterial securityParamteres)
public async Task<KeyMaterial> Store(KeyMaterial securityParamteres)
{
if (!KeysPath.Exists)
KeysPath.Create();
Expand All @@ -48,6 +48,7 @@ public async Task Store(KeyMaterial securityParamteres)

await File.WriteAllTextAsync(Path.Combine(KeysPath.FullName, $"{_options.Value.KeyPrefix}current-{securityParamteres.KeyId}.{keyType}.key"), JsonSerializer.Serialize(securityParamteres, new JsonSerializerOptions() { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }));
ClearCache();
return securityParamteres;
}

public bool NeedsUpdate(KeyMaterial current)
Expand All @@ -74,11 +75,11 @@ public async Task Revoke(KeyMaterial securityKeyWithPrivate, string reason = nul
}


public Task<KeyMaterial?> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws)
public Task<KeyMaterial?> GetCurrent(JwtKeyType jwtKeyType = JwtKeyType.Jws, bool bypassCache = false)
{
var cacheKey = JwkContants.CurrentJwkCache + jwtKeyType;

if (!_memoryCache.TryGetValue(cacheKey, out KeyMaterial credentials))
if (bypassCache || !_memoryCache.TryGetValue(cacheKey, out KeyMaterial credentials))
{
credentials = GetKey(GetCurrentFile(jwtKeyType));
// Set cache options.
Expand Down