diff --git a/.gitignore b/.gitignore index ada7a33..b03b315 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,6 @@ devenv.local.nix # pre-commit .pre-commit-config.yaml + +.vscode +.scannerwork diff --git a/internal/certinfo/certinfo_handlers_test.go b/internal/certinfo/certinfo_handlers_test.go index 0af3a3e..62d46a8 100644 --- a/internal/certinfo/certinfo_handlers_test.go +++ b/internal/certinfo/certinfo_handlers_test.go @@ -406,94 +406,7 @@ func TestCertinfo_PrintData(t *testing.T) { for _, tc := range tests { tt := tc t.Run("No errors test - "+tt.desc, func(t *testing.T) { - t.Parallel() - - buffer := bytes.Buffer{} - - cc, err := NewCertinfoConfig() - require.NoError(t, err) - - cc.SetPrivateKeyFromFile(tt.keyFile, "notSet", inputReader) - cc.SetCertsFromFile(tt.certFile, inputReader) - cc.SetCaPoolFromFile(tt.caCertFile, inputReader) - - if tt.tlsEndpoint != emptyString { - ts, errSrv := NewHTTPSTestServer(tt.srvCfg) - require.NoError(t, errSrv) - - defer ts.Close() - - cc.SetTLSServerName(tt.tlsServerName) - cc.SetTLSInsecure(tt.tlsInsecure) - - // in most of these test cases SetTLSEndpoint depends - // on SetTLSServerName and/or SetTLSInsecure to be set - // before being able to fetch certificates from the TLS - // endpoint. - // The dependency is addressed with the order of method calls - // in cmd/certinfo.go. - // In this test, we call SetTLSServerName and SetTLSInsecure before - // SetTLSEndpoint to be sure the dependency is being addressed the - // same way. - err = cc.SetTLSEndpoint(tt.tlsEndpoint) - if !tt.expectCertsFetchErr { - require.NoError(t, err, "SetTLSEndpoint require NoError") - } - - if tt.expectCertsFetchErr { - require.EqualError(t, err, tt.expectCertsFetcMsg) - } - } - - errPrint := cc.PrintData(&buffer) - require.NoError(t, errPrint) - - got := buffer.String() - - if tt.keyFile != emptyString { - require.Contains(t, got, "PrivateKey file: "+tt.keyFile) - } - - if tt.certFile != emptyString { - require.Contains(t, got, "Certificate bundle file: "+tt.certFile) - } - - if tt.caCertFile != emptyString { - require.Contains(t, got, "CA Certificates file: "+tt.caCertFile) - } - - if !tt.expectCertsFetchErr { - for _, want := range []string{ - "Certinfo", - "Certificate", - "Subject", - "Issuer", - "NotBefore", - "NotAfter", - "Expiration", - "IsCA", - "AuthorityKeyId", - "SubjectKeyId", - "PublicKeyAlgorithm", - "SignatureAlgorithm", - "SerialNumber", - "Fingerprint SHA-256", - } { - require.Contains(t, got, want) - } - - if tt.keyFile != emptyString && tt.keyCertMatch { - require.Contains(t, got, "PrivateKey match: true") - } else { - require.Contains(t, got, "PrivateKey match: false") - } - - if tt.tlsEndpoint != emptyString { - require.Contains(t, got, "TLSEndpoint Certificates") - require.Contains(t, got, "Endpoint: "+tt.tlsEndpoint) - require.Contains(t, got, "ServerName: "+tt.tlsServerName) - } - } + runPrintDataSubtest(t, tt) }) } @@ -543,3 +456,93 @@ func TestCertinfo_PrintData(t *testing.T) { require.ErrorContains(t, errPrint, "unable for read Root certificates") }) } + +type printDataTestCase struct { + desc string + keyFile string + certFile string + caCertFile string + keyCertMatch bool + tlsEndpoint string + tlsInsecure bool + tlsServerName string + srvCfg demoHTTPServerConfig + expectCertsFetchErr bool + expectCertsFetcMsg string +} + +func runPrintDataSubtest(t *testing.T, tt printDataTestCase) { + t.Parallel() + + buffer := bytes.Buffer{} + + cc, err := NewCertinfoConfig() + require.NoError(t, err) + + require.NoError(t, cc.SetPrivateKeyFromFile(tt.keyFile, "notSet", inputReader)) + require.NoError(t, cc.SetCertsFromFile(tt.certFile, inputReader)) + require.NoError(t, cc.SetCaPoolFromFile(tt.caCertFile, inputReader)) + + if tt.tlsEndpoint != emptyString { + ts, errSrv := NewHTTPSTestServer(tt.srvCfg) + require.NoError(t, errSrv) + + defer ts.Close() + + cc.SetTLSServerName(tt.tlsServerName) + cc.SetTLSInsecure(tt.tlsInsecure) + + err = cc.SetTLSEndpoint(tt.tlsEndpoint) + if tt.expectCertsFetchErr { + require.EqualError(t, err, tt.expectCertsFetcMsg) + } else { + require.NoError(t, err, "SetTLSEndpoint require NoError") + } + } + + errPrint := cc.PrintData(&buffer) + require.NoError(t, errPrint) + + got := buffer.String() + verifyPrintDataOutput(t, got, tt) +} + +func verifyPrintDataOutput(t *testing.T, got string, tt printDataTestCase) { + if tt.keyFile != emptyString { + require.Contains(t, got, "PrivateKey file: "+tt.keyFile) + } + + if tt.certFile != emptyString { + require.Contains(t, got, "Certificate bundle file: "+tt.certFile) + } + + if tt.caCertFile != emptyString { + require.Contains(t, got, "CA Certificates file: "+tt.caCertFile) + } + + if tt.expectCertsFetchErr { + return + } + + for _, want := range []string{ + "Certinfo", "Certificate", "Subject", "Issuer", "NotBefore", "NotAfter", + "Expiration", "IsCA", "AuthorityKeyId", "SubjectKeyId", "PublicKeyAlgorithm", + "SignatureAlgorithm", "SerialNumber", "Fingerprint SHA-256", + } { + require.Contains(t, got, want) + } + + if tt.keyFile != emptyString { + if tt.keyCertMatch { + require.Contains(t, got, "PrivateKey match: true") + } else { + require.Contains(t, got, "PrivateKey match: false") + } + } + + if tt.tlsEndpoint != emptyString { + require.Contains(t, got, "TLSEndpoint Certificates") + require.Contains(t, got, "Endpoint: "+tt.tlsEndpoint) + require.Contains(t, got, "ServerName: "+tt.tlsServerName) + } +} diff --git a/internal/jwtinfo/jwtinfo_test.go b/internal/jwtinfo/jwtinfo_test.go index 2dab137..90e08e9 100644 --- a/internal/jwtinfo/jwtinfo_test.go +++ b/internal/jwtinfo/jwtinfo_test.go @@ -249,68 +249,71 @@ func TestRequestToken(t *testing.T) { for _, tc := range tests { tt := tc t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - server, err := NewJwtTestServer() - - require.NoError(t, err) - - defer server.Close() + runRequestTokenSubtest(t, tt) + }) + } + //nolint:revive +} - client := server.Client() - serverRoot := server.URL +type requestTokenTestCase struct { + name string + user string + pass string + scope string + expError bool +} - serverJwtEndpoint := serverRoot + "/jwt" +func runRequestTokenSubtest(t *testing.T, tt requestTokenTestCase) { + t.Parallel() - if tt.scope == "emptyReqUrl" { - serverJwtEndpoint = "" - } + server, err := NewJwtTestServer() + require.NoError(t, err) - if tt.scope == "wrongReqUrl" { - serverJwtEndpoint = "https://does.not.exist/wrong" - } + defer server.Close() - if tt.scope == "wrongReqParam" { - serverJwtEndpoint = "https://local$%#@@&host/wrongUrl" - } + client := server.Client() + serverJwtEndpoint := getJwtEndpoint(server.URL, tt.scope) + reqValues := getReqValues(tt) - reqValues := make(map[string]string) - if tt.scope != "emptyValuesMap" { - reqValues["user"] = tt.user - reqValues["pass"] = tt.pass - reqValues["scope"] = tt.scope - } + _, err = RequestToken( + context.Background(), + serverJwtEndpoint, + reqValues, + client, + io.ReadAll, + ) - _, err = RequestToken( - context.Background(), - serverJwtEndpoint, - reqValues, - client, - io.ReadAll, - ) + if tt.expError { + require.Error(t, err, "RequestToken - expected error: %s", err) + return + } - if tt.expError { - require.Error( - t, - err, - "RequestToken - expected error: %s", - err, - ) + require.NoError(t, err, "RequestToken error: %s", err) +} - return - } +func getJwtEndpoint(serverURL, scope string) string { + switch scope { + case "emptyReqUrl": + return "" + case "wrongReqUrl": + return "https://does.not.exist/wrong" + case "wrongReqParam": + return "https://local$%#@@&host/wrongUrl" + default: + return serverURL + "/jwt" + } +} - require.NoError( - t, - err, - "RequestToken error: %s", - err, - ) +func getReqValues(tt requestTokenTestCase) map[string]string { + if tt.scope == "emptyValuesMap" { + return make(map[string]string) + } - // godump.Dump(td) - }) + return map[string]string{ + "user": tt.user, + "pass": tt.pass, + "scope": tt.scope, } - //nolint:revive } //nolint:revive diff --git a/internal/requests/main_test.go b/internal/requests/main_test.go index dc82dec..b190a20 100644 --- a/internal/requests/main_test.go +++ b/internal/requests/main_test.go @@ -159,6 +159,8 @@ func createTmpFileWithContent(tempDir string, filePattern string, fileContent [] } func printResponseBody(res *http.Response) { + defer res.Body.Close() + body, err := io.ReadAll(res.Body) if err != nil { fmt.Printf("Error reading response body: %v\n", err) @@ -409,6 +411,7 @@ func TestHTTPSTestServer(t *testing.T) { if err != nil { t.Fatal(err) } + defer res.Body.Close() fmt.Printf("Resp StatusCode was: %v\n", res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode) diff --git a/internal/requests/requests_handlers_test.go b/internal/requests/requests_handlers_test.go index ff8bfd2..44a3f04 100644 --- a/internal/requests/requests_handlers_test.go +++ b/internal/requests/requests_handlers_test.go @@ -491,47 +491,61 @@ func TestHandleRequests(t *testing.T) { for _, tc := range tests { tt := tc // safer when using t.Parallel() t.Run(tt.desc, func(t *testing.T) { - t.Parallel() + runHandleRequestsSubtest(t, tt) + }) + } +} - httpSrvData := demoHttpServerData{ - serverAddr: tt.srvAddr, - serverName: "localhost", - } +type handleRequestsTestCase struct { + desc string + srvAddr string + reqMeta RequestsMetaConfig + expectErr bool +} - ts, err := NewHTTPSTestServer(httpSrvData) - require.NoError(t, err) +func runHandleRequestsSubtest(t *testing.T, tt handleRequestsTestCase) { + t.Parallel() - defer ts.Close() + httpSrvData := demoHttpServerData{ + serverAddr: tt.srvAddr, + serverName: "localhost", + } - buffer := bytes.Buffer{} - respMap, err := HandleRequests(&buffer, &tt.reqMeta) + ts, err := NewHTTPSTestServer(httpSrvData) + require.NoError(t, err) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } + defer ts.Close() + + buffer := bytes.Buffer{} + respMap, err := HandleRequests(&buffer, &tt.reqMeta) - out := buffer.String() - assert.Contains(t, out, "Requests") + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } - for reqConfName, rdList := range respMap { - for _, rd := range rdList { - gotReqConfName := rd.Request.Name + out := buffer.String() + assert.Contains(t, out, "Requests") - assert.Equal(t, gotReqConfName, reqConfName) + verifyHandleRequestsResults(t, respMap, tt.reqMeta) +} - ua := tt.reqMeta.Requests[0].UserAgent - wantUa := httpUserAgent +func verifyHandleRequestsResults(t *testing.T, respMap map[string][]ResponseData, reqMeta RequestsMetaConfig) { + for reqConfName, rdList := range respMap { + for _, rd := range rdList { + gotReqConfName := rd.Request.Name + assert.Equal(t, gotReqConfName, reqConfName) - if ua != wantUa { - wantUa = ua - } + ua := reqMeta.Requests[0].UserAgent + wantUa := httpUserAgent - gotUa := rd.Response.Request.UserAgent() - assert.Equal(t, wantUa, gotUa) - } + if ua != wantUa { + wantUa = ua } - }) + + gotUa := rd.Response.Request.UserAgent() + assert.Equal(t, wantUa, gotUa) + } } } diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index b56d951..4777113 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -412,72 +412,7 @@ func TestNewHTTPClientFromRequestConfig(t *testing.T) { for _, tc := range tests { tt := tc // safer when using t.Parallel() t.Run(tt.desc, func(t *testing.T) { - t.Parallel() - - rcClient, err := NewHTTPClientFromRequestConfig( - tt.reqConf, - tt.serverName, - tt.pool, - ) - if err != nil { - t.Fatal(err) - } - - var i any = rcClient.client - - client, ok := i.(*http.Client) - if !ok { - t.Errorf("expecting *http.Client, got %T", client) - } - - assert.Equal(t, - time.Duration(tt.reqConf.ClientTimeout)*time.Second, - client.Timeout, - "check client Timeout", - ) - - assert.Equal(t, - tt.reqConf.RequestMethod, - rcClient.method, - "check client Method", - ) - - assert.Equal(t, - tt.reqConf.EnableProxyProtocolV2, - rcClient.enableProxyProtoV2, - "check proxy proto enabled", - ) - - if tt.transportAddress != emptyString { - assert.Equal(t, - tt.transportAddress, - rcClient.transportAddress, - "check transportAddress", - ) - } - - var ti any = rcClient.client.Transport - - transport, ok := ti.(*http.Transport) - if !ok { - t.Errorf("expecting *http.Transport, got %T", transport) - } - - assert.Equal(t, - tt.reqConf.Insecure, - transport.TLSClientConfig.InsecureSkipVerify, - "check Insecure", - ) - - currPool := systemCertPool - - if tt.pool != nil { - currPool = caCertPool - } - - if diff := cmp.Diff(currPool, transport.TLSClientConfig.RootCAs); diff != "" { - t.Errorf("Client CA Pool mismatch (-want +got):\n%s", diff) - } + runNewHTTPClientFromRequestConfigSubtest(t, tt) }) } } @@ -772,6 +707,9 @@ func TestNewRequestHTTPClient_SetInsecureSkipVerify_tlsServer(t *testing.T) { if tt { require.NoError(t, err) + + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) } }) @@ -935,77 +873,7 @@ func TestRequestHTTPClient_SetTransportOverride_transportAddress_server(t *testi tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt.trasportURL) t.Run(testname, func(t *testing.T) { - t.Parallel() - - c := NewRequestHTTPClient() - - _, _ = c.SetTransportOverride(tt.trasportURL) - if c.transportAddress != tt.transportAddr { - t.Errorf("expected %s, got %s", tt.transportAddr, tt.trasportURL) - } - - fmt.Printf("c.transportAddress is %s\n", c.transportAddress) - - httpSrvData := demoHttpServerData{serverAddr: tt.transportAddr} - - ts, err := NewHTTPSTestServer(httpSrvData) - if err != nil { - t.Fatal(err) - } - defer ts.Close() - - // WARN: the following check does not pass, but: - // expected: "localhost:6455" - // actual : "127.0.0.1:6455" - // - // fmt.Printf("Server Addr is %s\n", ts.Config.Addr) - // assert.Equal(t, tt.transportAddr, ts.Listener.Addr().String()) - - // Extract the transport via type assertion - tr, ok := c.client.Transport.(*http.Transport) - if !ok { - t.Fatalf("expected *http.Transport, got %T", tr) - } - - tr.TLSClientConfig = &tls.Config{ - RootCAs: caCertPool, - // InsecureSkipVerify: true, - } - testClient := &http.Client{Transport: tr} - - clientURL := "https://" + tt.requestHost - - // req, err := testClient.Get(clientURL) - req, err := http.NewRequest("GET", clientURL, nil) - if err != nil { - fmt.Println("Error:", err) - return - } - - fmt.Println(ts.URL) - - uaString := "TestSetTrasportOverride" - req.Header.Set("User-Agent", uaString) - - // res, err := testClient.Get(clientURL) - res, err := testClient.Do(req) - if err != nil { - t.Fatal(err) - } - // - // fmt.Printf("Resp StatusCode was: %v\n", res.StatusCode) - assert.Equal(t, http.StatusOK, res.StatusCode) - // - // fmt.Printf("Req URL was: %v\n", res.Request.URL) - assert.Equal(t, res.Request.URL.Scheme+"://"+res.Request.URL.Host, - "https://"+tt.requestHost) - // - // fmt.Printf("User Agent was: %v\n", - // res.Request.Header.Values("user-agent")) - assert.Equal(t, []string{uaString}, - res.Request.Header.Values("User-Agent")) - - printResponseBody(res) + runSetTransportOverrideSubtest(t, tt) }) } } @@ -1032,83 +900,7 @@ func TestRequestHTTPClient_SetProxyProtocolV2_server(t *testing.T) { for _, tc := range tests { tt := tc // safer when using t.Parallel() t.Run(tt.testname, func(t *testing.T) { - t.Parallel() - - httpSrvData := demoHttpServerData{ - serverAddr: tt.addr, - proxyprotoEnabled: true, - } - - ts, err := NewHTTPSTestServer(httpSrvData) - if err != nil { - t.Fatal(err) - } - - defer ts.Close() - - // fmt.Println(tt.testname) - // fmt.Print("Client URL: ") - // fmt.Println(ts.URL) - // fmt.Print("Listener address: ") - // fmt.Println(ts.Listener.Addr()) - // - transportURL := "https://" + tt.addr - reqURL := "https://" + tt.serverName - - reqConf := RequestConfig{ - EnableProxyProtocolV2: true, - TransportOverrideURL: transportURL, - } - - header, err := proxyProtoHeaderFromRequest(reqConf, tt.serverName) - if err != nil { - t.Fatal(err) - } - - c := NewRequestHTTPClient() - c.SetTransportOverride(transportURL) - c.SetProxyProtocolHeader(header) - - // Extract the transport via type assertion - transport, ok := c.client.Transport.(*http.Transport) - if !ok { - t.Fatalf("expected *http.Transport, got %T", c.client.Transport) - } - - transport.TLSClientConfig = &tls.Config{ - RootCAs: caCertPool, - // InsecureSkipVerify: true, - } - - testClient := &http.Client{Transport: transport} - - req, err := http.NewRequest("GET", reqURL, nil) - if err != nil { - fmt.Println("Error:", err) - return - } - - uaString := "TestSetProxyProtocolV2" - req.Header.Set("User-Agent", uaString) - - res, err := testClient.Do(req) - if err != nil { - t.Fatal(err) - } - - // fmt.Printf("Resp StatusCode was: %v\n", res.StatusCode) - assert.Equal(t, http.StatusOK, res.StatusCode) - - // fmt.Printf("Req URL was: %v\n", res.Request.URL) - assert.Equal(t, res.Request.URL.Scheme+"://"+res.Request.URL.Host, - reqURL) - - // fmt.Printf("User Agent was: %v\n", - // res.Request.Header.Values("user-agent")) - assert.Equal(t, []string{uaString}, - res.Request.Header.Values("User-Agent")) - - printResponseBody(res) + runSetProxyProtocolV2Subtest(t, tt) }) } } @@ -1224,53 +1016,7 @@ func TestPrintResponseDebug(t *testing.T) { for _, tc := range tests { tt := tc // safer when using t.Parallel() t.Run(tt.desc, func(t *testing.T) { - t.Parallel() - - httpSrvData := demoHttpServerData{ - serverAddr: tt.srvAddr, - proxyprotoEnabled: false, - serverName: "localhost", - } - - ts, err := NewHTTPSTestServer(httpSrvData) - if err != nil { - t.Fatal(err) - } - defer ts.Close() - - tr := &http.Transport{TLSClientConfig: &tls.Config{ - RootCAs: caCertPool, - }} - - client := &http.Client{Transport: tr} - - res, err := client.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - - rc := RequestConfig{ResponseDebug: tt.verbose} - buffer := bytes.Buffer{} - rc.PrintResponseDebug(&buffer, res) - - got := buffer.String() - fmt.Printf("got:\n%s\n", got) - - if !tt.verbose && len(got) == 0 { - assert.Empty(t, - buffer.Bytes(), - "check PrintResponseDebug with verbose False", - ) - } - - if tt.verbose { - for _, output := range tt.outputs { - assert.True(t, - bytes.Contains(buffer.Bytes(), []byte(output)), - "check PrintResponseDebug contains: %s", output, - ) - } - } + runPrintResponseDebugSubtest(t, tt) }) } } @@ -1287,38 +1033,6 @@ func TestPrintResponseDebug_Error(t *testing.T) { "output should be empty when Response is nil", ) }) - - // TODO: trigger error on malformed Response - // - // t.Run("MalformedResponse", func(t *testing.T) { - // httpTestHeader := http.Header{} - // httpTestHeader.Add("user-agent", "go-test") - // - // httpTestRequest := http.Request{ - // // Method: http.MethodGet, - // // URL: &url.URL{Scheme: "https", Host: "localhost"}, - // ContentLength: 300, - // Body: nil, - // } - // - // response := http.Response{ - // // Status: "200 OK", - // // StatusCode: 200, - // // Header: httpTestHeader, - // Request: &httpTestRequest, - // } - // rc := RequestConfig{ResponseDebug: true} - // buffer := bytes.Buffer{} - // rc.PrintResponseDebug(&buffer, &response) - // - // got := buffer.String() - // fmt.Printf("got:\n%s\n", got) - // assert.Contains(t, - // got, - // "Warning: failed to dump response:", - // "check PrintResponseDebug: MalformedResponse", - // ) - // }) } func TestPrintResponseDebug_nonTLS(t *testing.T) { @@ -1488,86 +1202,7 @@ func TestProcessHTTPRequestsByHost(t *testing.T) { for _, tc := range tests { tt := tc // safer when using t.Parallel() t.Run(tt.reqConf.Name, func(t *testing.T) { - // t.Parallel() - httpSrvData := demoHttpServerData{ - serverAddr: tt.srvAddr, - proxyprotoEnabled: false, - serverName: "localhost", - } - - ts, err := NewHTTPSTestServer(httpSrvData) - if err != nil { - t.Fatal(err) - } - defer ts.Close() - - respList, err := processHTTPRequestsByHost( - tt.reqConf, - tt.pool, - tt.verbose, - ) - if err != nil { - t.Error(err) - } - - for _, r := range respList { - fmt.Printf("resp type: %T\n", r) - - assert.Equal(t, - tt.srvAddr, - r.TransportAddress, - "check TransportAddress", - ) - - if tt.respStatusCode == 0 { - assert.Equal(t, - tt.errMsg, - r.Error.Error(), - "check Response Error", - ) - } - - // if expecting and error from the request do not - // check values from the response - if tt.respStatusCode != 0 { - require.NoError(t, - r.Error, - "check NoError in ResponseData", - ) - - ua := httpUserAgent - - if tt.reqConf.UserAgent != emptyString { - ua = tt.reqConf.UserAgent - } - - assert.Equal(t, - ua, - r.Response.Request.Header.Get("user-agent"), - "check UserAgent", - ) - - assert.Equal(t, - len(tt.reqConf.ResponseBodyMatchRegexp) > 0, - r.ResponseBodyRegexpMatched, - "check body rex match", - ) - - assert.Equal(t, - tt.respStatusCode, - r.Response.StatusCode, - "check StatusCode", - ) - - for _, headers := range tt.reqConf.RequestHeaders { - assert.Equal(t, - headers.Value, - r.Response.Request.Header.Get(headers.Key), - "check RequestHeaders Key", - ) - } - } - } + runProcessHTTPRequestsByHostSubtest(t, tt) }) } } @@ -1675,3 +1310,336 @@ func TestImportResponseBody_Errors(t *testing.T) { require.Equal(t, "test body", rd.ResponseBody) }) } + +type newHTTPClientFromRequestConfigTestCase struct { + desc string + reqConf RequestConfig + serverName string + pool *x509.CertPool + transportAddress string +} + +func runNewHTTPClientFromRequestConfigSubtest(t *testing.T, tt newHTTPClientFromRequestConfigTestCase) { + t.Parallel() + + rcClient, err := NewHTTPClientFromRequestConfig( + tt.reqConf, + tt.serverName, + tt.pool, + ) + require.NoError(t, err) + + client := rcClient.client + + assert.Equal(t, + time.Duration(tt.reqConf.ClientTimeout)*time.Second, + client.Timeout, + "check client Timeout", + ) + + assert.Equal(t, + tt.reqConf.RequestMethod, + rcClient.method, + "check client Method", + ) + + assert.Equal(t, + tt.reqConf.EnableProxyProtocolV2, + rcClient.enableProxyProtoV2, + "check proxy proto enabled", + ) + + if tt.transportAddress != emptyString { + assert.Equal(t, + tt.transportAddress, + rcClient.transportAddress, + "check transportAddress", + ) + } + + transport, ok := rcClient.client.Transport.(*http.Transport) + require.True(t, ok, "expecting *http.Transport, got %T", rcClient.client.Transport) + + assert.Equal(t, + tt.reqConf.Insecure, + transport.TLSClientConfig.InsecureSkipVerify, + "check Insecure", + ) + + currPool := systemCertPool + if tt.pool != nil { + currPool = caCertPool + } + + if diff := cmp.Diff(currPool, transport.TLSClientConfig.RootCAs); diff != "" { + t.Errorf("Client CA Pool mismatch (-want +got):\n%s", diff) + } +} + +type setTransportOverrideTestCase struct { + trasportURL string + transportAddr string + requestHost string +} + +func runSetTransportOverrideSubtest(t *testing.T, tt setTransportOverrideTestCase) { + t.Parallel() + + c := NewRequestHTTPClient() + + _, err := c.SetTransportOverride(tt.trasportURL) + require.NoError(t, err) + + assert.Equal(t, tt.transportAddr, c.transportAddress) + + fmt.Printf("c.transportAddress is %s\n", c.transportAddress) + + httpSrvData := demoHttpServerData{serverAddr: tt.transportAddr} + + ts, err := NewHTTPSTestServer(httpSrvData) + require.NoError(t, err) + + defer ts.Close() + + // Extract the transport via type assertion + tr, ok := c.client.Transport.(*http.Transport) + require.True(t, ok, "expected *http.Transport, got %T", c.client.Transport) + + tr.TLSClientConfig = &tls.Config{ + RootCAs: caCertPool, + } + testClient := &http.Client{Transport: tr} + + clientURL := "https://" + tt.requestHost + + req, err := http.NewRequest("GET", clientURL, nil) + require.NoError(t, err) + + fmt.Println(ts.URL) + + uaString := "TestSetTrasportOverride" + req.Header.Set("User-Agent", uaString) + + res, err := testClient.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, res.Request.URL.Scheme+"://"+res.Request.URL.Host, + "https://"+tt.requestHost) + assert.Equal(t, []string{uaString}, + res.Request.Header.Values("User-Agent")) + + printResponseBody(res) +} + +type setProxyProtocolV2TestCase struct { + testname string + addr string + serverName string +} + +func runSetProxyProtocolV2Subtest(t *testing.T, tt setProxyProtocolV2TestCase) { + t.Parallel() + + httpSrvData := demoHttpServerData{ + serverAddr: tt.addr, + proxyprotoEnabled: true, + } + + ts, err := NewHTTPSTestServer(httpSrvData) + require.NoError(t, err) + + defer ts.Close() + + transportURL := "https://" + tt.addr + reqURL := "https://" + tt.serverName + + reqConf := RequestConfig{ + EnableProxyProtocolV2: true, + TransportOverrideURL: transportURL, + } + + header, err := proxyProtoHeaderFromRequest(reqConf, tt.serverName) + require.NoError(t, err) + + c := NewRequestHTTPClient() + _, err = c.SetTransportOverride(transportURL) + require.NoError(t, err) + + _, err = c.SetProxyProtocolHeader(header) + require.NoError(t, err) + + // Extract the transport via type assertion + transport, ok := c.client.Transport.(*http.Transport) + require.True(t, ok, "expected *http.Transport, got %T", c.client.Transport) + + transport.TLSClientConfig = &tls.Config{ + RootCAs: caCertPool, + } + + testClient := &http.Client{Transport: transport} + + req, err := http.NewRequest("GET", reqURL, nil) + require.NoError(t, err) + + uaString := "TestSetProxyProtocolV2" + req.Header.Set("User-Agent", uaString) + + res, err := testClient.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, res.Request.URL.Scheme+"://"+res.Request.URL.Host, + reqURL) + assert.Equal(t, []string{uaString}, + res.Request.Header.Values("User-Agent")) +} + +type printResponseDebugTestCase struct { + desc string + srvAddr string + verbose bool + outputs []string +} + +func runPrintResponseDebugSubtest(t *testing.T, tt printResponseDebugTestCase) { + t.Parallel() + + httpSrvData := demoHttpServerData{ + serverAddr: tt.srvAddr, + proxyprotoEnabled: false, + serverName: "localhost", + } + + ts, err := NewHTTPSTestServer(httpSrvData) + require.NoError(t, err) + + defer ts.Close() + + tr := &http.Transport{TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + }} + + client := &http.Client{Transport: tr} + + res, err := client.Get(ts.URL) + require.NoError(t, err) + + defer res.Body.Close() + + rc := RequestConfig{ResponseDebug: tt.verbose} + buffer := bytes.Buffer{} + rc.PrintResponseDebug(&buffer, res) + + got := buffer.String() + fmt.Printf("got:\n%s\n", got) + + if !tt.verbose { + assert.Empty(t, got, "check PrintResponseDebug with verbose False") + return + } + + for _, output := range tt.outputs { + assert.Contains(t, got, output, "check PrintResponseDebug contains: %s", output) + } +} + +type processHTTPRequestsByHostTestCase struct { + srvAddr string + reqConf RequestConfig + pool *x509.CertPool + verbose bool + respStatusCode int + errMsg string +} + +func runProcessHTTPRequestsByHostSubtest(t *testing.T, tt processHTTPRequestsByHostTestCase) { + // t.Parallel() + httpSrvData := demoHttpServerData{ + serverAddr: tt.srvAddr, + proxyprotoEnabled: false, + serverName: "localhost", + } + + ts, err := NewHTTPSTestServer(httpSrvData) + require.NoError(t, err) + + defer ts.Close() + + respList, err := processHTTPRequestsByHost( + tt.reqConf, + tt.pool, + tt.verbose, + ) + if err != nil { + t.Error(err) + } + + verifyProcessHTTPRequestsResults(t, tt, respList) +} + +func verifyProcessHTTPRequestsResults(t *testing.T, tt processHTTPRequestsByHostTestCase, respList []ResponseData) { + t.Helper() + + for _, r := range respList { + fmt.Printf("resp type: %T\n", r) + + assert.Equal(t, + tt.srvAddr, + r.TransportAddress, + "check TransportAddress", + ) + + if tt.respStatusCode == 0 { + assert.Equal(t, + tt.errMsg, + r.Error.Error(), + "check Response Error", + ) + + continue + } + + // if expecting and error from the request do not + // check values from the response + require.NoError(t, + r.Error, + "check NoError in ResponseData", + ) + + ua := httpUserAgent + if tt.reqConf.UserAgent != emptyString { + ua = tt.reqConf.UserAgent + } + + assert.Equal(t, + ua, + r.Response.Request.Header.Get("user-agent"), + "check UserAgent", + ) + + assert.Equal(t, + len(tt.reqConf.ResponseBodyMatchRegexp) > 0, + r.ResponseBodyRegexpMatched, + "check body rex match", + ) + + assert.Equal(t, + tt.respStatusCode, + r.Response.StatusCode, + "check StatusCode", + ) + + for _, headers := range tt.reqConf.RequestHeaders { + assert.Equal(t, + headers.Value, + r.Response.Request.Header.Get(headers.Key), + "check RequestHeaders Key", + ) + } + } +} diff --git a/sonar-scanner.sh b/sonar-scanner.sh new file mode 100755 index 0000000..8797cee --- /dev/null +++ b/sonar-scanner.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +sonar-scanner -Dsonar.organization=xenos76 \ + -Dsonar.projectKey=xenOs76_https-wrench \ + -Dsonar.go.coverage.reportPaths=cover.out \ + -Dsonar.exclusions=completions/**,.devenv/**,.direnv/** \ + -D"sonar.tests=." \ + -D"sonar.test.inclusions=*_test.go,**/*_test.go"