From 6227802a74525666377693532ae2f353841cf0ad Mon Sep 17 00:00:00 2001 From: Carlos Monastyrski Date: Fri, 19 Jun 2026 00:49:19 -0300 Subject: [PATCH] Add PKI HSM option --- .../workflows/release_build_infisical_cli.yml | 7 + .goreleaser.yaml | 57 ++- go.mod | 1 + go.sum | 2 + packages/api/api.go | 3 +- packages/api/model.go | 4 + packages/cmd/gateway.go | 25 +- packages/gateway-v2/gateway.go | 72 ++- packages/gateway-v2/pkcs11.go | 53 ++ packages/gateway-v2/pkcs11_disabled.go | 10 + packages/gateway-v2/pkcs11_enabled.go | 452 ++++++++++++++++++ packages/gateway-v2/pkcs11_handler.go | 353 ++++++++++++++ 12 files changed, 1025 insertions(+), 14 deletions(-) create mode 100644 packages/gateway-v2/pkcs11.go create mode 100644 packages/gateway-v2/pkcs11_disabled.go create mode 100644 packages/gateway-v2/pkcs11_enabled.go create mode 100644 packages/gateway-v2/pkcs11_handler.go diff --git a/.github/workflows/release_build_infisical_cli.yml b/.github/workflows/release_build_infisical_cli.yml index 22226804..3669e210 100644 --- a/.github/workflows/release_build_infisical_cli.yml +++ b/.github/workflows/release_build_infisical_cli.yml @@ -137,6 +137,13 @@ jobs: sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 3B4FE6ACC0B21F32 sudo apt update sudo apt-get install -y libssl1.0-dev + - name: Install glibc cross-compilers for PKCS#11 (HSM) builds + run: | + set -euo pipefail + # PKCS#11 driver loading uses dlopen; the artifact must be dynamically + # linked against glibc. We use the system gcc for amd64 (native) and + # gcc-aarch64-linux-gnu for arm64. + sudo apt-get install -y gcc-aarch64-linux-gnu - name: Install cross-compile toolchains for RDP tier run: | set -euo pipefail diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 1b0a5362..9fde2d64 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -117,6 +117,41 @@ builds: goarm: - "7" + # PKCS#11-enabled HSM build. Loads the vendor's PKCS#11 driver via dlopen at + # runtime, so it MUST be dynamically linked (no -extldflags "-static") and + # must use a glibc toolchain. Ships as a separate `infisical-pkcs11` artifact. + - id: linux-amd64-pkcs11 + binary: infisical-pkcs11 + ldflags: + - -X github.com/Infisical/infisical-merge/packages/util.CLI_VERSION={{ .Version }} + - -X github.com/Infisical/infisical-merge/packages/telemetry.POSTHOG_API_KEY_FOR_CLI={{ .Env.POSTHOG_API_KEY_FOR_CLI }} + flags: + - -trimpath + - -tags=pkcs11 + env: + - CGO_ENABLED=1 + - CC=x86_64-linux-gnu-gcc + goos: + - linux + goarch: + - amd64 + + - id: linux-arm64-pkcs11 + binary: infisical-pkcs11 + ldflags: + - -X github.com/Infisical/infisical-merge/packages/util.CLI_VERSION={{ .Version }} + - -X github.com/Infisical/infisical-merge/packages/telemetry.POSTHOG_API_KEY_FOR_CLI={{ .Env.POSTHOG_API_KEY_FOR_CLI }} + flags: + - -trimpath + - -tags=pkcs11 + env: + - CGO_ENABLED=1 + - CC=aarch64-linux-gnu-gcc + goos: + - linux + goarch: + - arm64 + # BSDs and windows/arm64 stay on CGO=0 stub; see build-rdp-bridge.yml. - id: all-other-builds env: @@ -151,7 +186,18 @@ builds: goarch: arm archives: - - format_overrides: + - id: default + builds_info: + group: default + builds: + - linux-amd64-rdp + - linux-arm64-rdp + - linux-386-rdp + - linux-armv6-rdp + - linux-armv7-rdp + - windows-amd64-rdp + - all-other-builds + format_overrides: - goos: windows format: zip files: @@ -160,6 +206,15 @@ archives: - manpages/* - completions/* + - id: pkcs11 + builds: + - linux-amd64-pkcs11 + - linux-arm64-pkcs11 + name_template: "{{ .ProjectName }}-pkcs11_{{ .Version }}_{{ .Os }}_{{ .Arch }}" + files: + - README* + - LICENSE* + release: mode: append diff --git a/go.mod b/go.mod index eb6d44de..dd08a0ed 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/jackc/pgx/v5 v5.9.2 github.com/jcmturner/gokrb5/v8 v8.4.4 github.com/mattn/go-isatty v0.0.20 + github.com/miekg/pkcs11 v1.1.1 github.com/muesli/ansi v0.0.0-20221106050444-61f0cd9a192a github.com/muesli/mango-cobra v1.2.0 github.com/muesli/reflow v0.3.0 diff --git a/go.sum b/go.sum index 63fd6ff6..15176da8 100644 --- a/go.sum +++ b/go.sum @@ -444,6 +444,8 @@ github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRC github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU= +github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= diff --git a/packages/api/api.go b/packages/api/api.go index 25c72961..325c1c2e 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -812,10 +812,11 @@ func CallGatewayHeartBeatV1(httpClient *resty.Client) error { return nil } -func CallGatewayHeartBeatV2(httpClient *resty.Client) error { +func CallGatewayHeartBeatV2(httpClient *resty.Client, request GatewayHeartbeatRequest) error { response, err := httpClient. R(). SetHeader("User-Agent", USER_AGENT). + SetBody(request). Post(fmt.Sprintf("%v/v2/gateways/heartbeat", config.INFISICAL_URL)) if err != nil { diff --git a/packages/api/model.go b/packages/api/model.go index 92879031..f7d116c0 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -1008,6 +1008,10 @@ type RelayHeartbeatRequest struct { Name string `json:"name"` } +type GatewayHeartbeatRequest struct { + Capabilities map[string]any `json:"capabilities,omitempty"` +} + type RelayLoginRequest struct { Method string `json:"method"` Token string `json:"token,omitempty"` diff --git a/packages/cmd/gateway.go b/packages/cmd/gateway.go index 4cef610b..6ee646ab 100644 --- a/packages/cmd/gateway.go +++ b/packages/cmd/gateway.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "os/signal" + "path/filepath" "runtime" "sync/atomic" "syscall" @@ -401,11 +402,26 @@ var gatewayStartCmd = &cobra.Command{ } } + pkcs11ModulePath, _ := cmd.Flags().GetString("pkcs11-module") + if pkcs11ModulePath != "" { + if !filepath.IsAbs(pkcs11ModulePath) { + util.HandleError(fmt.Errorf("--pkcs11-module must be an absolute path (got %q)", pkcs11ModulePath), "unable to load PKCS#11 driver") + } + info, statErr := os.Stat(pkcs11ModulePath) + if statErr != nil { + util.HandleError(fmt.Errorf("PKCS#11 driver not found at %q: %w", pkcs11ModulePath, statErr), "unable to load PKCS#11 driver") + } + if info.IsDir() { + util.HandleError(fmt.Errorf("--pkcs11-module path is a directory, expected a driver file: %q", pkcs11ModulePath), "unable to load PKCS#11 driver") + } + } + gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ - Name: gatewayName, - RelayName: relayName, - ReconnectDelay: 10 * time.Second, - UseV3Connect: runningWithStoredToken, + Name: gatewayName, + RelayName: relayName, + ReconnectDelay: 10 * time.Second, + UseV3Connect: runningWithStoredToken, + Pkcs11ModulePath: pkcs11ModulePath, }) if err != nil { @@ -759,6 +775,7 @@ func init() { gatewayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") gatewayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") gatewayStartCmd.Flags().String("pam-session-recording-path", "", "directory path for PAM session recordings (defaults to /var/lib/infisical/session_recordings)") + gatewayStartCmd.Flags().String("pkcs11-module", "", "absolute path to a PKCS#11 driver (e.g. /opt/fortanix/pkcs11/fortanix_pkcs11.so). When set, the gateway loads the driver, advertises pkcs11 capability on heartbeat, and serves HSM operations.") // Legacy install command flags (v1) gatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index de87d37b..0539e5d6 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -40,6 +40,7 @@ const ( ForwardModePAMCapabilities ForwardMode = "PAM_CAPABILITIES" ForwardModePing ForwardMode = "PING" ForwardModeHealth ForwardMode = "HEALTH" + ForwardModePkcs11 ForwardMode = "PKCS11" ) type ActorType string @@ -82,12 +83,13 @@ type ActorDetails struct { } type GatewayConfig struct { - Name string - RelayName string - IdentityToken string - SSHPort int - ReconnectDelay time.Duration - UseV3Connect bool // Use V3 /connect endpoint instead of V2 /gateways for cert refresh + Name string + RelayName string + IdentityToken string + SSHPort int + ReconnectDelay time.Duration + UseV3Connect bool // Use V3 /connect endpoint instead of V2 /gateways for cert refresh + Pkcs11ModulePath string } type pamSessionEntry struct { @@ -132,6 +134,7 @@ type Gateway struct { // MongoDB proxy registry: one topology per session, shared across connections mongoProxies map[string]*mongoProxyEntry mongoProxiesMu sync.Mutex + pkcs11Module Pkcs11Module } // mongoProxyEntry holds a session-level MongoDB proxy with a ready signal. @@ -160,6 +163,17 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { pamCredentialsManager := session.NewCredentialsManager(httpClient) + var pkcs11Module Pkcs11Module + if config.Pkcs11ModulePath != "" { + mod, err := LoadPkcs11Module(config.Pkcs11ModulePath) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to load PKCS#11 module: %w", err) + } + pkcs11Module = mod + log.Info().Str("path", config.Pkcs11ModulePath).Msg("PKCS#11 module loaded; Gateway will serve HSM requests") + } + return &Gateway{ httpClient: httpClient, config: config, @@ -169,6 +183,7 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { pamSessionUploader: session.NewSessionUploader(httpClient, pamCredentialsManager), pamSessions: make(map[string][]*pamSessionEntry), mongoProxies: make(map[string]*mongoProxyEntry), + pkcs11Module: pkcs11Module, }, nil } @@ -366,7 +381,12 @@ func (g *Gateway) reapIdleSessions() { func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() error { - if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { + capabilities := map[string]any{} + if g.pkcs11Module != nil { + capabilities[CapabilityPkcs11] = true + } + req := api.GatewayHeartbeatRequest{Capabilities: capabilities} + if err := api.CallGatewayHeartBeatV2(g.httpClient, req); err != nil { log.Warn().Msgf("Heartbeat failed: %v", err) select { case errCh <- err: @@ -502,6 +522,13 @@ func (g *Gateway) Stop() { if g.pamCredentialsManager != nil { g.pamCredentialsManager.Shutdown() } + + if g.pkcs11Module != nil { + if err := g.pkcs11Module.Finalize(); err != nil { + log.Warn().Err(err).Msg("PKCS#11 module Finalize returned an error") + } + g.pkcs11Module = nil + } } func (g *Gateway) startHeartbeatOnce(ctx context.Context, errCh chan error) { @@ -707,7 +734,7 @@ func (g *Gateway) setupTLSConfig() error { ClientCAs: clientCAPool, ClientAuth: tls.RequireAndVerifyClientCert, MinVersion: tls.VersionTLS12, - NextProtos: []string{"infisical-http-proxy", "infisical-tcp-proxy", "infisical-health", "infisical-ping", "infisical-pam-proxy", "infisical-pam-rdp-browser", "infisical-pam-session-cancellation", "infisical-pam-capabilities"}, + NextProtos: nextProtosForGateway(g.pkcs11Module != nil), } return nil @@ -921,6 +948,14 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Info().Msg("Health handler completed") } return + } else if forwardConfig.Mode == ForwardModePkcs11 { + log.Info().Msg("Starting PKCS#11 handler") + if err := servePkcs11OverTLS(g.ctx, tlsConn, reader, g.pkcs11Module); err != nil { + log.Error().Err(err).Msg("PKCS#11 handler ended with error") + } else { + log.Info().Msg("PKCS#11 handler completed") + } + return } } @@ -975,6 +1010,10 @@ func (g *Gateway) parseForwardConfigFromALPN(tlsConn *tls.Conn, reader *bufio.Re config.Mode = ForwardModeHealth return config, nil + case "infisical-pkcs11": + config.Mode = ForwardModePkcs11 + return config, nil + default: return nil, fmt.Errorf("unsupported ALPN protocol: %s", negotiatedProtocol) } @@ -1137,3 +1176,20 @@ func (g *Gateway) renewCertificates() error { return nil } + +func nextProtosForGateway(pkcs11Loaded bool) []string { + base := []string{ + "infisical-http-proxy", + "infisical-tcp-proxy", + "infisical-health", + "infisical-ping", + "infisical-pam-proxy", + "infisical-pam-rdp-browser", + "infisical-pam-session-cancellation", + "infisical-pam-capabilities", + } + if pkcs11Loaded { + base = append(base, "infisical-pkcs11") + } + return base +} diff --git a/packages/gateway-v2/pkcs11.go b/packages/gateway-v2/pkcs11.go new file mode 100644 index 00000000..59077a01 --- /dev/null +++ b/packages/gateway-v2/pkcs11.go @@ -0,0 +1,53 @@ +package gatewayv2 + +type Pkcs11Module interface { + Test(slotLabel string, pin []byte) (SlotInfo, error) + + GenerateKeyPair(slotLabel string, pin []byte, keyLabel string, keyAlgorithm string) ([]byte, error) + + GetPublicKey(slotLabel string, pin []byte, keyLabel string) ([]byte, error) + + Sign(slotLabel string, pin []byte, keyLabel string, mechanism string, data []byte, isDigest bool) ([]byte, error) + + Finalize() error +} + +type SlotInfo struct { + Manufacturer string `json:"manufacturer"` + Model string `json:"model"` + Firmware string `json:"firmware"` +} + +type Pkcs11ErrorCode string + +const ( + Pkcs11ErrPinIncorrect Pkcs11ErrorCode = "pin_incorrect" + Pkcs11ErrPinLocked Pkcs11ErrorCode = "pin_locked" + Pkcs11ErrSlotNotFound Pkcs11ErrorCode = "slot_not_found" + Pkcs11ErrKeyNotFound Pkcs11ErrorCode = "key_not_found" + Pkcs11ErrMechanismInvalid Pkcs11ErrorCode = "mechanism_invalid" + Pkcs11ErrDriverUnavailable Pkcs11ErrorCode = "driver_unavailable" + Pkcs11ErrLoginFailed Pkcs11ErrorCode = "login_failed" + Pkcs11ErrNotSupported Pkcs11ErrorCode = "pkcs11_not_supported" + Pkcs11ErrBadRequest Pkcs11ErrorCode = "bad_request" + Pkcs11ErrInternal Pkcs11ErrorCode = "internal" +) + +type Pkcs11Error struct { + Code Pkcs11ErrorCode + Message string +} + +func (e *Pkcs11Error) Error() string { + return string(e.Code) + ": " + e.Message +} + +// Supported keyAlgorithm values. +const ( + KeyAlgorithmRSA2048 = "RSA_2048" + KeyAlgorithmRSA4096 = "RSA_4096" + KeyAlgorithmECCP256 = "ECC_P256" + KeyAlgorithmECCP384 = "ECC_P384" +) + +const CapabilityPkcs11 = "pkcs11" diff --git a/packages/gateway-v2/pkcs11_disabled.go b/packages/gateway-v2/pkcs11_disabled.go new file mode 100644 index 00000000..9479390e --- /dev/null +++ b/packages/gateway-v2/pkcs11_disabled.go @@ -0,0 +1,10 @@ +//go:build !pkcs11 + +package gatewayv2 + +func LoadPkcs11Module(_ string) (Pkcs11Module, error) { + return nil, &Pkcs11Error{ + Code: Pkcs11ErrNotSupported, + Message: "This Gateway build was compiled without PKCS#11 support. Use the infisical-pkcs11 release artifact, or build from source with `go build -tags pkcs11` (cgo + dynamic linking required).", + } +} diff --git a/packages/gateway-v2/pkcs11_enabled.go b/packages/gateway-v2/pkcs11_enabled.go new file mode 100644 index 00000000..5808a3ea --- /dev/null +++ b/packages/gateway-v2/pkcs11_enabled.go @@ -0,0 +1,452 @@ +//go:build pkcs11 + +package gatewayv2 + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/asn1" + "encoding/binary" + "fmt" + "math/big" + "strings" + "sync" + + "github.com/miekg/pkcs11" + "github.com/rs/zerolog/log" +) + +var ( + ecParamsP256 = []byte{0x06, 0x08, 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07} + ecParamsP384 = []byte{0x06, 0x05, 0x2B, 0x81, 0x04, 0x00, 0x22} +) + +type pkcs11ModuleImpl struct { + mu sync.Mutex + ctx *pkcs11.Ctx +} + +func LoadPkcs11Module(path string) (Pkcs11Module, error) { + if strings.TrimSpace(path) == "" { + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: "Empty --pkcs11-module path", + } + } + ctx := pkcs11.New(path) + if ctx == nil { + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: fmt.Sprintf("Failed to dlopen PKCS#11 driver at %q", path), + } + } + if err := ctx.Initialize(); err != nil { + if e, ok := err.(pkcs11.Error); !ok || e != pkcs11.CKR_CRYPTOKI_ALREADY_INITIALIZED { + ctx.Destroy() + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: fmt.Sprintf("PKCS#11 C_Initialize failed: %v", err), + } + } + } + if _, err := ctx.GetSlotList(true); err != nil { + _ = ctx.Finalize() + ctx.Destroy() + return nil, &Pkcs11Error{ + Code: Pkcs11ErrDriverUnavailable, + Message: fmt.Sprintf("PKCS#11 C_GetSlotList failed: %v", err), + } + } + return &pkcs11ModuleImpl{ctx: ctx}, nil +} + +func (m *pkcs11ModuleImpl) Finalize() error { + m.mu.Lock() + defer m.mu.Unlock() + if m.ctx == nil { + return nil + } + err := m.ctx.Finalize() + m.ctx.Destroy() + m.ctx = nil + return err +} + +type sessionFn func(slot uint, sh pkcs11.SessionHandle) error + +func (m *pkcs11ModuleImpl) withSession(slotLabel string, pin []byte, fn sessionFn) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.ctx == nil { + return &Pkcs11Error{Code: Pkcs11ErrDriverUnavailable, Message: "Module is not loaded"} + } + slots, err := m.ctx.GetSlotList(true) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "GetSlotList failed"} + } + slot, ok := findSlotByLabel(m.ctx, slots, slotLabel) + if !ok { + return &Pkcs11Error{Code: Pkcs11ErrSlotNotFound, Message: fmt.Sprintf("Slot %q not found on this HSM", slotLabel)} + } + session, err := m.ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "OpenSession failed"} + } + defer func() { + if closeErr := m.ctx.CloseSession(session); closeErr != nil { + log.Warn().Err(closeErr).Msg("pkcs11: CloseSession failed") + } + }() + + loggedIn := false + loginErr := m.ctx.Login(session, pkcs11.CKU_USER, string(pin)) + if loginErr != nil { + if e, ok := loginErr.(pkcs11.Error); !ok || e != pkcs11.CKR_USER_ALREADY_LOGGED_IN { + return mapPkcs11LoginError(loginErr) + } + } else { + loggedIn = true + } + if loggedIn { + defer func() { + if logoutErr := m.ctx.Logout(session); logoutErr != nil { + log.Warn().Err(logoutErr).Msg("pkcs11: Logout failed") + } + }() + } + + return fn(slot, session) +} + +func findSlotByLabel(ctx *pkcs11.Ctx, slots []uint, label string) (uint, bool) { + for _, slot := range slots { + ti, err := ctx.GetTokenInfo(slot) + if err != nil { + continue + } + if strings.TrimRight(ti.Label, " \x00") == label { + return slot, true + } + } + return 0, false +} + +func mapPkcs11LoginError(err error) error { + if e, ok := err.(pkcs11.Error); ok { + switch e { + case pkcs11.CKR_PIN_INCORRECT: + return &Pkcs11Error{Code: Pkcs11ErrPinIncorrect, Message: "The HSM rejected the PIN"} + case pkcs11.CKR_PIN_LOCKED: + return &Pkcs11Error{Code: Pkcs11ErrPinLocked, Message: "The HSM has locked the slot"} + case pkcs11.CKR_TOKEN_NOT_PRESENT, pkcs11.CKR_DEVICE_REMOVED, pkcs11.CKR_DEVICE_ERROR: + return &Pkcs11Error{Code: Pkcs11ErrDriverUnavailable, Message: "Driver unavailable"} + } + } + return &Pkcs11Error{Code: Pkcs11ErrLoginFailed, Message: "The HSM rejected the login"} +} + +func (m *pkcs11ModuleImpl) Test(slotLabel string, pin []byte) (SlotInfo, error) { + var info SlotInfo + err := m.withSession(slotLabel, pin, func(slot uint, _ pkcs11.SessionHandle) error { + ti, err := m.ctx.GetTokenInfo(slot) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "GetTokenInfo failed"} + } + info = SlotInfo{ + Manufacturer: strings.TrimRight(ti.ManufacturerID, " \x00"), + Model: strings.TrimRight(ti.Model, " \x00"), + Firmware: fmt.Sprintf("%d.%d", ti.FirmwareVersion.Major, ti.FirmwareVersion.Minor), + } + return nil + }) + return info, err +} + +func (m *pkcs11ModuleImpl) GenerateKeyPair(slotLabel string, pin []byte, keyLabel, keyAlgorithm string) ([]byte, error) { + var spkiDer []byte + err := m.withSession(slotLabel, pin, func(_ uint, session pkcs11.SessionHandle) error { + mech, pubTpl, privTpl, err := generateKeyPairTemplates(keyLabel, keyAlgorithm) + if err != nil { + return err + } + pubHandle, _, err := m.ctx.GenerateKeyPair(session, mech, pubTpl, privTpl) + if err != nil { + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "GenerateKeyPair failed"} + } + der, err := buildSpkiFromHandle(m.ctx, session, pubHandle, keyAlgorithm) + if err != nil { + return err + } + spkiDer = der + return nil + }) + return spkiDer, err +} + +func generateKeyPairTemplates(keyLabel, keyAlgorithm string) ([]*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute, error) { + commonPriv := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(keyLabel)), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_PRIVATE, true), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + } + commonPub := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(keyLabel)), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + } + switch keyAlgorithm { + case KeyAlgorithmRSA2048, KeyAlgorithmRSA4096: + modulusBits := 2048 + if keyAlgorithm == KeyAlgorithmRSA4096 { + modulusBits = 4096 + } + pubTpl := append(commonPub, + pkcs11.NewAttribute(pkcs11.CKA_MODULUS_BITS, modulusBits), + pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, []byte{0x01, 0x00, 0x01}), + ) + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_KEY_PAIR_GEN, nil)}, pubTpl, commonPriv, nil + case KeyAlgorithmECCP256: + pubTpl := append(commonPub, pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, ecParamsP256)) + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_EC_KEY_PAIR_GEN, nil)}, pubTpl, commonPriv, nil + case KeyAlgorithmECCP384: + pubTpl := append(commonPub, pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, ecParamsP384)) + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_EC_KEY_PAIR_GEN, nil)}, pubTpl, commonPriv, nil + default: + return nil, nil, nil, &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported keyAlgorithm %q", keyAlgorithm)} + } +} + +func buildSpkiFromHandle(ctx *pkcs11.Ctx, session pkcs11.SessionHandle, pubHandle pkcs11.ObjectHandle, keyAlgorithm string) ([]byte, error) { + switch keyAlgorithm { + case KeyAlgorithmRSA2048, KeyAlgorithmRSA4096: + attrs, err := ctx.GetAttributeValue(session, pubHandle, []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil), + pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, nil), + }) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "RSA GetAttributeValue failed"} + } + var modulus, exp []byte + for _, a := range attrs { + switch a.Type { + case pkcs11.CKA_MODULUS: + modulus = a.Value + case pkcs11.CKA_PUBLIC_EXPONENT: + exp = a.Value + } + } + pub := &rsa.PublicKey{ + N: new(big.Int).SetBytes(modulus), + E: int(new(big.Int).SetBytes(exp).Int64()), + } + der, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "MarshalPKIXPublicKey failed"} + } + return der, nil + + case KeyAlgorithmECCP256, KeyAlgorithmECCP384: + attrs, err := ctx.GetAttributeValue(session, pubHandle, []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil), + }) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "EC GetAttributeValue failed"} + } + if len(attrs) == 0 { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "CKA_EC_POINT missing from response"} + } + // CKA_EC_POINT is DER OCTET STRING wrapping the raw point. + var raw []byte + if _, err := asn1.Unmarshal(attrs[0].Value, &raw); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Unmarshal CKA_EC_POINT failed"} + } + var curve elliptic.Curve + if keyAlgorithm == KeyAlgorithmECCP256 { + curve = elliptic.P256() + } else { + curve = elliptic.P384() + } + // Parse the uncompressed point format (RFC 5480 Section 2.2): 0x04 || X || Y + // with each coordinate padded to (BitSize + 7) / 8 bytes. Stdlib's + // elliptic.Unmarshal is deprecated and there is no ECDSA-specific replacement, + // so do the parse inline. + byteLen := (curve.Params().BitSize + 7) / 8 + if len(raw) != 1+2*byteLen || raw[0] != 0x04 { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to unmarshal EC point"} + } + x := new(big.Int).SetBytes(raw[1 : 1+byteLen]) + y := new(big.Int).SetBytes(raw[1+byteLen:]) + pub := &ecdsa.PublicKey{Curve: curve, X: x, Y: y} + der, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "MarshalPKIXPublicKey failed"} + } + return der, nil + } + return nil, &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: "Unsupported keyAlgorithm for SPKI build"} +} + +func (m *pkcs11ModuleImpl) GetPublicKey(slotLabel string, pin []byte, keyLabel string) ([]byte, error) { + var spkiDer []byte + err := m.withSession(slotLabel, pin, func(_ uint, session pkcs11.SessionHandle) error { + handle, found, err := findObject(m.ctx, session, keyLabel, pkcs11.CKO_PUBLIC_KEY) + if err != nil { + return err + } + if !found { + return &Pkcs11Error{Code: Pkcs11ErrKeyNotFound, Message: fmt.Sprintf("Public key with label %q not found", keyLabel)} + } + alg, err := detectKeyAlgorithm(m.ctx, session, handle) + if err != nil { + return err + } + der, err := buildSpkiFromHandle(m.ctx, session, handle, alg) + if err != nil { + return err + } + spkiDer = der + return nil + }) + return spkiDer, err +} + +func detectKeyAlgorithm(ctx *pkcs11.Ctx, session pkcs11.SessionHandle, handle pkcs11.ObjectHandle) (string, error) { + attrs, err := ctx.GetAttributeValue(session, handle, []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil), + }) + if err != nil || len(attrs) == 0 { + return "", &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to read CKA_KEY_TYPE"} + } + raw := make([]byte, 8) + copy(raw, attrs[0].Value) + keyType := uint(binary.LittleEndian.Uint64(raw)) + switch keyType { + case pkcs11.CKK_RSA: + modAttrs, err := ctx.GetAttributeValue(session, handle, []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil)}) + if err != nil || len(modAttrs) == 0 { + return "", &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to read CKA_MODULUS"} + } + switch len(modAttrs[0].Value) { + case 256: + return KeyAlgorithmRSA2048, nil + case 512: + return KeyAlgorithmRSA4096, nil + } + return "", &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported RSA modulus length: %d bits", len(modAttrs[0].Value)*8)} + case pkcs11.CKK_EC: + paramsAttrs, err := ctx.GetAttributeValue(session, handle, []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, nil)}) + if err != nil || len(paramsAttrs) == 0 { + return "", &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Failed to read CKA_EC_PARAMS"} + } + if bytes.Equal(paramsAttrs[0].Value, ecParamsP256) { + return KeyAlgorithmECCP256, nil + } + if bytes.Equal(paramsAttrs[0].Value, ecParamsP384) { + return KeyAlgorithmECCP384, nil + } + return "", &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: "Unsupported EC curve"} + } + return "", &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported PKCS#11 key type: %d", keyType)} +} + +func findObject(ctx *pkcs11.Ctx, session pkcs11.SessionHandle, label string, class uint) (pkcs11.ObjectHandle, bool, error) { + tpl := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_LABEL, []byte(label)), + pkcs11.NewAttribute(pkcs11.CKA_CLASS, class), + } + if err := ctx.FindObjectsInit(session, tpl); err != nil { + return 0, false, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "FindObjectsInit failed"} + } + defer func() { + if finalErr := ctx.FindObjectsFinal(session); finalErr != nil { + log.Warn().Err(finalErr).Msg("pkcs11: FindObjectsFinal failed") + } + }() + objs, _, err := ctx.FindObjects(session, 2) + if err != nil { + return 0, false, &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "FindObjects failed"} + } + if len(objs) == 0 { + return 0, false, nil + } + if len(objs) > 1 { + return 0, false, &Pkcs11Error{ + Code: Pkcs11ErrBadRequest, + Message: fmt.Sprintf("Multiple objects on the HSM share label %q. Resolve the duplicate before proceeding.", label), + } + } + return objs[0], true, nil +} + +func (m *pkcs11ModuleImpl) Sign(slotLabel string, pin []byte, keyLabel, mechanism string, data []byte, isDigest bool) ([]byte, error) { + log.Debug().Str("keyLabel", keyLabel).Str("mech", mechanism).Int("dataLen", len(data)).Msg("pkcs11.Sign: enter") + var sig []byte + err := m.withSession(slotLabel, pin, func(_ uint, session pkcs11.SessionHandle) error { + mechCode, params, err := resolveMechanism(mechanism, isDigest) + if err != nil { + return err + } + handle, found, err := findObject(m.ctx, session, keyLabel, pkcs11.CKO_PRIVATE_KEY) + if err != nil { + return err + } + if !found { + return &Pkcs11Error{Code: Pkcs11ErrKeyNotFound, Message: fmt.Sprintf("Private key with label %q not found", keyLabel)} + } + if err := m.ctx.SignInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(mechCode, params)}, handle); err != nil { + return mapPkcs11SignError(err) + } + out, err := m.ctx.Sign(session, data) + if err != nil { + return mapPkcs11SignError(err) + } + sig = out + return nil + }) + log.Debug().Bool("ok", err == nil).Int("sigLen", len(sig)).Msg("pkcs11.Sign: done") + return sig, err +} + +func resolveMechanism(name string, isDigest bool) (uint, []byte, error) { + switch name { + case "CKM_SHA256_RSA_PKCS": + return pkcs11.CKM_SHA256_RSA_PKCS, nil, nil + case "CKM_SHA384_RSA_PKCS": + return pkcs11.CKM_SHA384_RSA_PKCS, nil, nil + case "CKM_SHA512_RSA_PKCS": + return pkcs11.CKM_SHA512_RSA_PKCS, nil, nil + case "CKM_ECDSA_SHA256": + if isDigest { + return pkcs11.CKM_ECDSA, nil, nil + } + return pkcs11.CKM_ECDSA_SHA256, nil, nil + case "CKM_ECDSA_SHA384": + if isDigest { + return pkcs11.CKM_ECDSA, nil, nil + } + return pkcs11.CKM_ECDSA_SHA384, nil, nil + case "CKM_ECDSA": + return pkcs11.CKM_ECDSA, nil, nil + } + return 0, nil, &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: fmt.Sprintf("Unsupported mechanism %q", name)} +} + +func mapPkcs11SignError(err error) error { + if e, ok := err.(pkcs11.Error); ok { + switch e { + case pkcs11.CKR_KEY_HANDLE_INVALID, pkcs11.CKR_OBJECT_HANDLE_INVALID: + return &Pkcs11Error{Code: Pkcs11ErrKeyNotFound, Message: "The HSM rejected the key handle"} + case pkcs11.CKR_MECHANISM_INVALID, pkcs11.CKR_KEY_TYPE_INCONSISTENT: + return &Pkcs11Error{Code: Pkcs11ErrMechanismInvalid, Message: "The HSM does not support the requested signing algorithm"} + case pkcs11.CKR_TOKEN_NOT_PRESENT, pkcs11.CKR_DEVICE_REMOVED, pkcs11.CKR_DEVICE_ERROR: + return &Pkcs11Error{Code: Pkcs11ErrDriverUnavailable, Message: "Driver unavailable"} + } + } + return &Pkcs11Error{Code: Pkcs11ErrInternal, Message: "Sign operation failed"} +} diff --git a/packages/gateway-v2/pkcs11_handler.go b/packages/gateway-v2/pkcs11_handler.go new file mode 100644 index 00000000..6a0eb5ec --- /dev/null +++ b/packages/gateway-v2/pkcs11_handler.go @@ -0,0 +1,353 @@ +package gatewayv2 + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + +type pkcs11RequestEnvelope struct { + SlotLabel string `json:"slotLabel"` + PIN []byte `json:"-"` + Params json.RawMessage `json:"params"` +} + +func (e *pkcs11RequestEnvelope) UnmarshalJSON(data []byte) error { + var raw struct { + SlotLabel string `json:"slotLabel"` + PIN string `json:"pin"` + Params json.RawMessage `json:"params"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + e.SlotLabel = raw.SlotLabel + e.PIN = []byte(raw.PIN) + e.Params = raw.Params + return nil +} + +type pkcs11Response struct { + Result json.RawMessage `json:"result"` +} + +type pkcs11ErrorResponse struct { + Error pkcs11ErrorBody `json:"error"` +} + +type pkcs11ErrorBody struct { + Code Pkcs11ErrorCode `json:"code"` + Message string `json:"message"` +} + +const pkcs11RequestDeadline = 30 * time.Second + +func servePkcs11OverTLS(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, module Pkcs11Module) error { + _ = conn.SetDeadline(time.Now().Add(pkcs11RequestDeadline)) + + if module == nil { + writeErrorResponse(conn, http.StatusServiceUnavailable, Pkcs11ErrNotSupported, "PKCS#11 module not loaded") + return errors.New("PKCS#11 module is nil") + } + + reqCh := make(chan *http.Request, 1) + errCh := make(chan error, 1) + go func() { + req, err := http.ReadRequest(reader) + if err != nil { + errCh <- err + return + } + reqCh <- req + }() + + var req *http.Request + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return fmt.Errorf("failed to read HTTP request: %w", err) + case req = <-reqCh: + } + + log.Debug().Str("path", req.URL.Path).Int64("contentLength", req.ContentLength).Msg("pkcs11: request received") + + rw := newBufferedResponseWriter() + servePkcs11Mux(module).ServeHTTP(rw, req) + if err := rw.writeTo(conn); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + log.Debug().Int("status", rw.status).Msg("pkcs11: response written") + return nil +} + +type bufferedResponseWriter struct { + header http.Header + body bytes.Buffer + status int + wroteStart bool +} + +func newBufferedResponseWriter() *bufferedResponseWriter { + return &bufferedResponseWriter{header: http.Header{}, status: http.StatusOK} +} +func (b *bufferedResponseWriter) Header() http.Header { return b.header } +func (b *bufferedResponseWriter) WriteHeader(s int) { + if b.wroteStart { + return + } + b.status = s + b.wroteStart = true +} +func (b *bufferedResponseWriter) Write(p []byte) (int, error) { + if !b.wroteStart { + b.WriteHeader(http.StatusOK) + } + return b.body.Write(p) +} +func (b *bufferedResponseWriter) writeTo(conn *tls.Conn) error { + body := b.body.Bytes() + if b.header.Get("Content-Length") == "" { + b.header.Set("Content-Length", strconv.Itoa(len(body))) + } + if b.header.Get("Connection") == "" { + b.header.Set("Connection", "close") + } + var sb strings.Builder + fmt.Fprintf(&sb, "HTTP/1.1 %d %s\r\n", b.status, http.StatusText(b.status)) + for k, vs := range b.header { + for _, v := range vs { + sb.WriteString(k) + sb.WriteString(": ") + sb.WriteString(v) + sb.WriteString("\r\n") + } + } + sb.WriteString("\r\n") + if _, err := conn.Write([]byte(sb.String())); err != nil { + return err + } + if _, err := conn.Write(body); err != nil { + return err + } + return nil +} + +func servePkcs11Mux(module Pkcs11Module) *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/v1/test", wrapPkcs11(module, handleTest)) + mux.HandleFunc("/v1/generate-key-pair", wrapPkcs11(module, handleGenerateKeyPair)) + mux.HandleFunc("/v1/sign", wrapPkcs11(module, handleSign)) + mux.HandleFunc("/v1/get-public-key", wrapPkcs11(module, handleGetPublicKey)) + return mux +} + +type pkcs11Handler func(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) + +const maxPkcs11RequestBodyBytes = 256 * 1024 + +func zeroBytes(b []byte) { + for i := range b { + b[i] = 0 + } +} + +func safeMessageForCode(code Pkcs11ErrorCode) string { + switch code { + case Pkcs11ErrPinIncorrect: + return "The HSM rejected the PIN" + case Pkcs11ErrPinLocked: + return "The HSM has locked the slot" + case Pkcs11ErrLoginFailed: + return "The HSM rejected the login" + case Pkcs11ErrSlotNotFound: + return "Slot not found on this HSM" + case Pkcs11ErrKeyNotFound: + return "Key not found on this HSM" + case Pkcs11ErrMechanismInvalid: + return "Mechanism not supported by this HSM" + case Pkcs11ErrDriverUnavailable: + return "Driver unavailable" + case Pkcs11ErrNotSupported: + return "Operation not supported" + case Pkcs11ErrBadRequest: + return "Invalid request" + } + return "Operation failed" +} + +func wrapPkcs11(module Pkcs11Module, fn pkcs11Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + if r.Body != nil { + _ = r.Body.Close() + } + }() + log.Debug().Str("path", r.URL.Path).Str("method", r.Method).Int64("contentLength", r.ContentLength).Msg("pkcs11: handler received request") + if r.Method != http.MethodPost { + writeErrorResponse(w, http.StatusMethodNotAllowed, Pkcs11ErrBadRequest, "Only POST is supported") + return + } + if r.ContentLength > maxPkcs11RequestBodyBytes { + log.Error().Int64("contentLength", r.ContentLength).Msg("pkcs11: request body too large") + writeErrorResponse(w, http.StatusRequestEntityTooLarge, Pkcs11ErrBadRequest, "Request body too large") + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxPkcs11RequestBodyBytes) + var env pkcs11RequestEnvelope + if err := json.NewDecoder(r.Body).Decode(&env); err != nil { + log.Warn().Err(err).Msg("pkcs11: body decode failed") + writeErrorResponse(w, http.StatusBadRequest, Pkcs11ErrBadRequest, "Malformed request body") + return + } + defer zeroBytes(env.PIN) + log.Debug().Bool("hasPin", len(env.PIN) > 0).Msg("pkcs11: body decoded, dispatching to op handler") + result, err := fn(module, &env) + log.Debug().Bool("ok", err == nil).Msg("pkcs11: op handler returned") + if err != nil { + var p11Err *Pkcs11Error + if errors.As(err, &p11Err) { + log.Error().Str("code", string(p11Err.Code)).Str("errorMessage", p11Err.Message).Msg("pkcs11: op handler returned typed error") + writeErrorResponse(w, statusForCode(p11Err.Code), p11Err.Code, safeMessageForCode(p11Err.Code)) + return + } + log.Error().Err(err).Msg("pkcs11: op handler returned untyped error") + writeErrorResponse(w, http.StatusInternalServerError, Pkcs11ErrInternal, safeMessageForCode(Pkcs11ErrInternal)) + return + } + raw, err := json.Marshal(result) + if err != nil { + log.Error().Err(err).Msg("pkcs11: failed to marshal result") + writeErrorResponse(w, http.StatusInternalServerError, Pkcs11ErrInternal, "Failed to marshal result") + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(pkcs11Response{Result: raw}); err != nil { + log.Warn().Err(err).Msg("pkcs11: failed to encode response") + } + } +} + +type generateKeyPairParams struct { + KeyLabel string `json:"keyLabel"` + KeyAlgorithm string `json:"keyAlgorithm"` +} + +type signParams struct { + KeyLabel string `json:"keyLabel"` + Mechanism string `json:"mechanism"` + Data string `json:"data"` + IsDigest bool `json:"isDigest"` +} + +type singleKeyLabelParams struct { + KeyLabel string `json:"keyLabel"` +} + +const ( + maxKeyLabelLen = 256 + maxSignDataBytes = 64 * 1024 +) + +func handleTest(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + info, err := module.Test(env.SlotLabel, env.PIN) + if err != nil { + return nil, err + } + return map[string]any{"slotInfo": info}, nil +} + +func handleGenerateKeyPair(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + var p generateKeyPairParams + if err := json.Unmarshal(env.Params, &p); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Malformed params for generate-key-pair"} + } + if len(p.KeyLabel) > maxKeyLabelLen { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "keyLabel too long"} + } + spki, err := module.GenerateKeyPair(env.SlotLabel, env.PIN, p.KeyLabel, p.KeyAlgorithm) + if err != nil { + return nil, err + } + return map[string]any{"publicKey": base64.StdEncoding.EncodeToString(spki)}, nil +} + +func handleSign(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + var p signParams + if err := json.Unmarshal(env.Params, &p); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Malformed params for sign"} + } + if len(p.KeyLabel) > maxKeyLabelLen { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "keyLabel too long"} + } + data, err := base64.StdEncoding.DecodeString(p.Data) + if err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "data is not valid base64"} + } + if len(data) == 0 { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "data is empty"} + } + if len(data) > maxSignDataBytes { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Data too large for signing"} + } + sig, err := module.Sign(env.SlotLabel, env.PIN, p.KeyLabel, p.Mechanism, data, p.IsDigest) + if err != nil { + return nil, err + } + return map[string]any{"signature": base64.StdEncoding.EncodeToString(sig)}, nil +} + +func handleGetPublicKey(module Pkcs11Module, env *pkcs11RequestEnvelope) (any, error) { + var p singleKeyLabelParams + if err := json.Unmarshal(env.Params, &p); err != nil { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "Malformed params for get-public-key"} + } + if len(p.KeyLabel) > maxKeyLabelLen { + return nil, &Pkcs11Error{Code: Pkcs11ErrBadRequest, Message: "keyLabel too long"} + } + spki, err := module.GetPublicKey(env.SlotLabel, env.PIN, p.KeyLabel) + if err != nil { + return nil, err + } + return map[string]any{"publicKey": base64.StdEncoding.EncodeToString(spki)}, nil +} + +func statusForCode(code Pkcs11ErrorCode) int { + switch code { + case Pkcs11ErrPinIncorrect, Pkcs11ErrPinLocked, Pkcs11ErrLoginFailed, Pkcs11ErrSlotNotFound, Pkcs11ErrKeyNotFound, Pkcs11ErrMechanismInvalid, Pkcs11ErrBadRequest: + return http.StatusBadRequest + case Pkcs11ErrDriverUnavailable, Pkcs11ErrInternal: + return http.StatusBadGateway + case Pkcs11ErrNotSupported: + return http.StatusServiceUnavailable + } + return http.StatusBadGateway +} + +func writeErrorResponse(w any, status int, code Pkcs11ErrorCode, message string) { + body, _ := json.Marshal(pkcs11ErrorResponse{Error: pkcs11ErrorBody{Code: code, Message: message}}) + switch sink := w.(type) { + case http.ResponseWriter: + sink.Header().Set("Content-Type", "application/json") + sink.WriteHeader(status) + _, _ = sink.Write(body) + case *tls.Conn: + resp := fmt.Sprintf("HTTP/1.1 %d %s\r\nContent-Type: application/json\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s", + status, http.StatusText(status), len(body), body) + _, _ = sink.Write([]byte(resp)) + default: + log.Warn().Msg("writeErrorResponse called with unsupported sink type") + } +}