diff --git a/src/CertManager.sol b/src/CertManager.sol index 1a5c0e2..47354a8 100644 --- a/src/CertManager.sol +++ b/src/CertManager.sol @@ -381,8 +381,11 @@ contract CertManager is ICertManager { "invalid cert algo param" ); - uint256 end = subjectPubKeyPtr.content() + subjectPubKeyPtr.length(); - subjectPubKey = certificate.slice(end - 96, 96); + uint256 keyStart = subjectPubKeyPtr.content(); + uint256 keyLength = subjectPubKeyPtr.length(); + require(keyLength == 97 && keyStart + keyLength <= certificate.length, "invalid subject public key length"); + require(certificate[keyStart] == 0x04, "invalid subject public key format"); + subjectPubKey = certificate.slice(keyStart + 1, 96); } function _verifyValidity(bytes memory certificate, Asn1Ptr validityPtr) internal view returns (uint64 notAfter) { diff --git a/test/CertManager.t.sol b/test/CertManager.t.sol index 687e222..ba838f5 100644 --- a/test/CertManager.t.sol +++ b/test/CertManager.t.sol @@ -21,11 +21,31 @@ contract Asn1DecodeHarness { } } +contract CertManagerHarness is CertManager { + using Asn1Decode for bytes; + + constructor() CertManager(new P384Verifier()) {} + + function parsePubKey(bytes memory subjectPublicKeyInfo) external pure returns (bytes memory) { + return _parsePubKey(subjectPublicKeyInfo, subjectPublicKeyInfo.root()); + } + + function parsePubKeyAt(bytes memory certificate, uint256 header, uint256 content, uint256 length) + external + pure + returns (bytes memory) + { + return _parsePubKey(certificate, LibAsn1Ptr.toAsn1Ptr(header, content, length)); + } +} + contract CertManagerTest is Test { Asn1DecodeHarness public harness; + CertManagerHarness public certManagerHarness; function setUp() public { harness = new Asn1DecodeHarness(); + certManagerHarness = new CertManagerHarness(); } // 's' INTEGER from cabundle[3] (2026-04-02 attestation): DER-encoded with a 0x00 @@ -41,6 +61,46 @@ contract CertManagerTest is Test { assertEq(lo, 0xa2eda9c549dc01460f5fe650814ebe0e7ee855d3bcffde95afd2e82e21df0eac); } + function test_ParsePubKeyAcceptsUncompressedP384Point() public view { + bytes memory pubKey = _patternBytes(96); + bytes memory spki = abi.encodePacked(hex"3076301006072a8648ce3d020106052b8104002203620004", pubKey); + + assertEq(certManagerHarness.parsePubKey(spki), pubKey); + } + + function test_ParsePubKeyRejectsCompressedP384Point() public { + bytes memory compressedKey = _patternBytes(48); + bytes memory spki = abi.encodePacked(hex"3046301006072a8648ce3d020106052b8104002203320002", compressedKey); + bytes memory paddedCertificate = abi.encodePacked(new bytes(128), spki); + + vm.expectRevert("invalid subject public key length"); + certManagerHarness.parsePubKeyAt(paddedCertificate, 128, 130, 0x46); + } + + function test_ParsePubKeyRejectsOversizedP384Point() public { + bytes memory oversizedKey = _patternBytes(97); + bytes memory spki = abi.encodePacked(hex"3077301006072a8648ce3d020106052b8104002203630004", oversizedKey); + + vm.expectRevert("invalid subject public key length"); + certManagerHarness.parsePubKey(spki); + } + + function test_ParsePubKeyRejectsTruncatedP384Point() public { + bytes memory truncatedKey = _patternBytes(95); + bytes memory spki = abi.encodePacked(hex"3076301006072a8648ce3d020106052b8104002203620004", truncatedKey); + + vm.expectRevert("invalid subject public key length"); + certManagerHarness.parsePubKey(spki); + } + + function test_ParsePubKeyRejectsMissingUncompressedPrefix() public { + bytes memory pubKey = _patternBytes(96); + bytes memory spki = abi.encodePacked(hex"3076301006072a8648ce3d020106052b8104002203620002", pubKey); + + vm.expectRevert("invalid subject public key format"); + certManagerHarness.parsePubKey(spki); + } + // Cert chain from the 2026-04-02 ~15:35 UTC dev attestation that produced the live revert. // CB0 is the AWS Nitro root (keccak256(CB0) == CertManager.ROOT_CA_CERT_HASH, pinned in the // constructor), so the chain is verified starting from CB1. @@ -155,4 +215,11 @@ contract CertManagerTest is Test { return der; } + + function _patternBytes(uint256 len) internal pure returns (bytes memory out) { + out = new bytes(len); + for (uint256 i = 0; i < len; i++) { + out[i] = bytes1(uint8(i + 1)); + } + } }