From 7ab4d5f1b711359549c16258634e427279d02e83 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:49:10 -0400 Subject: [PATCH 01/24] feat(doctor): add core types and Check interface --- pkg/doctor/doctor.go | 81 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 pkg/doctor/doctor.go diff --git a/pkg/doctor/doctor.go b/pkg/doctor/doctor.go new file mode 100644 index 0000000000..27467a323a --- /dev/null +++ b/pkg/doctor/doctor.go @@ -0,0 +1,81 @@ +package doctor + +import ( + "errors" + + sitter "github.com/smacker/go-tree-sitter" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/schema" +) + +// ErrNoAutoFix is returned by Fix() for detect-only checks. +var ErrNoAutoFix = errors.New("no auto-fix available for this check") + +// Severity of a finding. +type Severity int + +const ( + SeverityError Severity = iota // Must fix -- will cause build/predict failures + SeverityWarning // Should fix -- deprecated patterns, future breakage + SeverityInfo // Informational -- suggestions, best practices +) + +// String returns the human-readable name of the severity level. +func (s Severity) String() string { + switch s { + case SeverityError: + return "error" + case SeverityWarning: + return "warning" + case SeverityInfo: + return "info" + default: + return "unknown" + } +} + +// Group categorizes checks for display purposes. +type Group string + +const ( + GroupConfig Group = "Config" + GroupPython Group = "Python" + GroupEnvironment Group = "Environment" +) + +// Finding represents a single problem detected by a check. +type Finding struct { + Severity Severity + Message string // What's wrong + Remediation string // How to fix it + File string // Optional: file path where the issue was found + Line int // Optional: line number (1-indexed, 0 means unknown) +} + +// Check is the interface every doctor rule implements. 
+type Check interface { + Name() string + Group() Group + Description() string + Check(ctx *CheckContext) ([]Finding, error) + Fix(ctx *CheckContext, findings []Finding) error +} + +// ParsedFile holds tree-sitter parse results for a Python file. +type ParsedFile struct { + Path string // Relative path from project root + Source []byte // Raw file contents + Tree *sitter.Tree // Tree-sitter parse tree + Imports *schema.ImportContext // Collected imports +} + +// CheckContext provides checks with access to project state. +// Built once by the runner and passed to every check. +type CheckContext struct { + ProjectDir string + Config *config.Config // Parsed cog.yaml (nil if parsing failed) + ConfigFile []byte // Raw cog.yaml bytes (available even if parsing failed) + PythonFiles map[string]*ParsedFile // Pre-parsed Python files keyed by relative path + PythonPath string // Path to python binary (empty if not found) +} From d2f198054417e8fa77ceecb0c60494522db80374 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:51:11 -0400 Subject: [PATCH 02/24] refactor(schema): export tree-sitter helpers for reuse by doctor checks --- pkg/schema/python/parser.go | 196 ++++++++++++++++++------------------ 1 file changed, 98 insertions(+), 98 deletions(-) diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index 253de24d48..a425956909 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -38,7 +38,7 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi root := tree.RootNode() // 1. Collect imports - imports := collectImports(root, source) + imports := CollectImports(root, source) // 2. 
Collect module-level variable assignments moduleScope := collectModuleScope(root, source) @@ -101,8 +101,8 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi // Helpers // --------------------------------------------------------------------------- -// namedChildren returns all named children of a node. -func namedChildren(n *sitter.Node) []*sitter.Node { +// NamedChildren returns all named children of a node. +func NamedChildren(n *sitter.Node) []*sitter.Node { count := int(n.NamedChildCount()) result := make([]*sitter.Node, 0, count) for i := range count { @@ -111,8 +111,8 @@ func namedChildren(n *sitter.Node) []*sitter.Node { return result } -// allChildren returns all children (named and anonymous) of a node. -func allChildren(n *sitter.Node) []*sitter.Node { +// AllChildren returns all children (named and anonymous) of a node. +func AllChildren(n *sitter.Node) []*sitter.Node { count := int(n.ChildCount()) result := make([]*sitter.Node, 0, count) for i := range count { @@ -121,8 +121,8 @@ func allChildren(n *sitter.Node) []*sitter.Node { return result } -// content returns the source text for a node. -func content(n *sitter.Node, source []byte) string { +// Content returns the source text for a node. 
+func Content(n *sitter.Node, source []byte) string { return n.Content(source) } @@ -130,10 +130,10 @@ func content(n *sitter.Node, source []byte) string { // Import collection // --------------------------------------------------------------------------- -func collectImports(root *sitter.Node, source []byte) *schema.ImportContext { +func CollectImports(root *sitter.Node, source []byte) *schema.ImportContext { ctx := schema.NewImportContext() - for _, child := range namedChildren(root) { + for _, child := range NamedChildren(root) { if child.Type() == "import_from_statement" { parseImportFrom(child, source, ctx) } @@ -157,15 +157,15 @@ func parseImportFrom(node *sitter.Node, source []byte, ctx *schema.ImportContext if moduleNode == nil { return } - module := content(moduleNode, source) + module := Content(moduleNode, source) - for _, child := range allChildren(node) { + for _, child := range AllChildren(node) { switch child.Type() { case "dotted_name": // Single import: `from X import name` // Skip if this is the module_name itself if child.StartByte() != moduleNode.StartByte() { - name := content(child, source) + name := Content(child, source) ctx.Names.Set(name, schema.ImportEntry{Module: module, Original: name}) } case "aliased_import": @@ -174,29 +174,29 @@ func parseImportFrom(node *sitter.Node, source []byte, ctx *schema.ImportContext aliasNode := child.ChildByFieldName("alias") orig := "" if origNode != nil { - orig = content(origNode, source) + orig = Content(origNode, source) } alias := orig if aliasNode != nil { - alias = content(aliasNode, source) + alias = Content(aliasNode, source) } ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig}) case "import_list": - for _, importChild := range allChildren(child) { + for _, importChild := range AllChildren(child) { switch importChild.Type() { case "dotted_name": - name := content(importChild, source) + name := Content(importChild, source) ctx.Names.Set(name, schema.ImportEntry{Module: module, 
Original: name}) case "aliased_import": origNode := importChild.ChildByFieldName("name") aliasNode := importChild.ChildByFieldName("alias") orig := "" if origNode != nil { - orig = content(origNode, source) + orig = Content(origNode, source) } alias := orig if aliasNode != nil { - alias = content(aliasNode, source) + alias = Content(aliasNode, source) } ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig}) } @@ -213,7 +213,7 @@ type moduleScope map[string]schema.DefaultValue func collectModuleScope(root *sitter.Node, source []byte) moduleScope { scope := make(moduleScope) - for _, child := range namedChildren(root) { + for _, child := range NamedChildren(root) { var assign *sitter.Node if child.Type() == "expression_statement" { if child.NamedChildCount() == 1 { @@ -233,7 +233,7 @@ func collectModuleScope(root *sitter.Node, source []byte) moduleScope { if left == nil || left.Type() != "identifier" { continue } - name := content(left, source) + name := Content(left, source) right := assign.ChildByFieldName("right") if right == nil { @@ -253,7 +253,7 @@ func resolveDefaultExpr(node *sitter.Node, source []byte, scope moduleScope) (sc return val, true } if node.Type() == "identifier" { - name := content(node, source) + name := Content(node, source) if val, ok := scope[name]; ok { return val, true } @@ -268,7 +268,7 @@ func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([] return parseListLiteral(node, source) case "identifier": - name := content(node, source) + name := Content(node, source) val, ok := scope[name] if !ok { return nil, false @@ -284,8 +284,8 @@ func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([] case "binary_operator": // Only handle + (list concatenation) hasPlus := false - for _, c := range allChildren(node) { - if !c.IsNamed() && content(c, source) == "+" { + for _, c := range AllChildren(node) { + if !c.IsNamed() && Content(c, source) == "+" { hasPlus = true break } @@ -314,7 
+314,7 @@ func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([] // resolveChoicesCall resolves list(X.keys()) or list(X.values()). func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([]schema.DefaultValue, bool) { funcNode := node.ChildByFieldName("function") - if funcNode == nil || content(funcNode, source) != "list" { + if funcNode == nil || Content(funcNode, source) != "list" { return nil, false } @@ -325,7 +325,7 @@ func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([] // Find the single positional argument var arg *sitter.Node - for _, c := range namedChildren(args) { + for _, c := range NamedChildren(args) { arg = c break } @@ -344,8 +344,8 @@ func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([] return nil, false } - varName := content(obj, source) - methodName := content(attr, source) + varName := Content(obj, source) + methodName := Content(attr, source) dictVal, ok := scope[varName] if !ok || dictVal.Kind != schema.DefaultDict { @@ -368,8 +368,8 @@ func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([] func collectModelClasses(root *sitter.Node, source []byte, imports *schema.ImportContext) schema.ModelClassMap { models := schema.NewOrderedMap[string, []schema.ModelField]() - for _, child := range namedChildren(root) { - classNode := unwrapClass(child) + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) if classNode == nil { continue } @@ -378,9 +378,9 @@ func collectModelClasses(root *sitter.Node, source []byte, imports *schema.Impor if nameNode == nil { continue } - className := content(nameNode, source) + className := Content(nameNode, source) - if !inheritsFromBaseModel(classNode, source, imports) { + if !InheritsFromBaseModel(classNode, source, imports) { continue } @@ -449,7 +449,7 @@ func resolveExternalModels(sourceDir string, imports *schema.ImportContext, mode return } - fileImports := 
collectImports(tree.RootNode(), source) + fileImports := CollectImports(tree.RootNode(), source) fileModels := collectModelClasses(tree.RootNode(), source, fileImports) // Merge discovered models into the caller's map. @@ -518,12 +518,12 @@ func isKnownExternalModule(module string) bool { return false } -func unwrapClass(node *sitter.Node) *sitter.Node { +func UnwrapClass(node *sitter.Node) *sitter.Node { if node.Type() == "class_definition" { return node } if node.Type() == "decorated_definition" { - for _, c := range namedChildren(node) { + for _, c := range NamedChildren(node) { if c.Type() == "class_definition" { return c } @@ -532,12 +532,12 @@ func unwrapClass(node *sitter.Node) *sitter.Node { return nil } -func unwrapFunction(node *sitter.Node) *sitter.Node { +func UnwrapFunction(node *sitter.Node) *sitter.Node { if node.Type() == "function_definition" { return node } if node.Type() == "decorated_definition" { - for _, c := range namedChildren(node) { + for _, c := range NamedChildren(node) { if c.Type() == "function_definition" { return c } @@ -546,21 +546,21 @@ func unwrapFunction(node *sitter.Node) *sitter.Node { return nil } -func inheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { +func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { return false } - for _, child := range allChildren(supers) { + for _, child := range AllChildren(supers) { switch child.Type() { case "identifier": - name := content(child, source) + name := Content(child, source) if imports.IsBaseModel(name) || name == "BaseModel" { return true } case "attribute": // Handle dotted access: pydantic.BaseModel, cog.BaseModel - text := content(child, source) + text := Content(child, source) if strings.HasSuffix(text, ".BaseModel") { return true } @@ -576,7 +576,7 @@ func extractClassAnnotations(classNode *sitter.Node, source 
[]byte) []schema.Mod } var fields []schema.ModelField - for _, child := range namedChildren(body) { + for _, child := range NamedChildren(body) { node := child if child.Type() == "expression_statement" && child.NamedChildCount() == 1 { node = child.NamedChild(0) @@ -603,7 +603,7 @@ func parseAnnotatedAssignment(node *sitter.Node, source []byte) (schema.ModelFie return schema.ModelField{}, false } - name := content(left, source) + name := Content(left, source) typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { return schema.ModelField{}, false @@ -620,7 +620,7 @@ func parseAnnotatedAssignment(node *sitter.Node, source []byte) (schema.ModelFie } func parseBareAnnotation(node *sitter.Node, source []byte) (schema.ModelField, bool) { - text := strings.TrimSpace(content(node, source)) + text := strings.TrimSpace(Content(node, source)) parts := strings.SplitN(text, ":", 2) if len(parts) != 2 { return schema.ModelField{}, false @@ -763,8 +763,8 @@ func newInputRegistry() *inputRegistry { func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.ImportContext, scope moduleScope) *inputRegistry { registry := newInputRegistry() - for _, child := range namedChildren(root) { - classNode := unwrapClass(child) + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) if classNode == nil { continue } @@ -772,14 +772,14 @@ func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.Impo if nameNode == nil { continue } - className := content(nameNode, source) + className := Content(nameNode, source) body := classNode.ChildByFieldName("body") if body == nil { continue } - for _, stmt := range namedChildren(body) { + for _, stmt := range NamedChildren(body) { inner := stmt if stmt.Type() == "expression_statement" && stmt.NamedChildCount() == 1 { inner = stmt.NamedChild(0) @@ -789,7 +789,7 @@ func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.Impo collectInputAttribute(className, inner, 
source, imports, scope, registry) } - if funcNode := unwrapFunction(inner); funcNode != nil { + if funcNode := UnwrapFunction(inner); funcNode != nil { collectInputMethod(className, funcNode, source, imports, scope, registry) } } @@ -803,7 +803,7 @@ func collectInputAttribute(className string, assignment *sitter.Node, source []b if left == nil || left.Type() != "identifier" { return } - attrName := content(left, source) + attrName := Content(left, source) right := assignment.ChildByFieldName("right") if right == nil || !isInputCall(right, source, imports) { @@ -823,7 +823,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, if nameNode == nil { return } - methodName := content(nameNode, source) + methodName := Content(nameNode, source) params := funcNode.ChildByFieldName("parameters") if params == nil { @@ -831,10 +831,10 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, } var paramNames []string - for _, param := range allChildren(params) { + for _, param := range AllChildren(params) { switch param.Type() { case "identifier": - name := content(param, source) + name := Content(param, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } @@ -843,7 +843,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, for j := 0; j < int(param.NamedChildCount()); j++ { c := param.NamedChild(j) if c.Type() == "identifier" { - name := content(c, source) + name := Content(c, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } @@ -852,7 +852,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, } case "typed_default_parameter", "default_parameter": if n := param.ChildByFieldName("name"); n != nil { - name := content(n, source) + name := Content(n, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } @@ -879,7 +879,7 @@ func collectInputMethod(className string, 
funcNode *sitter.Node, source []byte, } func findReturnInputCall(body *sitter.Node, source []byte, imports *schema.ImportContext) *sitter.Node { - for _, child := range namedChildren(body) { + for _, child := range NamedChildren(body) { if child.Type() == "return_statement" { if child.NamedChildCount() > 0 { expr := child.NamedChild(0) @@ -895,7 +895,7 @@ func findReturnInputCall(body *sitter.Node, source []byte, imports *schema.Impor func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegistry) (inputCallInfo, bool) { switch node.Type() { case "attribute": - text := content(node, source) + text := Content(node, source) info, ok := registry.Attributes[text] return info, ok @@ -904,7 +904,7 @@ func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegi if funcNode == nil || funcNode.Type() != "attribute" { return inputCallInfo{}, false } - key := content(funcNode, source) + key := Content(funcNode, source) methodInfo, ok := registry.Methods[key] if !ok { return inputCallInfo{}, false @@ -920,12 +920,12 @@ func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegi // Build param_name -> call-site value map argValues := make(map[string]*sitter.Node) positionalIdx := 0 - for _, arg := range namedChildren(args) { + for _, arg := range NamedChildren(args) { if arg.Type() == "keyword_argument" { nameNode := arg.ChildByFieldName("name") valNode := arg.ChildByFieldName("value") if nameNode != nil && valNode != nil { - argValues[content(nameNode, source)] = valNode + argValues[Content(nameNode, source)] = valNode } } else if positionalIdx < len(methodInfo.ParamNames) { argValues[methodInfo.ParamNames[positionalIdx]] = arg @@ -966,26 +966,26 @@ func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegi func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName string) (*sitter.Node, error) { // First: look for a class with this name - for _, child := range 
namedChildren(root) { - classNode := unwrapClass(child) + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) if classNode == nil { continue } nameNode := classNode.ChildByFieldName("name") - if nameNode != nil && content(nameNode, source) == predictRef { + if nameNode != nil && Content(nameNode, source) == predictRef { return findMethodInClass(classNode, source, predictRef, methodName) } } // Second: look for standalone function - for _, child := range namedChildren(root) { - funcNode := unwrapFunction(child) + for _, child := range NamedChildren(root) { + funcNode := UnwrapFunction(child) if funcNode == nil { continue } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil { - name := content(nameNode, source) + name := Content(nameNode, source) if name == predictRef || name == methodName { return funcNode, nil } @@ -1001,13 +1001,13 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN return nil, schema.WrapError(schema.ErrParse, fmt.Sprintf("class %s has no body", className), nil) } - for _, child := range namedChildren(body) { - funcNode := unwrapFunction(child) + for _, child := range NamedChildren(body) { + funcNode := UnwrapFunction(child) if funcNode == nil { continue } nameNode := funcNode.ChildByFieldName("name") - if nameNode != nil && content(nameNode, source) == methodName { + if nameNode != nil && Content(nameNode, source) == methodName { return funcNode, nil } } @@ -1020,9 +1020,9 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN // --------------------------------------------------------------------------- func firstParamIsSelf(params *sitter.Node, source []byte) bool { - for _, child := range allChildren(params) { + for _, child := range AllChildren(params) { if child.Type() == "identifier" { - return content(child, source) == "self" + return Content(child, source) == "self" } } return false @@ -1041,11 +1041,11 @@ func extractInputs( order := 0 
seenSelf := false - for _, child := range allChildren(paramsNode) { + for _, child := range AllChildren(paramsNode) { switch child.Type() { case "identifier": if !seenSelf && skipSelf { - name := content(child, source) + name := Content(child, source) if name == "self" { seenSelf = true continue @@ -1072,7 +1072,7 @@ func extractInputs( nameNode := child.ChildByFieldName("name") paramName := "" if nameNode != nil { - paramName = content(nameNode, source) + paramName = Content(nameNode, source) } return nil, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", paramName, methodName), nil) } @@ -1091,7 +1091,7 @@ func parseTypedParameter(node *sitter.Node, source []byte, order int, methodName switch c.Type() { case "identifier": if name == "" { - name = content(c, source) + name = Content(c, source) } case "type": typeNode = c @@ -1133,7 +1133,7 @@ func parseTypedDefaultParameter( if nameNode == nil { return schema.InputField{}, schema.WrapError(schema.ErrParse, "typed_default_parameter has no name", nil) } - name := content(nameNode, source) + name := Content(nameNode, source) typeNode := node.ChildByFieldName("type") if typeNode == nil { @@ -1203,7 +1203,7 @@ func parseTypedDefaultParameter( } // Can't resolve — hard error - valText := content(valNode, source) + valText := Content(valNode, source) return schema.InputField{}, schema.WrapError(schema.ErrDefaultNotResolvable, fmt.Sprintf("parameter '%s': default `%s` cannot be statically resolved", name, valText), nil) } @@ -1229,17 +1229,17 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio switch n.Type() { case "identifier": - return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: content(n, source)}, nil + return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: Content(n, source)}, nil case "subscript": value := n.ChildByFieldName("value") if value == nil { return schema.TypeAnnotation{}, 
schema.WrapError(schema.ErrParse, "subscript has no value", nil) } - outer := content(value, source) + outer := Content(value, source) var args []schema.TypeAnnotation - for _, child := range namedChildren(n) { + for _, child := range NamedChildren(n) { // Skip the outer identifier (the value field) if child.StartByte() == value.StartByte() { continue @@ -1265,8 +1265,8 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio // Check that operator is | isUnion := false - for _, c := range allChildren(n) { - if !c.IsNamed() && content(c, source) == "|" { + for _, c := range AllChildren(n) { + if !c.IsNamed() && Content(c, source) == "|" { isUnion = true break } @@ -1303,10 +1303,10 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: "None"}, nil case "attribute": - return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: content(n, source)}, nil + return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: Content(n, source)}, nil case "string", "concatenated_string": - text := content(n, source) + text := Content(n, source) inner := strings.TrimLeft(text, "\"'") inner = strings.TrimRight(inner, "\"'") if ann, ok := parseTypeFromString(inner); ok { @@ -1315,7 +1315,7 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio return schema.TypeAnnotation{}, errUnsupported(fmt.Sprintf("string annotation: %s", text)) default: - text := content(n, source) + text := Content(n, source) if ann, ok := parseTypeFromString(text); ok { return ann, nil } @@ -1339,7 +1339,7 @@ func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext if funcNode == nil { return false } - name := content(funcNode, source) + name := Content(funcNode, source) if name == "Input" { return true } @@ -1357,7 +1357,7 @@ func parseInputCall(node *sitter.Node, source []byte, paramName string, scope mo return info, nil } - 
for _, child := range namedChildren(args) { + for _, child := range NamedChildren(args) { if child.Type() != "keyword_argument" { continue } @@ -1367,7 +1367,7 @@ func parseInputCall(node *sitter.Node, source []byte, paramName string, scope mo continue } - key := content(keyNode, source) + key := Content(keyNode, source) switch key { case "default": val, ok := resolveDefaultExpr(valNode, source, scope) @@ -1435,14 +1435,14 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b case "false": return schema.DefaultValue{Kind: schema.DefaultBool, Bool: false}, true case "integer": - text := content(node, source) + text := Content(node, source) n, err := strconv.ParseInt(text, 0, 64) if err != nil { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultInt, Int: n}, true case "float": - text := content(node, source) + text := Content(node, source) f, err := strconv.ParseFloat(text, 64) if err != nil { return schema.DefaultValue{}, false @@ -1473,7 +1473,7 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b } return schema.DefaultValue{Kind: schema.DefaultSet, List: items}, true case "unary_operator": - text := strings.TrimSpace(content(node, source)) + text := strings.TrimSpace(Content(node, source)) if n, err := strconv.ParseInt(text, 0, 64); err == nil { return schema.DefaultValue{Kind: schema.DefaultInt, Int: n}, true } @@ -1483,7 +1483,7 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b return schema.DefaultValue{}, false case "tuple": var items []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { if val, ok := parseDefaultValue(child, source); ok { items = append(items, val) } @@ -1494,7 +1494,7 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b } func parseStringLiteral(node *sitter.Node, source []byte) (string, bool) { - text := content(node, 
source) + text := Content(node, source) if strings.HasPrefix(text, `"""`) || strings.HasPrefix(text, `'''`) { if len(text) >= 6 { return text[3 : len(text)-3], true @@ -1517,7 +1517,7 @@ func parseStringLiteral(node *sitter.Node, source []byte) (string, bool) { } func parseNumberLiteral(node *sitter.Node, source []byte) (float64, bool) { - text := strings.TrimSpace(content(node, source)) + text := strings.TrimSpace(Content(node, source)) f, err := strconv.ParseFloat(text, 64) if err != nil { return 0, false @@ -1532,7 +1532,7 @@ func parseBoolLiteral(node *sitter.Node, source []byte) (bool, bool) { case "false": return false, true } - text := content(node, source) + text := Content(node, source) switch text { case "True": return true, true @@ -1547,7 +1547,7 @@ func parseListLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, return nil, false } var items []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { val, ok := parseDefaultValue(child, source) if !ok { return nil, false @@ -1562,7 +1562,7 @@ func parseDictLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, return nil, nil, false } var keys, vals []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { if child.Type() == "pair" { keyNode := child.ChildByFieldName("key") valNode := child.ChildByFieldName("value") @@ -1585,7 +1585,7 @@ func parseSetLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, b return nil, false } var items []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { val, ok := parseDefaultValue(child, source) if !ok { return nil, false From d2ba918f37386da9b6869c66ccf9fde2e10f6eeb Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:52:14 -0400 Subject: [PATCH 03/24] feat(doctor): add runner with check orchestration and fix flow --- pkg/doctor/doctor_test.go | 
147 +++++++++++++++++++++++++++++++++ pkg/doctor/runner.go | 165 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+) create mode 100644 pkg/doctor/doctor_test.go create mode 100644 pkg/doctor/runner.go diff --git a/pkg/doctor/doctor_test.go b/pkg/doctor/doctor_test.go new file mode 100644 index 0000000000..0e029d631f --- /dev/null +++ b/pkg/doctor/doctor_test.go @@ -0,0 +1,147 @@ +package doctor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// mockCheck is a test double for Check. +type mockCheck struct { + name string + group Group + description string + findings []Finding + checkErr error + fixErr error + fixCalled bool +} + +func (m *mockCheck) Name() string { return m.name } +func (m *mockCheck) Group() Group { return m.group } +func (m *mockCheck) Description() string { return m.description } + +func (m *mockCheck) Check(_ *CheckContext) ([]Finding, error) { + return m.findings, m.checkErr +} + +func (m *mockCheck) Fix(_ *CheckContext, _ []Finding) error { + m.fixCalled = true + return m.fixErr +} + +func TestRunCollectsFindings(t *testing.T) { + checks := []Check{ + &mockCheck{ + name: "passing-check", + group: GroupConfig, + findings: nil, + }, + &mockCheck{ + name: "failing-check", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "something is wrong"}, + }, + }, + } + + result, err := Run(context.Background(), RunOptions{ProjectDir: t.TempDir()}, checks) + require.NoError(t, err) + require.Len(t, result.Results, 2) + require.Empty(t, result.Results[0].Findings) + require.Len(t, result.Results[1].Findings, 1) + require.Equal(t, "something is wrong", result.Results[1].Findings[0].Message) +} + +func TestRunCallsFixWhenEnabled(t *testing.T) { + check := &mockCheck{ + name: "fixable-check", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "fixable issue"}, + }, + } + + _, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: 
t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.True(t, check.fixCalled) +} + +func TestRunDoesNotCallFixWhenDisabled(t *testing.T) { + check := &mockCheck{ + name: "fixable-check", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "fixable issue"}, + }, + } + + _, err := Run(context.Background(), RunOptions{Fix: false, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.False(t, check.fixCalled) +} + +func TestRunDoesNotCallFixWhenNoFindings(t *testing.T) { + check := &mockCheck{ + name: "clean-check", + group: GroupConfig, + findings: nil, + } + + _, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.False(t, check.fixCalled) +} + +func TestRunHandlesCheckError(t *testing.T) { + checks := []Check{ + &mockCheck{ + name: "error-check", + group: GroupEnvironment, + checkErr: context.DeadlineExceeded, + }, + &mockCheck{ + name: "ok-check", + group: GroupConfig, + findings: nil, + }, + } + + result, err := Run(context.Background(), RunOptions{ProjectDir: t.TempDir()}, checks) + require.NoError(t, err) // Run itself doesn't fail; individual check errors are captured + require.Len(t, result.Results, 2) + require.Error(t, result.Results[0].Err) + require.NoError(t, result.Results[1].Err) +} + +func TestRunMarksFixedOnSuccess(t *testing.T) { + check := &mockCheck{ + name: "fixable", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "broken"}, + }, + fixErr: nil, + } + + result, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.True(t, result.Results[0].Fixed) +} + +func TestRunMarksNotFixedOnErrNoAutoFix(t *testing.T) { + check := &mockCheck{ + name: "unfixable", + group: GroupConfig, + findings: []Finding{ + {Severity: SeverityWarning, Message: "deprecated"}, + }, + fixErr: ErrNoAutoFix, + } + + result, 
err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.False(t, result.Results[0].Fixed) +} diff --git a/pkg/doctor/runner.go b/pkg/doctor/runner.go new file mode 100644 index 0000000000..b716318aa5 --- /dev/null +++ b/pkg/doctor/runner.go @@ -0,0 +1,165 @@ +package doctor + +import ( + "context" + "errors" + "os" + "os/exec" + "path/filepath" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + "github.com/replicate/cog/pkg/config" + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// RunOptions configures a doctor run. +type RunOptions struct { + Fix bool + ProjectDir string +} + +// CheckResult holds the outcome of running a single check. +type CheckResult struct { + Check Check + Findings []Finding + Fixed bool // True if --fix was passed and Fix() succeeded + Err error // Non-nil if the check itself errored +} + +// Result holds the outcome of a full doctor run. +type Result struct { + Results []CheckResult +} + +// HasErrors returns true if any check produced error-severity findings. +func (r *Result) HasErrors() bool { + for _, cr := range r.Results { + for _, f := range cr.Findings { + if f.Severity == SeverityError && !cr.Fixed { + return true + } + } + } + return false +} + +// Run executes all checks and optionally applies fixes. 
+func Run(_ context.Context, opts RunOptions, checks []Check) (*Result, error) { + checkCtx, err := buildCheckContext(opts.ProjectDir) + if err != nil { + return nil, err + } + + result := &Result{} + + for _, check := range checks { + cr := CheckResult{Check: check} + + findings, err := check.Check(checkCtx) + if err != nil { + cr.Err = err + result.Results = append(result.Results, cr) + continue + } + + cr.Findings = findings + + if opts.Fix && len(findings) > 0 { + fixErr := check.Fix(checkCtx, findings) + if fixErr == nil { + cr.Fixed = true + } else if !errors.Is(fixErr, ErrNoAutoFix) { + cr.Err = fixErr + } + } + + result.Results = append(result.Results, cr) + } + + return result, nil +} + +// buildCheckContext constructs the shared context for all checks. +func buildCheckContext(projectDir string) (*CheckContext, error) { + ctx := &CheckContext{ + ProjectDir: projectDir, + PythonFiles: make(map[string]*ParsedFile), + } + + // Load cog.yaml + configPath := filepath.Join(projectDir, "cog.yaml") + configBytes, err := os.ReadFile(configPath) + if err == nil { + ctx.ConfigFile = configBytes + // Try to load and validate config + f, err := os.Open(configPath) + if err == nil { + defer f.Close() + loadResult, err := config.Load(f, projectDir) + if err == nil { + ctx.Config = loadResult.Config + } + } + } + + // Find python binary + if pythonPath, err := exec.LookPath("python3"); err == nil { + ctx.PythonPath = pythonPath + } else if pythonPath, err := exec.LookPath("python"); err == nil { + ctx.PythonPath = pythonPath + } + + // Pre-parse Python files referenced in config + if ctx.Config != nil { + parsePythonRef(ctx, ctx.Config.Predict) + parsePythonRef(ctx, ctx.Config.Train) + } + + return ctx, nil +} + +// parsePythonRef parses a predict/train reference like "predict.py:Predictor" +// and adds the parsed file to ctx.PythonFiles. 
+func parsePythonRef(ctx *CheckContext, ref string) { + if ref == "" { + return + } + parts := splitPredictRef(ref) + if parts[0] == "" { + return + } + + fullPath := filepath.Join(ctx.ProjectDir, parts[0]) + source, err := os.ReadFile(fullPath) + if err != nil { + return + } + + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(context.Background(), nil, source) + if err != nil { + return + } + + imports := schemaPython.CollectImports(tree.RootNode(), source) + + ctx.PythonFiles[parts[0]] = &ParsedFile{ + Path: parts[0], + Source: source, + Tree: tree, + Imports: imports, + } +} + +// splitPredictRef splits "predict.py:Predictor" into ["predict.py", "Predictor"]. +func splitPredictRef(ref string) [2]string { + for i := len(ref) - 1; i >= 0; i-- { + if ref[i] == ':' { + return [2]string{ref[:i], ref[i+1:]} + } + } + return [2]string{ref, ""} +} From 4c7f252feb10a295422552349fbcdecbaf1c14c3 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:52:32 -0400 Subject: [PATCH 04/24] feat(doctor): add check registry --- pkg/doctor/registry.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 pkg/doctor/registry.go diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go new file mode 100644 index 0000000000..0d209d8b7f --- /dev/null +++ b/pkg/doctor/registry.go @@ -0,0 +1,13 @@ +package doctor + +// AllChecks returns all registered doctor checks in execution order. +// To add a new check: implement the Check interface and add it here. 
+func AllChecks() []Check { + return []Check{ + // Config checks (added in subsequent tasks) + + // Python checks (added in subsequent tasks) + + // Environment checks (added in subsequent tasks) + } +} From dc23230242604fd19f370b04a1f8781c3f6c8df7 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:55:14 -0400 Subject: [PATCH 05/24] feat(doctor): add CLI command with output formatting --- pkg/cli/doctor.go | 188 ++++++++++++++++++++++++++++++++++++++++++++++ pkg/cli/root.go | 1 + 2 files changed, 189 insertions(+) create mode 100644 pkg/cli/doctor.go diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go new file mode 100644 index 0000000000..755786bb35 --- /dev/null +++ b/pkg/cli/doctor.go @@ -0,0 +1,188 @@ +package cli + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/doctor" + "github.com/replicate/cog/pkg/util/console" +) + +func newDoctorCommand() *cobra.Command { + var fix bool + + cmd := &cobra.Command{ + Use: "doctor", + Short: "Check your project for common issues and fix them", + Long: `Diagnose and fix common issues in your Cog project. + +By default, cog doctor reports problems without modifying any files. 
+Pass --fix to automatically apply safe fixes.`, + RunE: func(cmd *cobra.Command, args []string) error { + return runDoctor(cmd.Context(), fix) + }, + Args: cobra.NoArgs, + } + + addConfigFlag(cmd) + cmd.Flags().BoolVar(&fix, "fix", false, "Automatically apply fixes") + + return cmd +} + +func runDoctor(ctx context.Context, fix bool) error { + projectDir, err := config.GetProjectDir(configFilename) + if err != nil { + return err + } + + if fix { + console.Infof("Running cog doctor --fix...") + } else { + console.Infof("Running cog doctor...") + } + console.Info("") + + result, err := doctor.Run(ctx, doctor.RunOptions{ + Fix: fix, + ProjectDir: projectDir, + }, doctor.AllChecks()) + if err != nil { + return err + } + + printDoctorResults(result, fix) + + if result.HasErrors() { + return fmt.Errorf("doctor found errors") + } + + return nil +} + +func printDoctorResults(result *doctor.Result, fix bool) { + var currentGroup doctor.Group + errorCount := 0 + warningCount := 0 + fixedCount := 0 + + for _, cr := range result.Results { + // Print group header when group changes + if cr.Check.Group() != currentGroup { + currentGroup = cr.Check.Group() + console.Infof("%s", string(currentGroup)) + } + + // Check errored internally + if cr.Err != nil { + console.Errorf("%s: %v", cr.Check.Description(), cr.Err) + errorCount++ + continue + } + + // No findings — passed + if len(cr.Findings) == 0 { + console.Successf("%s", cr.Check.Description()) + continue + } + + // Has findings + if cr.Fixed { + console.Successf("Fixed: %s", cr.Check.Description()) + fixedCount += len(cr.Findings) + } else { + // Determine worst severity for the check header + worstSeverity := cr.Findings[0].Severity + for _, f := range cr.Findings[1:] { + if f.Severity < worstSeverity { + worstSeverity = f.Severity + } + } + switch worstSeverity { + case doctor.SeverityError: + console.Errorf("%s", cr.Check.Description()) + errorCount++ + case doctor.SeverityWarning: + console.Warnf("%s", 
cr.Check.Description()) + warningCount++ + default: + console.Infof("%s", cr.Check.Description()) + } + } + + // Print individual findings + for _, f := range cr.Findings { + location := "" + if f.File != "" { + if f.Line > 0 { + location = fmt.Sprintf("%s:%d — ", f.File, f.Line) + } else { + location = fmt.Sprintf("%s — ", f.File) + } + } + console.Infof(" %s%s", location, f.Message) + + if fix && !cr.Fixed && f.Remediation != "" { + console.Infof(" (no auto-fix available)") + } + } + } + + console.Info("") + + // Summary line + if fix && fixedCount > 0 { + msg := fmt.Sprintf("Fixed %d issue", fixedCount) + if fixedCount != 1 { + msg += "s" + } + if warningCount > 0 { + msg += fmt.Sprintf(". %d warning", warningCount) + if warningCount != 1 { + msg += "s" + } + msg += " remaining" + } + if errorCount > 0 { + msg += fmt.Sprintf(". %d unfixed error", errorCount) + if errorCount != 1 { + msg += "s" + } + } + console.Infof("%s.", msg) + } else if errorCount > 0 || warningCount > 0 { + parts := []string{} + if errorCount > 0 { + s := fmt.Sprintf("%d error", errorCount) + if errorCount != 1 { + s += "s" + } + parts = append(parts, s) + } + if warningCount > 0 { + s := fmt.Sprintf("%d warning", warningCount) + if warningCount != 1 { + s += "s" + } + parts = append(parts, s) + } + summary := "Found " + for i, p := range parts { + if i > 0 { + summary += ", " + } + summary += p + } + summary += "." 
+ + if !fix && errorCount > 0 { + summary += ` Run "cog doctor --fix" to auto-fix.` + } + console.Infof("%s", summary) + } else { + console.Successf("no issues found") + } +} diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 9184411ec4..c6ccc4087d 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -45,6 +45,7 @@ https://github.com/replicate/cog`, rootCmd.AddCommand( newBuildCommand(), newDebugCommand(), + newDoctorCommand(), newInitCommand(), newInspectCommand(), newLoginCommand(), From 96ee37e65146b1c0586dafc6ffeb152af9bbf862 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:56:00 -0400 Subject: [PATCH 06/24] feat(doctor): add config parse check --- pkg/doctor/check_config_parse.go | 65 ++++++++++++++++++++++++++++++++ pkg/doctor/check_config_test.go | 60 +++++++++++++++++++++++++++++ pkg/doctor/registry.go | 3 +- 3 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 pkg/doctor/check_config_parse.go create mode 100644 pkg/doctor/check_config_test.go diff --git a/pkg/doctor/check_config_parse.go b/pkg/doctor/check_config_parse.go new file mode 100644 index 0000000000..c11c859e2a --- /dev/null +++ b/pkg/doctor/check_config_parse.go @@ -0,0 +1,65 @@ +package doctor + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/replicate/cog/pkg/config" +) + +// ConfigParseCheck verifies that cog.yaml exists and can be parsed as valid YAML. 
+type ConfigParseCheck struct{} + +func (c *ConfigParseCheck) Name() string { return "config-parse" } +func (c *ConfigParseCheck) Group() Group { return GroupConfig } +func (c *ConfigParseCheck) Description() string { return "Config parsing" } + +func (c *ConfigParseCheck) Check(ctx *CheckContext) ([]Finding, error) { + configPath := filepath.Join(ctx.ProjectDir, "cog.yaml") + + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return []Finding{{ + Severity: SeverityError, + Message: "cog.yaml not found", + Remediation: `Run "cog init" to create a cog.yaml`, + File: "cog.yaml", + }}, nil + } + + f, err := os.Open(configPath) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot read cog.yaml: %v", err), + File: "cog.yaml", + }}, nil + } + defer f.Close() + + _, loadErr := config.Load(f, ctx.ProjectDir) + if loadErr != nil { + var parseErr *config.ParseError + if isParseError(loadErr, &parseErr) { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cog.yaml has invalid YAML: %v", loadErr), + Remediation: "Fix the YAML syntax in cog.yaml", + File: "cog.yaml", + }}, nil + } + // Other errors (validation, schema) are handled by other checks + } + + return nil, nil +} + +func (c *ConfigParseCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// isParseError checks if the error chain contains a ParseError. 
+func isParseError(err error, target **config.ParseError) bool { + return errors.As(err, target) +} diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go new file mode 100644 index 0000000000..7ff83d74ea --- /dev/null +++ b/pkg/doctor/check_config_test.go @@ -0,0 +1,60 @@ +package doctor + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfigParseCheck_Valid(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + + ctx := &CheckContext{ProjectDir: dir} + ctx.ConfigFile, _ = os.ReadFile(filepath.Join(dir, "cog.yaml")) + + check := &ConfigParseCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigParseCheck_InvalidYAML(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: [invalid yaml`) + + ctx := &CheckContext{ProjectDir: dir} + ctx.ConfigFile, _ = os.ReadFile(filepath.Join(dir, "cog.yaml")) + + check := &ConfigParseCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "cog.yaml") +} + +func TestConfigParseCheck_MissingFile(t *testing.T) { + dir := t.TempDir() + + ctx := &CheckContext{ProjectDir: dir} + + check := &ConfigParseCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Contains(t, findings[0].Message, "cog.yaml not found") +} + +// writeFile is a test helper to create fixture files. 
+func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + fullPath := filepath.Join(dir, name) + require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0o755)) + require.NoError(t, os.WriteFile(fullPath, []byte(content), 0o644)) +} diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 0d209d8b7f..f58a41b641 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -4,7 +4,8 @@ package doctor // To add a new check: implement the Check interface and add it here. func AllChecks() []Check { return []Check{ - // Config checks (added in subsequent tasks) + // Config checks + &ConfigParseCheck{}, // Python checks (added in subsequent tasks) From 7c5d1f1e4feb3b9ccb8a3b8ff47f2955cebfe393 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:56:52 -0400 Subject: [PATCH 07/24] feat(doctor): add deprecated config fields check --- pkg/doctor/check_config_deprecated.go | 48 +++++++++++++++++++++++++ pkg/doctor/check_config_test.go | 52 +++++++++++++++++++++++++++ pkg/doctor/registry.go | 1 + 3 files changed, 101 insertions(+) create mode 100644 pkg/doctor/check_config_deprecated.go diff --git a/pkg/doctor/check_config_deprecated.go b/pkg/doctor/check_config_deprecated.go new file mode 100644 index 0000000000..617fc6f71e --- /dev/null +++ b/pkg/doctor/check_config_deprecated.go @@ -0,0 +1,48 @@ +package doctor + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/replicate/cog/pkg/config" +) + +// ConfigDeprecatedFieldsCheck detects deprecated fields in cog.yaml. 
+type ConfigDeprecatedFieldsCheck struct{} + +func (c *ConfigDeprecatedFieldsCheck) Name() string { return "config-deprecated-fields" } +func (c *ConfigDeprecatedFieldsCheck) Group() Group { return GroupConfig } +func (c *ConfigDeprecatedFieldsCheck) Description() string { return "Deprecated fields" } + +func (c *ConfigDeprecatedFieldsCheck) Check(ctx *CheckContext) ([]Finding, error) { + configPath := filepath.Join(ctx.ProjectDir, "cog.yaml") + f, err := os.Open(configPath) + if err != nil { + return nil, nil // Config parse check handles missing file + } + defer f.Close() + + // We need to run validation to get deprecation warnings. + // Load does parse + validate + complete; we want just parse + validate. + loadResult, err := config.Load(f, ctx.ProjectDir) + if err != nil { + return nil, nil // Other config checks handle parse/validation errors + } + + var findings []Finding + for _, w := range loadResult.Warnings { + findings = append(findings, Finding{ + Severity: SeverityWarning, + Message: fmt.Sprintf("%q is deprecated: %s", w.Field, w.Message), + Remediation: fmt.Sprintf("Use %q instead", w.Replacement), + File: "cog.yaml", + }) + } + + return findings, nil +} + +func (c *ConfigDeprecatedFieldsCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go index 7ff83d74ea..944835d3ae 100644 --- a/pkg/doctor/check_config_test.go +++ b/pkg/doctor/check_config_test.go @@ -51,6 +51,58 @@ func TestConfigParseCheck_MissingFile(t *testing.T) { require.Contains(t, findings[0].Message, "cog.yaml not found") } +func TestConfigDeprecatedFieldsCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" + python_requirements: "requirements.txt" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "requirements.txt", "torch==2.0.0\n") + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigDeprecatedFieldsCheck{} + 
findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigDeprecatedFieldsCheck_PythonPackages(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" + python_packages: + - torch==2.0.0 +predict: "predict.py:Predictor" +`) + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigDeprecatedFieldsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "python_packages") +} + +func TestConfigDeprecatedFieldsCheck_PreInstall(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" + pre_install: + - pip install something +predict: "predict.py:Predictor" +`) + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigDeprecatedFieldsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "pre_install") +} + // writeFile is a test helper to create fixture files. 
func writeFile(t *testing.T, dir, name, content string) { t.Helper() diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index f58a41b641..c762f65977 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -6,6 +6,7 @@ func AllChecks() []Check { return []Check{ // Config checks &ConfigParseCheck{}, + &ConfigDeprecatedFieldsCheck{}, // Python checks (added in subsequent tasks) From 158968b83a38795aa43a5a44e719785e1da462d7 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 13:58:27 -0400 Subject: [PATCH 08/24] feat(doctor): add predict reference check --- pkg/doctor/check_config_predict_ref.go | 130 +++++++++++++++++++++++++ pkg/doctor/check_config_test.go | 97 ++++++++++++++++++ pkg/doctor/registry.go | 1 + 3 files changed, 228 insertions(+) create mode 100644 pkg/doctor/check_config_predict_ref.go diff --git a/pkg/doctor/check_config_predict_ref.go b/pkg/doctor/check_config_predict_ref.go new file mode 100644 index 0000000000..0de9cc6361 --- /dev/null +++ b/pkg/doctor/check_config_predict_ref.go @@ -0,0 +1,130 @@ +package doctor + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// ConfigPredictRefCheck verifies that the predict field in cog.yaml +// points to a file and class that actually exist. 
+type ConfigPredictRefCheck struct{} + +func (c *ConfigPredictRefCheck) Name() string { return "config-predict-ref" } +func (c *ConfigPredictRefCheck) Group() Group { return GroupConfig } +func (c *ConfigPredictRefCheck) Description() string { return "Predict reference" } + +func (c *ConfigPredictRefCheck) Check(ctx *CheckContext) ([]Finding, error) { + // Get predict ref from config + predictRef := "" + if ctx.Config != nil { + predictRef = ctx.Config.Predict + } + if predictRef == "" { + return nil, nil // No predict field — nothing to check + } + + parts := splitPredictRef(predictRef) + pyFile := parts[0] + className := parts[1] + + if pyFile == "" || className == "" { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("predict reference %q must be in the form 'file.py:ClassName'", predictRef), + Remediation: `Set predict to "predict.py:Predictor" in cog.yaml`, + File: "cog.yaml", + }}, nil + } + + // Check file exists + fullPath := filepath.Join(ctx.ProjectDir, pyFile) + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("%s not found", pyFile), + Remediation: fmt.Sprintf("Create %s or update the predict field in cog.yaml", pyFile), + File: "cog.yaml", + }}, nil + } + + // Check class exists in file using tree-sitter + source, err := os.ReadFile(fullPath) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot read %s: %v", pyFile, err), + File: pyFile, + }}, nil + } + + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(context.Background(), nil, source) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot parse %s: %v", pyFile, err), + File: pyFile, + }}, nil + } + + if !hasClassDefinition(tree.RootNode(), source, className) { + // List available classes to help the user + classes := listClassNames(tree.RootNode(), source) + msg := 
fmt.Sprintf("class %q not found in %s", className, pyFile) + if len(classes) > 0 { + msg += fmt.Sprintf("; found: %s", strings.Join(classes, ", ")) + } + return []Finding{{ + Severity: SeverityError, + Message: msg, + Remediation: fmt.Sprintf("Add class %s to %s or update the predict field in cog.yaml", className, pyFile), + File: pyFile, + }}, nil + } + + return nil, nil +} + +func (c *ConfigPredictRefCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// hasClassDefinition checks whether a class with the given name exists at module level. +func hasClassDefinition(root *sitter.Node, source []byte, name string) bool { + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + return true + } + } + return false +} + +// listClassNames returns the names of all top-level classes in the file. 
+func listClassNames(root *sitter.Node, source []byte) []string { + var names []string + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil { + names = append(names, schemaPython.Content(nameNode, source)) + } + } + return names +} diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go index 944835d3ae..62d8f5caf8 100644 --- a/pkg/doctor/check_config_test.go +++ b/pkg/doctor/check_config_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" ) func TestConfigParseCheck_Valid(t *testing.T) { @@ -103,6 +105,101 @@ predict: "predict.py:Predictor" require.Contains(t, findings[0].Message, "pre_install") } +func TestConfigPredictRefCheck_Valid(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigPredictRefCheck_MissingFile(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "predict.py") + require.Contains(t, findings[0].Message, "not found") +} + +func TestConfigPredictRefCheck_MissingClass(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, 
"cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:DoesNotExist" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "DoesNotExist") +} + +func TestConfigPredictRefCheck_NoPredictField(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) // No predict field is valid (some projects are train-only) +} + +// buildTestCheckContext creates a CheckContext by loading the cog.yaml in the given dir. +func buildTestCheckContext(t *testing.T, dir string) *CheckContext { + t.Helper() + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: make(map[string]*ParsedFile), + } + + configPath := filepath.Join(dir, "cog.yaml") + configBytes, err := os.ReadFile(configPath) + if err == nil { + ctx.ConfigFile = configBytes + f, err := os.Open(configPath) + if err == nil { + defer f.Close() + loadResult, err := config.Load(f, dir) + if err == nil { + ctx.Config = loadResult.Config + } + } + } + + return ctx +} + // writeFile is a test helper to create fixture files. 
func writeFile(t *testing.T, dir, name, content string) { t.Helper() diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index c762f65977..5921f2548c 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -7,6 +7,7 @@ func AllChecks() []Check { // Config checks &ConfigParseCheck{}, &ConfigDeprecatedFieldsCheck{}, + &ConfigPredictRefCheck{}, // Python checks (added in subsequent tasks) From f33aaff703ba26f81d710959f1901dc8023d0d09 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:02:33 -0400 Subject: [PATCH 09/24] feat(doctor): add pydantic BaseModel check with auto-fix --- pkg/doctor/check_python_pydantic_basemodel.go | 234 ++++++++++++++++++ pkg/doctor/check_python_test.go | 133 ++++++++++ pkg/doctor/registry.go | 3 +- 3 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 pkg/doctor/check_python_pydantic_basemodel.go create mode 100644 pkg/doctor/check_python_test.go diff --git a/pkg/doctor/check_python_pydantic_basemodel.go b/pkg/doctor/check_python_pydantic_basemodel.go new file mode 100644 index 0000000000..64fa1f631d --- /dev/null +++ b/pkg/doctor/check_python_pydantic_basemodel.go @@ -0,0 +1,234 @@ +package doctor + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + + "github.com/replicate/cog/pkg/schema" + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// PydanticBaseModelCheck detects output classes that inherit from pydantic.BaseModel +// with arbitrary_types_allowed=True instead of using cog.BaseModel. 
type PydanticBaseModelCheck struct{}

func (c *PydanticBaseModelCheck) Name() string        { return "python-pydantic-basemodel" }
func (c *PydanticBaseModelCheck) Group() Group        { return GroupPython }
func (c *PydanticBaseModelCheck) Description() string { return "Pydantic BaseModel workaround" }

// Check scans every parsed Python file for top-level classes that both
// inherit from pydantic.BaseModel and set arbitrary_types_allowed, and
// reports one error-severity finding per offending class.
//
// NOTE(review): ctx.PythonFiles is a map, so findings across files are
// emitted in nondeterministic order — confirm the runner sorts findings
// before display if stable output matters.
func (c *PydanticBaseModelCheck) Check(ctx *CheckContext) ([]Finding, error) {
	var findings []Finding

	for _, pf := range ctx.PythonFiles {
		root := pf.Tree.RootNode()

		for _, child := range schemaPython.NamedChildren(root) {
			classNode := schemaPython.UnwrapClass(child)
			if classNode == nil {
				continue
			}

			// Check if class inherits from pydantic.BaseModel (not cog.BaseModel)
			if !inheritsPydanticBaseModel(classNode, pf.Source, pf.Imports) {
				continue
			}

			// Check if class has arbitrary_types_allowed=True
			if !hasArbitraryTypesAllowed(classNode, pf.Source) {
				continue
			}

			// Report at the class-name line when available; Line stays 0
			// (unknown) for anonymous/malformed class nodes.
			nameNode := classNode.ChildByFieldName("name")
			className := ""
			line := 0
			if nameNode != nil {
				className = schemaPython.Content(nameNode, pf.Source)
				line = int(nameNode.StartPoint().Row) + 1 // tree-sitter rows are 0-based
			}

			findings = append(findings, Finding{
				Severity:    SeverityError,
				Message:     fmt.Sprintf("%s inherits from pydantic.BaseModel with arbitrary_types_allowed; use cog.BaseModel instead", className),
				Remediation: "Replace pydantic.BaseModel with cog.BaseModel and remove ConfigDict(arbitrary_types_allowed=True)",
				File:        pf.Path,
				Line:        line,
			})
		}
	}

	return findings, nil
}

// Fix rewrites each file named in findings via fixPydanticBaseModel and
// writes the result back in place. The file is re-read from disk (rather
// than reusing ctx.PythonFiles) so fixes from other checks that already
// ran are not clobbered.
func (c *PydanticBaseModelCheck) Fix(ctx *CheckContext, findings []Finding) error {
	// Group findings by file
	fileFindings := make(map[string][]Finding)
	for _, f := range findings {
		fileFindings[f.File] = append(fileFindings[f.File], f)
	}

	for relPath := range fileFindings {
		fullPath := filepath.Join(ctx.ProjectDir, relPath)
		source, err := os.ReadFile(fullPath)
		if err != nil {
			return fmt.Errorf("reading %s: %w", relPath, err)
		}

		fixed := fixPydanticBaseModel(string(source))

		if err := os.WriteFile(fullPath, []byte(fixed), 0o644); err != nil {
			return fmt.Errorf("writing %s: %w", relPath, err)
		}
	}

	return nil
}

// inheritsPydanticBaseModel checks if a class inherits from pydantic.BaseModel
// (as opposed to cog.BaseModel or another BaseModel). A bare "BaseModel"
// superclass is attributed via the file's collected imports; a dotted
// "pydantic.BaseModel" superclass matches directly.
//
// NOTE(review): an aliased import ("import pydantic as pd" then
// "pd.BaseModel") is not detected — confirm whether that pattern matters.
func inheritsPydanticBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool {
	supers := classNode.ChildByFieldName("superclasses")
	if supers == nil {
		return false
	}

	for _, child := range schemaPython.AllChildren(supers) {
		text := schemaPython.Content(child, source)

		switch child.Type() {
		case "identifier":
			if text == "BaseModel" {
				// Check if BaseModel was imported from pydantic
				if entry, ok := imports.Names.Get("BaseModel"); ok {
					return entry.Module == "pydantic"
				}
			}
		case "attribute":
			if text == "pydantic.BaseModel" {
				return true
			}
		}
	}
	return false
}

// hasArbitraryTypesAllowed checks if a class body contains
// model_config = ConfigDict(arbitrary_types_allowed=True).
// The right-hand side is matched textually (substring checks for
// "arbitrary_types_allowed" and "True") rather than structurally.
func hasArbitraryTypesAllowed(classNode *sitter.Node, source []byte) bool {
	body := classNode.ChildByFieldName("body")
	if body == nil {
		return false
	}

	for _, stmt := range schemaPython.NamedChildren(body) {
		// Assignments at class-body level arrive wrapped in an
		// expression_statement node; unwrap before inspecting.
		node := stmt
		if stmt.Type() == "expression_statement" && stmt.NamedChildCount() == 1 {
			node = stmt.NamedChild(0)
		}

		if node.Type() != "assignment" {
			continue
		}

		left := node.ChildByFieldName("left")
		if left == nil || schemaPython.Content(left, source) != "model_config" {
			continue
		}

		right := node.ChildByFieldName("right")
		if right == nil {
			continue
		}

		text := schemaPython.Content(right, source)
		if strings.Contains(text, "arbitrary_types_allowed") && strings.Contains(text, "True") {
			return true
		}
	}

	return false
}

// fixPydanticBaseModel rewrites Python source to replace pydantic.BaseModel with cog.BaseModel.
+func fixPydanticBaseModel(source string) string { + lines := strings.Split(source, "\n") + var result []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Remove "from pydantic import BaseModel" (and ConfigDict) + if strings.HasPrefix(trimmed, "from pydantic import") { + // Remove BaseModel and ConfigDict from the import + remaining := removePydanticImports(trimmed) + if remaining == "" { + continue // Drop the entire line + } + result = append(result, remaining) + continue + } + + // Remove model_config = ConfigDict(...) lines + if strings.Contains(trimmed, "model_config") && strings.Contains(trimmed, "ConfigDict") { + continue + } + + result = append(result, line) + } + + // Now add BaseModel to the cog import line + fixed := strings.Join(result, "\n") + fixed = addToCogImport(fixed, "BaseModel") + + return fixed +} + +// removePydanticImports removes BaseModel and ConfigDict from a pydantic import line. +// Returns "" if no imports remain. +func removePydanticImports(line string) string { + // Parse "from pydantic import X, Y, Z" + prefix := "from pydantic import " + if !strings.HasPrefix(strings.TrimSpace(line), prefix) { + return line + } + + importPart := strings.TrimSpace(line)[len(prefix):] + names := strings.Split(importPart, ",") + + var remaining []string + for _, name := range names { + name = strings.TrimSpace(name) + if name == "BaseModel" || name == "ConfigDict" { + continue + } + if name != "" { + remaining = append(remaining, name) + } + } + + if len(remaining) == 0 { + return "" + } + + return prefix + strings.Join(remaining, ", ") +} + +// addToCogImport adds a name to an existing "from cog import ..." line. 
+func addToCogImport(source string, name string) string { + lines := strings.Split(source, "\n") + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "from cog import") { + if strings.Contains(trimmed, name) { + return source // Already imported + } + // Add the name at the end + lines[i] = trimmed + ", " + name + return strings.Join(lines, "\n") + } + } + return source +} diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go new file mode 100644 index 0000000000..d17b7d884c --- /dev/null +++ b/pkg/doctor/check_python_test.go @@ -0,0 +1,133 @@ +package doctor + +import ( + "context" + "os" + "path/filepath" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + "github.com/stretchr/testify/require" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +func TestPydanticBaseModelCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, BaseModel, Path + +class Output(BaseModel): + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPydanticBaseModelCheck_Detects(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: 
parsePythonFiles(t, dir, "predict.py"), + } + + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "VoiceCloningOutputs") + require.Contains(t, findings[0].Message, "pydantic.BaseModel") +} + +func TestPydanticBaseModelCheck_Fix(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + // Re-read and verify + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BasePredictor, Path, BaseModel") + require.NotContains(t, string(content), "from pydantic import BaseModel") + require.NotContains(t, string(content), "arbitrary_types_allowed") + require.NotContains(t, string(content), "model_config") + + // Re-parse and verify doctor passes + ctx.PythonFiles = parsePythonFiles(t, dir, "predict.py") + findings, err = check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +// parsePythonFiles is a test helper that parses Python files into ParsedFile structs. 
+func parsePythonFiles(t *testing.T, dir string, filenames ...string) map[string]*ParsedFile { + t.Helper() + files := make(map[string]*ParsedFile) + for _, name := range filenames { + source, err := os.ReadFile(filepath.Join(dir, name)) + require.NoError(t, err) + + sitterParser := sitter.NewParser() + sitterParser.SetLanguage(python.GetLanguage()) + tree, err := sitterParser.ParseCtx(context.Background(), nil, source) + require.NoError(t, err) + + imports := schemaPython.CollectImports(tree.RootNode(), source) + files[name] = &ParsedFile{ + Path: name, + Source: source, + Tree: tree, + Imports: imports, + } + } + return files +} diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 5921f2548c..0ad41d7e34 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -9,7 +9,8 @@ func AllChecks() []Check { &ConfigDeprecatedFieldsCheck{}, &ConfigPredictRefCheck{}, - // Python checks (added in subsequent tasks) + // Python checks + &PydanticBaseModelCheck{}, // Environment checks (added in subsequent tasks) } From f2275f419df709a922d53259b8e903abc50383f0 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:04:18 -0400 Subject: [PATCH 10/24] feat(doctor): add deprecated imports check with auto-fix --- pkg/doctor/check_python_deprecated_imports.go | 186 ++++++++++++++++++ pkg/doctor/check_python_test.go | 77 ++++++++ pkg/doctor/registry.go | 1 + 3 files changed, 264 insertions(+) create mode 100644 pkg/doctor/check_python_deprecated_imports.go diff --git a/pkg/doctor/check_python_deprecated_imports.go b/pkg/doctor/check_python_deprecated_imports.go new file mode 100644 index 0000000000..c40e8f5d93 --- /dev/null +++ b/pkg/doctor/check_python_deprecated_imports.go @@ -0,0 +1,186 @@ +package doctor + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// deprecatedImport describes an import that was removed or moved. 
type deprecatedImport struct {
	Module  string // e.g., "cog.types"
	Name    string // e.g., "ExperimentalFeatureWarning"
	Message string // e.g., "removed in cog 0.17"
}

// deprecatedImportsList is the list of known deprecated imports.
// Extend this table to teach the check (and its auto-fix) about new removals.
var deprecatedImportsList = []deprecatedImport{
	{
		Module:  "cog.types",
		Name:    "ExperimentalFeatureWarning",
		Message: "ExperimentalFeatureWarning was removed in cog 0.17; current_scope().record_metric() is no longer experimental",
	},
}

// DeprecatedImportsCheck detects imports that were removed or moved in recent cog versions.
type DeprecatedImportsCheck struct{}

func (c *DeprecatedImportsCheck) Name() string        { return "python-deprecated-imports" }
func (c *DeprecatedImportsCheck) Group() Group        { return GroupPython }
func (c *DeprecatedImportsCheck) Description() string { return "Deprecated imports" }

// Check walks every parsed Python file's top-level "from X import ..."
// statements and reports an error-severity finding for each imported name
// that appears in deprecatedImportsList.
//
// NOTE(review): only module-level import_from_statement nodes are
// inspected; imports nested inside functions or try blocks are not
// visited — confirm that is intentional.
func (c *DeprecatedImportsCheck) Check(ctx *CheckContext) ([]Finding, error) {
	var findings []Finding

	for _, pf := range ctx.PythonFiles {
		root := pf.Tree.RootNode()

		for _, child := range schemaPython.NamedChildren(root) {
			if child.Type() != "import_from_statement" {
				continue
			}

			moduleNode := child.ChildByFieldName("module_name")
			if moduleNode == nil {
				continue
			}
			module := schemaPython.Content(moduleNode, pf.Source)

			// Check each imported name against the deprecated list
			for _, name := range extractImportedNames(child, pf.Source) {
				for _, dep := range deprecatedImportsList {
					if module == dep.Module && name == dep.Name {
						line := int(child.StartPoint().Row) + 1 // tree-sitter rows are 0-based
						findings = append(findings, Finding{
							Severity:    SeverityError,
							Message:     dep.Message,
							Remediation: fmt.Sprintf("Remove the import of %s from %s", dep.Name, dep.Module),
							File:        pf.Path,
							Line:        line,
						})
					}
				}
			}
		}
	}

	return findings, nil
}

// Fix removes the deprecated imported names from each file named in
// findings, rewriting the files in place.
func (c *DeprecatedImportsCheck) Fix(ctx *CheckContext, findings []Finding) error {
	// Group findings by file
	fileFindings := make(map[string][]Finding)
	for _, f := range findings {
		fileFindings[f.File] = append(fileFindings[f.File], f)
	}

	for relPath, fileFindingsList := range fileFindings {
		fullPath := filepath.Join(ctx.ProjectDir, relPath)
		source, err := os.ReadFile(fullPath)
		if err != nil {
			return fmt.Errorf("reading %s: %w", relPath, err)
		}

		// Collect deprecated names to remove.
		// NOTE(review): findings are mapped back to table entries by
		// substring-matching dep.Name against the finding message, which
		// works only while every table entry's Message mentions its Name —
		// fragile if new entries are added; consider carrying the entry on
		// the Finding instead.
		namesToRemove := make(map[string]map[string]bool) // module -> set of names
		for _, f := range fileFindingsList {
			for _, dep := range deprecatedImportsList {
				if strings.Contains(f.Message, dep.Name) {
					if namesToRemove[dep.Module] == nil {
						namesToRemove[dep.Module] = make(map[string]bool)
					}
					namesToRemove[dep.Module][dep.Name] = true
				}
			}
		}

		fixed := removeDeprecatedImportLines(string(source), namesToRemove)

		if err := os.WriteFile(fullPath, []byte(fixed), 0o644); err != nil {
			return fmt.Errorf("writing %s: %w", relPath, err)
		}
	}

	return nil
}

// extractImportedNames returns the names imported in a "from X import a, b, c" statement.
+func extractImportedNames(importNode *sitter.Node, source []byte) []string { + moduleNode := importNode.ChildByFieldName("module_name") + var names []string + + for _, child := range schemaPython.AllChildren(importNode) { + switch child.Type() { + case "dotted_name": + if moduleNode != nil && child.StartByte() != moduleNode.StartByte() { + names = append(names, schemaPython.Content(child, source)) + } + case "aliased_import": + if origNode := child.ChildByFieldName("name"); origNode != nil { + names = append(names, schemaPython.Content(origNode, source)) + } + case "import_list": + for _, ic := range schemaPython.AllChildren(child) { + switch ic.Type() { + case "dotted_name": + names = append(names, schemaPython.Content(ic, source)) + case "aliased_import": + if origNode := ic.ChildByFieldName("name"); origNode != nil { + names = append(names, schemaPython.Content(origNode, source)) + } + } + } + } + } + + return names +} + +// removeDeprecatedImportLines removes specific names from import lines. +// If all names are removed, the entire import line is dropped. 
+func removeDeprecatedImportLines(source string, namesToRemove map[string]map[string]bool) string { + lines := strings.Split(source, "\n") + var result []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + removed := false + for module, names := range namesToRemove { + prefix := "from " + module + " import " + if !strings.HasPrefix(trimmed, prefix) { + continue + } + + importPart := trimmed[len(prefix):] + importNames := strings.Split(importPart, ",") + + var remaining []string + for _, name := range importNames { + name = strings.TrimSpace(name) + if !names[name] { + remaining = append(remaining, name) + } + } + + if len(remaining) == 0 { + removed = true + } else { + result = append(result, prefix+strings.Join(remaining, ", ")) + removed = true + } + break + } + + if !removed { + result = append(result, line) + } + } + + return strings.Join(result, "\n") +} diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go index d17b7d884c..da54191d26 100644 --- a/pkg/doctor/check_python_test.go +++ b/pkg/doctor/check_python_test.go @@ -108,6 +108,83 @@ class Predictor(BasePredictor): require.Empty(t, findings) } +func TestDeprecatedImportsCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestDeprecatedImportsCheck_Detects(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: 
parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "ExperimentalFeatureWarning") +} + +func TestDeprecatedImportsCheck_Fix(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.NotContains(t, string(content), "ExperimentalFeatureWarning") + + // Re-parse and verify clean + ctx.PythonFiles = parsePythonFiles(t, dir, "predict.py") + findings, err = check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + // parsePythonFiles is a test helper that parses Python files into ParsedFile structs. 
func parsePythonFiles(t *testing.T, dir string, filenames ...string) map[string]*ParsedFile { t.Helper() diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 0ad41d7e34..3ec9dbc388 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -11,6 +11,7 @@ func AllChecks() []Check { // Python checks &PydanticBaseModelCheck{}, + &DeprecatedImportsCheck{}, // Environment checks (added in subsequent tasks) } From 5967098fbc9a8f87bf529418c43acc5963be19c4 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:06:47 -0400 Subject: [PATCH 11/24] feat(doctor): add Docker availability check --- pkg/doctor/check_env_docker.go | 29 +++++++++++++++++++++++++++++ pkg/doctor/check_env_test.go | 22 ++++++++++++++++++++++ pkg/doctor/registry.go | 3 ++- 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 pkg/doctor/check_env_docker.go create mode 100644 pkg/doctor/check_env_test.go diff --git a/pkg/doctor/check_env_docker.go b/pkg/doctor/check_env_docker.go new file mode 100644 index 0000000000..efe4dbfa48 --- /dev/null +++ b/pkg/doctor/check_env_docker.go @@ -0,0 +1,29 @@ +package doctor + +import ( + "fmt" + "os/exec" +) + +// DockerCheck verifies that Docker is installed and the daemon is reachable. 
+type DockerCheck struct{} + +func (c *DockerCheck) Name() string { return "env-docker" } +func (c *DockerCheck) Group() Group { return GroupEnvironment } +func (c *DockerCheck) Description() string { return "Docker" } + +func (c *DockerCheck) Check(_ *CheckContext) ([]Finding, error) { + if err := exec.Command("docker", "info").Run(); err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("Docker is not available: %v", err), + Remediation: "Install Docker (https://docs.docker.com/get-docker/) and ensure the daemon is running", + }}, nil + } + + return nil, nil +} + +func (c *DockerCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} diff --git a/pkg/doctor/check_env_test.go b/pkg/doctor/check_env_test.go new file mode 100644 index 0000000000..83a583bbe6 --- /dev/null +++ b/pkg/doctor/check_env_test.go @@ -0,0 +1,22 @@ +package doctor + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDockerCheck_RunsWithoutError(t *testing.T) { + ctx := &CheckContext{ProjectDir: t.TempDir()} + check := &DockerCheck{} + // We don't assert findings — Docker may or may not be available in the test environment. + // We just verify the check doesn't panic or return an error. 
+ _, err := check.Check(ctx) + require.NoError(t, err) +} + +func TestDockerCheck_FixReturnsNoAutoFix(t *testing.T) { + check := &DockerCheck{} + err := check.Fix(nil, nil) + require.ErrorIs(t, err, ErrNoAutoFix) +} diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 3ec9dbc388..bfaf8bfa78 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -13,6 +13,7 @@ func AllChecks() []Check { &PydanticBaseModelCheck{}, &DeprecatedImportsCheck{}, - // Environment checks (added in subsequent tasks) + // Environment checks + &DockerCheck{}, } } From 7363db260e994541a9e4e5990a4f1840a78eecc4 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:07:57 -0400 Subject: [PATCH 12/24] feat(doctor): add Python version check --- pkg/doctor/check_env_python_version.go | 68 ++++++++++++++++++++++++++ pkg/doctor/check_env_test.go | 68 ++++++++++++++++++++++++++ pkg/doctor/registry.go | 1 + 3 files changed, 137 insertions(+) create mode 100644 pkg/doctor/check_env_python_version.go diff --git a/pkg/doctor/check_env_python_version.go b/pkg/doctor/check_env_python_version.go new file mode 100644 index 0000000000..6270dce2f7 --- /dev/null +++ b/pkg/doctor/check_env_python_version.go @@ -0,0 +1,68 @@ +package doctor + +import ( + "fmt" + "os/exec" + "strings" +) + +// PythonVersionCheck verifies that Python is available and that the local +// version is consistent with the version configured in cog.yaml. 
+type PythonVersionCheck struct{} + +func (c *PythonVersionCheck) Name() string { return "env-python-version" } +func (c *PythonVersionCheck) Group() Group { return GroupEnvironment } +func (c *PythonVersionCheck) Description() string { return "Python version" } + +func (c *PythonVersionCheck) Check(ctx *CheckContext) ([]Finding, error) { + if ctx.PythonPath == "" { + return []Finding{{ + Severity: SeverityWarning, + Message: "Python not found in PATH", + Remediation: "Install Python 3.10+ or ensure it is on your PATH", + }}, nil + } + + out, err := exec.Command(ctx.PythonPath, "--version").Output() + if err != nil { + return []Finding{{ + Severity: SeverityWarning, + Message: fmt.Sprintf("could not determine Python version: %v", err), + Remediation: "Ensure your Python installation is working correctly", + }}, nil + } + + // Output is "Python 3.12.1\n" + version := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(string(out)), "Python")) + localMajorMinor := majorMinor(version) + + // If cog.yaml specifies a Python version, compare + if ctx.Config != nil && ctx.Config.Build != nil && ctx.Config.Build.PythonVersion != "" { + configMajorMinor := majorMinor(ctx.Config.Build.PythonVersion) + if configMajorMinor != "" && localMajorMinor != "" && configMajorMinor != localMajorMinor { + return []Finding{{ + Severity: SeverityWarning, + Message: fmt.Sprintf( + "local Python is %s but cog.yaml specifies %s", + localMajorMinor, configMajorMinor, + ), + Remediation: "This is usually fine -- Docker builds use the configured version. Update cog.yaml or your local Python if needed.", + }}, nil + } + } + + return nil, nil +} + +func (c *PythonVersionCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// majorMinor extracts "3.12" from "3.12.1" or "3.12". +func majorMinor(version string) string { + parts := strings.SplitN(version, ".", 3) + if len(parts) < 2 { + return "" + } + return parts[0] + "." 
+ parts[1] +} diff --git a/pkg/doctor/check_env_test.go b/pkg/doctor/check_env_test.go index 83a583bbe6..b67d25a3ac 100644 --- a/pkg/doctor/check_env_test.go +++ b/pkg/doctor/check_env_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" ) func TestDockerCheck_RunsWithoutError(t *testing.T) { @@ -20,3 +22,69 @@ func TestDockerCheck_FixReturnsNoAutoFix(t *testing.T) { err := check.Fix(nil, nil) require.ErrorIs(t, err, ErrNoAutoFix) } + +func TestPythonVersionCheck_RunsWithoutError(t *testing.T) { + ctx := &CheckContext{ProjectDir: t.TempDir()} + check := &PythonVersionCheck{} + // Python may or may not be available; just ensure no panic or error. + _, err := check.Check(ctx) + require.NoError(t, err) +} + +func TestPythonVersionCheck_NoPython(t *testing.T) { + ctx := &CheckContext{ + ProjectDir: t.TempDir(), + PythonPath: "", // explicitly no Python + } + + check := &PythonVersionCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "Python not found") +} + +func TestPythonVersionCheck_FixReturnsNoAutoFix(t *testing.T) { + check := &PythonVersionCheck{} + err := check.Fix(nil, nil) + require.ErrorIs(t, err, ErrNoAutoFix) +} + +func TestMajorMinor(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"3.12.1", "3.12"}, + {"3.12", "3.12"}, + {"3", ""}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.want, majorMinor(tt.input)) + }) + } +} + +func TestPythonVersionCheck_VersionMismatch(t *testing.T) { + // This test verifies that when PythonPath is set but not a real binary, + // we get a warning. We can't easily fake the version output without a real binary. 
+ ctx := &CheckContext{ + ProjectDir: t.TempDir(), + PythonPath: "/nonexistent/python3", + Config: &config.Config{ + Build: &config.Build{ + PythonVersion: "3.12", + }, + }, + } + + check := &PythonVersionCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "could not determine Python version") +} diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index bfaf8bfa78..3111356ae6 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -15,5 +15,6 @@ func AllChecks() []Check { // Environment checks &DockerCheck{}, + &PythonVersionCheck{}, } } From 7c35c143be1de3e7d6ee7ee775a29561cd4c0d81 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:15:25 -0400 Subject: [PATCH 13/24] feat(doctor): add config schema validation check --- pkg/doctor/check_config_schema.go | 51 ++++++++++++++++++++++++++++ pkg/doctor/check_config_test.go | 56 +++++++++++++++++++++++++++++++ pkg/doctor/registry.go | 1 + 3 files changed, 108 insertions(+) create mode 100644 pkg/doctor/check_config_schema.go diff --git a/pkg/doctor/check_config_schema.go b/pkg/doctor/check_config_schema.go new file mode 100644 index 0000000000..e41a3f029e --- /dev/null +++ b/pkg/doctor/check_config_schema.go @@ -0,0 +1,51 @@ +package doctor + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/replicate/cog/pkg/config" +) + +// ConfigSchemaCheck validates cog.yaml against the configuration schema. +// Parse errors are handled by ConfigParseCheck; this check catches schema +// and validation errors (wrong types, invalid values, etc.). 
+type ConfigSchemaCheck struct{} + +func (c *ConfigSchemaCheck) Name() string { return "config-schema" } +func (c *ConfigSchemaCheck) Group() Group { return GroupConfig } +func (c *ConfigSchemaCheck) Description() string { return "Config schema" } + +func (c *ConfigSchemaCheck) Check(ctx *CheckContext) ([]Finding, error) { + configPath := filepath.Join(ctx.ProjectDir, "cog.yaml") + + f, err := os.Open(configPath) + if err != nil { + return nil, nil // ConfigParseCheck handles missing files + } + defer f.Close() + + _, loadErr := config.Load(f, ctx.ProjectDir) + if loadErr == nil { + return nil, nil // Valid config + } + + // If this is a parse error, skip — ConfigParseCheck handles it + var parseErr *config.ParseError + if isParseError(loadErr, &parseErr) { + return nil, nil + } + + // Any other error is a schema/validation error + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cog.yaml validation failed: %v", loadErr), + Remediation: "Fix the configuration errors in cog.yaml", + File: "cog.yaml", + }}, nil +} + +func (c *ConfigSchemaCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go index 62d8f5caf8..d530b4f5b7 100644 --- a/pkg/doctor/check_config_test.go +++ b/pkg/doctor/check_config_test.go @@ -200,6 +200,62 @@ func buildTestCheckContext(t *testing.T, dir string) *CheckContext { return ctx } +func TestConfigSchemaCheck_Valid(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigSchemaCheck_InvalidSchema(t *testing.T) { + 
dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "2.7" +predict: "predict.py:Predictor" +`) + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "validation failed") +} + +func TestConfigSchemaCheck_ParseError(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: [invalid yaml`) + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) // Parse errors are handled by ConfigParseCheck +} + +func TestConfigSchemaCheck_MissingFile(t *testing.T) { + dir := t.TempDir() + + ctx := &CheckContext{ProjectDir: dir} + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) // Missing file handled by ConfigParseCheck +} + // writeFile is a test helper to create fixture files. 
func writeFile(t *testing.T, dir, name, content string) { t.Helper() diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 3111356ae6..9bb175d4af 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -6,6 +6,7 @@ func AllChecks() []Check { return []Check{ // Config checks &ConfigParseCheck{}, + &ConfigSchemaCheck{}, &ConfigDeprecatedFieldsCheck{}, &ConfigPredictRefCheck{}, From bdab4b16121839b47b1edb650c67a5d9e8fd66d4 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:16:36 -0400 Subject: [PATCH 14/24] feat(doctor): add missing type annotations check --- pkg/doctor/check_python_test.go | 64 ++++++++++ pkg/doctor/check_python_type_annotations.go | 123 ++++++++++++++++++++ pkg/doctor/registry.go | 1 + 3 files changed, 188 insertions(+) create mode 100644 pkg/doctor/check_python_type_annotations.go diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go index da54191d26..942c172ecf 100644 --- a/pkg/doctor/check_python_test.go +++ b/pkg/doctor/check_python_test.go @@ -10,6 +10,7 @@ import ( "github.com/smacker/go-tree-sitter/python" "github.com/stretchr/testify/require" + "github.com/replicate/cog/pkg/config" schemaPython "github.com/replicate/cog/pkg/schema/python" ) @@ -185,6 +186,69 @@ class Predictor(BasePredictor): require.Empty(t, findings) } +func TestMissingTypeAnnotationsCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + Config: &config.Config{Predict: "predict.py:Predictor"}, + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestMissingTypeAnnotationsCheck_MissingReturnType(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", 
`from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str): + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + Config: &config.Config{Predict: "predict.py:Predictor"}, + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "predict()") + require.Contains(t, findings[0].Message, "return type annotation") +} + +func TestMissingTypeAnnotationsCheck_NoConfig(t *testing.T) { + ctx := &CheckContext{ + ProjectDir: t.TempDir(), + Config: nil, + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestMissingTypeAnnotationsCheck_FixReturnsNoAutoFix(t *testing.T) { + check := &MissingTypeAnnotationsCheck{} + err := check.Fix(nil, nil) + require.ErrorIs(t, err, ErrNoAutoFix) +} + // parsePythonFiles is a test helper that parses Python files into ParsedFile structs. func parsePythonFiles(t *testing.T, dir string, filenames ...string) map[string]*ParsedFile { t.Helper() diff --git a/pkg/doctor/check_python_type_annotations.go b/pkg/doctor/check_python_type_annotations.go new file mode 100644 index 0000000000..48ecec56e5 --- /dev/null +++ b/pkg/doctor/check_python_type_annotations.go @@ -0,0 +1,123 @@ +package doctor + +import ( + "fmt" + + sitter "github.com/smacker/go-tree-sitter" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// MissingTypeAnnotationsCheck detects predict/train methods that are +// missing return type annotations. 
+type MissingTypeAnnotationsCheck struct{} + +func (c *MissingTypeAnnotationsCheck) Name() string { return "python-missing-type-annotations" } +func (c *MissingTypeAnnotationsCheck) Group() Group { return GroupPython } +func (c *MissingTypeAnnotationsCheck) Description() string { return "Type annotations" } + +func (c *MissingTypeAnnotationsCheck) Check(ctx *CheckContext) ([]Finding, error) { + if ctx.Config == nil { + return nil, nil + } + + var findings []Finding + + if ctx.Config.Predict != "" { + f := checkMethodAnnotations(ctx, ctx.Config.Predict, "predict") + findings = append(findings, f...) + } + + if ctx.Config.Train != "" { + f := checkMethodAnnotations(ctx, ctx.Config.Train, "train") + findings = append(findings, f...) + } + + return findings, nil +} + +func (c *MissingTypeAnnotationsCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// checkMethodAnnotations checks that the given method has a return type annotation. +func checkMethodAnnotations(ctx *CheckContext, ref string, methodName string) []Finding { + parts := splitPredictRef(ref) + fileName := parts[0] + className := parts[1] + + if fileName == "" || className == "" { + return nil // Invalid ref — handled by ConfigPredictRefCheck + } + + pf, ok := ctx.PythonFiles[fileName] + if !ok { + return nil // File not parsed — handled by ConfigPredictRefCheck + } + + root := pf.Tree.RootNode() + + // Find the class + classNode := findClass(root, pf.Source, className) + if classNode == nil { + return nil // Class not found — handled by ConfigPredictRefCheck + } + + // Find the method inside the class + funcNode := findMethod(classNode, pf.Source, methodName) + if funcNode == nil { + return nil // Method not found — could be a separate check later + } + + // Check for return type annotation + if funcNode.ChildByFieldName("return_type") == nil { + line := int(funcNode.StartPoint().Row) + 1 + return []Finding{{ + Severity: SeverityWarning, + Message: fmt.Sprintf( + "%s.%s() is missing 
a return type annotation", + className, methodName, + ), + Remediation: fmt.Sprintf("Add a return type annotation: def %s(self, ...) -> YourType:", methodName), + File: fileName, + Line: line, + }} + } + + return nil +} + +// findClass locates a top-level class by name. +func findClass(root *sitter.Node, source []byte, name string) *sitter.Node { + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + return classNode + } + } + return nil +} + +// findMethod locates a method by name inside a class body. +func findMethod(classNode *sitter.Node, source []byte, name string) *sitter.Node { + body := classNode.ChildByFieldName("body") + if body == nil { + return nil + } + + for _, child := range schemaPython.NamedChildren(body) { + funcNode := schemaPython.UnwrapFunction(child) + if funcNode == nil { + continue + } + nameNode := funcNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + return funcNode + } + } + return nil +} diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go index 9bb175d4af..5a2873138b 100644 --- a/pkg/doctor/registry.go +++ b/pkg/doctor/registry.go @@ -13,6 +13,7 @@ func AllChecks() []Check { // Python checks &PydanticBaseModelCheck{}, &DeprecatedImportsCheck{}, + &MissingTypeAnnotationsCheck{}, // Environment checks &DockerCheck{}, From f88981823fda928b7fc75fcf6e077fc7c0ada60f Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:22:33 -0400 Subject: [PATCH 15/24] fix(doctor): address lint issues in printDoctorResults Replace if-else chain with switch statement (gocritic) and use strings.Join instead of string concatenation in a loop (modernize). 
--- pkg/cli/doctor.go | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 755786bb35..8ed4fdda39 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -3,6 +3,7 @@ package cli import ( "context" "fmt" + "strings" "github.com/spf13/cobra" @@ -134,7 +135,8 @@ func printDoctorResults(result *doctor.Result, fix bool) { console.Info("") // Summary line - if fix && fixedCount > 0 { + switch { + case fix && fixedCount > 0: msg := fmt.Sprintf("Fixed %d issue", fixedCount) if fixedCount != 1 { msg += "s" @@ -153,8 +155,8 @@ func printDoctorResults(result *doctor.Result, fix bool) { } } console.Infof("%s.", msg) - } else if errorCount > 0 || warningCount > 0 { - parts := []string{} + case errorCount > 0 || warningCount > 0: + var parts []string if errorCount > 0 { s := fmt.Sprintf("%d error", errorCount) if errorCount != 1 { @@ -169,20 +171,13 @@ func printDoctorResults(result *doctor.Result, fix bool) { } parts = append(parts, s) } - summary := "Found " - for i, p := range parts { - if i > 0 { - summary += ", " - } - summary += p - } - summary += "." + summary := "Found " + strings.Join(parts, ", ") + "." 
if !fix && errorCount > 0 { summary += ` Run "cog doctor --fix" to auto-fix.` } console.Infof("%s", summary) - } else { + default: console.Successf("no issues found") } } From 3a9676082a7dc326faf5f124cd37dd1a00005881 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 14:22:41 -0400 Subject: [PATCH 16/24] test(doctor): add integration tests for cog doctor command Add 8 txtar integration tests covering: - Clean project (no issues found, exit 0) - Pydantic BaseModel detection (exit 1) - Pydantic BaseModel auto-fix with --fix - Deprecated imports detection (exit 1) - Deprecated imports auto-fix with --fix - Missing predict class reference (exit 1) - Exit code 1 when predict file is missing - Deprecated config fields as warnings (exit 0) --- .../tests/doctor_clean_project.txtar | 17 ++++++++ .../tests/doctor_deprecated_fields.txtar | 21 ++++++++++ .../tests/doctor_deprecated_imports.txtar | 23 +++++++++++ .../tests/doctor_exit_code.txtar | 10 +++++ .../tests/doctor_fix_deprecated_imports.txtar | 35 ++++++++++++++++ .../tests/doctor_fix_pydantic.txtar | 41 +++++++++++++++++++ .../tests/doctor_missing_predict_ref.txtar | 19 +++++++++ .../tests/doctor_pydantic_basemodel.txtar | 26 ++++++++++++ 8 files changed, 192 insertions(+) create mode 100644 integration-tests/tests/doctor_clean_project.txtar create mode 100644 integration-tests/tests/doctor_deprecated_fields.txtar create mode 100644 integration-tests/tests/doctor_deprecated_imports.txtar create mode 100644 integration-tests/tests/doctor_exit_code.txtar create mode 100644 integration-tests/tests/doctor_fix_deprecated_imports.txtar create mode 100644 integration-tests/tests/doctor_fix_pydantic.txtar create mode 100644 integration-tests/tests/doctor_missing_predict_ref.txtar create mode 100644 integration-tests/tests/doctor_pydantic_basemodel.txtar diff --git a/integration-tests/tests/doctor_clean_project.txtar b/integration-tests/tests/doctor_clean_project.txtar new file mode 100644 index 
0000000000..76479483fe --- /dev/null +++ b/integration-tests/tests/doctor_clean_project.txtar @@ -0,0 +1,17 @@ +# Test that cog doctor succeeds on a valid project with no issues. + +cog doctor +stderr 'no issues found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_deprecated_fields.txtar b/integration-tests/tests/doctor_deprecated_fields.txtar new file mode 100644 index 0000000000..7a6c384880 --- /dev/null +++ b/integration-tests/tests/doctor_deprecated_fields.txtar @@ -0,0 +1,21 @@ +# Test that cog doctor reports deprecated config fields as warnings. +# Warnings do not cause a non-zero exit code; only errors do. + +cog doctor +stderr 'python_packages' +stderr 'deprecated' + +-- cog.yaml -- +build: + python_version: "3.12" + python_packages: + - torch==2.0.0 +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_deprecated_imports.txtar b/integration-tests/tests/doctor_deprecated_imports.txtar new file mode 100644 index 0000000000..ac07785290 --- /dev/null +++ b/integration-tests/tests/doctor_deprecated_imports.txtar @@ -0,0 +1,23 @@ +# Test that cog doctor detects deprecated imports. +# ExperimentalFeatureWarning was removed in cog 0.17. + +! 
cog doctor +stderr 'ExperimentalFeatureWarning' +stderr 'removed' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_exit_code.txtar b/integration-tests/tests/doctor_exit_code.txtar new file mode 100644 index 0000000000..c294a6039a --- /dev/null +++ b/integration-tests/tests/doctor_exit_code.txtar @@ -0,0 +1,10 @@ +# Test that cog doctor exits with code 1 when predict file is missing. + +! cog doctor +stderr 'predict.py' +stderr 'not found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" diff --git a/integration-tests/tests/doctor_fix_deprecated_imports.txtar b/integration-tests/tests/doctor_fix_deprecated_imports.txtar new file mode 100644 index 0000000000..4e652c7a80 --- /dev/null +++ b/integration-tests/tests/doctor_fix_deprecated_imports.txtar @@ -0,0 +1,35 @@ +# Test that cog doctor --fix removes deprecated imports. + +# First, doctor should detect the issue +! cog doctor +stderr 'ExperimentalFeatureWarning' + +# Fix the issue +cog doctor --fix +stderr 'Fixed' + +# Verify the import was removed from the file +exec cat predict.py +! stdout 'ExperimentalFeatureWarning' +! 
stdout 'cog.types' + +# Re-running doctor should now pass +cog doctor +stderr 'no issues found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_fix_pydantic.txtar b/integration-tests/tests/doctor_fix_pydantic.txtar new file mode 100644 index 0000000000..198d80bb21 --- /dev/null +++ b/integration-tests/tests/doctor_fix_pydantic.txtar @@ -0,0 +1,41 @@ +# Test that cog doctor --fix rewrites pydantic.BaseModel to cog.BaseModel. +# After fix, re-running cog doctor should pass. + +# First, doctor should detect the issue +! cog doctor +stderr 'pydantic.BaseModel' + +# Fix the issue +cog doctor --fix +stderr 'Fixed' + +# Verify the file was modified: pydantic.BaseModel replaced with cog.BaseModel +exec cat predict.py +stdout 'from cog import BasePredictor, Path, BaseModel' +! stdout 'from pydantic import BaseModel' +! stdout 'ConfigDict' +! 
stdout 'arbitrary_types_allowed' + +# Re-running doctor should now pass +cog doctor +stderr 'no issues found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") diff --git a/integration-tests/tests/doctor_missing_predict_ref.txtar b/integration-tests/tests/doctor_missing_predict_ref.txtar new file mode 100644 index 0000000000..0aad732b24 --- /dev/null +++ b/integration-tests/tests/doctor_missing_predict_ref.txtar @@ -0,0 +1,19 @@ +# Test that cog doctor detects when predict ref points to a nonexistent class. +# cog.yaml references DoesNotExist but only Predictor exists in predict.py. + +! cog doctor +stderr 'DoesNotExist' +stderr 'not found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:DoesNotExist" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_pydantic_basemodel.txtar b/integration-tests/tests/doctor_pydantic_basemodel.txtar new file mode 100644 index 0000000000..74ff03cf3c --- /dev/null +++ b/integration-tests/tests/doctor_pydantic_basemodel.txtar @@ -0,0 +1,26 @@ +# Test that cog doctor detects pydantic.BaseModel with arbitrary_types_allowed. +# This is a common workaround that should use cog.BaseModel instead. + +! 
cog doctor +stderr 'pydantic.BaseModel' +stderr 'cog.BaseModel' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") From 43b831e20af0f0be84d449f645f330fd122dc3f1 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 17:29:31 -0400 Subject: [PATCH 17/24] =?UTF-8?q?fix(doctor):=20address=20review=20finding?= =?UTF-8?q?s=20=E2=80=94=20fix=20correctness=20bugs,=20consolidate=20confi?= =?UTF-8?q?g=20loading,=20improve=20fix=20robustness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical fixes: - HasErrors() now accounts for check internal errors (exit code 1) - hasArbitraryTypesAllowed uses tree-sitter AST instead of string matching (no false positives) - addToCogImport handles missing 'from cog import' line - Config loading consolidated onto CheckContext (eliminates 4x redundant Load calls) High priority fixes: - Respect --file flag instead of hardcoding cog.yaml - Handle 'import pydantic' module-level style in fix - Deprecated imports fix uses tree-sitter re-scan instead of fragile string matching - ConfigPredictRefCheck reuses cached Python files from ctx.PythonFiles - model_config fix preserves other ConfigDict kwargs Medium fixes: - Docker/Python exec commands have 5-second timeout - Multi-line parenthesized imports handled by fix - File permissions preserved (not hardcoded to 0o644) - printDoctorResults counts per-finding consistently --- pkg/cli/doctor.go | 17 +- pkg/doctor/check_config_deprecated.go | 26 +-- pkg/doctor/check_config_parse.go | 33 ++-- 
pkg/doctor/check_config_predict_ref.go | 50 +++--- pkg/doctor/check_config_schema.go | 25 ++- pkg/doctor/check_config_test.go | 43 +++-- pkg/doctor/check_env_docker.go | 7 +- pkg/doctor/check_env_python_version.go | 7 +- pkg/doctor/check_python_deprecated_imports.go | 116 ++++++++++++-- pkg/doctor/check_python_pydantic_basemodel.go | 135 ++++++++++++++-- pkg/doctor/check_python_test.go | 150 ++++++++++++++++++ pkg/doctor/doctor.go | 13 +- pkg/doctor/doctor_test.go | 10 ++ pkg/doctor/runner.go | 42 +++-- 14 files changed, 525 insertions(+), 149 deletions(-) diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 8ed4fdda39..9191e6a4e3 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -48,8 +48,9 @@ func runDoctor(ctx context.Context, fix bool) error { console.Info("") result, err := doctor.Run(ctx, doctor.RunOptions{ - Fix: fix, - ProjectDir: projectDir, + Fix: fix, + ProjectDir: projectDir, + ConfigFilename: configFilename, }, doctor.AllChecks()) if err != nil { return err @@ -105,13 +106,21 @@ func printDoctorResults(result *doctor.Result, fix bool) { switch worstSeverity { case doctor.SeverityError: console.Errorf("%s", cr.Check.Description()) - errorCount++ case doctor.SeverityWarning: console.Warnf("%s", cr.Check.Description()) - warningCount++ default: console.Infof("%s", cr.Check.Description()) } + + // Count per-finding for consistent totals + for _, f := range cr.Findings { + switch f.Severity { + case doctor.SeverityError: + errorCount++ + case doctor.SeverityWarning: + warningCount++ + } + } } // Print individual findings diff --git a/pkg/doctor/check_config_deprecated.go b/pkg/doctor/check_config_deprecated.go index 617fc6f71e..7eeda77cac 100644 --- a/pkg/doctor/check_config_deprecated.go +++ b/pkg/doctor/check_config_deprecated.go @@ -2,10 +2,6 @@ package doctor import ( "fmt" - "os" - "path/filepath" - - "github.com/replicate/cog/pkg/config" ) // ConfigDeprecatedFieldsCheck detects deprecated fields in cog.yaml. 
@@ -16,27 +12,21 @@ func (c *ConfigDeprecatedFieldsCheck) Group() Group { return GroupConfig func (c *ConfigDeprecatedFieldsCheck) Description() string { return "Deprecated fields" } func (c *ConfigDeprecatedFieldsCheck) Check(ctx *CheckContext) ([]Finding, error) { - configPath := filepath.Join(ctx.ProjectDir, "cog.yaml") - f, err := os.Open(configPath) - if err != nil { - return nil, nil // Config parse check handles missing file - } - defer f.Close() - - // We need to run validation to get deprecation warnings. - // Load does parse + validate + complete; we want just parse + validate. - loadResult, err := config.Load(f, ctx.ProjectDir) - if err != nil { - return nil, nil // Other config checks handle parse/validation errors + // No config loaded successfully — other checks handle parse/validation errors. + // Note: warnings are only available when Load succeeds because they come from + // ValidateConfigFile which uses an unexported type. If the config has validation + // errors, deprecation warnings cannot be surfaced through Load. 
+ if ctx.LoadResult == nil { + return nil, nil } var findings []Finding - for _, w := range loadResult.Warnings { + for _, w := range ctx.LoadResult.Warnings { findings = append(findings, Finding{ Severity: SeverityWarning, Message: fmt.Sprintf("%q is deprecated: %s", w.Field, w.Message), Remediation: fmt.Sprintf("Use %q instead", w.Replacement), - File: "cog.yaml", + File: ctx.ConfigFilename, }) } diff --git a/pkg/doctor/check_config_parse.go b/pkg/doctor/check_config_parse.go index c11c859e2a..971ca9bd9e 100644 --- a/pkg/doctor/check_config_parse.go +++ b/pkg/doctor/check_config_parse.go @@ -3,8 +3,6 @@ package doctor import ( "errors" "fmt" - "os" - "path/filepath" "github.com/replicate/cog/pkg/config" ) @@ -17,36 +15,25 @@ func (c *ConfigParseCheck) Group() Group { return GroupConfig } func (c *ConfigParseCheck) Description() string { return "Config parsing" } func (c *ConfigParseCheck) Check(ctx *CheckContext) ([]Finding, error) { - configPath := filepath.Join(ctx.ProjectDir, "cog.yaml") - - if _, err := os.Stat(configPath); os.IsNotExist(err) { + // Config file not found on disk + if ctx.ConfigFile == nil { return []Finding{{ Severity: SeverityError, - Message: "cog.yaml not found", + Message: fmt.Sprintf("%s not found", ctx.ConfigFilename), Remediation: `Run "cog init" to create a cog.yaml`, - File: "cog.yaml", - }}, nil - } - - f, err := os.Open(configPath) - if err != nil { - return []Finding{{ - Severity: SeverityError, - Message: fmt.Sprintf("cannot read cog.yaml: %v", err), - File: "cog.yaml", + File: ctx.ConfigFilename, }}, nil } - defer f.Close() - _, loadErr := config.Load(f, ctx.ProjectDir) - if loadErr != nil { + // Check for parse errors from the single Load call in buildCheckContext + if ctx.LoadErr != nil { var parseErr *config.ParseError - if isParseError(loadErr, &parseErr) { + if isParseError(ctx.LoadErr, &parseErr) { return []Finding{{ Severity: SeverityError, - Message: fmt.Sprintf("cog.yaml has invalid YAML: %v", loadErr), - Remediation: 
"Fix the YAML syntax in cog.yaml", - File: "cog.yaml", + Message: fmt.Sprintf("%s has invalid YAML: %v", ctx.ConfigFilename, ctx.LoadErr), + Remediation: fmt.Sprintf("Fix the YAML syntax in %s", ctx.ConfigFilename), + File: ctx.ConfigFilename, }}, nil } // Other errors (validation, schema) are handled by other checks diff --git a/pkg/doctor/check_config_predict_ref.go b/pkg/doctor/check_config_predict_ref.go index 0de9cc6361..0ead3a9273 100644 --- a/pkg/doctor/check_config_predict_ref.go +++ b/pkg/doctor/check_config_predict_ref.go @@ -55,30 +55,40 @@ func (c *ConfigPredictRefCheck) Check(ctx *CheckContext) ([]Finding, error) { }}, nil } - // Check class exists in file using tree-sitter - source, err := os.ReadFile(fullPath) - if err != nil { - return []Finding{{ - Severity: SeverityError, - Message: fmt.Sprintf("cannot read %s: %v", pyFile, err), - File: pyFile, - }}, nil - } + // Use cached parse tree if available, otherwise parse from disk + var rootNode *sitter.Node + var source []byte + + if pf, ok := ctx.PythonFiles[pyFile]; ok { + rootNode = pf.Tree.RootNode() + source = pf.Source + } else { + var err error + source, err = os.ReadFile(fullPath) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot read %s: %v", pyFile, err), + File: pyFile, + }}, nil + } - parser := sitter.NewParser() - parser.SetLanguage(python.GetLanguage()) - tree, err := parser.ParseCtx(context.Background(), nil, source) - if err != nil { - return []Finding{{ - Severity: SeverityError, - Message: fmt.Sprintf("cannot parse %s: %v", pyFile, err), - File: pyFile, - }}, nil + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(context.Background(), nil, source) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot parse %s: %v", pyFile, err), + File: pyFile, + }}, nil + } + rootNode = tree.RootNode() } - if !hasClassDefinition(tree.RootNode(), source, 
className) { + if !hasClassDefinition(rootNode, source, className) { // List available classes to help the user - classes := listClassNames(tree.RootNode(), source) + classes := listClassNames(rootNode, source) msg := fmt.Sprintf("class %q not found in %s", className, pyFile) if len(classes) > 0 { msg += fmt.Sprintf("; found: %s", strings.Join(classes, ", ")) diff --git a/pkg/doctor/check_config_schema.go b/pkg/doctor/check_config_schema.go index e41a3f029e..8c5bb41a52 100644 --- a/pkg/doctor/check_config_schema.go +++ b/pkg/doctor/check_config_schema.go @@ -2,8 +2,6 @@ package doctor import ( "fmt" - "os" - "path/filepath" "github.com/replicate/cog/pkg/config" ) @@ -18,31 +16,28 @@ func (c *ConfigSchemaCheck) Group() Group { return GroupConfig } func (c *ConfigSchemaCheck) Description() string { return "Config schema" } func (c *ConfigSchemaCheck) Check(ctx *CheckContext) ([]Finding, error) { - configPath := filepath.Join(ctx.ProjectDir, "cog.yaml") - - f, err := os.Open(configPath) - if err != nil { - return nil, nil // ConfigParseCheck handles missing files + // No config file on disk — ConfigParseCheck handles this + if ctx.ConfigFile == nil { + return nil, nil } - defer f.Close() - _, loadErr := config.Load(f, ctx.ProjectDir) - if loadErr == nil { - return nil, nil // Valid config + // No load error means valid config + if ctx.LoadErr == nil { + return nil, nil } // If this is a parse error, skip — ConfigParseCheck handles it var parseErr *config.ParseError - if isParseError(loadErr, &parseErr) { + if isParseError(ctx.LoadErr, &parseErr) { return nil, nil } // Any other error is a schema/validation error return []Finding{{ Severity: SeverityError, - Message: fmt.Sprintf("cog.yaml validation failed: %v", loadErr), - Remediation: "Fix the configuration errors in cog.yaml", - File: "cog.yaml", + Message: fmt.Sprintf("%s validation failed: %v", ctx.ConfigFilename, ctx.LoadErr), + Remediation: fmt.Sprintf("Fix the configuration errors in %s", ctx.ConfigFilename), + 
File: ctx.ConfigFilename, }}, nil } diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go index d530b4f5b7..31a563626a 100644 --- a/pkg/doctor/check_config_test.go +++ b/pkg/doctor/check_config_test.go @@ -1,6 +1,7 @@ package doctor import ( + "bytes" "os" "path/filepath" "testing" @@ -17,9 +18,7 @@ func TestConfigParseCheck_Valid(t *testing.T) { predict: "predict.py:Predictor" `) - ctx := &CheckContext{ProjectDir: dir} - ctx.ConfigFile, _ = os.ReadFile(filepath.Join(dir, "cog.yaml")) - + ctx := buildTestCheckContext(t, dir) check := &ConfigParseCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -30,9 +29,7 @@ func TestConfigParseCheck_InvalidYAML(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "cog.yaml", `build: [invalid yaml`) - ctx := &CheckContext{ProjectDir: dir} - ctx.ConfigFile, _ = os.ReadFile(filepath.Join(dir, "cog.yaml")) - + ctx := buildTestCheckContext(t, dir) check := &ConfigParseCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -44,8 +41,7 @@ func TestConfigParseCheck_InvalidYAML(t *testing.T) { func TestConfigParseCheck_MissingFile(t *testing.T) { dir := t.TempDir() - ctx := &CheckContext{ProjectDir: dir} - + ctx := buildTestCheckContext(t, dir) check := &ConfigParseCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -62,7 +58,7 @@ predict: "predict.py:Predictor" `) writeFile(t, dir, "requirements.txt", "torch==2.0.0\n") - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigDeprecatedFieldsCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -78,7 +74,7 @@ func TestConfigDeprecatedFieldsCheck_PythonPackages(t *testing.T) { predict: "predict.py:Predictor" `) - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigDeprecatedFieldsCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -96,7 +92,7 @@ func TestConfigDeprecatedFieldsCheck_PreInstall(t *testing.T) 
{ predict: "predict.py:Predictor" `) - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigDeprecatedFieldsCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -179,21 +175,20 @@ func TestConfigPredictRefCheck_NoPredictField(t *testing.T) { func buildTestCheckContext(t *testing.T, dir string) *CheckContext { t.Helper() ctx := &CheckContext{ - ProjectDir: dir, - PythonFiles: make(map[string]*ParsedFile), + ProjectDir: dir, + ConfigFilename: "cog.yaml", + PythonFiles: make(map[string]*ParsedFile), } configPath := filepath.Join(dir, "cog.yaml") configBytes, err := os.ReadFile(configPath) if err == nil { ctx.ConfigFile = configBytes - f, err := os.Open(configPath) - if err == nil { - defer f.Close() - loadResult, err := config.Load(f, dir) - if err == nil { - ctx.Config = loadResult.Config - } + loadResult, loadErr := config.Load(bytes.NewReader(configBytes), dir) + ctx.LoadErr = loadErr + if loadResult != nil { + ctx.LoadResult = loadResult + ctx.Config = loadResult.Config } } @@ -212,7 +207,7 @@ class Predictor(BasePredictor): return text `) - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigSchemaCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -226,7 +221,7 @@ func TestConfigSchemaCheck_InvalidSchema(t *testing.T) { predict: "predict.py:Predictor" `) - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigSchemaCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -239,7 +234,7 @@ func TestConfigSchemaCheck_ParseError(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "cog.yaml", `build: [invalid yaml`) - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigSchemaCheck{} findings, err := check.Check(ctx) require.NoError(t, err) @@ -249,7 +244,7 @@ func TestConfigSchemaCheck_ParseError(t *testing.T) { func TestConfigSchemaCheck_MissingFile(t *testing.T) 
{ dir := t.TempDir() - ctx := &CheckContext{ProjectDir: dir} + ctx := buildTestCheckContext(t, dir) check := &ConfigSchemaCheck{} findings, err := check.Check(ctx) require.NoError(t, err) diff --git a/pkg/doctor/check_env_docker.go b/pkg/doctor/check_env_docker.go index efe4dbfa48..10580d7939 100644 --- a/pkg/doctor/check_env_docker.go +++ b/pkg/doctor/check_env_docker.go @@ -1,8 +1,10 @@ package doctor import ( + "context" "fmt" "os/exec" + "time" ) // DockerCheck verifies that Docker is installed and the daemon is reachable. @@ -13,7 +15,10 @@ func (c *DockerCheck) Group() Group { return GroupEnvironment } func (c *DockerCheck) Description() string { return "Docker" } func (c *DockerCheck) Check(_ *CheckContext) ([]Finding, error) { - if err := exec.Command("docker", "info").Run(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := exec.CommandContext(ctx, "docker", "info").Run(); err != nil { return []Finding{{ Severity: SeverityError, Message: fmt.Sprintf("Docker is not available: %v", err), diff --git a/pkg/doctor/check_env_python_version.go b/pkg/doctor/check_env_python_version.go index 6270dce2f7..a4f8241e75 100644 --- a/pkg/doctor/check_env_python_version.go +++ b/pkg/doctor/check_env_python_version.go @@ -1,9 +1,11 @@ package doctor import ( + "context" "fmt" "os/exec" "strings" + "time" ) // PythonVersionCheck verifies that Python is available and that the local @@ -23,7 +25,10 @@ func (c *PythonVersionCheck) Check(ctx *CheckContext) ([]Finding, error) { }}, nil } - out, err := exec.Command(ctx.PythonPath, "--version").Output() + execCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + out, err := exec.CommandContext(execCtx, ctx.PythonPath, "--version").Output() if err != nil { return []Finding{{ Severity: SeverityWarning, diff --git a/pkg/doctor/check_python_deprecated_imports.go b/pkg/doctor/check_python_deprecated_imports.go index 
c40e8f5d93..e5b66357d6 100644 --- a/pkg/doctor/check_python_deprecated_imports.go +++ b/pkg/doctor/check_python_deprecated_imports.go @@ -74,34 +74,55 @@ func (c *DeprecatedImportsCheck) Check(ctx *CheckContext) ([]Finding, error) { func (c *DeprecatedImportsCheck) Fix(ctx *CheckContext, findings []Finding) error { // Group findings by file - fileFindings := make(map[string][]Finding) + affectedFiles := make(map[string]bool) for _, f := range findings { - fileFindings[f.File] = append(fileFindings[f.File], f) + affectedFiles[f.File] = true } - for relPath, fileFindingsList := range fileFindings { + for relPath := range affectedFiles { fullPath := filepath.Join(ctx.ProjectDir, relPath) + info, err := os.Stat(fullPath) + if err != nil { + return fmt.Errorf("stat %s: %w", relPath, err) + } + source, err := os.ReadFile(fullPath) if err != nil { return fmt.Errorf("reading %s: %w", relPath, err) } - // Collect deprecated names to remove + // Re-scan the file using tree-sitter to find deprecated imports directly, + // rather than relying on fragile string matching against finding messages. 
+ pf, ok := ctx.PythonFiles[relPath] + if !ok { + continue + } namesToRemove := make(map[string]map[string]bool) // module -> set of names - for _, f := range fileFindingsList { - for _, dep := range deprecatedImportsList { - if strings.Contains(f.Message, dep.Name) { - if namesToRemove[dep.Module] == nil { - namesToRemove[dep.Module] = make(map[string]bool) + root := pf.Tree.RootNode() + for _, child := range schemaPython.NamedChildren(root) { + if child.Type() != "import_from_statement" { + continue + } + moduleNode := child.ChildByFieldName("module_name") + if moduleNode == nil { + continue + } + module := schemaPython.Content(moduleNode, pf.Source) + for _, name := range extractImportedNames(child, pf.Source) { + for _, dep := range deprecatedImportsList { + if module == dep.Module && name == dep.Name { + if namesToRemove[dep.Module] == nil { + namesToRemove[dep.Module] = make(map[string]bool) + } + namesToRemove[dep.Module][dep.Name] = true } - namesToRemove[dep.Module][dep.Name] = true } } } fixed := removeDeprecatedImportLines(string(source), namesToRemove) - if err := os.WriteFile(fullPath, []byte(fixed), 0o644); err != nil { + if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { return fmt.Errorf("writing %s: %w", relPath, err) } } @@ -143,13 +164,48 @@ func extractImportedNames(importNode *sitter.Node, source []byte) []string { // removeDeprecatedImportLines removes specific names from import lines. // If all names are removed, the entire import line is dropped. +// Handles both single-line and multi-line parenthesized imports. 
func removeDeprecatedImportLines(source string, namesToRemove map[string]map[string]bool) string { lines := strings.Split(source, "\n") var result []string + // Track multi-line import state + inMultilineImport := false + var multilineModule string + var multilineNames []string + for _, line := range lines { trimmed := strings.TrimSpace(line) + // Handle multi-line imports + if inMultilineImport { + // Collect names from continuation lines + cleaned := strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(trimmed), ")")) + if cleaned != "" { + for n := range strings.SplitSeq(cleaned, ",") { + n = strings.TrimSpace(n) + if n != "" { + multilineNames = append(multilineNames, n) + } + } + } + if strings.Contains(trimmed, ")") { + inMultilineImport = false + // Now filter the collected names + names := namesToRemove[multilineModule] + var remaining []string + for _, name := range multilineNames { + if !names[name] { + remaining = append(remaining, name) + } + } + if len(remaining) > 0 { + result = append(result, "from "+multilineModule+" import "+strings.Join(remaining, ", ")) + } + } + continue + } + removed := false for module, names := range namesToRemove { prefix := "from " + module + " import " @@ -158,6 +214,44 @@ func removeDeprecatedImportLines(source string, namesToRemove map[string]map[str } importPart := trimmed[len(prefix):] + + // Check for multi-line parenthesized import + if strings.HasPrefix(strings.TrimSpace(importPart), "(") { + inner := strings.TrimSpace(importPart)[1:] // strip leading "(" + if strings.Contains(inner, ")") { + // Single-line parenthesized: from X import (A, B) + inner = strings.TrimSuffix(strings.TrimSpace(inner), ")") + importNames := strings.Split(inner, ",") + var remaining []string + for _, name := range importNames { + name = strings.TrimSpace(name) + if name != "" && !names[name] { + remaining = append(remaining, name) + } + } + if len(remaining) > 0 { + result = append(result, prefix+strings.Join(remaining, ", ")) + } + 
removed = true + } else { + // Start of multi-line import + inMultilineImport = true + multilineModule = module + multilineNames = nil + // Collect any names on the first line after "(" + if inner != "" { + for n := range strings.SplitSeq(inner, ",") { + n = strings.TrimSpace(n) + if n != "" { + multilineNames = append(multilineNames, n) + } + } + } + removed = true + } + break + } + importNames := strings.Split(importPart, ",") var remaining []string diff --git a/pkg/doctor/check_python_pydantic_basemodel.go b/pkg/doctor/check_python_pydantic_basemodel.go index 64fa1f631d..4d7908f1db 100644 --- a/pkg/doctor/check_python_pydantic_basemodel.go +++ b/pkg/doctor/check_python_pydantic_basemodel.go @@ -72,6 +72,11 @@ func (c *PydanticBaseModelCheck) Fix(ctx *CheckContext, findings []Finding) erro for relPath := range fileFindings { fullPath := filepath.Join(ctx.ProjectDir, relPath) + info, err := os.Stat(fullPath) + if err != nil { + return fmt.Errorf("stat %s: %w", relPath, err) + } + source, err := os.ReadFile(fullPath) if err != nil { return fmt.Errorf("reading %s: %w", relPath, err) @@ -79,7 +84,7 @@ func (c *PydanticBaseModelCheck) Fix(ctx *CheckContext, findings []Finding) erro fixed := fixPydanticBaseModel(string(source)) - if err := os.WriteFile(fullPath, []byte(fixed), 0o644); err != nil { + if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { return fmt.Errorf("writing %s: %w", relPath, err) } } @@ -117,6 +122,8 @@ func inheritsPydanticBaseModel(classNode *sitter.Node, source []byte, imports *s // hasArbitraryTypesAllowed checks if a class body contains // model_config = ConfigDict(arbitrary_types_allowed=True). +// Uses tree-sitter to properly parse keyword arguments, avoiding false positives +// on arbitrary_types_allowed=False. 
func hasArbitraryTypesAllowed(classNode *sitter.Node, source []byte) bool { body := classNode.ChildByFieldName("body") if body == nil { @@ -139,13 +146,26 @@ func hasArbitraryTypesAllowed(classNode *sitter.Node, source []byte) bool { } right := node.ChildByFieldName("right") - if right == nil { + if right == nil || right.Type() != "call" { continue } - text := schemaPython.Content(right, source) - if strings.Contains(text, "arbitrary_types_allowed") && strings.Contains(text, "True") { - return true + // Walk keyword arguments of the call + args := right.ChildByFieldName("arguments") + if args == nil { + continue + } + for _, arg := range schemaPython.NamedChildren(args) { + if arg.Type() != "keyword_argument" { + continue + } + key := arg.ChildByFieldName("name") + val := arg.ChildByFieldName("value") + if key != nil && val != nil && + schemaPython.Content(key, source) == "arbitrary_types_allowed" && + schemaPython.Content(val, source) == "True" { + return true + } } } @@ -156,13 +176,34 @@ func hasArbitraryTypesAllowed(classNode *sitter.Node, source []byte) bool { func fixPydanticBaseModel(source string) string { lines := strings.Split(source, "\n") var result []string + inPydanticImport := false for _, line := range lines { trimmed := strings.TrimSpace(line) - // Remove "from pydantic import BaseModel" (and ConfigDict) + // Handle multi-line "from pydantic import (" style + if strings.HasPrefix(trimmed, "from pydantic import (") { + inPydanticImport = true + // Check if this single-line also closes: from pydantic import (BaseModel, ConfigDict) + if strings.Contains(trimmed, ")") { + inPydanticImport = false + remaining := removePydanticImportsMultiline(trimmed) + if remaining != "" { + result = append(result, remaining) + } + } + continue + } + if inPydanticImport { + if strings.Contains(trimmed, ")") { + inPydanticImport = false + } + // Skip all lines in the parenthesized pydantic import + continue + } + + // Remove single-line "from pydantic import BaseModel" 
(and ConfigDict) if strings.HasPrefix(trimmed, "from pydantic import") { - // Remove BaseModel and ConfigDict from the import remaining := removePydanticImports(trimmed) if remaining == "" { continue // Drop the entire line @@ -171,9 +212,32 @@ func fixPydanticBaseModel(source string) string { continue } - // Remove model_config = ConfigDict(...) lines + // Handle "import pydantic" style + if trimmed == "import pydantic" { + continue // Drop the line + } + + // Handle model_config = ConfigDict(...) lines -- only remove arbitrary_types_allowed=True if strings.Contains(trimmed, "model_config") && strings.Contains(trimmed, "ConfigDict") { - continue + fixed := removeKeywordArg(line, "arbitrary_types_allowed=True") + if fixed != line { + // Successfully removed the argument; check if ConfigDict is now empty + if isEmptyConfigDict(fixed) { + continue // Drop the line entirely + } + result = append(result, fixed) + continue + } + } + + // Replace pydantic.BaseModel in class definitions + if strings.Contains(line, "pydantic.BaseModel") { + line = strings.ReplaceAll(line, "pydantic.BaseModel", "BaseModel") + } + + // Replace pydantic.ConfigDict references + if strings.Contains(line, "pydantic.ConfigDict") { + line = strings.ReplaceAll(line, "pydantic.ConfigDict", "ConfigDict") } result = append(result, line) @@ -186,6 +250,52 @@ func fixPydanticBaseModel(source string) string { return fixed } +// removeKeywordArg removes a specific keyword argument from a line containing a function call. +// Handles "arg, ", ", arg", and standalone "arg". +func removeKeywordArg(line string, arg string) string { + result := strings.Replace(line, arg+", ", "", 1) + if result == line { + result = strings.Replace(line, ", "+arg, "", 1) + } + if result == line { + result = strings.Replace(line, arg, "", 1) + } + return result +} + +// isEmptyConfigDict checks if a line contains an empty ConfigDict() call. 
+func isEmptyConfigDict(line string) bool { + trimmed := strings.TrimSpace(line) + return strings.Contains(trimmed, "ConfigDict()") || strings.Contains(trimmed, "ConfigDict( )") +} + +// removePydanticImportsMultiline handles "from pydantic import (X, Y, Z)" on a single line. +func removePydanticImportsMultiline(line string) string { + // Extract contents between "from pydantic import (" and ")" + start := strings.Index(line, "(") + end := strings.LastIndex(line, ")") + if start == -1 || end == -1 || start >= end { + return "" + } + importPart := line[start+1 : end] + names := strings.Split(importPart, ",") + + var remaining []string + for _, name := range names { + name = strings.TrimSpace(name) + if name == "BaseModel" || name == "ConfigDict" || name == "" { + continue + } + remaining = append(remaining, name) + } + + if len(remaining) == 0 { + return "" + } + + return "from pydantic import " + strings.Join(remaining, ", ") +} + // removePydanticImports removes BaseModel and ConfigDict from a pydantic import line. // Returns "" if no imports remain. func removePydanticImports(line string) string { @@ -217,6 +327,7 @@ func removePydanticImports(line string) string { } // addToCogImport adds a name to an existing "from cog import ..." line. +// If no "from cog import" line exists, inserts one at the top of the file. 
func addToCogImport(source string, name string) string { lines := strings.Split(source, "\n") for i, line := range lines { @@ -226,9 +337,11 @@ func addToCogImport(source string, name string) string { return source // Already imported } // Add the name at the end - lines[i] = trimmed + ", " + name + lines[i] = line + ", " + name return strings.Join(lines, "\n") } } - return source + // No existing cog import found -- add one at the top + newImport := "from cog import " + name + return newImport + "\n" + source } diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go index 942c172ecf..04a1106493 100644 --- a/pkg/doctor/check_python_test.go +++ b/pkg/doctor/check_python_test.go @@ -109,6 +109,156 @@ class Predictor(BasePredictor): require.Empty(t, findings) } +func TestPydanticBaseModelCheck_NoFalsePositive(t *testing.T) { + // arbitrary_types_allowed=False should NOT trigger the check + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=False, validate_default=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPydanticBaseModelCheck_Fix_NoCogImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from pydantic import BaseModel, ConfigDict + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + name: str +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + 
require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BaseModel") + require.NotContains(t, string(content), "from pydantic import BaseModel") +} + +func TestPydanticBaseModelCheck_Fix_ImportPydanticStyle(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +import pydantic + +class Output(pydantic.BaseModel): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BasePredictor, Path, BaseModel") + require.NotContains(t, string(content), "import pydantic") + require.NotContains(t, string(content), "pydantic.BaseModel") + require.NotContains(t, string(content), "pydantic.ConfigDict") +} + +func TestPydanticBaseModelCheck_Fix_PreservesOtherConfigDict(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + 
findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "validate_default=True") + require.NotContains(t, string(content), "arbitrary_types_allowed") + require.Contains(t, string(content), "model_config") +} + +func TestPydanticBaseModelCheck_Fix_MultilineImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import ( + BaseModel, + ConfigDict, +) + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BasePredictor, Path, BaseModel") + require.NotContains(t, string(content), "from pydantic import") + require.NotContains(t, string(content), "arbitrary_types_allowed") +} + func TestDeprecatedImportsCheck_Clean(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "predict.py", `from cog import BasePredictor diff --git a/pkg/doctor/doctor.go b/pkg/doctor/doctor.go index 27467a323a..15d0a6e695 100644 --- a/pkg/doctor/doctor.go +++ b/pkg/doctor/doctor.go @@ -73,9 +73,12 @@ type ParsedFile struct { // CheckContext provides checks with access to project state. // Built once by the runner and passed to every check. 
type CheckContext struct { - ProjectDir string - Config *config.Config // Parsed cog.yaml (nil if parsing failed) - ConfigFile []byte // Raw cog.yaml bytes (available even if parsing failed) - PythonFiles map[string]*ParsedFile // Pre-parsed Python files keyed by relative path - PythonPath string // Path to python binary (empty if not found) + ProjectDir string + ConfigFilename string // Config filename (e.g. "cog.yaml") + Config *config.Config // Parsed cog.yaml (nil if parsing failed) + ConfigFile []byte // Raw cog.yaml bytes (available even if parsing failed) + LoadResult *config.LoadResult // Non-nil if config loaded successfully + LoadErr error // Non-nil if config loading failed + PythonFiles map[string]*ParsedFile // Pre-parsed Python files keyed by relative path + PythonPath string // Path to python binary (empty if not found) } diff --git a/pkg/doctor/doctor_test.go b/pkg/doctor/doctor_test.go index 0e029d631f..15a8d53a12 100644 --- a/pkg/doctor/doctor_test.go +++ b/pkg/doctor/doctor_test.go @@ -2,6 +2,7 @@ package doctor import ( "context" + "errors" "testing" "github.com/stretchr/testify/require" @@ -145,3 +146,12 @@ func TestRunMarksNotFixedOnErrNoAutoFix(t *testing.T) { require.NoError(t, err) require.False(t, result.Results[0].Fixed) } + +func TestHasErrorsWithCheckError(t *testing.T) { + result := &Result{ + Results: []CheckResult{ + {Err: errors.New("check failed")}, + }, + } + require.True(t, result.HasErrors()) +} diff --git a/pkg/doctor/runner.go b/pkg/doctor/runner.go index b716318aa5..fa01a9fd6c 100644 --- a/pkg/doctor/runner.go +++ b/pkg/doctor/runner.go @@ -1,6 +1,7 @@ package doctor import ( + "bytes" "context" "errors" "os" @@ -16,8 +17,9 @@ import ( // RunOptions configures a doctor run. type RunOptions struct { - Fix bool - ProjectDir string + Fix bool + ProjectDir string + ConfigFilename string // Config filename (defaults to "cog.yaml" if empty) } // CheckResult holds the outcome of running a single check. 
@@ -33,9 +35,13 @@ type Result struct { Results []CheckResult } -// HasErrors returns true if any check produced error-severity findings. +// HasErrors returns true if any check produced error-severity findings +// or if any check itself errored. func (r *Result) HasErrors() bool { for _, cr := range r.Results { + if cr.Err != nil { + return true + } for _, f := range cr.Findings { if f.Severity == SeverityError && !cr.Fixed { return true @@ -47,7 +53,12 @@ func (r *Result) HasErrors() bool { // Run executes all checks and optionally applies fixes. func Run(_ context.Context, opts RunOptions, checks []Check) (*Result, error) { - checkCtx, err := buildCheckContext(opts.ProjectDir) + configFilename := opts.ConfigFilename + if configFilename == "" { + configFilename = "cog.yaml" + } + + checkCtx, err := buildCheckContext(opts.ProjectDir, configFilename) if err != nil { return nil, err } @@ -82,25 +93,24 @@ func Run(_ context.Context, opts RunOptions, checks []Check) (*Result, error) { } // buildCheckContext constructs the shared context for all checks. 
-func buildCheckContext(projectDir string) (*CheckContext, error) { +func buildCheckContext(projectDir string, configFilename string) (*CheckContext, error) { ctx := &CheckContext{ - ProjectDir: projectDir, - PythonFiles: make(map[string]*ParsedFile), + ProjectDir: projectDir, + ConfigFilename: configFilename, + PythonFiles: make(map[string]*ParsedFile), } // Load cog.yaml - configPath := filepath.Join(projectDir, "cog.yaml") + configPath := filepath.Join(projectDir, configFilename) configBytes, err := os.ReadFile(configPath) if err == nil { ctx.ConfigFile = configBytes - // Try to load and validate config - f, err := os.Open(configPath) - if err == nil { - defer f.Close() - loadResult, err := config.Load(f, projectDir) - if err == nil { - ctx.Config = loadResult.Config - } + // Load and validate config once — checks use ctx.LoadResult / ctx.LoadErr + loadResult, loadErr := config.Load(bytes.NewReader(configBytes), projectDir) + ctx.LoadErr = loadErr + if loadResult != nil { + ctx.LoadResult = loadResult + ctx.Config = loadResult.Config } } From bf45c61d1b1c05d2678c0141aa22ba5b7b98c9fd Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 17:32:18 -0400 Subject: [PATCH 18/24] chore: update cli docs Signed-off-by: Mark Phelps --- docs/cli.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/cli.md b/docs/cli.md index 04446b2e29..13869bde7e 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -66,6 +66,24 @@ cog build [flags] --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` +## `cog doctor` + +Diagnose and fix common issues in your Cog project. + +By default, cog doctor reports problems without modifying any files. +Pass --fix to automatically apply safe fixes. 
+ +``` +cog doctor [flags] +``` + +**Options** + +``` + -f, --file string The name of the config file. (default "cog.yaml") + --fix Automatically apply fixes + -h, --help help for doctor +``` ## `cog exec` Execute a command inside a Docker environment defined by cog.yaml. From 98fcd5ad0bb7c49c2630c01f21871101b84d05a1 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 9 Apr 2026 17:35:13 -0400 Subject: [PATCH 19/24] chore: update llms docs Signed-off-by: Mark Phelps --- docs/llms.txt | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/llms.txt b/docs/llms.txt index 73975c6beb..01a0d7d42c 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -262,6 +262,24 @@ cog build [flags] --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` +## `cog doctor` + +Diagnose and fix common issues in your Cog project. + +By default, cog doctor reports problems without modifying any files. +Pass --fix to automatically apply safe fixes. + +``` +cog doctor [flags] +``` + +**Options** + +``` + -f, --file string The name of the config file. (default "cog.yaml") + --fix Automatically apply fixes + -h, --help help for doctor +``` ## `cog exec` Execute a command inside a Docker environment defined by cog.yaml. 
From ca418848fd8f33a28dc8bab9b548256b2996854c Mon Sep 17 00:00:00 2001 From: Mark Phelps <209477+markphelps@users.noreply.github.com> Date: Mon, 13 Apr 2026 10:43:06 -0400 Subject: [PATCH 20/24] Update pkg/doctor/runner.go Co-authored-by: ask-bonk[bot] <249159057+ask-bonk[bot]@users.noreply.github.com> Signed-off-by: Mark Phelps <209477+markphelps@users.noreply.github.com> --- pkg/doctor/runner.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/doctor/runner.go b/pkg/doctor/runner.go index fa01a9fd6c..388c52a75a 100644 --- a/pkg/doctor/runner.go +++ b/pkg/doctor/runner.go @@ -149,7 +149,7 @@ func parsePythonRef(ctx *CheckContext, ref string) { parser := sitter.NewParser() parser.SetLanguage(python.GetLanguage()) - tree, err := parser.ParseCtx(context.Background(), nil, source) + tree, err := parser.ParseCtx(ctx, nil, source) if err != nil { return } From af0d2ece814a74d523dd24d938d25f9c824cae14 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 13 Apr 2026 11:55:14 -0400 Subject: [PATCH 21/24] chore: fix context passing Signed-off-by: Mark Phelps --- pkg/doctor/doctor.go | 2 ++ pkg/doctor/runner.go | 39 ++++++++++++++++++++------------------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/pkg/doctor/doctor.go b/pkg/doctor/doctor.go index 15d0a6e695..76cce45792 100644 --- a/pkg/doctor/doctor.go +++ b/pkg/doctor/doctor.go @@ -1,6 +1,7 @@ package doctor import ( + "context" "errors" sitter "github.com/smacker/go-tree-sitter" @@ -73,6 +74,7 @@ type ParsedFile struct { // CheckContext provides checks with access to project state. // Built once by the runner and passed to every check. type CheckContext struct { + ctx context.Context ProjectDir string ConfigFilename string // Config filename (e.g. 
"cog.yaml") Config *config.Config // Parsed cog.yaml (nil if parsing failed) diff --git a/pkg/doctor/runner.go b/pkg/doctor/runner.go index 388c52a75a..4a9193c8cd 100644 --- a/pkg/doctor/runner.go +++ b/pkg/doctor/runner.go @@ -52,13 +52,13 @@ func (r *Result) HasErrors() bool { } // Run executes all checks and optionally applies fixes. -func Run(_ context.Context, opts RunOptions, checks []Check) (*Result, error) { +func Run(ctx context.Context, opts RunOptions, checks []Check) (*Result, error) { configFilename := opts.ConfigFilename if configFilename == "" { configFilename = "cog.yaml" } - checkCtx, err := buildCheckContext(opts.ProjectDir, configFilename) + checkCtx, err := buildCheckContext(ctx, opts.ProjectDir, configFilename) if err != nil { return nil, err } @@ -93,8 +93,9 @@ func Run(_ context.Context, opts RunOptions, checks []Check) (*Result, error) { } // buildCheckContext constructs the shared context for all checks. -func buildCheckContext(projectDir string, configFilename string) (*CheckContext, error) { - ctx := &CheckContext{ +func buildCheckContext(ctx context.Context, projectDir string, configFilename string) (*CheckContext, error) { + ctxt := &CheckContext{ + ctx: ctx, ProjectDir: projectDir, ConfigFilename: configFilename, PythonFiles: make(map[string]*ParsedFile), @@ -104,35 +105,35 @@ func buildCheckContext(projectDir string, configFilename string) (*CheckContext, configPath := filepath.Join(projectDir, configFilename) configBytes, err := os.ReadFile(configPath) if err == nil { - ctx.ConfigFile = configBytes - // Load and validate config once — checks use ctx.LoadResult / ctx.LoadErr + ctxt.ConfigFile = configBytes + // Load and validate config once — checks use ctxt.LoadResult / ctxt.LoadErr loadResult, loadErr := config.Load(bytes.NewReader(configBytes), projectDir) - ctx.LoadErr = loadErr + ctxt.LoadErr = loadErr if loadResult != nil { - ctx.LoadResult = loadResult - ctx.Config = loadResult.Config + ctxt.LoadResult = loadResult + 
ctxt.Config = loadResult.Config } } // Find python binary if pythonPath, err := exec.LookPath("python3"); err == nil { - ctx.PythonPath = pythonPath + ctxt.PythonPath = pythonPath } else if pythonPath, err := exec.LookPath("python"); err == nil { - ctx.PythonPath = pythonPath + ctxt.PythonPath = pythonPath } // Pre-parse Python files referenced in config - if ctx.Config != nil { - parsePythonRef(ctx, ctx.Config.Predict) - parsePythonRef(ctx, ctx.Config.Train) + if ctxt.Config != nil { + parsePythonRef(ctxt, ctxt.Config.Predict) + parsePythonRef(ctxt, ctxt.Config.Train) } - return ctx, nil + return ctxt, nil } // parsePythonRef parses a predict/train reference like "predict.py:Predictor" // and adds the parsed file to ctx.PythonFiles. -func parsePythonRef(ctx *CheckContext, ref string) { +func parsePythonRef(ctxt *CheckContext, ref string) { if ref == "" { return } @@ -141,7 +142,7 @@ func parsePythonRef(ctx *CheckContext, ref string) { return } - fullPath := filepath.Join(ctx.ProjectDir, parts[0]) + fullPath := filepath.Join(ctxt.ProjectDir, parts[0]) source, err := os.ReadFile(fullPath) if err != nil { return @@ -149,14 +150,14 @@ func parsePythonRef(ctx *CheckContext, ref string) { parser := sitter.NewParser() parser.SetLanguage(python.GetLanguage()) - tree, err := parser.ParseCtx(ctx, nil, source) + tree, err := parser.ParseCtx(ctxt.ctx, nil, source) if err != nil { return } imports := schemaPython.CollectImports(tree.RootNode(), source) - ctx.PythonFiles[parts[0]] = &ParsedFile{ + ctxt.PythonFiles[parts[0]] = &ParsedFile{ Path: parts[0], Source: source, Tree: tree, From 5b5008d051847f2401c3727b98eadf7331af714d Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 13 Apr 2026 13:33:48 -0400 Subject: [PATCH 22/24] fix: remove deprecated name references and orphaned imports in doctor --fix The deprecated imports fix was only removing the import line itself (e.g. 
'from cog.types import ExperimentalFeatureWarning') but leaving behind statements that reference the removed symbol (e.g. 'warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning)'). Use tree-sitter AST traversal to: 1. Find and remove statements referencing deprecated names by walking identifier nodes recursively 2. Detect and remove orphaned 'import X' statements where the module is no longer referenced anywhere in the file --- pkg/doctor/check_python_deprecated_imports.go | 233 ++++++++++++++++-- pkg/doctor/check_python_test.go | 76 ++++++ 2 files changed, 285 insertions(+), 24 deletions(-) diff --git a/pkg/doctor/check_python_deprecated_imports.go b/pkg/doctor/check_python_deprecated_imports.go index e5b66357d6..dc6eaf8a9b 100644 --- a/pkg/doctor/check_python_deprecated_imports.go +++ b/pkg/doctor/check_python_deprecated_imports.go @@ -1,12 +1,15 @@ package doctor import ( + "context" "fmt" "os" "path/filepath" + "sort" "strings" sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" schemaPython "github.com/replicate/cog/pkg/schema/python" ) @@ -91,43 +94,225 @@ func (c *DeprecatedImportsCheck) Fix(ctx *CheckContext, findings []Finding) erro return fmt.Errorf("reading %s: %w", relPath, err) } - // Re-scan the file using tree-sitter to find deprecated imports directly, - // rather than relying on fragile string matching against finding messages. 
pf, ok := ctx.PythonFiles[relPath] if !ok { continue } - namesToRemove := make(map[string]map[string]bool) // module -> set of names - root := pf.Tree.RootNode() - for _, child := range schemaPython.NamedChildren(root) { - if child.Type() != "import_from_statement" { - continue - } - moduleNode := child.ChildByFieldName("module_name") - if moduleNode == nil { - continue - } - module := schemaPython.Content(moduleNode, pf.Source) - for _, name := range extractImportedNames(child, pf.Source) { - for _, dep := range deprecatedImportsList { - if module == dep.Module && name == dep.Name { - if namesToRemove[dep.Module] == nil { - namesToRemove[dep.Module] = make(map[string]bool) - } - namesToRemove[dep.Module][dep.Name] = true + + fixed := removeDeprecatedImportsAST(source, pf.Tree) + + if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { + return fmt.Errorf("writing %s: %w", relPath, err) + } + } + + return nil +} + +// byteRange represents a range of bytes to remove from source, corresponding +// to a full line (including its trailing newline). +type byteRange struct { + start uint32 + end uint32 +} + +// removeDeprecatedImportsAST uses tree-sitter to identify and remove: +// 1. import_from_statement nodes that import deprecated names +// 2. expression_statement nodes that reference those deprecated names +// 3. orphaned "import X" statements where X is no longer used +func removeDeprecatedImportsAST(source []byte, tree *sitter.Tree) string { + root := tree.RootNode() + + // Step 1: Walk the AST to find which deprecated names are present in this file. 
+ deprecatedNames := make(map[string]bool) + namesToRemove := make(map[string]map[string]bool) // module -> set of names + + for _, child := range schemaPython.NamedChildren(root) { + if child.Type() != "import_from_statement" { + continue + } + moduleNode := child.ChildByFieldName("module_name") + if moduleNode == nil { + continue + } + module := schemaPython.Content(moduleNode, source) + + for _, name := range extractImportedNames(child, source) { + for _, dep := range deprecatedImportsList { + if module == dep.Module && name == dep.Name { + deprecatedNames[dep.Name] = true + if namesToRemove[dep.Module] == nil { + namesToRemove[dep.Module] = make(map[string]bool) } + namesToRemove[dep.Module][dep.Name] = true } } } + } - fixed := removeDeprecatedImportLines(string(source), namesToRemove) + if len(deprecatedNames) == 0 { + return string(source) + } - if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { - return fmt.Errorf("writing %s: %w", relPath, err) + // Step 2: Remove deprecated import lines/names (handles partial imports). + fixed := removeDeprecatedImportLines(string(source), namesToRemove) + + // Step 3: Re-parse and use tree-sitter to find statements referencing + // the deprecated names, then remove them by byte range. + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + newTree, err := parser.ParseCtx(context.Background(), nil, []byte(fixed)) + if err != nil { + return fixed + } + + newSource := []byte(fixed) + newRoot := newTree.RootNode() + var removals []byteRange + for _, child := range schemaPython.NamedChildren(newRoot) { + if child.Type() == "import_from_statement" || child.Type() == "import_statement" { + continue + } + if nodeReferencesAny(child, newSource, deprecatedNames) { + removals = append(removals, nodeLineRange(child, newSource)) } } - return nil + fixed = applyRemovals(newSource, removals) + + // Step 4: Remove orphaned "import X" statements via AST. 
+ fixed = removeOrphanedImportsAST(fixed) + + return fixed +} + +// nodeReferencesAny walks a tree-sitter node recursively and returns true if +// any identifier node matches one of the given names. +func nodeReferencesAny(node *sitter.Node, source []byte, names map[string]bool) bool { + if node.Type() == "identifier" { + return names[schemaPython.Content(node, source)] + } + for _, child := range schemaPython.AllChildren(node) { + if nodeReferencesAny(child, source, names) { + return true + } + } + return false +} + +// nodeLineRange returns a byte range covering the full line(s) of a node, +// including the trailing newline. +func nodeLineRange(node *sitter.Node, source []byte) byteRange { + start := node.StartByte() + end := node.EndByte() + + // Extend start to beginning of line + for start > 0 && source[start-1] != '\n' { + start-- + } + // Extend end past trailing newline + if int(end) < len(source) && source[end] == '\n' { + end++ + } + + return byteRange{start: start, end: end} +} + +// applyRemovals removes all byte ranges from source, handling overlaps. +// Ranges are sorted descending by start so earlier indices remain valid. +func applyRemovals(source []byte, ranges []byteRange) string { + if len(ranges) == 0 { + return string(source) + } + + // Sort descending by start so we can remove from back to front + sort.Slice(ranges, func(i, j int) bool { + return ranges[i].start > ranges[j].start + }) + + result := make([]byte, len(source)) + copy(result, source) + + for _, r := range ranges { + if int(r.start) >= len(result) { + continue + } + end := min(int(r.end), len(result)) + result = append(result[:r.start], result[end:]...) + } + + return string(result) +} + +// removeOrphanedImportsAST re-parses source and removes "import X" statements +// where X is no longer referenced anywhere else in the file. 
+func removeOrphanedImportsAST(source string) string { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(context.Background(), nil, []byte(source)) + if err != nil { + return source + } + + src := []byte(source) + root := tree.RootNode() + var removals []byteRange + + for _, child := range schemaPython.NamedChildren(root) { + if child.Type() != "import_statement" { + continue + } + + // Get the module name being imported (e.g. "warnings" from "import warnings") + var moduleName string + for _, c := range schemaPython.NamedChildren(child) { + if c.Type() == "dotted_name" { + moduleName = schemaPython.Content(c, src) + break + } + } + if moduleName == "" { + continue + } + + // Check if this module is referenced elsewhere (outside import statements) + used := false + for _, stmt := range schemaPython.NamedChildren(root) { + if stmt.Type() == "import_statement" || stmt.Type() == "import_from_statement" { + continue + } + if nodeReferencesModule(stmt, src, moduleName) { + used = true + break + } + } + + if !used { + removals = append(removals, nodeLineRange(child, src)) + } + } + + return applyRemovals(src, removals) +} + +// nodeReferencesModule checks if a node contains an attribute access on the +// given module (e.g. "warnings.filterwarnings") or a bare identifier matching it. +func nodeReferencesModule(node *sitter.Node, source []byte, moduleName string) bool { + if node.Type() == "attribute" { + obj := node.ChildByFieldName("object") + if obj != nil && obj.Type() == "identifier" && schemaPython.Content(obj, source) == moduleName { + return true + } + } + if node.Type() == "identifier" && schemaPython.Content(node, source) == moduleName { + return true + } + for _, child := range schemaPython.AllChildren(node) { + if nodeReferencesModule(child, source, moduleName) { + return true + } + } + return false } // extractImportedNames returns the names imported in a "from X import a, b, c" statement. 
diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go index 04a1106493..f507b0e4e8 100644 --- a/pkg/doctor/check_python_test.go +++ b/pkg/doctor/check_python_test.go @@ -336,6 +336,82 @@ class Predictor(BasePredictor): require.Empty(t, findings) } +func TestDeprecatedImportsCheck_Fix_WithWarningsFilterwarnings(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.NotContains(t, string(content), "ExperimentalFeatureWarning") + require.NotContains(t, string(content), "cog.types") + require.NotContains(t, string(content), "import warnings") + + // Re-parse and verify clean + ctx.PythonFiles = parsePythonFiles(t, dir, "predict.py") + findings, err = check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestDeprecatedImportsCheck_Fix_WarningsImportPreserved(t *testing.T) { + // When warnings module is still used elsewhere, import should be preserved + dir := t.TempDir() + writeFile(t, dir, "predict.py", `import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) +warnings.warn("something else") + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text +`) + + ctx := &CheckContext{ + 
ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.NotContains(t, string(content), "ExperimentalFeatureWarning") + require.Contains(t, string(content), "import warnings") + require.Contains(t, string(content), "warnings.warn") +} + func TestMissingTypeAnnotationsCheck_Clean(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "predict.py", `from cog import BasePredictor From 2401425111dcccf93d558d427656216c46a9816b Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 13 Apr 2026 14:01:17 -0400 Subject: [PATCH 23/24] chore: pass context Signed-off-by: Mark Phelps --- pkg/doctor/check_python_deprecated_imports.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/doctor/check_python_deprecated_imports.go b/pkg/doctor/check_python_deprecated_imports.go index dc6eaf8a9b..a3f1afc3a6 100644 --- a/pkg/doctor/check_python_deprecated_imports.go +++ b/pkg/doctor/check_python_deprecated_imports.go @@ -99,7 +99,7 @@ func (c *DeprecatedImportsCheck) Fix(ctx *CheckContext, findings []Finding) erro continue } - fixed := removeDeprecatedImportsAST(source, pf.Tree) + fixed := removeDeprecatedImportsAST(ctx.ctx, source, pf.Tree) if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { return fmt.Errorf("writing %s: %w", relPath, err) @@ -120,7 +120,7 @@ type byteRange struct { // 1. import_from_statement nodes that import deprecated names // 2. expression_statement nodes that reference those deprecated names // 3. 
orphaned "import X" statements where X is no longer used -func removeDeprecatedImportsAST(source []byte, tree *sitter.Tree) string { +func removeDeprecatedImportsAST(ctx context.Context, source []byte, tree *sitter.Tree) string { root := tree.RootNode() // Step 1: Walk the AST to find which deprecated names are present in this file. @@ -161,7 +161,7 @@ func removeDeprecatedImportsAST(source []byte, tree *sitter.Tree) string { // the deprecated names, then remove them by byte range. parser := sitter.NewParser() parser.SetLanguage(python.GetLanguage()) - newTree, err := parser.ParseCtx(context.Background(), nil, []byte(fixed)) + newTree, err := parser.ParseCtx(ctx, nil, []byte(fixed)) if err != nil { return fixed } @@ -181,7 +181,7 @@ func removeDeprecatedImportsAST(source []byte, tree *sitter.Tree) string { fixed = applyRemovals(newSource, removals) // Step 4: Remove orphaned "import X" statements via AST. - fixed = removeOrphanedImportsAST(fixed) + fixed = removeOrphanedImportsAST(ctx, fixed) return fixed } @@ -246,10 +246,10 @@ func applyRemovals(source []byte, ranges []byteRange) string { // removeOrphanedImportsAST re-parses source and removes "import X" statements // where X is no longer referenced anywhere else in the file. 
-func removeOrphanedImportsAST(source string) string { +func removeOrphanedImportsAST(ctx context.Context, source string) string { parser := sitter.NewParser() parser.SetLanguage(python.GetLanguage()) - tree, err := parser.ParseCtx(context.Background(), nil, []byte(source)) + tree, err := parser.ParseCtx(ctx, nil, []byte(source)) if err != nil { return source } From ea9766882d4abd572cac89212fec129ccf953707 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 13 Apr 2026 14:10:53 -0400 Subject: [PATCH 24/24] fix: pass context.Background() in deprecated imports fix tests The CheckContext.ctx field was nil in tests that exercise DeprecatedImportsCheck.Fix(), causing a nil pointer dereference when tree-sitter's ParseCtx was called with a nil context. --- pkg/doctor/check_python_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go index f507b0e4e8..b8f92e8388 100644 --- a/pkg/doctor/check_python_test.go +++ b/pkg/doctor/check_python_test.go @@ -313,6 +313,7 @@ class Predictor(BasePredictor): `) ctx := &CheckContext{ + ctx: context.Background(), ProjectDir: dir, PythonFiles: parsePythonFiles(t, dir, "predict.py"), } @@ -351,6 +352,7 @@ class Predictor(BasePredictor): `) ctx := &CheckContext{ + ctx: context.Background(), ProjectDir: dir, PythonFiles: parsePythonFiles(t, dir, "predict.py"), } @@ -393,6 +395,7 @@ class Predictor(BasePredictor): `) ctx := &CheckContext{ + ctx: context.Background(), ProjectDir: dir, PythonFiles: parsePythonFiles(t, dir, "predict.py"), }