diff --git a/github/admin_users.go b/github/admin_users.go index c2f331deeb7..d4077694541 100644 --- a/github/admin_users.go +++ b/github/admin_users.go @@ -31,13 +31,13 @@ func (s *AdminService) CreateUser(ctx context.Context, userReq CreateUserRequest return nil, nil, err } - var user User + var user *User resp, err := s.client.Do(req, &user) if err != nil { return nil, resp, err } - return &user, resp, nil + return user, resp, nil } // DeleteUser deletes a user in GitHub Enterprise. diff --git a/github/git_commits_test.go b/github/git_commits_test.go index 92934d1d2e9..be9607556b3 100644 --- a/github/git_commits_test.go +++ b/github/git_commits_test.go @@ -315,14 +315,14 @@ Commit Message.` Parents: []*Commit{{SHA: Ptr("p")}}, Author: &author, } - wantBody := createCommit{ + wantBody := &createCommit{ Message: input.Message, Tree: Ptr("t"), Parents: []string{"p"}, Author: &author, Signature: &signature, } - var gotBody createCommit + var gotBody *createCommit mux.HandleFunc("/repos/o/r/git/commits", func(w http.ResponseWriter, r *http.Request) { assertNilError(t, json.NewDecoder(r.Body).Decode(&gotBody)) testMethod(t, r, "POST") diff --git a/github/issue_import.go b/github/issue_import.go index 31761632c20..1907f02f12b 100644 --- a/github/issue_import.go +++ b/github/issue_import.go @@ -83,20 +83,20 @@ func (s *IssueImportService) Create(ctx context.Context, owner, repo string, iss req.Header.Set("Accept", mediaTypeIssueImportAPI) - var i IssueImportResponse + var i *IssueImportResponse resp, err := s.client.Do(req, &i) if err != nil { var aerr *AcceptedError if errors.As(err, &aerr) { if err := json.Unmarshal(aerr.Raw, &i); err != nil { - return &i, resp, err + return i, resp, err } - return &i, resp, err + return i, resp, err } return nil, resp, err } - return &i, resp, nil + return i, resp, nil } // CheckStatus checks the status of an imported issue. diff --git a/github/private_registries.go b/github/private_registries.go index d1899f4f9a8..e22f1e47f17 100644 --- a/github/private_registries.go +++ b/github/private_registries.go @@ -249,12 +249,12 @@ func (s *PrivateRegistriesService) ListOrganizationPrivateRegistries(ctx context return nil, nil, err } - var privateRegistries PrivateRegistries + var privateRegistries *PrivateRegistries resp, err := s.client.Do(req, &privateRegistries) if err != nil { return nil, resp, err } - return &privateRegistries, resp, nil + return privateRegistries, resp, nil } // CreateOrganizationPrivateRegistry creates a private registry configuration with an encrypted value for an organization. @@ -270,12 +270,12 @@ func (s *PrivateRegistriesService) CreateOrganizationPrivateRegistry(ctx context return nil, nil, err } - var result PrivateRegistry + var result *PrivateRegistry resp, err := s.client.Do(req, &result) if err != nil { return nil, resp, err } - return &result, resp, nil + return result, resp, nil } // GetOrganizationPrivateRegistriesPublicKey retrieves the public key for encrypting secrets for an organization's private registries. @@ -291,12 +291,12 @@ func (s *PrivateRegistriesService) GetOrganizationPrivateRegistriesPublicKey(ctx return nil, nil, err } - var publicKey PublicKey + var publicKey *PublicKey resp, err := s.client.Do(req, &publicKey) if err != nil { return nil, resp, err } - return &publicKey, resp, nil + return publicKey, resp, nil } // GetOrganizationPrivateRegistry gets a specific private registry for an organization. @@ -313,13 +313,13 @@ func (s *PrivateRegistriesService) GetOrganizationPrivateRegistry(ctx context.Co return nil, nil, err } - var privateRegistry PrivateRegistry + var privateRegistry *PrivateRegistry resp, err := s.client.Do(req, &privateRegistry) if err != nil { return nil, resp, err } - return &privateRegistry, resp, nil + return privateRegistry, resp, nil } // UpdateOrganizationPrivateRegistry updates a specific private registry for an organization. diff --git a/github/repos_forks.go b/github/repos_forks.go index 2e4ec16d765..dca0b7f7c88 100644 --- a/github/repos_forks.go +++ b/github/repos_forks.go @@ -79,20 +79,20 @@ func (s *RepositoriesService) CreateFork(ctx context.Context, owner, repo string return nil, nil, err } - var fork Repository + var fork *Repository resp, err := s.client.Do(req, &fork) if err != nil { // Persist AcceptedError's metadata to the Repository object. var aerr *AcceptedError if errors.As(err, &aerr) { if err := json.Unmarshal(aerr.Raw, &fork); err != nil { - return &fork, resp, err + return fork, resp, err } - return &fork, resp, err + return fork, resp, err } return nil, resp, err } - return &fork, resp, nil + return fork, resp, nil } diff --git a/github/repos_forks_test.go b/github/repos_forks_test.go index f3115cc15ae..393e4d669b6 100644 --- a/github/repos_forks_test.go +++ b/github/repos_forks_test.go @@ -131,6 +131,29 @@ func TestRepositoriesService_CreateFork_deferred(t *testing.T) { } } +func TestRepositoriesService_CreateFork_deferred_badBody(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + + opt := &RepositoryCreateForkOptions{Organization: "o", Name: "n", DefaultBranchOnly: true} + + mux.HandleFunc("/repos/o/r/forks", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "POST") + testJSONBody(t, r, opt) + w.WriteHeader(http.StatusAccepted) + fmt.Fprint(w, `{invalid json`) + }) + + ctx := t.Context() + repo, _, err := client.Repositories.CreateFork(ctx, "o", "r", opt) + if err == nil { + t.Fatal("Repositories.CreateFork returned nil error") + } + if repo != nil { + t.Errorf("Repositories.CreateFork returned non-nil repo: %+v", repo) + } +} + func TestRepositoriesService_CreateFork_invalidOwner(t *testing.T) { t.Parallel() client, _, _ := setup(t) diff --git a/github/security_advisories.go b/github/security_advisories.go index c749f0b1b45..f705d1adc5e 100644 --- a/github/security_advisories.go +++ b/github/security_advisories.go @@ -168,21 +168,21 @@ func (s *SecurityAdvisoriesService) CreateTemporaryPrivateFork(ctx context.Conte return nil, nil, err } - var fork Repository + var fork *Repository resp, err := s.client.Do(req, &fork) if err != nil { var aerr *AcceptedError if errors.As(err, &aerr) { if err := json.Unmarshal(aerr.Raw, &fork); err != nil { - return &fork, resp, err + return fork, resp, err } - return &fork, resp, err + return fork, resp, err } return nil, resp, err } - return &fork, resp, nil + return fork, resp, nil } // ListRepositorySecurityAdvisoriesForOrg lists the repository security advisories for an organization. diff --git a/github/security_advisories_test.go b/github/security_advisories_test.go index 5e4e187f979..a5d36e84f77 100644 --- a/github/security_advisories_test.go +++ b/github/security_advisories_test.go @@ -513,6 +513,26 @@ func TestSecurityAdvisoriesService_CreateTemporaryPrivateFork_deferred(t *testin } } +func TestSecurityAdvisoriesService_CreateTemporaryPrivateFork_deferred_badBody(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + + mux.HandleFunc("/repos/o/r/security-advisories/ghsa_id/forks", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "POST") + w.WriteHeader(http.StatusAccepted) + fmt.Fprint(w, `{invalid json`) + }) + + ctx := t.Context() + fork, _, err := client.SecurityAdvisories.CreateTemporaryPrivateFork(ctx, "o", "r", "ghsa_id") + if err == nil { + t.Fatal("SecurityAdvisories.CreateTemporaryPrivateFork returned nil error") + } + if fork != nil { + t.Errorf("SecurityAdvisories.CreateTemporaryPrivateFork returned non-nil fork: %+v", fork) + } +} + func TestSecurityAdvisoriesService_CreateTemporaryPrivateFork_invalidOwner(t *testing.T) { t.Parallel() client, _, _ := setup(t) diff --git a/tools/extraneousnew/extraneousnew.go b/tools/extraneousnew/extraneousnew.go index c28a700c315..c48ae0cf38e 100644 --- a/tools/extraneousnew/extraneousnew.go +++ b/tools/extraneousnew/extraneousnew.go @@ -119,11 +119,12 @@ func inspectAllBlocks(pass *analysis.Pass, root ast.Node) { } func inspectBlock(pass *analysis.Pass, block *ast.BlockStmt) { - // Track pointers that are currently nil. nilPointers := make(map[string]*ast.Ident) + valueVars := make(map[string]*valueVarInfo) for i, stmt := range block.List { - // 1. Check for `var v *T` or `var v *struct{...}` + // 1. Check for `var v *T` or `var v *struct{...}` (nil pointer declarations) + // and `var v T` (value-type declarations that should use a pointer instead). if decl, ok := stmt.(*ast.DeclStmt); ok { if gen, ok := decl.Decl.(*ast.GenDecl); ok && gen.Tok == token.VAR { for _, spec := range gen.Specs { @@ -132,6 +133,14 @@ func inspectBlock(pass *analysis.Pass, block *ast.BlockStmt) { for _, name := range vSpec.Names { nilPointers[name.Name] = name } + } else if typeIdent, ok := vSpec.Type.(*ast.Ident); ok && len(vSpec.Values) == 0 && token.IsExported(typeIdent.Name) { + for _, name := range vSpec.Names { + valueVars[name.Name] = &valueVarInfo{ + ident: name, + typeName: typeIdent.Name, + stmtIndex: i, + } + } } } } @@ -146,8 +155,9 @@ func inspectBlock(pass *analysis.Pass, block *ast.BlockStmt) { if assign, ok := stmt.(*ast.AssignStmt); ok && len(assign.Lhs) == 1 && len(assign.Rhs) == 1 { if lhs, ok := assign.Lhs[0].(*ast.Ident); ok { assignLHS = lhs - // Any assignment to v means it's no longer a "nil pointer" for our simple tracking. + // Any assignment to v means it's no longer tracked as a nil pointer or value var. delete(nilPointers, lhs.Name) + delete(valueVars, lhs.Name) // Check for v := new(T) or v := &T{} if call, ok := assign.Rhs[0].(*ast.CallExpr); ok { @@ -174,11 +184,13 @@ func inspectBlock(pass *analysis.Pass, block *ast.BlockStmt) { continue } - // If it's a regular assignment (possibly with multiple variables), it might initialize a nil pointer. + // If it's a regular assignment (possibly with multiple variables), remove tracked variables + // since they are no longer nil pointers or uninitialized value vars. if assign, ok := stmt.(*ast.AssignStmt); ok { for _, lhs := range assign.Lhs { if ident, ok := lhs.(*ast.Ident); ok { delete(nilPointers, ident.Name) + delete(valueVars, ident.Name) } } } @@ -211,6 +223,50 @@ func inspectBlock(pass *analysis.Pass, block *ast.BlockStmt) { } return true }) + + // 4. Check if a value-type var is passed to Do/Decode and could use a pointer instead. + ast.Inspect(stmt, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + fnName := getFunctionName(call.Fun) + if fnName == "" { + return true + } + + var targetArg ast.Expr + if fnName == "Do" && len(call.Args) == 2 { + targetArg = call.Args[1] + } else if fnName == "Decode" && len(call.Args) == 1 { + targetArg = call.Args[0] + } + + if targetArg == nil { + return true + } + + if ident, ok := targetArg.(*ast.Ident); ok { + if info, isValue := valueVars[ident.Name]; isValue { + delete(valueVars, ident.Name) + if !isUsedElsewhere(block, info.stmtIndex, i, info.ident.Name) && !isReadAfterCall(block, i, info.ident.Name) { + pass.Reportf(ident.Pos(), "use 'var %v *%v' and pass '&%v' instead", ident.Name, info.typeName, ident.Name) + } + } + } + if unary, ok := targetArg.(*ast.UnaryExpr); ok && unary.Op == token.AND { + if ident, ok := unary.X.(*ast.Ident); ok { + if info, isValue := valueVars[ident.Name]; isValue { + delete(valueVars, ident.Name) + if !isUsedElsewhere(block, info.stmtIndex, i, info.ident.Name) && !isReadAfterCall(block, i, info.ident.Name) { + pass.Reportf(ident.Pos(), "use 'var %v *%v' instead", ident.Name, info.typeName) + } + } + } + } + return true + }) } } @@ -284,6 +340,87 @@ func lookAhead(pass *analysis.Pass, block *ast.BlockStmt, startIndex int, lhsIde } } +type valueVarInfo struct { + ident *ast.Ident + typeName string + stmtIndex int +} + +// isUsedElsewhere checks if the variable named name is used for anything other than +// being a target argument in a Do/Decode call, between declIndex and useIndex (exclusive). +func isUsedElsewhere(block *ast.BlockStmt, declIndex, useIndex int, name string) bool { + for j := declIndex + 1; j < useIndex; j++ { + stmt := block.List[j] + var foundOtherUse bool + ast.Inspect(stmt, func(un ast.Node) bool { + if foundOtherUse { + return false + } + ident, ok := un.(*ast.Ident) + if !ok || ident.Name != name { + return true + } + + isTargetArg := false + ast.Inspect(stmt, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + fnName := getFunctionName(call.Fun) + var targetArg ast.Expr + if fnName == "Do" && len(call.Args) == 2 { + targetArg = call.Args[1] + } else if fnName == "Decode" && len(call.Args) == 1 { + targetArg = call.Args[0] + } + if targetArg != nil && isIdentOrAddressOfIdent(targetArg, name) { + isTargetArg = true + return false + } + return true + }) + + if !isTargetArg { + foundOtherUse = true + } + return false + }) + if foundOtherUse { + return true + } + } + return false +} + +// isReadAfterCall checks if the variable is read (accessed via selector like v.Field) +// in any statement after callIndex. This indicates correct zero-value usage where +// Do/Decode fills the value and the caller reads its fields afterward. +func isReadAfterCall(block *ast.BlockStmt, callIndex int, name string) bool { + for j := callIndex + 1; j < len(block.List); j++ { + stmt := block.List[j] + var found bool + ast.Inspect(stmt, func(un ast.Node) bool { + if found { + return false + } + sel, ok := un.(*ast.SelectorExpr) + if !ok { + return true + } + if ident, ok := sel.X.(*ast.Ident); ok && ident.Name == name { + found = true + return false + } + return true + }) + if found { + return true + } + } + return false +} + func isIdentOrAddressOfIdent(expr ast.Expr, name string) bool { if ident, ok := expr.(*ast.Ident); ok { return ident.Name == name diff --git a/tools/extraneousnew/testdata/src/has-warnings/main.go b/tools/extraneousnew/testdata/src/has-warnings/main.go index cef0441c89d..53354c33b3b 100644 --- a/tools/extraneousnew/testdata/src/has-warnings/main.go +++ b/tools/extraneousnew/testdata/src/has-warnings/main.go @@ -73,4 +73,18 @@ func (s *Service) TestMethod(req any, r *http.Request, t *testing.T) { var v13 *T s.client.Do(req, v13) // want "pass '&v13' instead" + + // Unnecessary use of value + var v14 T + s.client.Do(req, &v14) // want "use 'var v14 [*]T' instead" + + var v15 T + s.client.Do(req, v15) // want "use 'var v15 [*]T' and pass '&v15' instead" + + // Value-type var with Decode + var v16 T + json.NewDecoder(r.Body).Decode(&v16) // want "use 'var v16 [*]T' instead" + + var v17 T + json.NewDecoder(r.Body).Decode(v17) // want "use 'var v17 [*]T' and pass '&v17' instead" } diff --git a/tools/extraneousnew/testdata/src/no-warnings/main.go b/tools/extraneousnew/testdata/src/no-warnings/main.go index b69bce2f7fa..e1adcb84ac9 100644 --- a/tools/extraneousnew/testdata/src/no-warnings/main.go +++ b/tools/extraneousnew/testdata/src/no-warnings/main.go @@ -9,6 +9,10 @@ type T struct { Field string } +type CheckPrivateReporting struct { + Enabled bool +} + type Client struct{} func (c *Client) Do(req any, v any) (any, error) { @@ -49,3 +53,19 @@ func (s *Receiver) unexportedMethod(req any) { v := new(T) s.client.Do(req, v) // Should be ignored because unexported. } + +func (s *Receiver) ValueVarMethod(req any) { + // Value-type var used elsewhere before Do — no warning + var v T + v.Field = "set" + s.client.Do(req, &v) + + // Value-type var with non-struct type — no warning + var data any + s.client.Do(req, &data) + + // Value-type var read after Do via selector — correct zero-value usage, no warning + var reporting CheckPrivateReporting + s.client.Do(req, &reporting) + _ = reporting.Enabled +}