diff --git a/cmd/certinfo.go b/cmd/certinfo.go index f808f9c..e5f8497 100644 --- a/cmd/certinfo.go +++ b/cmd/certinfo.go @@ -5,9 +5,6 @@ Copyright © 2025 Zeno Belli xeno@os76.xyz package cmd import ( - "fmt" - "os" - "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/xenos76/https-wrench/internal/certinfo" @@ -48,14 +45,14 @@ Examples: https-wrench certinfo --ca-bundle ./ca-bundle.pem --tls-endpoint example.com:443 https-wrench certinfo --ca-bundle ./ca-bundle.pem --cert-bundle ./bundle.pem --key-file ./key.pem `, - Run: func(cmd *cobra.Command, args []string) { + Run: func(cmd *cobra.Command, _ []string) { caBundleValue := viper.GetString("ca-bundle") certBundleValue := viper.GetString("cert-bundle") keyFileValue := viper.GetString("key-file") versionRequested := viper.GetBool("version") if versionRequested { - fmt.Print(version) + cmd.Print(version) return } @@ -67,16 +64,16 @@ Examples: certinfoCfg, err := certinfo.NewCertinfoConfig() if err != nil { - fmt.Printf("Error creating new Certinfo config: %s", err) + cmd.Printf("Error creating new Certinfo config: %s", err) return } if err = certinfoCfg.SetCaPoolFromFile(caBundleValue, fileReader); err != nil { - fmt.Printf("Error importing CA Certificate bundle from file: %s", err) + cmd.Printf("Error importing CA Certificate bundle from file: %s", err) } if err = certinfoCfg.SetCertsFromFile(certBundleValue, fileReader); err != nil { - fmt.Printf("Error importing Certificate bundle from file: %s", err) + cmd.Printf("Error importing Certificate bundle from file: %s", err) } certinfoCfg.SetTLSInsecure(tlsInsecure).SetTLSServerName(tlsServerName) @@ -85,7 +82,7 @@ Examples: // before being able to ask details about the certificate we want to a // webserver using self-signed and valid certificates if err = certinfoCfg.SetTLSEndpoint(tlsEndpoint); err != nil { - fmt.Printf("Error setting TLS endpoint: %s", err) + cmd.Printf("Error setting TLS endpoint: %s", err) } if err = certinfoCfg.SetPrivateKeyFromFile( @@ -93,12 +90,12 @@ Examples: keyPwEnvVar, fileReader, ); err != nil { - fmt.Printf("Error importing key from file: %s", err) + cmd.Printf("Error importing key from file: %s", err) } // dump.Print(certinfoCfg) - if err = certinfoCfg.PrintData(os.Stdout); err != nil { - fmt.Printf("error printing Certinfo data: %s", err) + if err = certinfoCfg.PrintData(cmd.OutOrStdout()); err != nil { + cmd.Printf("error printing Certinfo data: %s", err) } }, } diff --git a/cmd/certinfo_test.go b/cmd/certinfo_test.go index 635a6e5..19f0b12 100644 --- a/cmd/certinfo_test.go +++ b/cmd/certinfo_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" ) +//nolint:revive func TestCertinfoCmd(t *testing.T) { tests := []struct { name string @@ -88,12 +89,36 @@ func TestCertinfoCmd(t *testing.T) { "--version Display the version", }, }, + { + name: "version", + args: []string{"certinfo", "--version"}, + expectError: false, + expected: []string{version}, + }, + { + //nolint:revive + name: "invalid files and endpoints", + //nolint:revive + args: []string{"certinfo", "--ca-bundle", "non_existent.pem", "--cert-bundle", "non_existent.pem", "--key-file", "non_existent.pem", "--tls-endpoint", "invalid://"}, + expectError: false, + expected: []string{"Error importing CA Certificate bundle", "Error importing Certificate bundle", "Error importing key", "Error setting TLS endpoint"}, + }, } for _, tc := range tests { tt := tc t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + require.NoError(t, rootCmd.PersistentFlags().Set("version", "false")) + require.NoError(t, certinfoCmd.Flags().Set("ca-bundle", "")) + require.NoError(t, certinfoCmd.Flags().Set("tls-endpoint", "")) + require.NoError(t, certinfoCmd.Flags().Set("tls-servername", "")) + require.NoError(t, certinfoCmd.Flags().Set("tls-insecure", "false")) + require.NoError(t, certinfoCmd.Flags().Set("cert-bundle", "")) + require.NoError(t, certinfoCmd.Flags().Set("key-file", "")) + }) + reqOut := new(bytes.Buffer) reqCmd := rootCmd reqCmd.SetOut(reqOut) diff --git a/cmd/jwtinfo.go b/cmd/jwtinfo.go index 02bd7c1..07ca6ac 100644 --- a/cmd/jwtinfo.go +++ b/cmd/jwtinfo.go @@ -5,10 +5,8 @@ Copyright © 2026 Zeno Belli package cmd import ( - "fmt" "io" "net/http" - "os" "github.com/MicahParks/keyfunc/v3" "github.com/spf13/cobra" @@ -31,8 +29,8 @@ var ( var jwtinfoCmd = &cobra.Command{ Use: "jwtinfo", - Short: "JwtInfo request and display JWT token data", - Long: `JwtInfo request and display JWT token data + Short: "JwtInfo shows data from a JWT token", + Long: `JwtInfo shows data from a JWT token Examples: export REQ_URL="https://sample.provider/oauth/token" @@ -51,7 +49,7 @@ Examples: # Request and validate a JWT token https-wrench jwtinfo --request-url $REQ_URL --request-values-json $REQ_VALUES --validation-url $VALIDATION_URL `, - Run: func(cmd *cobra.Command, args []string) { + Run: func(cmd *cobra.Command, _ []string) { // TODO: remove global --config option var err error @@ -60,7 +58,7 @@ Examples: if tokenFile != "" { tokenData, err = jwtinfo.ReadTokenFromFile(tokenFile) if err != nil { - fmt.Printf( + cmd.Printf( "error while reading token value from file: %s", err, ) @@ -78,7 +76,7 @@ Examples: requestValuesMap, ) if err != nil { - fmt.Printf( + cmd.Printf( "error while reading request's values from file: %s", err, ) @@ -92,7 +90,7 @@ Examples: requestValuesMap, ) if err != nil { - fmt.Printf( + cmd.Printf( "error while parsing request's values JSON string: %s", err, ) @@ -107,7 +105,7 @@ Examples: io.ReadAll, ) if err != nil { - fmt.Printf("error while requesting token data: %s\n", err) + cmd.Printf("error while requesting token data: %s\n", err) return } } @@ -115,21 +113,21 @@ Examples: if tokenData.AccessTokenRaw != "" { err = tokenData.DecodeBase64() if err != nil { - fmt.Printf("DecodeBase64 error: %s\n", err) + cmd.Printf("DecodeBase64 error: %s\n", err) return } if jwksURL != "" { err = tokenData.ParseWithJWKS(jwksURL, keyfuncDefOverride) if err != nil { - fmt.Printf("error while parsing token data: %s\n", err) + cmd.Printf("error while parsing token data: %s\n", err) return } } - err = jwtinfo.PrintTokenInfo(tokenData, os.Stdout) + err = jwtinfo.PrintTokenInfo(tokenData, cmd.OutOrStdout()) if err != nil { - fmt.Printf("error while printing token data: %s\n", err) + cmd.Printf("error while printing token data: %s\n", err) return } } else { diff --git a/cmd/jwtinfo_test.go b/cmd/jwtinfo_test.go new file mode 100644 index 0000000..73d3bbf --- /dev/null +++ b/cmd/jwtinfo_test.go @@ -0,0 +1,58 @@ +package cmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestJwtinfoCmd(t *testing.T) { + tests := []struct { + name string + args []string + expectError bool + errMsgs []string + expected []string + }{ + { + name: "invalid file", + args: []string{"jwtinfo", "--token-file", "non_existent.jwt"}, + expectError: false, + expected: []string{"error while reading token value from file"}, + }, + } + + for _, tc := range tests { + tt := tc + t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + rootCmd.Flags().Set("version", "false") + jwtinfoCmd.Flags().Set("token-file", "") + jwtinfoCmd.Flags().Set("clipboard", "false") + }) + + reqOut := new(bytes.Buffer) + reqCmd := rootCmd + reqCmd.SetOut(reqOut) + reqCmd.SetErr(reqOut) + reqCmd.SetArgs(tt.args) + err := reqCmd.Execute() + + if tt.expectError { + require.Error(t, err) + + for _, expected := range tt.errMsgs { + require.ErrorContains(t, err, expected) + } + } else { + require.NoError(t, err) + } + + got := reqOut.String() + for _, expected := range tt.expected { + require.Contains(t, got, expected) + } + }) + } +} diff --git a/cmd/requests.go b/cmd/requests.go index ea74a87..3d6469c 100644 --- a/cmd/requests.go +++ b/cmd/requests.go @@ -6,7 +6,6 @@ package cmd import ( _ "embed" - "fmt" "os" "github.com/gookit/goutil/dump" @@ -41,7 +40,7 @@ Examples: https-wrench requests --config https-wrench-sample-config.yaml `, - Run: func(cmd *cobra.Command, args []string) { + Run: func(cmd *cobra.Command, _ []string) { versionRequested := viper.GetBool("version") if versionRequested { @@ -61,14 +60,14 @@ Examples: _, err := os.Stat(viper.ConfigFileUsed()) if err != nil { - fmt.Printf("\nConfig file not found: %s\n", viper.ConfigFileUsed()) + cmd.Printf("\nConfig file not found: %s\n", viper.ConfigFileUsed()) _ = cmd.Help() return } cfg, err := LoadConfig() if err != nil { - fmt.Print(err) + cmd.Print(err) return } @@ -78,7 +77,7 @@ Examples: requestsCfg, err := requests.NewRequestsMetaConfig() if err != nil { - fmt.Print(err) + cmd.Print(err) return } @@ -87,16 +86,16 @@ Examples: SetRequests(cfg.Requests) if err := requestsCfg.SetCaPoolFromYAML(cfg.CaBundle); err != nil { - fmt.Print(err) + cmd.Print(err) } if err := requestsCfg.SetCaPoolFromFile(caBundlePath, fileReader); err != nil { - fmt.Print(err) + cmd.Print(err) } - responseMap, err := requests.HandleRequests(os.Stdout, requestsCfg) + responseMap, err := requests.HandleRequests(cmd.OutOrStdout(), requestsCfg) if err != nil { - fmt.Print(err) + cmd.Print(err) } if cfg.Debug { diff --git a/cmd/requests_test.go b/cmd/requests_test.go index ba8fcea..e7123a5 100644 --- a/cmd/requests_test.go +++ b/cmd/requests_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" ) +//nolint:revive func TestRequestsCmd(t *testing.T) { tests := []struct { name string @@ -111,6 +112,13 @@ func TestRequestsCmd(t *testing.T) { tt := tc t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + require.NoError(t, rootCmd.Flags().Set("version", "false")) + require.NoError(t, requestsCmd.Flags().Set("ca-bundle", "")) + require.NoError(t, rootCmd.Flags().Set("config", "")) + require.NoError(t, requestsCmd.Flags().Set("show-sample-config", "false")) + }) + reqOut := new(bytes.Buffer) reqCmd := rootCmd reqCmd.SetOut(reqOut) diff --git a/cmd/root.go b/cmd/root.go index 72337e2..fcc1222 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -51,7 +51,8 @@ var rootCmd = &cobra.Command{ Use: "https-wrench", Short: "HTTPS Wrench, a tool to make Yaml defined HTTPS requests and inspect x.509 certificates and keys", Long: ` -HTTPS Wrench is a tool to make HTTPS requests according to a Yaml configuration file and to inspect x.509 certificates and keys. +HTTPS Wrench is a tool to make HTTPS requests according to a Yaml configuration file +and to inspect x.509 certificates and keys. https-wrench has two subcommands: requests and certinfo. @@ -66,7 +67,7 @@ certinfo can compare public keys extracted from certificates and private keys to HTTPS Wrench is distributed with an open source license and available at the following address: https://github.com/xenOs76/https-wrench`, - Run: func(cmd *cobra.Command, args []string) { + Run: func(cmd *cobra.Command, _ []string) { showVersion, _ := cmd.Flags().GetBool("version") if showVersion { cmd.Println(version) diff --git a/cmd/root_test.go b/cmd/root_test.go index e7e69aa..275f431 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -13,6 +13,7 @@ import ( "github.com/xenos76/https-wrench/internal/requests" ) +//nolint:revive func TestRootCmd_LoadConfig(t *testing.T) { t.Run("LoadConfig no config file", func(t *testing.T) { oldCfg := cfgFile @@ -73,6 +74,23 @@ func TestRootCmd_LoadConfig(t *testing.T) { require.Equal(t, "httpBunComGet", config.Requests[0].Name) require.Equal(t, "https://cat.httpbun.com:443", config.Requests[0].TransportOverrideURL) }) + t.Run("LoadConfig unmarshal error", func(t *testing.T) { + oldCfg := cfgFile + + t.Cleanup(func() { + cfgFile = oldCfg + + viper.Reset() + }) + + // Make Unmarshal fail by setting a type mismatch + viper.Set("Requests", "this is a string, not a slice") + + config, err := LoadConfig() + require.Error(t, err) + require.Nil(t, config) + require.ErrorContains(t, err, "unable to decode into config struct") + }) } func TestRootCmd_Execute(t *testing.T) { @@ -87,8 +105,27 @@ func TestRootCmd_Execute(t *testing.T) { err = Execute() require.EqualError(t, err, "flag needs an argument: --config") }) + + t.Run("Execute success", func(t *testing.T) { + oldCfg := cfgFile + + t.Cleanup(func() { + cfgFile = oldCfg + + rootCmd.SetArgs(nil) + viper.Reset() + }) + + rootCmd.SetArgs([]string{"--config", "./embedded/config-example.yaml"}) + + err := Execute() + require.NoError(t, err) + }) } +//nolint:revive + +//nolint:revive func TestRootCmd(t *testing.T) { tests := []struct { name string diff --git a/devenv.lock b/devenv.lock index 3bca2f5..c6f8840 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,10 +3,11 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1772483048, + "lastModified": 1776863933, + "narHash": "sha256-v9NoQFSln9n5zqVWUWUc9PajsMaGmga51HOAJqMx7Qw=", "owner": "cachix", "repo": "devenv", - "rev": "40f410e3a5e0f9198cf67bfa8673c9a17d8c605c", + "rev": "863b4204725efaeeb73811e376f928232b720646", "type": "github" }, "original": { @@ -20,6 +21,7 @@ "flake": false, "locked": { "lastModified": 1767039857, + "narHash": "sha256-vNpUSpF5Nuw8xvDLj2KCwwksIbjua2LZCqhV1LNRDns=", "owner": "NixOS", "repo": "flake-compat", "rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", @@ -38,10 +40,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1772024342, + "lastModified": 1776796298, + "narHash": "sha256-PcRvlWayisPSjd0UcRQbhG8Oqw78AcPE6x872cPRHN8=", "owner": "cachix", "repo": "git-hooks.nix", - "rev": "6e34e97ed9788b17796ee43ccdbaf871a5c2b476", + "rev": "3cfd774b0a530725a077e17354fbdb87ea1c4aad", "type": "github" }, "original": { @@ -58,10 +61,11 @@ ] }, "locked": { - "lastModified": 1762808025, + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", "owner": "hercules-ci", "repo": "gitignore.nix", - "rev": "cb5e3fdca1de58ccbc3ef53de65bd372b48f567c", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", "type": "github" }, "original": { @@ -72,10 +76,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1773597492, + "lastModified": 1770073757, + "narHash": "sha256-Vy+G+F+3E/Tl+GMNgiHl9Pah2DgShmIUBJXmbiQPHbI=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "a07d4ce6bee67d7c838a8a5796e75dff9caa21ef", + "rev": "47472570b1e607482890801aeaf29bfb749884f6", "type": "github" }, "original": { @@ -88,11 +93,11 @@ "nixpkgs-src": { "flake": false, "locked": { - "lastModified": 1769922788, - "narHash": "sha256-H3AfG4ObMDTkTJYkd8cz1/RbY9LatN5Mk4UF48VuSXc=", + "lastModified": 1776329215, + "narHash": "sha256-a8BYi3mzoJ/AcJP8UldOx8emoPRLeWqALZWu4ZvjPXw=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "207d15f1a6603226e1e223dc79ac29c7846da32e", + "rev": "b86751bc4085f48661017fa226dee99fab6c651b", "type": "github" }, "original": { @@ -104,10 +109,11 @@ }, "nixpkgs-stable": { "locked": { - "lastModified": 1772047000, + "lastModified": 1776734388, + "narHash": "sha256-vl3dkhlE5gzsItuHoEMVe+DlonsK+0836LIRDnm6MXQ=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "1267bb4920d0fc06ea916734c11b0bf004bbe17e", + "rev": "10e7ad5bbcb421fe07e3a4ad53a634b0cd57ffac", "type": "github" }, "original": { @@ -122,10 +128,11 @@ "nixpkgs-src": "nixpkgs-src" }, "locked": { - "lastModified": 1770434727, + "lastModified": 1776852779, + "narHash": "sha256-WwO/ITisCXwyiRgtktZgv3iGhAGO+IB5Av4kKCwezR0=", "owner": "cachix", "repo": "devenv-nixpkgs", - "rev": "8430f16a39c27bdeef236f1eeb56f0b51b33d348", + "rev": "ec3063523dcd911aeadb50faa589f237cdab5853", "type": "github" }, "original": { @@ -140,10 +147,7 @@ "devenv": "devenv", "git-hooks": "git-hooks", "nixpkgs": "nixpkgs_2", - "nixpkgs-stable": "nixpkgs-stable", - "pre-commit-hooks": [ - "git-hooks" - ] + "nixpkgs-stable": "nixpkgs-stable" } } }, diff --git a/internal/certinfo/certinfo.go b/internal/certinfo/certinfo.go index 243d3ec..52a2fac 100644 --- a/internal/certinfo/certinfo.go +++ b/internal/certinfo/certinfo.go @@ -45,8 +45,9 @@ type ( ) var ( - TlsServerName string - TlsInsecure bool + //nolint:revive + TLSServerName string + TLSInsecure bool inputReader InputReader ) diff --git a/internal/certinfo/certinfo_handlers.go b/internal/certinfo/certinfo_handlers.go index 2275af9..32d8ef7 100644 --- a/internal/certinfo/certinfo_handlers.go +++ b/internal/certinfo/certinfo_handlers.go @@ -21,6 +21,7 @@ import ( "github.com/xenos76/https-wrench/internal/style" ) +//nolint:revive func (c *CertinfoConfig) PrintData(w io.Writer) error { ks := style.ItemKey.PaddingBottom(0).PaddingTop(1).PaddingLeft(1) sl := style.CertKeyP4.Bold(true) diff --git a/internal/certinfo/certinfo_handlers_test.go b/internal/certinfo/certinfo_handlers_test.go index 5de7310..0af3a3e 100644 --- a/internal/certinfo/certinfo_handlers_test.go +++ b/internal/certinfo/certinfo_handlers_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" ) +//nolint:revive func TestCertinfo_GetRemoteCerts(t *testing.T) { tests := []struct { desc string @@ -39,7 +40,8 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { serverCertFile: RSASampleCertFile, serverKeyFile: RSASampleCertKeyFile, }, - caCertFile: emptyString, + caCertFile: emptyString, + //nolint:revive expectError: true, expectMsg: "TLS handshake failed: tls: failed to verify certificate: x509: certificate signed by unknown authority", }, @@ -54,6 +56,7 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { }, caCertFile: RSACaCertFile, expectSrvHost: "localhost", + //nolint:revive expectSrvPort: "46303", expectError: true, expectMsg: "TLS handshake failed: tls: failed to verify certificate: x509: certificate relies on legacy Common Name field, use SANs instead", @@ -79,7 +82,8 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { serverCertFile: RSASampleCertFile, serverKeyFile: RSASampleCertKeyFile, }, - caCertFile: RSASamplePKCS8Certificate, + caCertFile: RSASamplePKCS8Certificate, + //nolint:revive expectSrvHost: "localhost", expectSrvPort: "46305", expectError: true, @@ -116,7 +120,8 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { serverAddr: "localhost:46308", serverName: "example.co.uk", serverCertFile: RSASampleCertFile, - serverKeyFile: RSASampleCertKeyFile, + //nolint:revive + serverKeyFile: RSASampleCertKeyFile, }, caCertFile: RSACaCertFile, expectError: true, @@ -153,11 +158,13 @@ func TestCertinfo_GetRemoteCerts(t *testing.T) { return } + //nolint:revive require.EqualError(t, err, tt.expectMsg, "check error expected") }) } } +//nolint:revive func TestCertinfo_CertsToTables(t *testing.T) { rsaSampleCert, err := GetCertsFromBundle( RSASampleCertFile, @@ -277,12 +284,16 @@ func TestCertinfo_CertsToTables(t *testing.T) { tt.signatureAlgorithm, tt.expiration, } { + //nolint:revive require.Contains(t, got, want) } }) } } +//nolint:revive + +//nolint:revive func TestCertinfo_PrintData(t *testing.T) { tests := []struct { desc string @@ -330,12 +341,15 @@ func TestCertinfo_PrintData(t *testing.T) { caCertFile: emptyString, tlsEndpoint: "localhost:46402", tlsServerName: "example.com", + //nolint:revive srvCfg: demoHTTPServerConfig{ serverAddr: "localhost:46402", serverName: "example.com", serverCertFile: RSASampleCertFile, serverKeyFile: RSASampleCertKeyFile, + //nolint:revive }, + //nolint:revive expectCertsFetchErr: true, expectCertsFetcMsg: "unable to get endpoint certificates: TLS handshake failed: tls: failed to verify certificate: x509: certificate signed by unknown authority", }, @@ -356,16 +370,19 @@ func TestCertinfo_PrintData(t *testing.T) { }, }, { - desc: "local key and remote TLS Endpoint, missing TLS ServerName", - keyFile: RSASampleCertKeyFile, - caCertFile: RSACaCertFile, - tlsEndpoint: "localhost:46404", + desc: "local key and remote TLS Endpoint, missing TLS ServerName", + keyFile: RSASampleCertKeyFile, + caCertFile: RSACaCertFile, + tlsEndpoint: "localhost:46404", + //nolint:revive tlsServerName: emptyString, srvCfg: demoHTTPServerConfig{ serverAddr: "localhost:46404", serverName: "example.com", serverCertFile: RSASampleCertFile, - serverKeyFile: RSASampleCertKeyFile, + //nolint:revive + serverKeyFile: RSASampleCertKeyFile, + //nolint:revive }, expectCertsFetchErr: true, expectCertsFetcMsg: "unable to get endpoint certificates: TLS handshake failed: tls: failed to verify certificate: x509: certificate is valid for example.com, example.net, example.de, not localhost", @@ -479,4 +496,50 @@ func TestCertinfo_PrintData(t *testing.T) { } }) } + + t.Run("PrintData local cert private key match error", func(t *testing.T) { + buffer := bytes.Buffer{} + cc, err := NewCertinfoConfig() + require.NoError(t, err) + + // Inject a bad public key to force certMatchPrivateKey to fail + cc.PrivKey = "dummy_key" + cc.CertsBundle = append(cc.CertsBundle, &x509.Certificate{ + PublicKey: "unsupported_key_type", + }) + cc.CertsBundleFilePath = "dummy" + + errPrint := cc.PrintData(&buffer) + require.Error(t, errPrint) + require.ErrorContains(t, errPrint, "unable to check if private key matches local certificate") + }) + + t.Run("PrintData remote cert private key match error", func(t *testing.T) { + buffer := bytes.Buffer{} + cc, err := NewCertinfoConfig() + require.NoError(t, err) + + cc.PrivKey = "dummy_key" + cc.TLSEndpointCerts = append(cc.TLSEndpointCerts, &x509.Certificate{ + PublicKey: "unsupported_key_type", + }) + cc.TLSEndpointHost = "localhost" + cc.TLSEndpointPort = "443" + + errPrint := cc.PrintData(&buffer) + require.Error(t, errPrint) + require.ErrorContains(t, errPrint, "unable to check if private key matches remote TLS Endpoint certificate") + }) + + t.Run("PrintData CA cert file read error", func(t *testing.T) { + buffer := bytes.Buffer{} + cc, err := NewCertinfoConfig() + require.NoError(t, err) + + cc.CACertsFilePath = "non_existent_file.pem" + + errPrint := cc.PrintData(&buffer) + require.Error(t, errPrint) + require.ErrorContains(t, errPrint, "unable for read Root certificates") + }) } diff --git a/internal/certinfo/certinfo_test.go b/internal/certinfo/certinfo_test.go index 74d3cd4..c74af9d 100644 --- a/internal/certinfo/certinfo_test.go +++ b/internal/certinfo/certinfo_test.go @@ -325,6 +325,7 @@ func TestCertinfo_SetTLSServerName(t *testing.T) { } } +//nolint:revive func TestCertinfo_SetTLSEndpoint(t *testing.T) { tests := []struct { desc string @@ -357,10 +358,11 @@ func TestCertinfo_SetTLSEndpoint(t *testing.T) { expectPort: "443", }, { - desc: "error malformed host", - endpoint: "localh#$%ost:443", + desc: "error malformed host", + endpoint: "localh#$%ost:443", + //nolint:revive processErr: true, - expectMsg: "unable to get endpoint certificates: TLS handshake failed: dial tcp: lookup localh#$%ost: no such host", + expectMsg: "unable to get endpoint certificates: TLS handshake failed", }, { desc: "error missing port", @@ -369,12 +371,14 @@ func TestCertinfo_SetTLSEndpoint(t *testing.T) { expectMsg: "invalid TLS endpoint \"localhost\": address localhost: missing port in address", }, { - desc: "error missing host", + desc: "error missing host", + //nolint:revive endpoint: ":80443", processErr: true, expectMsg: "unable to get endpoint certificates: TLS handshake failed: dial tcp: address 80443: invalid port", }, { + //nolint:revive desc: "error endpoint includes scheme", endpoint: "https://localhost:80443", processErr: true, @@ -404,7 +408,7 @@ func TestCertinfo_SetTLSEndpoint(t *testing.T) { return } - require.EqualError(t, err, tt.expectMsg) + require.ErrorContains(t, err, tt.expectMsg) }) } } diff --git a/internal/certinfo/common_handlers.go b/internal/certinfo/common_handlers.go index 3f4b2e5..3d73bea 100644 --- a/internal/certinfo/common_handlers.go +++ b/internal/certinfo/common_handlers.go @@ -166,6 +166,7 @@ func IsPrivateKeyEncrypted(key []byte) (bool, error) { } } +//nolint:revive func getPassphraseIfNeeded(isEncrypted bool, pwEnvKey string, pwReader Reader) ([]byte, error) { if !isEncrypted { return nil, nil diff --git a/internal/certinfo/common_handlers_test.go b/internal/certinfo/common_handlers_test.go index bf4b22e..31485f9 100644 --- a/internal/certinfo/common_handlers_test.go +++ b/internal/certinfo/common_handlers_test.go @@ -175,6 +175,7 @@ func TestCertinfo_GetCertsFromBundle(t *testing.T) { }) } +//nolint:revive func TestCertinfo_GetKeyFromFile_inputReaderErrors(t *testing.T) { tests := []struct { desc string diff --git a/internal/jwtinfo/jwtinfo.go b/internal/jwtinfo/jwtinfo.go index aeba925..cb84a44 100644 --- a/internal/jwtinfo/jwtinfo.go +++ b/internal/jwtinfo/jwtinfo.go @@ -42,6 +42,7 @@ type JwtTokenData struct { type allReader func(io.Reader) ([]byte, error) +//nolint:revive func RequestToken(reqURL string, reqValues map[string]string, client *http.Client, readAll allReader) (JwtTokenData, error) { if reqURL == emptyString { return JwtTokenData{}, errors.New("empty string provided as request URL") @@ -196,6 +197,7 @@ func isValidJSON(data []byte) bool { return json.Unmarshal(data, &v) == nil } +//nolint:revive func (jtd *JwtTokenData) DecodeBase64() error { tokens := []struct { name string @@ -330,6 +332,7 @@ func (jtd *JwtTokenData) ParseWithJWKS(jwksURL string, keyfuncOverride keyfunc.O return nil } +//nolint:revive func PrintTokenInfo(jtd JwtTokenData, w io.Writer) error { sl := style.CertKeyP4.Render sv := style.CertValue.Render @@ -443,13 +446,13 @@ func unmarshallTokenTimeClaims(claims []byte) (map[string]string, error) { } for k, v := range genericClaims { - var vi any = v + vi := v if vf, ok := vi.(float64); ok { vInt64 := int64(vf) t := time.Unix(vInt64, 0) - dateUtc := t.UTC().Format(time.UnixDate) - tokenClaims[k] = fmt.Sprintf("%v", dateUtc) + dateUTC := t.UTC().Format(time.UnixDate) + tokenClaims[k] = fmt.Sprintf("%v", dateUTC) continue } diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index 31de5ba..3a964cf 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -3,6 +3,7 @@ package jwtinfo import ( "bytes" "encoding/base64" + "fmt" "io" "maps" "os" @@ -85,6 +86,7 @@ func TestReadRequestValuesFile(t *testing.T) { }) } +//nolint:revive func TestParseRequestJSONValues(t *testing.T) { inputMap := map[string]string{ "testKey": "testValue", @@ -164,6 +166,9 @@ func TestParseRequestJSONValues(t *testing.T) { } } +//nolint:revive + +//nolint:revive func TestRequestToken(t *testing.T) { tests := []struct { name string @@ -303,8 +308,12 @@ func TestRequestToken(t *testing.T) { // godump.Dump(td) }) } + //nolint:revive } +//nolint:revive + +//nolint:revive func TestParseWithJWKS(t *testing.T) { tests := []struct { name string @@ -573,9 +582,14 @@ func TestParseWithJWKS_Errors(t *testing.T) { err, "failed to create JWK Set from resource at URL", ) + //nolint:revive }) + //nolint:revive } +//nolint:revive + +//nolint:revive func TestDecodeBase64(t *testing.T) { notThreeDotted := "notThreeDottedBase64CompliantString" @@ -775,10 +789,16 @@ func TestUnmarshallTokenTimeClaims_MapErrors(t *testing.T) { _, err := unmarshallTokenTimeClaims(tt.claims) require.ErrorContains(t, err, tt.errMsg) + //nolint:revive }) + //nolint:revive } + //nolint:revive } +//nolint:revive + +//nolint:revive func TestPrintTokenInfo(t *testing.T) { tests := []struct { name string @@ -882,3 +902,71 @@ func TestPrintTokenInfo(t *testing.T) { }) } } + +func TestReadTokenFromFile(t *testing.T) { + t.Run("Success", func(t *testing.T) { + tokenRaw, err := createToken("demo") + require.NoError(t, err) + + tmpDir := t.TempDir() + tempFile, err := createTmpFileWithContent(tmpDir, "token.txt", []byte(tokenRaw)) + require.NoError(t, err) + + td, err := ReadTokenFromFile(tempFile) + require.NoError(t, err) + require.Equal(t, tokenRaw, td.AccessTokenRaw) + }) + + t.Run("File Read Error", func(t *testing.T) { + _, err := ReadTokenFromFile("non_existent_file.txt") + require.Error(t, err) + require.ErrorContains(t, err, "unable to read token file") + }) + + t.Run("Parse Error", func(t *testing.T) { + tmpDir := t.TempDir() + tempFile, err := createTmpFileWithContent(tmpDir, "token.txt", []byte("invalid_token")) + require.NoError(t, err) + + _, err = ReadTokenFromFile(tempFile) + require.Error(t, err) + require.ErrorContains(t, err, "unable to parse JWT token from file") + }) +} + +func TestPrintTokenInfo_Errors(t *testing.T) { + t.Run("jsonIndent error header", func(t *testing.T) { + //nolint:revive + buffer := bytes.Buffer{} + + // Valid claims so unmarshallTokenTimeClaims succeeds. + now := time.Now().Unix() + exp := now + 3600 + claimsJSON := fmt.Sprintf(`{"iat": %d, "exp": %d}`, now, exp) + + jtd := JwtTokenData{ + AccessTokenHeader: []byte("invalid json"), + AccessTokenClaims: []byte(claimsJSON), + } + + err := PrintTokenInfo(jtd, &buffer) + require.NoError(t, err) + + // The json.Indent for header failed and it wrote the raw header. + // It will be syntax-highlighted, adding ANSI codes, so we just check it wrote something. + require.Positive(t, buffer.Len()) + }) + + t.Run("unmarshallTokenTimeClaims error", func(t *testing.T) { + buffer := bytes.Buffer{} + + jtd := JwtTokenData{ + AccessTokenHeader: []byte(`{"typ":"JWT"}`), + AccessTokenClaims: []byte("invalid json"), + } + + err := PrintTokenInfo(jtd, &buffer) + require.Error(t, err) + require.ErrorContains(t, err, "unable to unmashall time claims from AccessToken") + }) +} diff --git a/internal/jwtinfo/main_test.go b/internal/jwtinfo/main_test.go index 332e208..f0ba3a9 100644 --- a/internal/jwtinfo/main_test.go +++ b/internal/jwtinfo/main_test.go @@ -49,7 +49,7 @@ var ( mockErrReader MockErrReader ) -func (MockErrReader) ReadAll(r io.Reader) ([]byte, error) { +func (MockErrReader) ReadAll(_ io.Reader) ([]byte, error) { return nil, errors.New("mock Reader error") } @@ -65,7 +65,8 @@ func TestMain(m *testing.M) { func fatal(err error) { if err != nil { - log.Fatal(err) + //nolint:revive + panic(err) } } diff --git a/internal/requests/main_test.go b/internal/requests/main_test.go index 2f3d2a7..dc82dec 100644 --- a/internal/requests/main_test.go +++ b/internal/requests/main_test.go @@ -32,6 +32,7 @@ type demoCertTemplate struct { parent *x509.Certificate } +//nolint:revive type demoHttpServerData struct { serverAddr string proxyprotoEnabled bool @@ -241,6 +242,9 @@ func NewHTTPSTestServer(data demoHttpServerData) (*httptest.Server, error) { return ts, nil } +//nolint:revive + +//nolint:revive func TestMain(m *testing.M) { fmt.Printf("Check test data dir: %s\n", testdataDir) diff --git a/internal/requests/requests.go b/internal/requests/requests.go index 1c6f152..3c14352 100644 --- a/internal/requests/requests.go +++ b/internal/requests/requests.go @@ -1,3 +1,4 @@ +//nolint:revive package requests import ( @@ -196,6 +197,7 @@ func (r *RequestsMetaConfig) PrintCmd(w io.Writer) { } } +//nolint:revive func (r *RequestConfig) PrintTitle(isVerbose bool) { if isVerbose { fmt.Print(style.LgSprintf(style.TitleKey, "Request:")) @@ -228,6 +230,7 @@ func (r *RequestConfig) PrintRequestDebug(w io.Writer, req *http.Request) error return nil } +//nolint:revive func (r *RequestConfig) PrintResponseDebug(w io.Writer, resp *http.Response) { // TODO: return an error if resp == nil { @@ -500,7 +503,11 @@ func (rc *RequestHTTPClient) SetClientTimeout(timeout int) (*RequestHTTPClient, return rc, nil } -func NewHTTPClientFromRequestConfig(r RequestConfig, serverName string, caPool *x509.CertPool) (*RequestHTTPClient, error) { +func NewHTTPClientFromRequestConfig( + r RequestConfig, + serverName string, + caPool *x509.CertPool, +) (*RequestHTTPClient, error) { reqClient := NewRequestHTTPClient() _, err := reqClient.SetCACertsPool(caPool) @@ -555,7 +562,12 @@ func NewHTTPClientFromRequestConfig(r RequestConfig, serverName string, caPool * return reqClient, nil } -func processHTTPRequestsByHost(r RequestConfig, caPool *x509.CertPool, isVerbose bool) ([]ResponseData, error) { +//nolint:revive +func processHTTPRequestsByHost( + r RequestConfig, + caPool *x509.CertPool, + isVerbose bool, +) ([]ResponseData, error) { var responseDataList []ResponseData requestBodyBytes := []byte(r.RequestBody) diff --git a/internal/requests/requests_handlers.go b/internal/requests/requests_handlers.go index b48903a..2e90003 100644 --- a/internal/requests/requests_handlers.go +++ b/internal/requests/requests_handlers.go @@ -242,9 +242,7 @@ func (rd *ResponseData) ImportResponseBody() { re, err := regexp.Compile(rd.Request.ResponseBodyMatchRegexp) if err != nil { fmt.Print(fmt.Errorf("unable to compile responseBodyMatchRegexp: %w", err)) - } - - if re.Match(body) { + } else if re.Match(body) { rd.ResponseBodyRegexpMatched = true } } @@ -276,6 +274,7 @@ func (rd *ResponseData) ImportResponseBody() { rd.ResponseBody = string(body) } +//nolint:revive func (rd ResponseData) PrintResponseData(isVerbose bool) { if !isVerbose { return diff --git a/internal/requests/requests_handlers_test.go b/internal/requests/requests_handlers_test.go index fe3c06c..ff8bfd2 100644 --- a/internal/requests/requests_handlers_test.go +++ b/internal/requests/requests_handlers_test.go @@ -313,6 +313,7 @@ func TestTransportAddressFromURLString(t *testing.T) { } } +//nolint:revive func TestRenderTLSData(t *testing.T) { tests := []struct { srvAddr string @@ -436,6 +437,7 @@ func TestRenderTLSData(t *testing.T) { } } +//nolint:revive func TestHandleRequests(t *testing.T) { reqMeta1 := RequestsMetaConfig{ CACertsPool: caCertPool, diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index e90b6b9..b56d951 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -4,7 +4,9 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "errors" "fmt" + "io" "net/http" "net/http/httptest" "net/url" @@ -160,6 +162,25 @@ func TestRequestsMetaConfig_SetCaPoolFromFile(t *testing.T) { } }) } + + t.Run("SetCaPoolFromFile_Error", func(t *testing.T) { + t.Parallel() + + rmc, _ := NewRequestsMetaConfig() + err := rmc.SetCaPoolFromFile("non_existent_file.pem", nil) + require.Error(t, err) + }) +} + +func TestRequestsMetaConfig_SetCaPoolFromYAML_Error(t *testing.T) { + t.Run("SetCaPoolFromYAML_Error", func(t *testing.T) { + t.Parallel() + + rmc, _ := NewRequestsMetaConfig() + err := rmc.SetCaPoolFromYAML("invalid cert data") + require.Error(t, err) + require.ErrorContains(t, err, "unable to create CA Certs Pool from YAML") + }) } func TestRequestsMetaConfig_SetRequests(t *testing.T) { @@ -288,6 +309,66 @@ func TestNewHTTPClientFromRequestConfig_Error(t *testing.T) { } } +func TestNewHTTPClientFromRequestConfig_SubErrors(t *testing.T) { + tests := []struct { + desc string + reqConf RequestConfig + serverName string + errMsg string + }{ + { + desc: "SetClientTimeout error", + reqConf: RequestConfig{ + ClientTimeout: -1, + }, + serverName: "localhost", + errMsg: "SetClientTimeout error: timeout value must be positive: -1 provided", + }, + { + desc: "SetMethod error", + reqConf: RequestConfig{ + RequestMethod: "INVALID", + }, + serverName: "localhost", + errMsg: "SetMethod error: INVALID: HTTP method not found", + }, + { + desc: "SetTransportOverride error", + reqConf: RequestConfig{ + TransportOverrideURL: "https://loca$%^lhost", + }, + serverName: "localhost", + errMsg: "SetTransportOverride error: failed to parse transport override url: https://loca$%^lhost", + }, + { + desc: "proxyProtoHeaderFromRequest error", + reqConf: RequestConfig{ + EnableProxyProtocolV2: true, + TransportOverrideURL: "https://test.invalid:443", + }, + serverName: "localhost", + errMsg: "error creating proxyproto Header: failed to resolve transport override hostname's IPs': " + + "lookup test.invalid", // we'll just check ErrorContains + }, + } + + for _, tc := range tests { + tt := tc + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + _, err := NewHTTPClientFromRequestConfig( + tt.reqConf, + tt.serverName, + nil, + ) + require.Error(t, err) + require.ErrorContains(t, err, tt.errMsg) + }) + } +} + +//nolint:revive func TestNewHTTPClientFromRequestConfig(t *testing.T) { tests := []struct { desc string @@ -674,7 +755,7 @@ func TestNewRequestHTTPClient_SetInsecureSkipVerify_tlsServer(t *testing.T) { t.Run(testname, func(t *testing.T) { t.Parallel() - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintln(w, "Hello, client") })) defer ts.Close() @@ -835,6 +916,8 @@ func TestRequestHTTPClient_SetTransportOverride_Error(t *testing.T) { // http.client. // Once the TLS server is started on an address other than https://hostname, we expect the // client to contact the TLS server even if it is requested to connect to https://servername. +// +//nolint:revive // test function func TestRequestHTTPClient_SetTransportOverride_transportAddress_server(t *testing.T) { tests := []struct { trasportURL string @@ -927,6 +1010,7 @@ func TestRequestHTTPClient_SetTransportOverride_transportAddress_server(t *testi } } +//nolint:revive // test function func TestRequestHTTPClient_SetProxyProtocolV2_server(t *testing.T) { tests := []struct { testname string @@ -1109,6 +1193,7 @@ func TestPrintCmd(t *testing.T) { } } +//nolint:revive func TestPrintResponseDebug(t *testing.T) { tests := []struct { desc string @@ -1255,6 +1340,7 @@ func TestPrintResponseDebug_nonTLS(t *testing.T) { }) } +//nolint:revive func TestPrintRequestDebug(t *testing.T) { httpTestHeader := http.Header{} httpTestHeader.Add("user-agent", "go-test") @@ -1337,6 +1423,7 @@ func TestPrintRequestDebug(t *testing.T) { } } +//nolint:revive func TestProcessHTTPRequestsByHost(t *testing.T) { tests := []struct { srvAddr string @@ -1377,7 +1464,8 @@ func TestProcessHTTPRequestsByHost(t *testing.T) { pool: caCertPool, verbose: false, respStatusCode: 0, - errMsg: "Get \"https://localhost\": tls: failed to verify certificate: x509: certificate is valid for example.com, example.net, example.de, not localhost", + errMsg: "Get \"https://localhost\": tls: failed to verify certificate: " + + "x509: certificate is valid for example.com, example.net, example.de, not localhost", }, { @@ -1400,8 +1488,7 @@ func TestProcessHTTPRequestsByHost(t *testing.T) { for _, tc := range tests { tt := tc // safer when using t.Parallel() t.Run(tt.reqConf.Name, func(t *testing.T) { - t.Parallel() - + // t.Parallel() httpSrvData := demoHttpServerData{ serverAddr: tt.srvAddr, proxyprotoEnabled: false, @@ -1484,3 +1571,107 @@ func TestProcessHTTPRequestsByHost(t *testing.T) { }) } } + +func TestRequestHTTPClient_DialContextErrors(t *testing.T) { + t.Run("SetTransportOverride DialContext Error", func(t *testing.T) { + t.Parallel() + + reqConf := RequestConfig{ + TransportOverrideURL: "https://localhost:11111", // dead port + } + + client, err := NewHTTPClientFromRequestConfig(reqConf, "localhost", nil) + require.NoError(t, err) + + req, _ := http.NewRequest("GET", "https://localhost", nil) + _, err = client.client.Do(req) + require.Error(t, err) + }) + + t.Run("SetProxyProtocolV2 DialContext Error", func(t *testing.T) { + t.Parallel() + + reqConf := RequestConfig{ + EnableProxyProtocolV2: true, + TransportOverrideURL: "https://localhost:11111", // dead port + } + + client, err := NewHTTPClientFromRequestConfig(reqConf, "localhost", nil) + require.NoError(t, err) + + req, _ := http.NewRequest("GET", "https://localhost", nil) + _, err = client.client.Do(req) + require.Error(t, err) + }) +} + +func TestProcessHTTPRequestsByHost_Errors(t *testing.T) { + t.Run("getUrlsFromHost error", func(t *testing.T) { + reqConf := RequestConfig{ + Hosts: []Host{ + {Name: "localhost", URIList: []URI{"invalid"}}, + }, + } + _, err := processHTTPRequestsByHost(reqConf, nil, false) + require.Error(t, err) + require.ErrorContains(t, err, "invalid uri") + }) +} + +func TestProxyProtoHeaderFromRequest_Errors(t *testing.T) { + t.Run("not enabled", func(t *testing.T) { + _, err := proxyProtoHeaderFromRequest(RequestConfig{}, "localhost") + require.ErrorContains(t, err, "proxy protocol v2 is not enabled") + }) + + // url.Parse won't fail for typical invalid URLs, but let's try a control character + t.Run("serverName parse fail", func(t *testing.T) { + _, err := proxyProtoHeaderFromRequest(RequestConfig{EnableProxyProtocolV2: true}, string([]byte{0x7f})) + require.Error(t, err) + }) + + t.Run("transportOverride parse fail", func(t *testing.T) { + _, err := proxyProtoHeaderFromRequest(RequestConfig{ + EnableProxyProtocolV2: true, + TransportOverrideURL: string([]byte{0x7f}), + }, "localhost") + require.Error(t, err) + }) +} + +type mockErrReader struct{} + +func (mockErrReader) Read(_ []byte) (n int, err error) { + return 0, errors.New("mock read error") +} + +func TestImportResponseBody_Errors(t *testing.T) { + t.Run("already imported", func(t *testing.T) { + rd := ResponseData{ResponseBody: "already imported"} + rd.ImportResponseBody() // should return immediately + require.Equal(t, "already imported", rd.ResponseBody) + }) + + t.Run("read error", func(t *testing.T) { + rd := ResponseData{ + Response: &http.Response{ + Body: io.NopCloser(mockErrReader{}), + }, + } + rd.ImportResponseBody() // should print error and return + require.Empty(t, rd.ResponseBody) + }) + + t.Run("bad regexp", func(t *testing.T) { + rd := ResponseData{ + Request: RequestConfig{ResponseBodyMatchRegexp: "["}, + Response: &http.Response{ + Header: make(http.Header), + Body: io.NopCloser(bytes.NewBufferString("test body")), + }, + } + rd.ImportResponseBody() + require.False(t, rd.ResponseBodyRegexpMatched) + require.Equal(t, "test body", rd.ResponseBody) + }) +} diff --git a/internal/style/style_handlers.go b/internal/style/style_handlers.go index 2e5de6f..aad50f1 100644 --- a/internal/style/style_handlers.go +++ b/internal/style/style_handlers.go @@ -47,6 +47,7 @@ func StatusCodeParse(sc int) string { return status } +//nolint:revive func BoolStyle(b bool) string { if b { return LgSprintf(BoolTrue, "true") diff --git a/internal/style/style_handlers_test.go b/internal/style/style_handlers_test.go index dc36d43..0fca9ce 100644 --- a/internal/style/style_handlers_test.go +++ b/internal/style/style_handlers_test.go @@ -232,3 +232,12 @@ func TestCodeSyntaxHighlight(t *testing.T) { }) } } + +func TestCodeSyntaxHighlightWithStyle_Fallback(t *testing.T) { + t.Run("invalid style uses fallback", func(t *testing.T) { + code := `{"test": "json"}` + // Pass an invalid style to hit the fallback + s := CodeSyntaxHighlightWithStyle("json", code, "invalid-style-that-does-not-exist") + require.Contains(t, s, "test") + }) +}