diff --git a/docs/cli.md b/docs/cli.md index 04446b2e29..13869bde7e 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -66,6 +66,24 @@ cog build [flags] --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` +## `cog doctor` + +Diagnose and fix common issues in your Cog project. + +By default, cog doctor reports problems without modifying any files. +Pass --fix to automatically apply safe fixes. + +``` +cog doctor [flags] +``` + +**Options** + +``` + -f, --file string The name of the config file. (default "cog.yaml") + --fix Automatically apply fixes + -h, --help help for doctor +``` ## `cog exec` Execute a command inside a Docker environment defined by cog.yaml. diff --git a/docs/llms.txt b/docs/llms.txt index 73975c6beb..01a0d7d42c 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -262,6 +262,24 @@ cog build [flags] --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` +## `cog doctor` + +Diagnose and fix common issues in your Cog project. + +By default, cog doctor reports problems without modifying any files. +Pass --fix to automatically apply safe fixes. + +``` +cog doctor [flags] +``` + +**Options** + +``` + -f, --file string The name of the config file. (default "cog.yaml") + --fix Automatically apply fixes + -h, --help help for doctor +``` ## `cog exec` Execute a command inside a Docker environment defined by cog.yaml. diff --git a/integration-tests/tests/doctor_clean_project.txtar b/integration-tests/tests/doctor_clean_project.txtar new file mode 100644 index 0000000000..76479483fe --- /dev/null +++ b/integration-tests/tests/doctor_clean_project.txtar @@ -0,0 +1,17 @@ +# Test that cog doctor succeeds on a valid project with no issues. + +cog doctor +stderr 'no issues found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_deprecated_fields.txtar b/integration-tests/tests/doctor_deprecated_fields.txtar new file mode 100644 index 0000000000..7a6c384880 --- /dev/null +++ b/integration-tests/tests/doctor_deprecated_fields.txtar @@ -0,0 +1,21 @@ +# Test that cog doctor reports deprecated config fields as warnings. +# Warnings do not cause a non-zero exit code; only errors do. + +cog doctor +stderr 'python_packages' +stderr 'deprecated' + +-- cog.yaml -- +build: + python_version: "3.12" + python_packages: + - torch==2.0.0 +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_deprecated_imports.txtar b/integration-tests/tests/doctor_deprecated_imports.txtar new file mode 100644 index 0000000000..ac07785290 --- /dev/null +++ b/integration-tests/tests/doctor_deprecated_imports.txtar @@ -0,0 +1,23 @@ +# Test that cog doctor detects deprecated imports. +# ExperimentalFeatureWarning was removed in cog 0.17. + +! cog doctor +stderr 'ExperimentalFeatureWarning' +stderr 'removed' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_exit_code.txtar b/integration-tests/tests/doctor_exit_code.txtar new file mode 100644 index 0000000000..c294a6039a --- /dev/null +++ b/integration-tests/tests/doctor_exit_code.txtar @@ -0,0 +1,10 @@ +# Test that cog doctor exits with code 1 when predict file is missing. + +! cog doctor +stderr 'predict.py' +stderr 'not found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" diff --git a/integration-tests/tests/doctor_fix_deprecated_imports.txtar b/integration-tests/tests/doctor_fix_deprecated_imports.txtar new file mode 100644 index 0000000000..4e652c7a80 --- /dev/null +++ b/integration-tests/tests/doctor_fix_deprecated_imports.txtar @@ -0,0 +1,35 @@ +# Test that cog doctor --fix removes deprecated imports. + +# First, doctor should detect the issue +! cog doctor +stderr 'ExperimentalFeatureWarning' + +# Fix the issue +cog doctor --fix +stderr 'Fixed' + +# Verify the import was removed from the file +exec cat predict.py +! stdout 'ExperimentalFeatureWarning' +! stdout 'cog.types' + +# Re-running doctor should now pass +cog doctor +stderr 'no issues found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_fix_pydantic.txtar b/integration-tests/tests/doctor_fix_pydantic.txtar new file mode 100644 index 0000000000..198d80bb21 --- /dev/null +++ b/integration-tests/tests/doctor_fix_pydantic.txtar @@ -0,0 +1,41 @@ +# Test that cog doctor --fix rewrites pydantic.BaseModel to cog.BaseModel. +# After fix, re-running cog doctor should pass. + +# First, doctor should detect the issue +! cog doctor +stderr 'pydantic.BaseModel' + +# Fix the issue +cog doctor --fix +stderr 'Fixed' + +# Verify the file was modified: pydantic.BaseModel replaced with cog.BaseModel +exec cat predict.py +stdout 'from cog import BasePredictor, Path, BaseModel' +! stdout 'from pydantic import BaseModel' +! stdout 'ConfigDict' +! stdout 'arbitrary_types_allowed' + +# Re-running doctor should now pass +cog doctor +stderr 'no issues found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") diff --git a/integration-tests/tests/doctor_missing_predict_ref.txtar b/integration-tests/tests/doctor_missing_predict_ref.txtar new file mode 100644 index 0000000000..0aad732b24 --- /dev/null +++ b/integration-tests/tests/doctor_missing_predict_ref.txtar @@ -0,0 +1,19 @@ +# Test that cog doctor detects when predict ref points to a nonexistent class. +# cog.yaml references DoesNotExist but only Predictor exists in predict.py. + +! cog doctor +stderr 'DoesNotExist' +stderr 'not found' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:DoesNotExist" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text diff --git a/integration-tests/tests/doctor_pydantic_basemodel.txtar b/integration-tests/tests/doctor_pydantic_basemodel.txtar new file mode 100644 index 0000000000..74ff03cf3c --- /dev/null +++ b/integration-tests/tests/doctor_pydantic_basemodel.txtar @@ -0,0 +1,26 @@ +# Test that cog doctor detects pydantic.BaseModel with arbitrary_types_allowed. +# This is a common workaround that should use cog.BaseModel instead. + +! cog doctor +stderr 'pydantic.BaseModel' +stderr 'cog.BaseModel' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go new file mode 100644 index 0000000000..9191e6a4e3 --- /dev/null +++ b/pkg/cli/doctor.go @@ -0,0 +1,192 @@ +package cli + +import ( + "context" + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/doctor" + "github.com/replicate/cog/pkg/util/console" +) + +func newDoctorCommand() *cobra.Command { + var fix bool + + cmd := &cobra.Command{ + Use: "doctor", + Short: "Check your project for common issues and fix them", + Long: `Diagnose and fix common issues in your Cog project. + +By default, cog doctor reports problems without modifying any files. +Pass --fix to automatically apply safe fixes.`, + RunE: func(cmd *cobra.Command, args []string) error { + return runDoctor(cmd.Context(), fix) + }, + Args: cobra.NoArgs, + } + + addConfigFlag(cmd) + cmd.Flags().BoolVar(&fix, "fix", false, "Automatically apply fixes") + + return cmd +} + +func runDoctor(ctx context.Context, fix bool) error { + projectDir, err := config.GetProjectDir(configFilename) + if err != nil { + return err + } + + if fix { + console.Infof("Running cog doctor --fix...") + } else { + console.Infof("Running cog doctor...") + } + console.Info("") + + result, err := doctor.Run(ctx, doctor.RunOptions{ + Fix: fix, + ProjectDir: projectDir, + ConfigFilename: configFilename, + }, doctor.AllChecks()) + if err != nil { + return err + } + + printDoctorResults(result, fix) + + if result.HasErrors() { + return fmt.Errorf("doctor found errors") + } + + return nil +} + +func printDoctorResults(result *doctor.Result, fix bool) { + var currentGroup doctor.Group + errorCount := 0 + warningCount := 0 + fixedCount := 0 + + for _, cr := range result.Results { + // Print group header when group changes + if cr.Check.Group() != currentGroup { + currentGroup = cr.Check.Group() + console.Infof("%s", string(currentGroup)) + } + + // Check errored internally + if cr.Err != nil { + console.Errorf("%s: %v", cr.Check.Description(), cr.Err) + errorCount++ + continue + } + + // No findings — passed + if len(cr.Findings) == 0 { + console.Successf("%s", cr.Check.Description()) + continue + } + + // Has findings + if cr.Fixed { + console.Successf("Fixed: %s", cr.Check.Description()) + fixedCount += len(cr.Findings) + } else { + // Determine worst severity for the check header + worstSeverity := cr.Findings[0].Severity + for _, f := range cr.Findings[1:] { + if f.Severity < worstSeverity { + worstSeverity = f.Severity + } + } + switch worstSeverity { + case doctor.SeverityError: + console.Errorf("%s", cr.Check.Description()) + case doctor.SeverityWarning: + console.Warnf("%s", cr.Check.Description()) + default: + console.Infof("%s", cr.Check.Description()) + } + + // Count per-finding for consistent totals + for _, f := range cr.Findings { + switch f.Severity { + case doctor.SeverityError: + errorCount++ + case doctor.SeverityWarning: + warningCount++ + } + } + } + + // Print individual findings + for _, f := range cr.Findings { + location := "" + if f.File != "" { + if f.Line > 0 { + location = fmt.Sprintf("%s:%d — ", f.File, f.Line) + } else { + location = fmt.Sprintf("%s — ", f.File) + } + } + console.Infof(" %s%s", location, f.Message) + + if fix && !cr.Fixed && f.Remediation != "" { + console.Infof(" (no auto-fix available)") + } + } + } + + console.Info("") + + // Summary line + switch { + case fix && fixedCount > 0: + msg := fmt.Sprintf("Fixed %d issue", fixedCount) + if fixedCount != 1 { + msg += "s" + } + if warningCount > 0 { + msg += fmt.Sprintf(". %d warning", warningCount) + if warningCount != 1 { + msg += "s" + } + msg += " remaining" + } + if errorCount > 0 { + msg += fmt.Sprintf(". %d unfixed error", errorCount) + if errorCount != 1 { + msg += "s" + } + } + console.Infof("%s.", msg) + case errorCount > 0 || warningCount > 0: + var parts []string + if errorCount > 0 { + s := fmt.Sprintf("%d error", errorCount) + if errorCount != 1 { + s += "s" + } + parts = append(parts, s) + } + if warningCount > 0 { + s := fmt.Sprintf("%d warning", warningCount) + if warningCount != 1 { + s += "s" + } + parts = append(parts, s) + } + summary := "Found " + strings.Join(parts, ", ") + "." + + if !fix && errorCount > 0 { + summary += ` Run "cog doctor --fix" to auto-fix.` + } + console.Infof("%s", summary) + default: + console.Successf("no issues found") + } +} diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 9184411ec4..c6ccc4087d 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -45,6 +45,7 @@ https://github.com/replicate/cog`, rootCmd.AddCommand( newBuildCommand(), newDebugCommand(), + newDoctorCommand(), newInitCommand(), newInspectCommand(), newLoginCommand(), diff --git a/pkg/doctor/check_config_deprecated.go b/pkg/doctor/check_config_deprecated.go new file mode 100644 index 0000000000..7eeda77cac --- /dev/null +++ b/pkg/doctor/check_config_deprecated.go @@ -0,0 +1,38 @@ +package doctor + +import ( + "fmt" +) + +// ConfigDeprecatedFieldsCheck detects deprecated fields in cog.yaml. +type ConfigDeprecatedFieldsCheck struct{} + +func (c *ConfigDeprecatedFieldsCheck) Name() string { return "config-deprecated-fields" } +func (c *ConfigDeprecatedFieldsCheck) Group() Group { return GroupConfig } +func (c *ConfigDeprecatedFieldsCheck) Description() string { return "Deprecated fields" } + +func (c *ConfigDeprecatedFieldsCheck) Check(ctx *CheckContext) ([]Finding, error) { + // No config loaded successfully — other checks handle parse/validation errors. + // Note: warnings are only available when Load succeeds because they come from + // ValidateConfigFile which uses an unexported type. If the config has validation + // errors, deprecation warnings cannot be surfaced through Load. + if ctx.LoadResult == nil { + return nil, nil + } + + var findings []Finding + for _, w := range ctx.LoadResult.Warnings { + findings = append(findings, Finding{ + Severity: SeverityWarning, + Message: fmt.Sprintf("%q is deprecated: %s", w.Field, w.Message), + Remediation: fmt.Sprintf("Use %q instead", w.Replacement), + File: ctx.ConfigFilename, + }) + } + + return findings, nil +} + +func (c *ConfigDeprecatedFieldsCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} diff --git a/pkg/doctor/check_config_parse.go b/pkg/doctor/check_config_parse.go new file mode 100644 index 0000000000..971ca9bd9e --- /dev/null +++ b/pkg/doctor/check_config_parse.go @@ -0,0 +1,52 @@ +package doctor + +import ( + "errors" + "fmt" + + "github.com/replicate/cog/pkg/config" +) + +// ConfigParseCheck verifies that cog.yaml exists and can be parsed as valid YAML. +type ConfigParseCheck struct{} + +func (c *ConfigParseCheck) Name() string { return "config-parse" } +func (c *ConfigParseCheck) Group() Group { return GroupConfig } +func (c *ConfigParseCheck) Description() string { return "Config parsing" } + +func (c *ConfigParseCheck) Check(ctx *CheckContext) ([]Finding, error) { + // Config file not found on disk + if ctx.ConfigFile == nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("%s not found", ctx.ConfigFilename), + Remediation: `Run "cog init" to create a cog.yaml`, + File: ctx.ConfigFilename, + }}, nil + } + + // Check for parse errors from the single Load call in buildCheckContext + if ctx.LoadErr != nil { + var parseErr *config.ParseError + if isParseError(ctx.LoadErr, &parseErr) { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("%s has invalid YAML: %v", ctx.ConfigFilename, ctx.LoadErr), + Remediation: fmt.Sprintf("Fix the YAML syntax in %s", ctx.ConfigFilename), + File: ctx.ConfigFilename, + }}, nil + } + // Other errors (validation, schema) are handled by other checks + } + + return nil, nil +} + +func (c *ConfigParseCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// isParseError checks if the error chain contains a ParseError. +func isParseError(err error, target **config.ParseError) bool { + return errors.As(err, target) +} diff --git a/pkg/doctor/check_config_predict_ref.go b/pkg/doctor/check_config_predict_ref.go new file mode 100644 index 0000000000..0ead3a9273 --- /dev/null +++ b/pkg/doctor/check_config_predict_ref.go @@ -0,0 +1,140 @@ +package doctor + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// ConfigPredictRefCheck verifies that the predict field in cog.yaml +// points to a file and class that actually exist. +type ConfigPredictRefCheck struct{} + +func (c *ConfigPredictRefCheck) Name() string { return "config-predict-ref" } +func (c *ConfigPredictRefCheck) Group() Group { return GroupConfig } +func (c *ConfigPredictRefCheck) Description() string { return "Predict reference" } + +func (c *ConfigPredictRefCheck) Check(ctx *CheckContext) ([]Finding, error) { + // Get predict ref from config + predictRef := "" + if ctx.Config != nil { + predictRef = ctx.Config.Predict + } + if predictRef == "" { + return nil, nil // No predict field — nothing to check + } + + parts := splitPredictRef(predictRef) + pyFile := parts[0] + className := parts[1] + + if pyFile == "" || className == "" { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("predict reference %q must be in the form 'file.py:ClassName'", predictRef), + Remediation: `Set predict to "predict.py:Predictor" in cog.yaml`, + File: "cog.yaml", + }}, nil + } + + // Check file exists + fullPath := filepath.Join(ctx.ProjectDir, pyFile) + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("%s not found", pyFile), + Remediation: fmt.Sprintf("Create %s or update the predict field in cog.yaml", pyFile), + File: "cog.yaml", + }}, nil + } + + // Use cached parse tree if available, otherwise parse from disk + var rootNode *sitter.Node + var source []byte + + if pf, ok := ctx.PythonFiles[pyFile]; ok { + rootNode = pf.Tree.RootNode() + source = pf.Source + } else { + var err error + source, err = os.ReadFile(fullPath) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot read %s: %v", pyFile, err), + File: pyFile, + }}, nil + } + + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(context.Background(), nil, source) + if err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("cannot parse %s: %v", pyFile, err), + File: pyFile, + }}, nil + } + rootNode = tree.RootNode() + } + + if !hasClassDefinition(rootNode, source, className) { + // List available classes to help the user + classes := listClassNames(rootNode, source) + msg := fmt.Sprintf("class %q not found in %s", className, pyFile) + if len(classes) > 0 { + msg += fmt.Sprintf("; found: %s", strings.Join(classes, ", ")) + } + return []Finding{{ + Severity: SeverityError, + Message: msg, + Remediation: fmt.Sprintf("Add class %s to %s or update the predict field in cog.yaml", className, pyFile), + File: pyFile, + }}, nil + } + + return nil, nil +} + +func (c *ConfigPredictRefCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// hasClassDefinition checks whether a class with the given name exists at module level. +func hasClassDefinition(root *sitter.Node, source []byte, name string) bool { + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + return true + } + } + return false +} + +// listClassNames returns the names of all top-level classes in the file. +func listClassNames(root *sitter.Node, source []byte) []string { + var names []string + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil { + names = append(names, schemaPython.Content(nameNode, source)) + } + } + return names +} diff --git a/pkg/doctor/check_config_schema.go b/pkg/doctor/check_config_schema.go new file mode 100644 index 0000000000..8c5bb41a52 --- /dev/null +++ b/pkg/doctor/check_config_schema.go @@ -0,0 +1,46 @@ +package doctor + +import ( + "fmt" + + "github.com/replicate/cog/pkg/config" +) + +// ConfigSchemaCheck validates cog.yaml against the configuration schema. +// Parse errors are handled by ConfigParseCheck; this check catches schema +// and validation errors (wrong types, invalid values, etc.). +type ConfigSchemaCheck struct{} + +func (c *ConfigSchemaCheck) Name() string { return "config-schema" } +func (c *ConfigSchemaCheck) Group() Group { return GroupConfig } +func (c *ConfigSchemaCheck) Description() string { return "Config schema" } + +func (c *ConfigSchemaCheck) Check(ctx *CheckContext) ([]Finding, error) { + // No config file on disk — ConfigParseCheck handles this + if ctx.ConfigFile == nil { + return nil, nil + } + + // No load error means valid config + if ctx.LoadErr == nil { + return nil, nil + } + + // If this is a parse error, skip — ConfigParseCheck handles it + var parseErr *config.ParseError + if isParseError(ctx.LoadErr, &parseErr) { + return nil, nil + } + + // Any other error is a schema/validation error + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("%s validation failed: %v", ctx.ConfigFilename, ctx.LoadErr), + Remediation: fmt.Sprintf("Fix the configuration errors in %s", ctx.ConfigFilename), + File: ctx.ConfigFilename, + }}, nil +} + +func (c *ConfigSchemaCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} diff --git a/pkg/doctor/check_config_test.go b/pkg/doctor/check_config_test.go new file mode 100644 index 0000000000..31a563626a --- /dev/null +++ b/pkg/doctor/check_config_test.go @@ -0,0 +1,260 @@ +package doctor + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" +) + +func TestConfigParseCheck_Valid(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigParseCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigParseCheck_InvalidYAML(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: [invalid yaml`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigParseCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "cog.yaml") +} + +func TestConfigParseCheck_MissingFile(t *testing.T) { + dir := t.TempDir() + + ctx := buildTestCheckContext(t, dir) + check := &ConfigParseCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Contains(t, findings[0].Message, "cog.yaml not found") +} + +func TestConfigDeprecatedFieldsCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" + python_requirements: "requirements.txt" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "requirements.txt", "torch==2.0.0\n") + + ctx := buildTestCheckContext(t, dir) + check := &ConfigDeprecatedFieldsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigDeprecatedFieldsCheck_PythonPackages(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" + python_packages: + - torch==2.0.0 +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigDeprecatedFieldsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "python_packages") +} + +func TestConfigDeprecatedFieldsCheck_PreInstall(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" + pre_install: + - pip install something +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigDeprecatedFieldsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "pre_install") +} + +func TestConfigPredictRefCheck_Valid(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigPredictRefCheck_MissingFile(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "predict.py") + require.Contains(t, findings[0].Message, "not found") +} + +func TestConfigPredictRefCheck_MissingClass(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:DoesNotExist" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "DoesNotExist") +} + +func TestConfigPredictRefCheck_NoPredictField(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigPredictRefCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) // No predict field is valid (some projects are train-only) +} + +// buildTestCheckContext creates a CheckContext by loading the cog.yaml in the given dir. +func buildTestCheckContext(t *testing.T, dir string) *CheckContext { + t.Helper() + ctx := &CheckContext{ + ProjectDir: dir, + ConfigFilename: "cog.yaml", + PythonFiles: make(map[string]*ParsedFile), + } + + configPath := filepath.Join(dir, "cog.yaml") + configBytes, err := os.ReadFile(configPath) + if err == nil { + ctx.ConfigFile = configBytes + loadResult, loadErr := config.Load(bytes.NewReader(configBytes), dir) + ctx.LoadErr = loadErr + if loadResult != nil { + ctx.LoadResult = loadResult + ctx.Config = loadResult.Config + } + } + + return ctx +} + +func TestConfigSchemaCheck_Valid(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "3.12" +predict: "predict.py:Predictor" +`) + writeFile(t, dir, "predict.py", `from cog import BasePredictor +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestConfigSchemaCheck_InvalidSchema(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: + python_version: "2.7" +predict: "predict.py:Predictor" +`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "validation failed") +} + +func TestConfigSchemaCheck_ParseError(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "cog.yaml", `build: [invalid yaml`) + + ctx := buildTestCheckContext(t, dir) + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) // Parse errors are handled by ConfigParseCheck +} + +func TestConfigSchemaCheck_MissingFile(t *testing.T) { + dir := t.TempDir() + + ctx := buildTestCheckContext(t, dir) + check := &ConfigSchemaCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) // Missing file handled by ConfigParseCheck +} + +// writeFile is a test helper to create fixture files. +func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + fullPath := filepath.Join(dir, name) + require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0o755)) + require.NoError(t, os.WriteFile(fullPath, []byte(content), 0o644)) +} diff --git a/pkg/doctor/check_env_docker.go b/pkg/doctor/check_env_docker.go new file mode 100644 index 0000000000..10580d7939 --- /dev/null +++ b/pkg/doctor/check_env_docker.go @@ -0,0 +1,34 @@ +package doctor + +import ( + "context" + "fmt" + "os/exec" + "time" +) + +// DockerCheck verifies that Docker is installed and the daemon is reachable. +type DockerCheck struct{} + +func (c *DockerCheck) Name() string { return "env-docker" } +func (c *DockerCheck) Group() Group { return GroupEnvironment } +func (c *DockerCheck) Description() string { return "Docker" } + +func (c *DockerCheck) Check(_ *CheckContext) ([]Finding, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := exec.CommandContext(ctx, "docker", "info").Run(); err != nil { + return []Finding{{ + Severity: SeverityError, + Message: fmt.Sprintf("Docker is not available: %v", err), + Remediation: "Install Docker (https://docs.docker.com/get-docker/) and ensure the daemon is running", + }}, nil + } + + return nil, nil +} + +func (c *DockerCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} diff --git a/pkg/doctor/check_env_python_version.go b/pkg/doctor/check_env_python_version.go new file mode 100644 index 0000000000..a4f8241e75 --- /dev/null +++ b/pkg/doctor/check_env_python_version.go @@ -0,0 +1,73 @@ +package doctor + +import ( + "context" + "fmt" + "os/exec" + "strings" + "time" +) + +// PythonVersionCheck verifies that Python is available and that the local +// version is consistent with the version configured in cog.yaml. +type PythonVersionCheck struct{} + +func (c *PythonVersionCheck) Name() string { return "env-python-version" } +func (c *PythonVersionCheck) Group() Group { return GroupEnvironment } +func (c *PythonVersionCheck) Description() string { return "Python version" } + +func (c *PythonVersionCheck) Check(ctx *CheckContext) ([]Finding, error) { + if ctx.PythonPath == "" { + return []Finding{{ + Severity: SeverityWarning, + Message: "Python not found in PATH", + Remediation: "Install Python 3.10+ or ensure it is on your PATH", + }}, nil + } + + execCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + out, err := exec.CommandContext(execCtx, ctx.PythonPath, "--version").Output() + if err != nil { + return []Finding{{ + Severity: SeverityWarning, + Message: fmt.Sprintf("could not determine Python version: %v", err), + Remediation: "Ensure your Python installation is working correctly", + }}, nil + } + + // Output is "Python 3.12.1\n" + version := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(string(out)), "Python")) + localMajorMinor := majorMinor(version) + + // If cog.yaml specifies a Python version, compare + if ctx.Config != nil && ctx.Config.Build != nil && ctx.Config.Build.PythonVersion != "" { + configMajorMinor := majorMinor(ctx.Config.Build.PythonVersion) + if configMajorMinor != "" && localMajorMinor != "" && configMajorMinor != localMajorMinor { + return []Finding{{ + Severity: SeverityWarning, + Message: fmt.Sprintf( + "local Python is %s but cog.yaml specifies %s", + localMajorMinor, configMajorMinor, + ), + Remediation: "This is usually fine -- Docker builds use the configured version. Update cog.yaml or your local Python if needed.", + }}, nil + } + } + + return nil, nil +} + +func (c *PythonVersionCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// majorMinor extracts "3.12" from "3.12.1" or "3.12". +func majorMinor(version string) string { + parts := strings.SplitN(version, ".", 3) + if len(parts) < 2 { + return "" + } + return parts[0] + "." + parts[1] +} diff --git a/pkg/doctor/check_env_test.go b/pkg/doctor/check_env_test.go new file mode 100644 index 0000000000..b67d25a3ac --- /dev/null +++ b/pkg/doctor/check_env_test.go @@ -0,0 +1,90 @@ +package doctor + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" +) + +func TestDockerCheck_RunsWithoutError(t *testing.T) { + ctx := &CheckContext{ProjectDir: t.TempDir()} + check := &DockerCheck{} + // We don't assert findings — Docker may or may not be available in the test environment. + // We just verify the check doesn't panic or return an error. + _, err := check.Check(ctx) + require.NoError(t, err) +} + +func TestDockerCheck_FixReturnsNoAutoFix(t *testing.T) { + check := &DockerCheck{} + err := check.Fix(nil, nil) + require.ErrorIs(t, err, ErrNoAutoFix) +} + +func TestPythonVersionCheck_RunsWithoutError(t *testing.T) { + ctx := &CheckContext{ProjectDir: t.TempDir()} + check := &PythonVersionCheck{} + // Python may or may not be available; just ensure no panic or error. + _, err := check.Check(ctx) + require.NoError(t, err) +} + +func TestPythonVersionCheck_NoPython(t *testing.T) { + ctx := &CheckContext{ + ProjectDir: t.TempDir(), + PythonPath: "", // explicitly no Python + } + + check := &PythonVersionCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "Python not found") +} + +func TestPythonVersionCheck_FixReturnsNoAutoFix(t *testing.T) { + check := &PythonVersionCheck{} + err := check.Fix(nil, nil) + require.ErrorIs(t, err, ErrNoAutoFix) +} + +func TestMajorMinor(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"3.12.1", "3.12"}, + {"3.12", "3.12"}, + {"3", ""}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.want, majorMinor(tt.input)) + }) + } +} + +func TestPythonVersionCheck_VersionMismatch(t *testing.T) { + // This test verifies that when PythonPath is set but not a real binary, + // we get a warning. We can't easily fake the version output without a real binary. + ctx := &CheckContext{ + ProjectDir: t.TempDir(), + PythonPath: "/nonexistent/python3", + Config: &config.Config{ + Build: &config.Build{ + PythonVersion: "3.12", + }, + }, + } + + check := &PythonVersionCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "could not determine Python version") +} diff --git a/pkg/doctor/check_python_deprecated_imports.go b/pkg/doctor/check_python_deprecated_imports.go new file mode 100644 index 0000000000..a3f1afc3a6 --- /dev/null +++ b/pkg/doctor/check_python_deprecated_imports.go @@ -0,0 +1,465 @@ +package doctor + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// deprecatedImport describes an import that was removed or moved. +type deprecatedImport struct { + Module string // e.g., "cog.types" + Name string // e.g., "ExperimentalFeatureWarning" + Message string // e.g., "removed in cog 0.17" +} + +// deprecatedImportsList is the list of known deprecated imports. +var deprecatedImportsList = []deprecatedImport{ + { + Module: "cog.types", + Name: "ExperimentalFeatureWarning", + Message: "ExperimentalFeatureWarning was removed in cog 0.17; current_scope().record_metric() is no longer experimental", + }, +} + +// DeprecatedImportsCheck detects imports that were removed or moved in recent cog versions. +type DeprecatedImportsCheck struct{} + +func (c *DeprecatedImportsCheck) Name() string { return "python-deprecated-imports" } +func (c *DeprecatedImportsCheck) Group() Group { return GroupPython } +func (c *DeprecatedImportsCheck) Description() string { return "Deprecated imports" } + +func (c *DeprecatedImportsCheck) Check(ctx *CheckContext) ([]Finding, error) { + var findings []Finding + + for _, pf := range ctx.PythonFiles { + root := pf.Tree.RootNode() + + for _, child := range schemaPython.NamedChildren(root) { + if child.Type() != "import_from_statement" { + continue + } + + moduleNode := child.ChildByFieldName("module_name") + if moduleNode == nil { + continue + } + module := schemaPython.Content(moduleNode, pf.Source) + + // Check each imported name against the deprecated list + for _, name := range extractImportedNames(child, pf.Source) { + for _, dep := range deprecatedImportsList { + if module == dep.Module && name == dep.Name { + line := int(child.StartPoint().Row) + 1 + findings = append(findings, Finding{ + Severity: SeverityError, + Message: dep.Message, + Remediation: fmt.Sprintf("Remove the import of %s from %s", dep.Name, dep.Module), + File: pf.Path, + Line: line, + }) + } + } + } + } + } + + return findings, nil +} + +func (c *DeprecatedImportsCheck) Fix(ctx *CheckContext, findings []Finding) error { + // Group findings by file + affectedFiles := make(map[string]bool) + for _, f := range findings { + affectedFiles[f.File] = true + } + + for relPath := range affectedFiles { + fullPath := filepath.Join(ctx.ProjectDir, relPath) + info, err := os.Stat(fullPath) + if err != nil { + return fmt.Errorf("stat %s: %w", relPath, err) + } + + source, err := os.ReadFile(fullPath) + if err != nil { + return fmt.Errorf("reading %s: %w", relPath, err) + } + + pf, ok := ctx.PythonFiles[relPath] + if !ok { + continue + } + + fixed := removeDeprecatedImportsAST(ctx.ctx, source, pf.Tree) + + if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { + return fmt.Errorf("writing %s: %w", relPath, err) + } + } + + return nil +} + +// byteRange represents a range of bytes to remove from source, corresponding +// to a full line (including its trailing newline). +type byteRange struct { + start uint32 + end uint32 +} + +// removeDeprecatedImportsAST uses tree-sitter to identify and remove: +// 1. import_from_statement nodes that import deprecated names +// 2. expression_statement nodes that reference those deprecated names +// 3. orphaned "import X" statements where X is no longer used +func removeDeprecatedImportsAST(ctx context.Context, source []byte, tree *sitter.Tree) string { + root := tree.RootNode() + + // Step 1: Walk the AST to find which deprecated names are present in this file. + deprecatedNames := make(map[string]bool) + namesToRemove := make(map[string]map[string]bool) // module -> set of names + + for _, child := range schemaPython.NamedChildren(root) { + if child.Type() != "import_from_statement" { + continue + } + moduleNode := child.ChildByFieldName("module_name") + if moduleNode == nil { + continue + } + module := schemaPython.Content(moduleNode, source) + + for _, name := range extractImportedNames(child, source) { + for _, dep := range deprecatedImportsList { + if module == dep.Module && name == dep.Name { + deprecatedNames[dep.Name] = true + if namesToRemove[dep.Module] == nil { + namesToRemove[dep.Module] = make(map[string]bool) + } + namesToRemove[dep.Module][dep.Name] = true + } + } + } + } + + if len(deprecatedNames) == 0 { + return string(source) + } + + // Step 2: Remove deprecated import lines/names (handles partial imports). + fixed := removeDeprecatedImportLines(string(source), namesToRemove) + + // Step 3: Re-parse and use tree-sitter to find statements referencing + // the deprecated names, then remove them by byte range. + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + newTree, err := parser.ParseCtx(ctx, nil, []byte(fixed)) + if err != nil { + return fixed + } + + newSource := []byte(fixed) + newRoot := newTree.RootNode() + var removals []byteRange + for _, child := range schemaPython.NamedChildren(newRoot) { + if child.Type() == "import_from_statement" || child.Type() == "import_statement" { + continue + } + if nodeReferencesAny(child, newSource, deprecatedNames) { + removals = append(removals, nodeLineRange(child, newSource)) + } + } + + fixed = applyRemovals(newSource, removals) + + // Step 4: Remove orphaned "import X" statements via AST. + fixed = removeOrphanedImportsAST(ctx, fixed) + + return fixed +} + +// nodeReferencesAny walks a tree-sitter node recursively and returns true if +// any identifier node matches one of the given names. +func nodeReferencesAny(node *sitter.Node, source []byte, names map[string]bool) bool { + if node.Type() == "identifier" { + return names[schemaPython.Content(node, source)] + } + for _, child := range schemaPython.AllChildren(node) { + if nodeReferencesAny(child, source, names) { + return true + } + } + return false +} + +// nodeLineRange returns a byte range covering the full line(s) of a node, +// including the trailing newline. +func nodeLineRange(node *sitter.Node, source []byte) byteRange { + start := node.StartByte() + end := node.EndByte() + + // Extend start to beginning of line + for start > 0 && source[start-1] != '\n' { + start-- + } + // Extend end past trailing newline + if int(end) < len(source) && source[end] == '\n' { + end++ + } + + return byteRange{start: start, end: end} +} + +// applyRemovals removes all byte ranges from source, handling overlaps. +// Ranges are sorted descending by start so earlier indices remain valid. +func applyRemovals(source []byte, ranges []byteRange) string { + if len(ranges) == 0 { + return string(source) + } + + // Sort descending by start so we can remove from back to front + sort.Slice(ranges, func(i, j int) bool { + return ranges[i].start > ranges[j].start + }) + + result := make([]byte, len(source)) + copy(result, source) + + for _, r := range ranges { + if int(r.start) >= len(result) { + continue + } + end := min(int(r.end), len(result)) + result = append(result[:r.start], result[end:]...) + } + + return string(result) +} + +// removeOrphanedImportsAST re-parses source and removes "import X" statements +// where X is no longer referenced anywhere else in the file. +func removeOrphanedImportsAST(ctx context.Context, source string) string { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(ctx, nil, []byte(source)) + if err != nil { + return source + } + + src := []byte(source) + root := tree.RootNode() + var removals []byteRange + + for _, child := range schemaPython.NamedChildren(root) { + if child.Type() != "import_statement" { + continue + } + + // Get the module name being imported (e.g. "warnings" from "import warnings") + var moduleName string + for _, c := range schemaPython.NamedChildren(child) { + if c.Type() == "dotted_name" { + moduleName = schemaPython.Content(c, src) + break + } + } + if moduleName == "" { + continue + } + + // Check if this module is referenced elsewhere (outside import statements) + used := false + for _, stmt := range schemaPython.NamedChildren(root) { + if stmt.Type() == "import_statement" || stmt.Type() == "import_from_statement" { + continue + } + if nodeReferencesModule(stmt, src, moduleName) { + used = true + break + } + } + + if !used { + removals = append(removals, nodeLineRange(child, src)) + } + } + + return applyRemovals(src, removals) +} + +// nodeReferencesModule checks if a node contains an attribute access on the +// given module (e.g. "warnings.filterwarnings") or a bare identifier matching it. +func nodeReferencesModule(node *sitter.Node, source []byte, moduleName string) bool { + if node.Type() == "attribute" { + obj := node.ChildByFieldName("object") + if obj != nil && obj.Type() == "identifier" && schemaPython.Content(obj, source) == moduleName { + return true + } + } + if node.Type() == "identifier" && schemaPython.Content(node, source) == moduleName { + return true + } + for _, child := range schemaPython.AllChildren(node) { + if nodeReferencesModule(child, source, moduleName) { + return true + } + } + return false +} + +// extractImportedNames returns the names imported in a "from X import a, b, c" statement. +func extractImportedNames(importNode *sitter.Node, source []byte) []string { + moduleNode := importNode.ChildByFieldName("module_name") + var names []string + + for _, child := range schemaPython.AllChildren(importNode) { + switch child.Type() { + case "dotted_name": + if moduleNode != nil && child.StartByte() != moduleNode.StartByte() { + names = append(names, schemaPython.Content(child, source)) + } + case "aliased_import": + if origNode := child.ChildByFieldName("name"); origNode != nil { + names = append(names, schemaPython.Content(origNode, source)) + } + case "import_list": + for _, ic := range schemaPython.AllChildren(child) { + switch ic.Type() { + case "dotted_name": + names = append(names, schemaPython.Content(ic, source)) + case "aliased_import": + if origNode := ic.ChildByFieldName("name"); origNode != nil { + names = append(names, schemaPython.Content(origNode, source)) + } + } + } + } + } + + return names +} + +// removeDeprecatedImportLines removes specific names from import lines. +// If all names are removed, the entire import line is dropped. +// Handles both single-line and multi-line parenthesized imports. +func removeDeprecatedImportLines(source string, namesToRemove map[string]map[string]bool) string { + lines := strings.Split(source, "\n") + var result []string + + // Track multi-line import state + inMultilineImport := false + var multilineModule string + var multilineNames []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Handle multi-line imports + if inMultilineImport { + // Collect names from continuation lines + cleaned := strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(trimmed), ")")) + if cleaned != "" { + for n := range strings.SplitSeq(cleaned, ",") { + n = strings.TrimSpace(n) + if n != "" { + multilineNames = append(multilineNames, n) + } + } + } + if strings.Contains(trimmed, ")") { + inMultilineImport = false + // Now filter the collected names + names := namesToRemove[multilineModule] + var remaining []string + for _, name := range multilineNames { + if !names[name] { + remaining = append(remaining, name) + } + } + if len(remaining) > 0 { + result = append(result, "from "+multilineModule+" import "+strings.Join(remaining, ", ")) + } + } + continue + } + + removed := false + for module, names := range namesToRemove { + prefix := "from " + module + " import " + if !strings.HasPrefix(trimmed, prefix) { + continue + } + + importPart := trimmed[len(prefix):] + + // Check for multi-line parenthesized import + if strings.HasPrefix(strings.TrimSpace(importPart), "(") { + inner := strings.TrimSpace(importPart)[1:] // strip leading "(" + if strings.Contains(inner, ")") { + // Single-line parenthesized: from X import (A, B) + inner = strings.TrimSuffix(strings.TrimSpace(inner), ")") + importNames := strings.Split(inner, ",") + var remaining []string + for _, name := range importNames { + name = strings.TrimSpace(name) + if name != "" && !names[name] { + remaining = append(remaining, name) + } + } + if len(remaining) > 0 { + result = append(result, prefix+strings.Join(remaining, ", ")) + } + removed = true + } else { + // Start of multi-line import + inMultilineImport = true + multilineModule = module + multilineNames = nil + // Collect any names on the first line after "(" + if inner != "" { + for n := range strings.SplitSeq(inner, ",") { + n = strings.TrimSpace(n) + if n != "" { + multilineNames = append(multilineNames, n) + } + } + } + removed = true + } + break + } + + importNames := strings.Split(importPart, ",") + + var remaining []string + for _, name := range importNames { + name = strings.TrimSpace(name) + if !names[name] { + remaining = append(remaining, name) + } + } + + if len(remaining) == 0 { + removed = true + } else { + result = append(result, prefix+strings.Join(remaining, ", ")) + removed = true + } + break + } + + if !removed { + result = append(result, line) + } + } + + return strings.Join(result, "\n") +} diff --git a/pkg/doctor/check_python_pydantic_basemodel.go b/pkg/doctor/check_python_pydantic_basemodel.go new file mode 100644 index 0000000000..4d7908f1db --- /dev/null +++ b/pkg/doctor/check_python_pydantic_basemodel.go @@ -0,0 +1,347 @@ +package doctor + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + + "github.com/replicate/cog/pkg/schema" + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// PydanticBaseModelCheck detects output classes that inherit from pydantic.BaseModel +// with arbitrary_types_allowed=True instead of using cog.BaseModel. +type PydanticBaseModelCheck struct{} + +func (c *PydanticBaseModelCheck) Name() string { return "python-pydantic-basemodel" } +func (c *PydanticBaseModelCheck) Group() Group { return GroupPython } +func (c *PydanticBaseModelCheck) Description() string { return "Pydantic BaseModel workaround" } + +func (c *PydanticBaseModelCheck) Check(ctx *CheckContext) ([]Finding, error) { + var findings []Finding + + for _, pf := range ctx.PythonFiles { + root := pf.Tree.RootNode() + + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + + // Check if class inherits from pydantic.BaseModel (not cog.BaseModel) + if !inheritsPydanticBaseModel(classNode, pf.Source, pf.Imports) { + continue + } + + // Check if class has arbitrary_types_allowed=True + if !hasArbitraryTypesAllowed(classNode, pf.Source) { + continue + } + + nameNode := classNode.ChildByFieldName("name") + className := "" + line := 0 + if nameNode != nil { + className = schemaPython.Content(nameNode, pf.Source) + line = int(nameNode.StartPoint().Row) + 1 + } + + findings = append(findings, Finding{ + Severity: SeverityError, + Message: fmt.Sprintf("%s inherits from pydantic.BaseModel with arbitrary_types_allowed; use cog.BaseModel instead", className), + Remediation: "Replace pydantic.BaseModel with cog.BaseModel and remove ConfigDict(arbitrary_types_allowed=True)", + File: pf.Path, + Line: line, + }) + } + } + + return findings, nil +} + +func (c *PydanticBaseModelCheck) Fix(ctx *CheckContext, findings []Finding) error { + // Group findings by file + fileFindings := make(map[string][]Finding) + for _, f := range findings { + fileFindings[f.File] = append(fileFindings[f.File], f) + } + + for relPath := range fileFindings { + fullPath := filepath.Join(ctx.ProjectDir, relPath) + info, err := os.Stat(fullPath) + if err != nil { + return fmt.Errorf("stat %s: %w", relPath, err) + } + + source, err := os.ReadFile(fullPath) + if err != nil { + return fmt.Errorf("reading %s: %w", relPath, err) + } + + fixed := fixPydanticBaseModel(string(source)) + + if err := os.WriteFile(fullPath, []byte(fixed), info.Mode()); err != nil { + return fmt.Errorf("writing %s: %w", relPath, err) + } + } + + return nil +} + +// inheritsPydanticBaseModel checks if a class inherits from pydantic.BaseModel +// (as opposed to cog.BaseModel or another BaseModel). +func inheritsPydanticBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { + supers := classNode.ChildByFieldName("superclasses") + if supers == nil { + return false + } + + for _, child := range schemaPython.AllChildren(supers) { + text := schemaPython.Content(child, source) + + switch child.Type() { + case "identifier": + if text == "BaseModel" { + // Check if BaseModel was imported from pydantic + if entry, ok := imports.Names.Get("BaseModel"); ok { + return entry.Module == "pydantic" + } + } + case "attribute": + if text == "pydantic.BaseModel" { + return true + } + } + } + return false +} + +// hasArbitraryTypesAllowed checks if a class body contains +// model_config = ConfigDict(arbitrary_types_allowed=True). +// Uses tree-sitter to properly parse keyword arguments, avoiding false positives +// on arbitrary_types_allowed=False. +func hasArbitraryTypesAllowed(classNode *sitter.Node, source []byte) bool { + body := classNode.ChildByFieldName("body") + if body == nil { + return false + } + + for _, stmt := range schemaPython.NamedChildren(body) { + node := stmt + if stmt.Type() == "expression_statement" && stmt.NamedChildCount() == 1 { + node = stmt.NamedChild(0) + } + + if node.Type() != "assignment" { + continue + } + + left := node.ChildByFieldName("left") + if left == nil || schemaPython.Content(left, source) != "model_config" { + continue + } + + right := node.ChildByFieldName("right") + if right == nil || right.Type() != "call" { + continue + } + + // Walk keyword arguments of the call + args := right.ChildByFieldName("arguments") + if args == nil { + continue + } + for _, arg := range schemaPython.NamedChildren(args) { + if arg.Type() != "keyword_argument" { + continue + } + key := arg.ChildByFieldName("name") + val := arg.ChildByFieldName("value") + if key != nil && val != nil && + schemaPython.Content(key, source) == "arbitrary_types_allowed" && + schemaPython.Content(val, source) == "True" { + return true + } + } + } + + return false +} + +// fixPydanticBaseModel rewrites Python source to replace pydantic.BaseModel with cog.BaseModel. +func fixPydanticBaseModel(source string) string { + lines := strings.Split(source, "\n") + var result []string + inPydanticImport := false + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Handle multi-line "from pydantic import (" style + if strings.HasPrefix(trimmed, "from pydantic import (") { + inPydanticImport = true + // Check if this single-line also closes: from pydantic import (BaseModel, ConfigDict) + if strings.Contains(trimmed, ")") { + inPydanticImport = false + remaining := removePydanticImportsMultiline(trimmed) + if remaining != "" { + result = append(result, remaining) + } + } + continue + } + if inPydanticImport { + if strings.Contains(trimmed, ")") { + inPydanticImport = false + } + // Skip all lines in the parenthesized pydantic import + continue + } + + // Remove single-line "from pydantic import BaseModel" (and ConfigDict) + if strings.HasPrefix(trimmed, "from pydantic import") { + remaining := removePydanticImports(trimmed) + if remaining == "" { + continue // Drop the entire line + } + result = append(result, remaining) + continue + } + + // Handle "import pydantic" style + if trimmed == "import pydantic" { + continue // Drop the line + } + + // Handle model_config = ConfigDict(...) lines -- only remove arbitrary_types_allowed=True + if strings.Contains(trimmed, "model_config") && strings.Contains(trimmed, "ConfigDict") { + fixed := removeKeywordArg(line, "arbitrary_types_allowed=True") + if fixed != line { + // Successfully removed the argument; check if ConfigDict is now empty + if isEmptyConfigDict(fixed) { + continue // Drop the line entirely + } + result = append(result, fixed) + continue + } + } + + // Replace pydantic.BaseModel in class definitions + if strings.Contains(line, "pydantic.BaseModel") { + line = strings.ReplaceAll(line, "pydantic.BaseModel", "BaseModel") + } + + // Replace pydantic.ConfigDict references + if strings.Contains(line, "pydantic.ConfigDict") { + line = strings.ReplaceAll(line, "pydantic.ConfigDict", "ConfigDict") + } + + result = append(result, line) + } + + // Now add BaseModel to the cog import line + fixed := strings.Join(result, "\n") + fixed = addToCogImport(fixed, "BaseModel") + + return fixed +} + +// removeKeywordArg removes a specific keyword argument from a line containing a function call. +// Handles "arg, ", ", arg", and standalone "arg". +func removeKeywordArg(line string, arg string) string { + result := strings.Replace(line, arg+", ", "", 1) + if result == line { + result = strings.Replace(line, ", "+arg, "", 1) + } + if result == line { + result = strings.Replace(line, arg, "", 1) + } + return result +} + +// isEmptyConfigDict checks if a line contains an empty ConfigDict() call. +func isEmptyConfigDict(line string) bool { + trimmed := strings.TrimSpace(line) + return strings.Contains(trimmed, "ConfigDict()") || strings.Contains(trimmed, "ConfigDict( )") +} + +// removePydanticImportsMultiline handles "from pydantic import (X, Y, Z)" on a single line. +func removePydanticImportsMultiline(line string) string { + // Extract contents between "from pydantic import (" and ")" + start := strings.Index(line, "(") + end := strings.LastIndex(line, ")") + if start == -1 || end == -1 || start >= end { + return "" + } + importPart := line[start+1 : end] + names := strings.Split(importPart, ",") + + var remaining []string + for _, name := range names { + name = strings.TrimSpace(name) + if name == "BaseModel" || name == "ConfigDict" || name == "" { + continue + } + remaining = append(remaining, name) + } + + if len(remaining) == 0 { + return "" + } + + return "from pydantic import " + strings.Join(remaining, ", ") +} + +// removePydanticImports removes BaseModel and ConfigDict from a pydantic import line. +// Returns "" if no imports remain. +func removePydanticImports(line string) string { + // Parse "from pydantic import X, Y, Z" + prefix := "from pydantic import " + if !strings.HasPrefix(strings.TrimSpace(line), prefix) { + return line + } + + importPart := strings.TrimSpace(line)[len(prefix):] + names := strings.Split(importPart, ",") + + var remaining []string + for _, name := range names { + name = strings.TrimSpace(name) + if name == "BaseModel" || name == "ConfigDict" { + continue + } + if name != "" { + remaining = append(remaining, name) + } + } + + if len(remaining) == 0 { + return "" + } + + return prefix + strings.Join(remaining, ", ") +} + +// addToCogImport adds a name to an existing "from cog import ..." line. +// If no "from cog import" line exists, inserts one at the top of the file. +func addToCogImport(source string, name string) string { + lines := strings.Split(source, "\n") + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "from cog import") { + if strings.Contains(trimmed, name) { + return source // Already imported + } + // Add the name at the end + lines[i] = line + ", " + name + return strings.Join(lines, "\n") + } + } + // No existing cog import found -- add one at the top + newImport := "from cog import " + name + return newImport + "\n" + source +} diff --git a/pkg/doctor/check_python_test.go b/pkg/doctor/check_python_test.go new file mode 100644 index 0000000000..b8f92e8388 --- /dev/null +++ b/pkg/doctor/check_python_test.go @@ -0,0 +1,503 @@ +package doctor + +import ( + "context" + "os" + "path/filepath" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/config" + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +func TestPydanticBaseModelCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, BaseModel, Path + +class Output(BaseModel): + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPydanticBaseModelCheck_Detects(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "VoiceCloningOutputs") + require.Contains(t, findings[0].Message, "pydantic.BaseModel") +} + +func TestPydanticBaseModelCheck_Fix(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class VoiceCloningOutputs(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + spectrogram: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> VoiceCloningOutputs: + return VoiceCloningOutputs(audio="a.wav", spectrogram="s.png") +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + // Re-read and verify + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BasePredictor, Path, BaseModel") + require.NotContains(t, string(content), "from pydantic import BaseModel") + require.NotContains(t, string(content), "arbitrary_types_allowed") + require.NotContains(t, string(content), "model_config") + + // Re-parse and verify doctor passes + ctx.PythonFiles = parsePythonFiles(t, dir, "predict.py") + findings, err = check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPydanticBaseModelCheck_NoFalsePositive(t *testing.T) { + // arbitrary_types_allowed=False should NOT trigger the check + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=False, validate_default=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestPydanticBaseModelCheck_Fix_NoCogImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from pydantic import BaseModel, ConfigDict + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + name: str +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BaseModel") + require.NotContains(t, string(content), "from pydantic import BaseModel") +} + +func TestPydanticBaseModelCheck_Fix_ImportPydanticStyle(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +import pydantic + +class Output(pydantic.BaseModel): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BasePredictor, Path, BaseModel") + require.NotContains(t, string(content), "import pydantic") + require.NotContains(t, string(content), "pydantic.BaseModel") + require.NotContains(t, string(content), "pydantic.ConfigDict") +} + +func TestPydanticBaseModelCheck_Fix_PreservesOtherConfigDict(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import BaseModel, ConfigDict + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "validate_default=True") + require.NotContains(t, string(content), "arbitrary_types_allowed") + require.Contains(t, string(content), "model_config") +} + +func TestPydanticBaseModelCheck_Fix_MultilineImport(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor, Path +from pydantic import ( + BaseModel, + ConfigDict, +) + +class Output(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + audio: Path + +class Predictor(BasePredictor): + def predict(self, text: str) -> Output: + return Output(audio="a.wav") +`) + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + check := &PydanticBaseModelCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.Contains(t, string(content), "from cog import BasePredictor, Path, BaseModel") + require.NotContains(t, string(content), "from pydantic import") + require.NotContains(t, string(content), "arbitrary_types_allowed") +} + +func TestDeprecatedImportsCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestDeprecatedImportsCheck_Detects(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityError, findings[0].Severity) + require.Contains(t, findings[0].Message, "ExperimentalFeatureWarning") +} + +func TestDeprecatedImportsCheck_Fix(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ctx: context.Background(), + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.NotContains(t, string(content), "ExperimentalFeatureWarning") + + // Re-parse and verify clean + ctx.PythonFiles = parsePythonFiles(t, dir, "predict.py") + findings, err = check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestDeprecatedImportsCheck_Fix_WithWarningsFilterwarnings(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text +`) + + ctx := &CheckContext{ + ctx: context.Background(), + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.NotContains(t, string(content), "ExperimentalFeatureWarning") + require.NotContains(t, string(content), "cog.types") + require.NotContains(t, string(content), "import warnings") + + // Re-parse and verify clean + ctx.PythonFiles = parsePythonFiles(t, dir, "predict.py") + findings, err = check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestDeprecatedImportsCheck_Fix_WarningsImportPreserved(t *testing.T) { + // When warnings module is still used elsewhere, import should be preserved + dir := t.TempDir() + writeFile(t, dir, "predict.py", `import warnings +from cog import BasePredictor +from cog.types import ExperimentalFeatureWarning + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) +warnings.warn("something else") + + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return "hello " + text +`) + + ctx := &CheckContext{ + ctx: context.Background(), + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + } + + check := &DeprecatedImportsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + + err = check.Fix(ctx, findings) + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(dir, "predict.py")) + require.NoError(t, err) + require.NotContains(t, string(content), "ExperimentalFeatureWarning") + require.Contains(t, string(content), "import warnings") + require.Contains(t, string(content), "warnings.warn") +} + +func TestMissingTypeAnnotationsCheck_Clean(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str) -> str: + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + Config: &config.Config{Predict: "predict.py:Predictor"}, + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestMissingTypeAnnotationsCheck_MissingReturnType(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "predict.py", `from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self, text: str): + return text +`) + + ctx := &CheckContext{ + ProjectDir: dir, + PythonFiles: parsePythonFiles(t, dir, "predict.py"), + Config: &config.Config{Predict: "predict.py:Predictor"}, + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Len(t, findings, 1) + require.Equal(t, SeverityWarning, findings[0].Severity) + require.Contains(t, findings[0].Message, "predict()") + require.Contains(t, findings[0].Message, "return type annotation") +} + +func TestMissingTypeAnnotationsCheck_NoConfig(t *testing.T) { + ctx := &CheckContext{ + ProjectDir: t.TempDir(), + Config: nil, + } + + check := &MissingTypeAnnotationsCheck{} + findings, err := check.Check(ctx) + require.NoError(t, err) + require.Empty(t, findings) +} + +func TestMissingTypeAnnotationsCheck_FixReturnsNoAutoFix(t *testing.T) { + check := &MissingTypeAnnotationsCheck{} + err := check.Fix(nil, nil) + require.ErrorIs(t, err, ErrNoAutoFix) +} + +// parsePythonFiles is a test helper that parses Python files into ParsedFile structs. +func parsePythonFiles(t *testing.T, dir string, filenames ...string) map[string]*ParsedFile { + t.Helper() + files := make(map[string]*ParsedFile) + for _, name := range filenames { + source, err := os.ReadFile(filepath.Join(dir, name)) + require.NoError(t, err) + + sitterParser := sitter.NewParser() + sitterParser.SetLanguage(python.GetLanguage()) + tree, err := sitterParser.ParseCtx(context.Background(), nil, source) + require.NoError(t, err) + + imports := schemaPython.CollectImports(tree.RootNode(), source) + files[name] = &ParsedFile{ + Path: name, + Source: source, + Tree: tree, + Imports: imports, + } + } + return files +} diff --git a/pkg/doctor/check_python_type_annotations.go b/pkg/doctor/check_python_type_annotations.go new file mode 100644 index 0000000000..48ecec56e5 --- /dev/null +++ b/pkg/doctor/check_python_type_annotations.go @@ -0,0 +1,123 @@ +package doctor + +import ( + "fmt" + + sitter "github.com/smacker/go-tree-sitter" + + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// MissingTypeAnnotationsCheck detects predict/train methods that are +// missing return type annotations. +type MissingTypeAnnotationsCheck struct{} + +func (c *MissingTypeAnnotationsCheck) Name() string { return "python-missing-type-annotations" } +func (c *MissingTypeAnnotationsCheck) Group() Group { return GroupPython } +func (c *MissingTypeAnnotationsCheck) Description() string { return "Type annotations" } + +func (c *MissingTypeAnnotationsCheck) Check(ctx *CheckContext) ([]Finding, error) { + if ctx.Config == nil { + return nil, nil + } + + var findings []Finding + + if ctx.Config.Predict != "" { + f := checkMethodAnnotations(ctx, ctx.Config.Predict, "predict") + findings = append(findings, f...) + } + + if ctx.Config.Train != "" { + f := checkMethodAnnotations(ctx, ctx.Config.Train, "train") + findings = append(findings, f...) + } + + return findings, nil +} + +func (c *MissingTypeAnnotationsCheck) Fix(_ *CheckContext, _ []Finding) error { + return ErrNoAutoFix +} + +// checkMethodAnnotations checks that the given method has a return type annotation. +func checkMethodAnnotations(ctx *CheckContext, ref string, methodName string) []Finding { + parts := splitPredictRef(ref) + fileName := parts[0] + className := parts[1] + + if fileName == "" || className == "" { + return nil // Invalid ref — handled by ConfigPredictRefCheck + } + + pf, ok := ctx.PythonFiles[fileName] + if !ok { + return nil // File not parsed — handled by ConfigPredictRefCheck + } + + root := pf.Tree.RootNode() + + // Find the class + classNode := findClass(root, pf.Source, className) + if classNode == nil { + return nil // Class not found — handled by ConfigPredictRefCheck + } + + // Find the method inside the class + funcNode := findMethod(classNode, pf.Source, methodName) + if funcNode == nil { + return nil // Method not found — could be a separate check later + } + + // Check for return type annotation + if funcNode.ChildByFieldName("return_type") == nil { + line := int(funcNode.StartPoint().Row) + 1 + return []Finding{{ + Severity: SeverityWarning, + Message: fmt.Sprintf( + "%s.%s() is missing a return type annotation", + className, methodName, + ), + Remediation: fmt.Sprintf("Add a return type annotation: def %s(self, ...) -> YourType:", methodName), + File: fileName, + Line: line, + }} + } + + return nil +} + +// findClass locates a top-level class by name. +func findClass(root *sitter.Node, source []byte, name string) *sitter.Node { + for _, child := range schemaPython.NamedChildren(root) { + classNode := schemaPython.UnwrapClass(child) + if classNode == nil { + continue + } + nameNode := classNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + return classNode + } + } + return nil +} + +// findMethod locates a method by name inside a class body. +func findMethod(classNode *sitter.Node, source []byte, name string) *sitter.Node { + body := classNode.ChildByFieldName("body") + if body == nil { + return nil + } + + for _, child := range schemaPython.NamedChildren(body) { + funcNode := schemaPython.UnwrapFunction(child) + if funcNode == nil { + continue + } + nameNode := funcNode.ChildByFieldName("name") + if nameNode != nil && schemaPython.Content(nameNode, source) == name { + return funcNode + } + } + return nil +} diff --git a/pkg/doctor/doctor.go b/pkg/doctor/doctor.go new file mode 100644 index 0000000000..76cce45792 --- /dev/null +++ b/pkg/doctor/doctor.go @@ -0,0 +1,86 @@ +package doctor + +import ( + "context" + "errors" + + sitter "github.com/smacker/go-tree-sitter" + + "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/schema" +) + +// ErrNoAutoFix is returned by Fix() for detect-only checks. +var ErrNoAutoFix = errors.New("no auto-fix available for this check") + +// Severity of a finding. +type Severity int + +const ( + SeverityError Severity = iota // Must fix -- will cause build/predict failures + SeverityWarning // Should fix -- deprecated patterns, future breakage + SeverityInfo // Informational -- suggestions, best practices +) + +// String returns the human-readable name of the severity level. +func (s Severity) String() string { + switch s { + case SeverityError: + return "error" + case SeverityWarning: + return "warning" + case SeverityInfo: + return "info" + default: + return "unknown" + } +} + +// Group categorizes checks for display purposes. +type Group string + +const ( + GroupConfig Group = "Config" + GroupPython Group = "Python" + GroupEnvironment Group = "Environment" +) + +// Finding represents a single problem detected by a check. +type Finding struct { + Severity Severity + Message string // What's wrong + Remediation string // How to fix it + File string // Optional: file path where the issue was found + Line int // Optional: line number (1-indexed, 0 means unknown) +} + +// Check is the interface every doctor rule implements. +type Check interface { + Name() string + Group() Group + Description() string + Check(ctx *CheckContext) ([]Finding, error) + Fix(ctx *CheckContext, findings []Finding) error +} + +// ParsedFile holds tree-sitter parse results for a Python file. +type ParsedFile struct { + Path string // Relative path from project root + Source []byte // Raw file contents + Tree *sitter.Tree // Tree-sitter parse tree + Imports *schema.ImportContext // Collected imports +} + +// CheckContext provides checks with access to project state. +// Built once by the runner and passed to every check. +type CheckContext struct { + ctx context.Context + ProjectDir string + ConfigFilename string // Config filename (e.g. "cog.yaml") + Config *config.Config // Parsed cog.yaml (nil if parsing failed) + ConfigFile []byte // Raw cog.yaml bytes (available even if parsing failed) + LoadResult *config.LoadResult // Non-nil if config loaded successfully + LoadErr error // Non-nil if config loading failed + PythonFiles map[string]*ParsedFile // Pre-parsed Python files keyed by relative path + PythonPath string // Path to python binary (empty if not found) +} diff --git a/pkg/doctor/doctor_test.go b/pkg/doctor/doctor_test.go new file mode 100644 index 0000000000..15a8d53a12 --- /dev/null +++ b/pkg/doctor/doctor_test.go @@ -0,0 +1,157 @@ +package doctor + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// mockCheck is a test double for Check. +type mockCheck struct { + name string + group Group + description string + findings []Finding + checkErr error + fixErr error + fixCalled bool +} + +func (m *mockCheck) Name() string { return m.name } +func (m *mockCheck) Group() Group { return m.group } +func (m *mockCheck) Description() string { return m.description } + +func (m *mockCheck) Check(_ *CheckContext) ([]Finding, error) { + return m.findings, m.checkErr +} + +func (m *mockCheck) Fix(_ *CheckContext, _ []Finding) error { + m.fixCalled = true + return m.fixErr +} + +func TestRunCollectsFindings(t *testing.T) { + checks := []Check{ + &mockCheck{ + name: "passing-check", + group: GroupConfig, + findings: nil, + }, + &mockCheck{ + name: "failing-check", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "something is wrong"}, + }, + }, + } + + result, err := Run(context.Background(), RunOptions{ProjectDir: t.TempDir()}, checks) + require.NoError(t, err) + require.Len(t, result.Results, 2) + require.Empty(t, result.Results[0].Findings) + require.Len(t, result.Results[1].Findings, 1) + require.Equal(t, "something is wrong", result.Results[1].Findings[0].Message) +} + +func TestRunCallsFixWhenEnabled(t *testing.T) { + check := &mockCheck{ + name: "fixable-check", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "fixable issue"}, + }, + } + + _, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.True(t, check.fixCalled) +} + +func TestRunDoesNotCallFixWhenDisabled(t *testing.T) { + check := &mockCheck{ + name: "fixable-check", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "fixable issue"}, + }, + } + + _, err := Run(context.Background(), RunOptions{Fix: false, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.False(t, check.fixCalled) +} + +func TestRunDoesNotCallFixWhenNoFindings(t *testing.T) { + check := &mockCheck{ + name: "clean-check", + group: GroupConfig, + findings: nil, + } + + _, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.False(t, check.fixCalled) +} + +func TestRunHandlesCheckError(t *testing.T) { + checks := []Check{ + &mockCheck{ + name: "error-check", + group: GroupEnvironment, + checkErr: context.DeadlineExceeded, + }, + &mockCheck{ + name: "ok-check", + group: GroupConfig, + findings: nil, + }, + } + + result, err := Run(context.Background(), RunOptions{ProjectDir: t.TempDir()}, checks) + require.NoError(t, err) // Run itself doesn't fail; individual check errors are captured + require.Len(t, result.Results, 2) + require.Error(t, result.Results[0].Err) + require.NoError(t, result.Results[1].Err) +} + +func TestRunMarksFixedOnSuccess(t *testing.T) { + check := &mockCheck{ + name: "fixable", + group: GroupPython, + findings: []Finding{ + {Severity: SeverityError, Message: "broken"}, + }, + fixErr: nil, + } + + result, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.True(t, result.Results[0].Fixed) +} + +func TestRunMarksNotFixedOnErrNoAutoFix(t *testing.T) { + check := &mockCheck{ + name: "unfixable", + group: GroupConfig, + findings: []Finding{ + {Severity: SeverityWarning, Message: "deprecated"}, + }, + fixErr: ErrNoAutoFix, + } + + result, err := Run(context.Background(), RunOptions{Fix: true, ProjectDir: t.TempDir()}, []Check{check}) + require.NoError(t, err) + require.False(t, result.Results[0].Fixed) +} + +func TestHasErrorsWithCheckError(t *testing.T) { + result := &Result{ + Results: []CheckResult{ + {Err: errors.New("check failed")}, + }, + } + require.True(t, result.HasErrors()) +} diff --git a/pkg/doctor/registry.go b/pkg/doctor/registry.go new file mode 100644 index 0000000000..5a2873138b --- /dev/null +++ b/pkg/doctor/registry.go @@ -0,0 +1,22 @@ +package doctor + +// AllChecks returns all registered doctor checks in execution order. +// To add a new check: implement the Check interface and add it here. +func AllChecks() []Check { + return []Check{ + // Config checks + &ConfigParseCheck{}, + &ConfigSchemaCheck{}, + &ConfigDeprecatedFieldsCheck{}, + &ConfigPredictRefCheck{}, + + // Python checks + &PydanticBaseModelCheck{}, + &DeprecatedImportsCheck{}, + &MissingTypeAnnotationsCheck{}, + + // Environment checks + &DockerCheck{}, + &PythonVersionCheck{}, + } +} diff --git a/pkg/doctor/runner.go b/pkg/doctor/runner.go new file mode 100644 index 0000000000..4a9193c8cd --- /dev/null +++ b/pkg/doctor/runner.go @@ -0,0 +1,176 @@ +package doctor + +import ( + "bytes" + "context" + "errors" + "os" + "os/exec" + "path/filepath" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + "github.com/replicate/cog/pkg/config" + schemaPython "github.com/replicate/cog/pkg/schema/python" +) + +// RunOptions configures a doctor run. +type RunOptions struct { + Fix bool + ProjectDir string + ConfigFilename string // Config filename (defaults to "cog.yaml" if empty) +} + +// CheckResult holds the outcome of running a single check. +type CheckResult struct { + Check Check + Findings []Finding + Fixed bool // True if --fix was passed and Fix() succeeded + Err error // Non-nil if the check itself errored +} + +// Result holds the outcome of a full doctor run. +type Result struct { + Results []CheckResult +} + +// HasErrors returns true if any check produced error-severity findings +// or if any check itself errored. +func (r *Result) HasErrors() bool { + for _, cr := range r.Results { + if cr.Err != nil { + return true + } + for _, f := range cr.Findings { + if f.Severity == SeverityError && !cr.Fixed { + return true + } + } + } + return false +} + +// Run executes all checks and optionally applies fixes. +func Run(ctx context.Context, opts RunOptions, checks []Check) (*Result, error) { + configFilename := opts.ConfigFilename + if configFilename == "" { + configFilename = "cog.yaml" + } + + checkCtx, err := buildCheckContext(ctx, opts.ProjectDir, configFilename) + if err != nil { + return nil, err + } + + result := &Result{} + + for _, check := range checks { + cr := CheckResult{Check: check} + + findings, err := check.Check(checkCtx) + if err != nil { + cr.Err = err + result.Results = append(result.Results, cr) + continue + } + + cr.Findings = findings + + if opts.Fix && len(findings) > 0 { + fixErr := check.Fix(checkCtx, findings) + if fixErr == nil { + cr.Fixed = true + } else if !errors.Is(fixErr, ErrNoAutoFix) { + cr.Err = fixErr + } + } + + result.Results = append(result.Results, cr) + } + + return result, nil +} + +// buildCheckContext constructs the shared context for all checks. +func buildCheckContext(ctx context.Context, projectDir string, configFilename string) (*CheckContext, error) { + ctxt := &CheckContext{ + ctx: ctx, + ProjectDir: projectDir, + ConfigFilename: configFilename, + PythonFiles: make(map[string]*ParsedFile), + } + + // Load cog.yaml + configPath := filepath.Join(projectDir, configFilename) + configBytes, err := os.ReadFile(configPath) + if err == nil { + ctxt.ConfigFile = configBytes + // Load and validate config once — checks use ctxt.LoadResult / ctxt.LoadErr + loadResult, loadErr := config.Load(bytes.NewReader(configBytes), projectDir) + ctxt.LoadErr = loadErr + if loadResult != nil { + ctxt.LoadResult = loadResult + ctxt.Config = loadResult.Config + } + } + + // Find python binary + if pythonPath, err := exec.LookPath("python3"); err == nil { + ctxt.PythonPath = pythonPath + } else if pythonPath, err := exec.LookPath("python"); err == nil { + ctxt.PythonPath = pythonPath + } + + // Pre-parse Python files referenced in config + if ctxt.Config != nil { + parsePythonRef(ctxt, ctxt.Config.Predict) + parsePythonRef(ctxt, ctxt.Config.Train) + } + + return ctxt, nil +} + +// parsePythonRef parses a predict/train reference like "predict.py:Predictor" +// and adds the parsed file to ctx.PythonFiles. +func parsePythonRef(ctxt *CheckContext, ref string) { + if ref == "" { + return + } + parts := splitPredictRef(ref) + if parts[0] == "" { + return + } + + fullPath := filepath.Join(ctxt.ProjectDir, parts[0]) + source, err := os.ReadFile(fullPath) + if err != nil { + return + } + + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + tree, err := parser.ParseCtx(ctxt.ctx, nil, source) + if err != nil { + return + } + + imports := schemaPython.CollectImports(tree.RootNode(), source) + + ctxt.PythonFiles[parts[0]] = &ParsedFile{ + Path: parts[0], + Source: source, + Tree: tree, + Imports: imports, + } +} + +// splitPredictRef splits "predict.py:Predictor" into ["predict.py", "Predictor"]. +func splitPredictRef(ref string) [2]string { + for i := len(ref) - 1; i >= 0; i-- { + if ref[i] == ':' { + return [2]string{ref[:i], ref[i+1:]} + } + } + return [2]string{ref, ""} +} diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index 253de24d48..a425956909 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -38,7 +38,7 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi root := tree.RootNode() // 1. Collect imports - imports := collectImports(root, source) + imports := CollectImports(root, source) // 2. Collect module-level variable assignments moduleScope := collectModuleScope(root, source) @@ -101,8 +101,8 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi // Helpers // --------------------------------------------------------------------------- -// namedChildren returns all named children of a node. -func namedChildren(n *sitter.Node) []*sitter.Node { +// NamedChildren returns all named children of a node. +func NamedChildren(n *sitter.Node) []*sitter.Node { count := int(n.NamedChildCount()) result := make([]*sitter.Node, 0, count) for i := range count { @@ -111,8 +111,8 @@ func namedChildren(n *sitter.Node) []*sitter.Node { return result } -// allChildren returns all children (named and anonymous) of a node. -func allChildren(n *sitter.Node) []*sitter.Node { +// AllChildren returns all children (named and anonymous) of a node. +func AllChildren(n *sitter.Node) []*sitter.Node { count := int(n.ChildCount()) result := make([]*sitter.Node, 0, count) for i := range count { @@ -121,8 +121,8 @@ func allChildren(n *sitter.Node) []*sitter.Node { return result } -// content returns the source text for a node. -func content(n *sitter.Node, source []byte) string { +// Content returns the source text for a node. +func Content(n *sitter.Node, source []byte) string { return n.Content(source) } @@ -130,10 +130,10 @@ func content(n *sitter.Node, source []byte) string { // Import collection // --------------------------------------------------------------------------- -func collectImports(root *sitter.Node, source []byte) *schema.ImportContext { +func CollectImports(root *sitter.Node, source []byte) *schema.ImportContext { ctx := schema.NewImportContext() - for _, child := range namedChildren(root) { + for _, child := range NamedChildren(root) { if child.Type() == "import_from_statement" { parseImportFrom(child, source, ctx) } @@ -157,15 +157,15 @@ func parseImportFrom(node *sitter.Node, source []byte, ctx *schema.ImportContext if moduleNode == nil { return } - module := content(moduleNode, source) + module := Content(moduleNode, source) - for _, child := range allChildren(node) { + for _, child := range AllChildren(node) { switch child.Type() { case "dotted_name": // Single import: `from X import name` // Skip if this is the module_name itself if child.StartByte() != moduleNode.StartByte() { - name := content(child, source) + name := Content(child, source) ctx.Names.Set(name, schema.ImportEntry{Module: module, Original: name}) } case "aliased_import": @@ -174,29 +174,29 @@ func parseImportFrom(node *sitter.Node, source []byte, ctx *schema.ImportContext aliasNode := child.ChildByFieldName("alias") orig := "" if origNode != nil { - orig = content(origNode, source) + orig = Content(origNode, source) } alias := orig if aliasNode != nil { - alias = content(aliasNode, source) + alias = Content(aliasNode, source) } ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig}) case "import_list": - for _, importChild := range allChildren(child) { + for _, importChild := range AllChildren(child) { switch importChild.Type() { case "dotted_name": - name := content(importChild, source) + name := Content(importChild, source) ctx.Names.Set(name, schema.ImportEntry{Module: module, Original: name}) case "aliased_import": origNode := importChild.ChildByFieldName("name") aliasNode := importChild.ChildByFieldName("alias") orig := "" if origNode != nil { - orig = content(origNode, source) + orig = Content(origNode, source) } alias := orig if aliasNode != nil { - alias = content(aliasNode, source) + alias = Content(aliasNode, source) } ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig}) } @@ -213,7 +213,7 @@ type moduleScope map[string]schema.DefaultValue func collectModuleScope(root *sitter.Node, source []byte) moduleScope { scope := make(moduleScope) - for _, child := range namedChildren(root) { + for _, child := range NamedChildren(root) { var assign *sitter.Node if child.Type() == "expression_statement" { if child.NamedChildCount() == 1 { @@ -233,7 +233,7 @@ func collectModuleScope(root *sitter.Node, source []byte) moduleScope { if left == nil || left.Type() != "identifier" { continue } - name := content(left, source) + name := Content(left, source) right := assign.ChildByFieldName("right") if right == nil { @@ -253,7 +253,7 @@ func resolveDefaultExpr(node *sitter.Node, source []byte, scope moduleScope) (sc return val, true } if node.Type() == "identifier" { - name := content(node, source) + name := Content(node, source) if val, ok := scope[name]; ok { return val, true } @@ -268,7 +268,7 @@ func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([] return parseListLiteral(node, source) case "identifier": - name := content(node, source) + name := Content(node, source) val, ok := scope[name] if !ok { return nil, false @@ -284,8 +284,8 @@ func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([] case "binary_operator": // Only handle + (list concatenation) hasPlus := false - for _, c := range allChildren(node) { - if !c.IsNamed() && content(c, source) == "+" { + for _, c := range AllChildren(node) { + if !c.IsNamed() && Content(c, source) == "+" { hasPlus = true break } @@ -314,7 +314,7 @@ func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([] // resolveChoicesCall resolves list(X.keys()) or list(X.values()). func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([]schema.DefaultValue, bool) { funcNode := node.ChildByFieldName("function") - if funcNode == nil || content(funcNode, source) != "list" { + if funcNode == nil || Content(funcNode, source) != "list" { return nil, false } @@ -325,7 +325,7 @@ func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([] // Find the single positional argument var arg *sitter.Node - for _, c := range namedChildren(args) { + for _, c := range NamedChildren(args) { arg = c break } @@ -344,8 +344,8 @@ func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([] return nil, false } - varName := content(obj, source) - methodName := content(attr, source) + varName := Content(obj, source) + methodName := Content(attr, source) dictVal, ok := scope[varName] if !ok || dictVal.Kind != schema.DefaultDict { @@ -368,8 +368,8 @@ func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([] func collectModelClasses(root *sitter.Node, source []byte, imports *schema.ImportContext) schema.ModelClassMap { models := schema.NewOrderedMap[string, []schema.ModelField]() - for _, child := range namedChildren(root) { - classNode := unwrapClass(child) + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) if classNode == nil { continue } @@ -378,9 +378,9 @@ func collectModelClasses(root *sitter.Node, source []byte, imports *schema.Impor if nameNode == nil { continue } - className := content(nameNode, source) + className := Content(nameNode, source) - if !inheritsFromBaseModel(classNode, source, imports) { + if !InheritsFromBaseModel(classNode, source, imports) { continue } @@ -449,7 +449,7 @@ func resolveExternalModels(sourceDir string, imports *schema.ImportContext, mode return } - fileImports := collectImports(tree.RootNode(), source) + fileImports := CollectImports(tree.RootNode(), source) fileModels := collectModelClasses(tree.RootNode(), source, fileImports) // Merge discovered models into the caller's map. @@ -518,12 +518,12 @@ func isKnownExternalModule(module string) bool { return false } -func unwrapClass(node *sitter.Node) *sitter.Node { +func UnwrapClass(node *sitter.Node) *sitter.Node { if node.Type() == "class_definition" { return node } if node.Type() == "decorated_definition" { - for _, c := range namedChildren(node) { + for _, c := range NamedChildren(node) { if c.Type() == "class_definition" { return c } @@ -532,12 +532,12 @@ func unwrapClass(node *sitter.Node) *sitter.Node { return nil } -func unwrapFunction(node *sitter.Node) *sitter.Node { +func UnwrapFunction(node *sitter.Node) *sitter.Node { if node.Type() == "function_definition" { return node } if node.Type() == "decorated_definition" { - for _, c := range namedChildren(node) { + for _, c := range NamedChildren(node) { if c.Type() == "function_definition" { return c } @@ -546,21 +546,21 @@ func unwrapFunction(node *sitter.Node) *sitter.Node { return nil } -func inheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { +func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { return false } - for _, child := range allChildren(supers) { + for _, child := range AllChildren(supers) { switch child.Type() { case "identifier": - name := content(child, source) + name := Content(child, source) if imports.IsBaseModel(name) || name == "BaseModel" { return true } case "attribute": // Handle dotted access: pydantic.BaseModel, cog.BaseModel - text := content(child, source) + text := Content(child, source) if strings.HasSuffix(text, ".BaseModel") { return true } @@ -576,7 +576,7 @@ func extractClassAnnotations(classNode *sitter.Node, source []byte) []schema.Mod } var fields []schema.ModelField - for _, child := range namedChildren(body) { + for _, child := range NamedChildren(body) { node := child if child.Type() == "expression_statement" && child.NamedChildCount() == 1 { node = child.NamedChild(0) @@ -603,7 +603,7 @@ func parseAnnotatedAssignment(node *sitter.Node, source []byte) (schema.ModelFie return schema.ModelField{}, false } - name := content(left, source) + name := Content(left, source) typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { return schema.ModelField{}, false @@ -620,7 +620,7 @@ func parseAnnotatedAssignment(node *sitter.Node, source []byte) (schema.ModelFie } func parseBareAnnotation(node *sitter.Node, source []byte) (schema.ModelField, bool) { - text := strings.TrimSpace(content(node, source)) + text := strings.TrimSpace(Content(node, source)) parts := strings.SplitN(text, ":", 2) if len(parts) != 2 { return schema.ModelField{}, false @@ -763,8 +763,8 @@ func newInputRegistry() *inputRegistry { func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.ImportContext, scope moduleScope) *inputRegistry { registry := newInputRegistry() - for _, child := range namedChildren(root) { - classNode := unwrapClass(child) + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) if classNode == nil { continue } @@ -772,14 +772,14 @@ func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.Impo if nameNode == nil { continue } - className := content(nameNode, source) + className := Content(nameNode, source) body := classNode.ChildByFieldName("body") if body == nil { continue } - for _, stmt := range namedChildren(body) { + for _, stmt := range NamedChildren(body) { inner := stmt if stmt.Type() == "expression_statement" && stmt.NamedChildCount() == 1 { inner = stmt.NamedChild(0) @@ -789,7 +789,7 @@ func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.Impo collectInputAttribute(className, inner, source, imports, scope, registry) } - if funcNode := unwrapFunction(inner); funcNode != nil { + if funcNode := UnwrapFunction(inner); funcNode != nil { collectInputMethod(className, funcNode, source, imports, scope, registry) } } @@ -803,7 +803,7 @@ func collectInputAttribute(className string, assignment *sitter.Node, source []b if left == nil || left.Type() != "identifier" { return } - attrName := content(left, source) + attrName := Content(left, source) right := assignment.ChildByFieldName("right") if right == nil || !isInputCall(right, source, imports) { @@ -823,7 +823,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, if nameNode == nil { return } - methodName := content(nameNode, source) + methodName := Content(nameNode, source) params := funcNode.ChildByFieldName("parameters") if params == nil { @@ -831,10 +831,10 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, } var paramNames []string - for _, param := range allChildren(params) { + for _, param := range AllChildren(params) { switch param.Type() { case "identifier": - name := content(param, source) + name := Content(param, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } @@ -843,7 +843,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, for j := 0; j < int(param.NamedChildCount()); j++ { c := param.NamedChild(j) if c.Type() == "identifier" { - name := content(c, source) + name := Content(c, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } @@ -852,7 +852,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, } case "typed_default_parameter", "default_parameter": if n := param.ChildByFieldName("name"); n != nil { - name := content(n, source) + name := Content(n, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } @@ -879,7 +879,7 @@ func collectInputMethod(className string, funcNode *sitter.Node, source []byte, } func findReturnInputCall(body *sitter.Node, source []byte, imports *schema.ImportContext) *sitter.Node { - for _, child := range namedChildren(body) { + for _, child := range NamedChildren(body) { if child.Type() == "return_statement" { if child.NamedChildCount() > 0 { expr := child.NamedChild(0) @@ -895,7 +895,7 @@ func findReturnInputCall(body *sitter.Node, source []byte, imports *schema.Impor func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegistry) (inputCallInfo, bool) { switch node.Type() { case "attribute": - text := content(node, source) + text := Content(node, source) info, ok := registry.Attributes[text] return info, ok @@ -904,7 +904,7 @@ func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegi if funcNode == nil || funcNode.Type() != "attribute" { return inputCallInfo{}, false } - key := content(funcNode, source) + key := Content(funcNode, source) methodInfo, ok := registry.Methods[key] if !ok { return inputCallInfo{}, false @@ -920,12 +920,12 @@ func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegi // Build param_name -> call-site value map argValues := make(map[string]*sitter.Node) positionalIdx := 0 - for _, arg := range namedChildren(args) { + for _, arg := range NamedChildren(args) { if arg.Type() == "keyword_argument" { nameNode := arg.ChildByFieldName("name") valNode := arg.ChildByFieldName("value") if nameNode != nil && valNode != nil { - argValues[content(nameNode, source)] = valNode + argValues[Content(nameNode, source)] = valNode } } else if positionalIdx < len(methodInfo.ParamNames) { argValues[methodInfo.ParamNames[positionalIdx]] = arg @@ -966,26 +966,26 @@ func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegi func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName string) (*sitter.Node, error) { // First: look for a class with this name - for _, child := range namedChildren(root) { - classNode := unwrapClass(child) + for _, child := range NamedChildren(root) { + classNode := UnwrapClass(child) if classNode == nil { continue } nameNode := classNode.ChildByFieldName("name") - if nameNode != nil && content(nameNode, source) == predictRef { + if nameNode != nil && Content(nameNode, source) == predictRef { return findMethodInClass(classNode, source, predictRef, methodName) } } // Second: look for standalone function - for _, child := range namedChildren(root) { - funcNode := unwrapFunction(child) + for _, child := range NamedChildren(root) { + funcNode := UnwrapFunction(child) if funcNode == nil { continue } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil { - name := content(nameNode, source) + name := Content(nameNode, source) if name == predictRef || name == methodName { return funcNode, nil } @@ -1001,13 +1001,13 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN return nil, schema.WrapError(schema.ErrParse, fmt.Sprintf("class %s has no body", className), nil) } - for _, child := range namedChildren(body) { - funcNode := unwrapFunction(child) + for _, child := range NamedChildren(body) { + funcNode := UnwrapFunction(child) if funcNode == nil { continue } nameNode := funcNode.ChildByFieldName("name") - if nameNode != nil && content(nameNode, source) == methodName { + if nameNode != nil && Content(nameNode, source) == methodName { return funcNode, nil } } @@ -1020,9 +1020,9 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN // --------------------------------------------------------------------------- func firstParamIsSelf(params *sitter.Node, source []byte) bool { - for _, child := range allChildren(params) { + for _, child := range AllChildren(params) { if child.Type() == "identifier" { - return content(child, source) == "self" + return Content(child, source) == "self" } } return false @@ -1041,11 +1041,11 @@ func extractInputs( order := 0 seenSelf := false - for _, child := range allChildren(paramsNode) { + for _, child := range AllChildren(paramsNode) { switch child.Type() { case "identifier": if !seenSelf && skipSelf { - name := content(child, source) + name := Content(child, source) if name == "self" { seenSelf = true continue @@ -1072,7 +1072,7 @@ func extractInputs( nameNode := child.ChildByFieldName("name") paramName := "" if nameNode != nil { - paramName = content(nameNode, source) + paramName = Content(nameNode, source) } return nil, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", paramName, methodName), nil) } @@ -1091,7 +1091,7 @@ func parseTypedParameter(node *sitter.Node, source []byte, order int, methodName switch c.Type() { case "identifier": if name == "" { - name = content(c, source) + name = Content(c, source) } case "type": typeNode = c @@ -1133,7 +1133,7 @@ func parseTypedDefaultParameter( if nameNode == nil { return schema.InputField{}, schema.WrapError(schema.ErrParse, "typed_default_parameter has no name", nil) } - name := content(nameNode, source) + name := Content(nameNode, source) typeNode := node.ChildByFieldName("type") if typeNode == nil { @@ -1203,7 +1203,7 @@ func parseTypedDefaultParameter( } // Can't resolve — hard error - valText := content(valNode, source) + valText := Content(valNode, source) return schema.InputField{}, schema.WrapError(schema.ErrDefaultNotResolvable, fmt.Sprintf("parameter '%s': default `%s` cannot be statically resolved", name, valText), nil) } @@ -1229,17 +1229,17 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio switch n.Type() { case "identifier": - return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: content(n, source)}, nil + return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: Content(n, source)}, nil case "subscript": value := n.ChildByFieldName("value") if value == nil { return schema.TypeAnnotation{}, schema.WrapError(schema.ErrParse, "subscript has no value", nil) } - outer := content(value, source) + outer := Content(value, source) var args []schema.TypeAnnotation - for _, child := range namedChildren(n) { + for _, child := range NamedChildren(n) { // Skip the outer identifier (the value field) if child.StartByte() == value.StartByte() { continue @@ -1265,8 +1265,8 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio // Check that operator is | isUnion := false - for _, c := range allChildren(n) { - if !c.IsNamed() && content(c, source) == "|" { + for _, c := range AllChildren(n) { + if !c.IsNamed() && Content(c, source) == "|" { isUnion = true break } @@ -1303,10 +1303,10 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: "None"}, nil case "attribute": - return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: content(n, source)}, nil + return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: Content(n, source)}, nil case "string", "concatenated_string": - text := content(n, source) + text := Content(n, source) inner := strings.TrimLeft(text, "\"'") inner = strings.TrimRight(inner, "\"'") if ann, ok := parseTypeFromString(inner); ok { @@ -1315,7 +1315,7 @@ func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotatio return schema.TypeAnnotation{}, errUnsupported(fmt.Sprintf("string annotation: %s", text)) default: - text := content(n, source) + text := Content(n, source) if ann, ok := parseTypeFromString(text); ok { return ann, nil } @@ -1339,7 +1339,7 @@ func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext if funcNode == nil { return false } - name := content(funcNode, source) + name := Content(funcNode, source) if name == "Input" { return true } @@ -1357,7 +1357,7 @@ func parseInputCall(node *sitter.Node, source []byte, paramName string, scope mo return info, nil } - for _, child := range namedChildren(args) { + for _, child := range NamedChildren(args) { if child.Type() != "keyword_argument" { continue } @@ -1367,7 +1367,7 @@ func parseInputCall(node *sitter.Node, source []byte, paramName string, scope mo continue } - key := content(keyNode, source) + key := Content(keyNode, source) switch key { case "default": val, ok := resolveDefaultExpr(valNode, source, scope) @@ -1435,14 +1435,14 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b case "false": return schema.DefaultValue{Kind: schema.DefaultBool, Bool: false}, true case "integer": - text := content(node, source) + text := Content(node, source) n, err := strconv.ParseInt(text, 0, 64) if err != nil { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultInt, Int: n}, true case "float": - text := content(node, source) + text := Content(node, source) f, err := strconv.ParseFloat(text, 64) if err != nil { return schema.DefaultValue{}, false @@ -1473,7 +1473,7 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b } return schema.DefaultValue{Kind: schema.DefaultSet, List: items}, true case "unary_operator": - text := strings.TrimSpace(content(node, source)) + text := strings.TrimSpace(Content(node, source)) if n, err := strconv.ParseInt(text, 0, 64); err == nil { return schema.DefaultValue{Kind: schema.DefaultInt, Int: n}, true } @@ -1483,7 +1483,7 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b return schema.DefaultValue{}, false case "tuple": var items []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { if val, ok := parseDefaultValue(child, source); ok { items = append(items, val) } @@ -1494,7 +1494,7 @@ func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, b } func parseStringLiteral(node *sitter.Node, source []byte) (string, bool) { - text := content(node, source) + text := Content(node, source) if strings.HasPrefix(text, `"""`) || strings.HasPrefix(text, `'''`) { if len(text) >= 6 { return text[3 : len(text)-3], true @@ -1517,7 +1517,7 @@ func parseStringLiteral(node *sitter.Node, source []byte) (string, bool) { } func parseNumberLiteral(node *sitter.Node, source []byte) (float64, bool) { - text := strings.TrimSpace(content(node, source)) + text := strings.TrimSpace(Content(node, source)) f, err := strconv.ParseFloat(text, 64) if err != nil { return 0, false @@ -1532,7 +1532,7 @@ func parseBoolLiteral(node *sitter.Node, source []byte) (bool, bool) { case "false": return false, true } - text := content(node, source) + text := Content(node, source) switch text { case "True": return true, true @@ -1547,7 +1547,7 @@ func parseListLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, return nil, false } var items []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { val, ok := parseDefaultValue(child, source) if !ok { return nil, false @@ -1562,7 +1562,7 @@ func parseDictLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, return nil, nil, false } var keys, vals []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { if child.Type() == "pair" { keyNode := child.ChildByFieldName("key") valNode := child.ChildByFieldName("value") @@ -1585,7 +1585,7 @@ func parseSetLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, b return nil, false } var items []schema.DefaultValue - for _, child := range namedChildren(node) { + for _, child := range NamedChildren(node) { val, ok := parseDefaultValue(child, source) if !ok { return nil, false