diff --git a/internal/cmd/install.go b/internal/cmd/install.go index 2b0baad..591733a 100644 --- a/internal/cmd/install.go +++ b/internal/cmd/install.go @@ -1,17 +1,228 @@ package cmd import ( + "fmt" + "os" + "strings" + + "github.com/ArmisSecurity/armis-cli/internal/install" "github.com/spf13/cobra" ) var installCmd = &cobra.Command{ - Use: "install", - Short: "Install Armis integrations", - Long: `Install Armis integrations for development tools.`, - Example: ` # Install the Claude Code MCP plugin - armis-cli install claude`, + Use: "install [editor...]", + Short: "Install the Armis security scanner MCP server", + Long: `Download and install the Armis AppSec MCP server for your coding tools. + +With no arguments, installs the plugin and registers it in all detected editors. +Specify one or more editor names to target specific tools. + +Supported editors: + claude Claude Code (uses plugin system) + vscode VS Code / GitHub Copilot + copilot Alias for vscode + cursor Cursor + windsurf Windsurf (Codeium) + zed Zed + cline Cline (VS Code extension) + amazonq Amazon Q Developer + continue Continue + antigravity Antigravity + gemini Gemini CLI + +Not auto-configurable (manual setup required): + jetbrains JetBrains IDEs (per-project .jb-mcp.json) + devin Devin (cloud-based, configure via web UI) + aider Aider (no MCP support)`, + Example: ` # Install to all detected editors + armis-cli install + + # Install to specific editors + armis-cli install vscode cursor + + # Install to Claude Code only + armis-cli install claude + + # Check installed version + armis-cli install --version`, + RunE: runInstall, } func init() { rootCmd.AddCommand(installCmd) + installCmd.Flags().Bool("version", false, "Print the installed plugin version and exit") +} + +func runInstall(cmd *cobra.Command, args []string) error { + showVersion, err := cmd.Flags().GetBool("version") + if err != nil { + return fmt.Errorf("reading --version flag: %w", err) + } + + if showVersion { + return showInstalledVersions() + } + + if len(args) == 0 { + return installAll() + } + + return installTargets(args) +} + +func showInstalledVersions() error { + ei := install.NewEditorInstaller() + v := ei.GetInstalledVersion() + + ci := install.NewClaudeInstaller() + cv := ci.GetInstalledVersion() + + if v == "" && cv == "" { + return fmt.Errorf("Armis AppSec MCP server is not installed — run: armis-cli install") //nolint:staticcheck // proper noun + } + + if cv != "" { + fmt.Fprintf(os.Stderr, "Claude Code plugin: v%s\n", cv) + } + if v != "" { + fmt.Fprintf(os.Stderr, "MCP server: v%s\n", v) + } + return nil +} + +func installAll() error { + ei := install.NewEditorInstaller() + + fmt.Fprintln(os.Stderr, "Downloading Armis AppSec MCP server...") + if err := ei.FetchPlugin(); err != nil { + return fmt.Errorf("download failed: %w", err) + } + fmt.Fprintf(os.Stderr, "MCP server v%s downloaded.\n\n", ei.InstalledVersion()) + + detected := install.DetectedEditors() + var registered []string + var failed []string + + for _, e := range detected { + if err := e.Register(ei.PluginDir()); err != nil { + fmt.Fprintf(os.Stderr, " ✗ %s: %v\n", e.Name, err) + failed = append(failed, e.Name) + } else { + fmt.Fprintf(os.Stderr, " ✓ %s\n", e.Name) + registered = append(registered, e.Name) + } + } + + ci := install.NewClaudeInstaller() + if err := ci.Install(); err != nil { + fmt.Fprintf(os.Stderr, " ✗ Claude Code: %v\n", err) + failed = append(failed, "Claude Code") + } else { + fmt.Fprintf(os.Stderr, " ✓ Claude Code\n") + registered = append(registered, "Claude Code") + } + + fmt.Fprintln(os.Stderr, "") + + if len(registered) > 0 { + fmt.Fprintf(os.Stderr, "Registered in: %s\n", strings.Join(registered, ", ")) + } + if len(failed) > 0 { + fmt.Fprintf(os.Stderr, "Failed: %s\n", strings.Join(failed, ", ")) + } + if len(detected) == 0 && len(registered) <= 1 { + fmt.Fprintln(os.Stderr, "No additional editors detected. Use 'armis-cli install ' to target a specific tool.") + } + + printCredentialStatus(ei) + return nil +} + +func installTargets(targets []string) error { + hasClaude := false + var editorIDs []install.EditorID + + for _, name := range targets { + switch name { + case "claude": + hasClaude = true + case "copilot": + editorIDs = append(editorIDs, install.EditorVSCode) + case "jetbrains": + fmt.Fprintln(os.Stderr, "JetBrains: MCP servers are configured per-project.") + fmt.Fprintln(os.Stderr, "After installing, copy .jb-mcp.json to your project root.") + fmt.Fprintln(os.Stderr, "Run: armis-cli install --jetbrains-project /path/to/project") + fmt.Fprintln(os.Stderr, "") + case "devin": + fmt.Fprintln(os.Stderr, "Devin: MCP servers are configured via the Devin web UI.") + fmt.Fprintln(os.Stderr, "See: Settings → MCP Servers in your Devin dashboard.") + fmt.Fprintln(os.Stderr, "") + case "aider": + fmt.Fprintln(os.Stderr, "Aider does not support MCP servers.") + fmt.Fprintln(os.Stderr, "") + default: + id := install.EditorID(name) + if _, ok := install.EditorByID(id); !ok { + return fmt.Errorf("unknown editor %q — run 'armis-cli install --help' for supported editors", name) + } + editorIDs = append(editorIDs, id) + } + } + + needsSharedPlugin := len(editorIDs) > 0 + var ei *install.EditorInstaller + + if needsSharedPlugin { + ei = install.NewEditorInstaller() + fmt.Fprintln(os.Stderr, "Downloading Armis AppSec MCP server...") + if err := ei.FetchPlugin(); err != nil { + return fmt.Errorf("download failed: %w", err) + } + fmt.Fprintf(os.Stderr, "MCP server v%s downloaded.\n\n", ei.InstalledVersion()) + + for _, id := range editorIDs { + e, _ := install.EditorByID(id) + if err := e.Register(ei.PluginDir()); err != nil { + fmt.Fprintf(os.Stderr, " ✗ %s: %v\n", e.Name, err) + } else { + fmt.Fprintf(os.Stderr, " ✓ %s\n", e.Name) + } + } + fmt.Fprintln(os.Stderr, "") + printCredentialStatus(ei) + } + + if hasClaude { + ci := install.NewClaudeInstaller() + fmt.Fprintln(os.Stderr, "Installing Armis AppSec plugin for Claude Code...") + if err := ci.Install(); err != nil { + return fmt.Errorf("Claude Code installation failed: %w", err) //nolint:staticcheck // proper noun + } + fmt.Fprintf(os.Stderr, " ✓ Claude Code v%s\n", ci.InstalledVersion()) + fmt.Fprintln(os.Stderr, "") + + if ci.HasExistingEnv() { + fmt.Fprintln(os.Stderr, "Credentials configured. Restart Claude Code to pick up the updated plugin.") + } else { + fmt.Fprintln(os.Stderr, "Next steps:") + fmt.Fprintf(os.Stderr, " 1. Set your credentials in %s:\n", ci.EnvFilePath()) + fmt.Fprintln(os.Stderr, " ARMIS_CLIENT_ID=") + fmt.Fprintln(os.Stderr, " ARMIS_CLIENT_SECRET=") + fmt.Fprintln(os.Stderr, " 2. Restart Claude Code") + } + } + + return nil +} + +func printCredentialStatus(ei *install.EditorInstaller) { + if ei.HasExistingEnv() { + fmt.Fprintln(os.Stderr, "Credentials configured. Restart your editors to use the MCP server.") + } else { + fmt.Fprintln(os.Stderr, "Next steps:") + fmt.Fprintf(os.Stderr, " 1. Set your credentials in %s:\n", ei.EnvFilePath()) + fmt.Fprintln(os.Stderr, " ARMIS_CLIENT_ID=") + fmt.Fprintln(os.Stderr, " ARMIS_CLIENT_SECRET=") + fmt.Fprintln(os.Stderr, " 2. Restart your editors") + } } diff --git a/internal/cmd/install_claude.go b/internal/cmd/install_claude.go deleted file mode 100644 index d625da6..0000000 --- a/internal/cmd/install_claude.go +++ /dev/null @@ -1,82 +0,0 @@ -package cmd - -import ( - "fmt" - "os" - - "github.com/ArmisSecurity/armis-cli/internal/cli" - "github.com/ArmisSecurity/armis-cli/internal/install" - "github.com/spf13/cobra" -) - -var installClaudeCmd = &cobra.Command{ - Use: "claude", - Short: "Install the Armis security scanner plugin for Claude Code", - Long: `Download and install the Armis AppSec MCP plugin for Claude Code. - -The plugin adds AI-powered vulnerability scanning directly into Claude Code: - - scan_code: Scan code snippets for vulnerabilities - - scan_file: Scan files on disk - - scan_diff: Scan git changes before committing - -After installation, set your credentials in the plugin's .env file -and restart Claude Code. - -Source: https://github.com/ArmisSecurity/armis-appsec-mcp`, - Example: ` # Install the Claude Code plugin - armis-cli install claude - - # Check the installed plugin version - armis-cli install claude --version`, - RunE: runInstallClaude, -} - -func init() { - installCmd.AddCommand(installClaudeCmd) - installClaudeCmd.Flags().Bool("version", false, "Print the installed plugin version and exit") -} - -func runInstallClaude(cmd *cobra.Command, args []string) error { - installer := install.NewClaudeInstaller() - - showVersion, err := cmd.Flags().GetBool("version") - if err != nil { - return fmt.Errorf("reading --version flag: %w", err) - } - if showVersion { - v := installer.GetInstalledVersion() - if v == "" { - return fmt.Errorf("Armis AppSec plugin is not installed — run: armis-cli install claude") //nolint:staticcheck // proper noun - } - fmt.Fprintf(os.Stderr, "Armis AppSec plugin v%s\n", v) - return nil - } - - fmt.Fprintln(os.Stderr, "Installing Armis AppSec plugin for Claude Code...") - - if err := installer.Install(); err != nil { - return fmt.Errorf("installation failed: %w", err) - } - - fmt.Fprintln(os.Stderr, "") - fmt.Fprintf(os.Stderr, "Plugin v%s installed successfully!\n", installer.InstalledVersion()) - fmt.Fprintln(os.Stderr, "") - - if !installer.HasExistingEnv() { - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("determining home directory: %w", err) - } - envPath := home + "/.claude/plugins/cache/armis-appsec-mcp/armis-appsec/latest/.env" - fmt.Fprintln(os.Stderr, "Next steps:") - fmt.Fprintf(os.Stderr, " 1. Set your credentials in %s:\n", envPath) - fmt.Fprintln(os.Stderr, " ARMIS_CLIENT_ID=") - fmt.Fprintln(os.Stderr, " ARMIS_CLIENT_SECRET=") - fmt.Fprintln(os.Stderr, " 2. Restart Claude Code") - } else { - cli.PrintWarning("Existing .env file preserved — credentials were not overwritten.") - fmt.Fprintln(os.Stderr, "Restart Claude Code to pick up the updated plugin.") - } - - return nil -} diff --git a/internal/install/claude.go b/internal/install/claude.go index 0263827..089bd3c 100644 --- a/internal/install/claude.go +++ b/internal/install/claude.go @@ -1,89 +1,48 @@ -// Package install provides installation logic for Armis integrations. package install import ( - "archive/tar" - "compress/gzip" "encoding/json" "fmt" - "io" - "net/http" - "net/url" "os" - "os/exec" "path/filepath" - "runtime" - "strings" "time" ) -const githubAPIHost = "api.github.com" - const ( - pluginRepo = "ArmisSecurity/armis-appsec-mcp" - marketplaceName = "armis-appsec-mcp" - pluginName = "armis-appsec" - releasesURL = "https://api.github.com/repos/" + pluginRepo + "/releases/latest" - downloadTimeout = 60 * time.Second - maxArchiveBytes = 50 * 1024 * 1024 // 50 MB safety limit - maxExtractedSize = 100 * 1024 * 1024 // 100 MB total extracted size - maxFileSize = 10 * 1024 * 1024 // 10 MB per file - maxArchiveEntries = 10000 // max tar entries to prevent resource exhaustion + marketplaceName = "armis-appsec-mcp" + pluginName = "armis-appsec" ) -// githubRelease is the minimal structure from the GitHub releases API. -type githubRelease struct { - TagName string `json:"tag_name"` - TarballURL string `json:"tarball_url"` -} - // ClaudeInstaller installs the Armis AppSec MCP plugin for Claude Code. type ClaudeInstaller struct { - claudeDir string - httpClient *http.Client - releasesURL string - installedVersion string - skipURLValidation bool // testing only: skip GitHub URL enforcement + claudeDir string + plugin *PluginInstaller } // NewClaudeInstaller creates an installer with the default Claude directory. func NewClaudeInstaller() *ClaudeInstaller { home, _ := os.UserHomeDir() return &ClaudeInstaller{ - claudeDir: filepath.Join(home, ".claude"), - httpClient: &http.Client{Timeout: downloadTimeout}, - releasesURL: releasesURL, + claudeDir: filepath.Join(home, ".claude"), + plugin: newPluginInstaller(), } } // InstalledVersion returns the version that was installed (available after Install). func (ci *ClaudeInstaller) InstalledVersion() string { - return ci.installedVersion + return ci.plugin.InstalledVersion() } -// Install downloads and installs the MCP plugin. +// Install downloads and installs the MCP plugin for Claude Code. func (ci *ClaudeInstaller) Install() error { if _, err := os.Stat(ci.claudeDir); os.IsNotExist(err) { return fmt.Errorf("Claude Code directory not found at %s — is Claude Code installed?", ci.claudeDir) //nolint:staticcheck // proper noun } - release, err := ci.fetchLatestRelease() - if err != nil { - return fmt.Errorf("failed to fetch latest release: %w", err) - } - ci.installedVersion = strings.TrimPrefix(release.TagName, "v") - pluginDir := ci.pluginCacheDir() - if err := os.MkdirAll(pluginDir, 0o750); err != nil { - return fmt.Errorf("failed to create plugin directory: %w", err) - } - - if err := ci.downloadAndExtract(release.TarballURL, pluginDir); err != nil { - return fmt.Errorf("failed to download plugin: %w", err) - } - if err := ci.createVenv(pluginDir); err != nil { - return fmt.Errorf("failed to set up Python environment: %w", err) + if err := ci.plugin.FetchAndInstall(pluginDir); err != nil { + return err } if err := ci.registerMarketplace(pluginDir); err != nil { @@ -98,16 +57,21 @@ func (ci *ClaudeInstaller) Install() error { return fmt.Errorf("failed to enable plugin: %w", err) } + writeEnvFromEnvironment(ci.EnvFilePath()) + return nil } -// pluginCacheDir returns the install target directory. func (ci *ClaudeInstaller) pluginCacheDir() string { return filepath.Join(ci.claudeDir, "plugins", "cache", marketplaceName, pluginName, "latest") } +// EnvFilePath returns the path to the plugin's .env file. +func (ci *ClaudeInstaller) EnvFilePath() string { + return filepath.Join(ci.pluginCacheDir(), ".env") +} + // GetInstalledVersion reads the installed plugin version from the registry. -// Returns empty string if the plugin is not installed. func (ci *ClaudeInstaller) GetInstalledVersion() string { instFile := filepath.Join(ci.claudeDir, "plugins", "installed_plugins.json") b, err := os.ReadFile(filepath.Clean(instFile)) @@ -140,203 +104,10 @@ func (ci *ClaudeInstaller) GetInstalledVersion() string { // HasExistingEnv checks whether credentials are already configured. func (ci *ClaudeInstaller) HasExistingEnv() bool { - envPath := filepath.Join(ci.pluginCacheDir(), ".env") - _, err := os.Stat(envPath) + _, err := os.Stat(ci.EnvFilePath()) return err == nil } -func (ci *ClaudeInstaller) fetchLatestRelease() (*githubRelease, error) { - if !ci.skipURLValidation { - if err := validateGitHubURL(ci.releasesURL); err != nil { - return nil, fmt.Errorf("invalid releases URL: %w", err) - } - } - - req, err := http.NewRequest("GET", ci.releasesURL, nil) //nolint:gosec // URL validated by validateGitHubURL above - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - req.Header.Set("Accept", "application/vnd.github.v3+json") - - resp, err := ci.httpClient.Do(req) //nolint:gosec // URL validated by validateGitHubURL above - if err != nil { - return nil, fmt.Errorf("querying GitHub releases: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GitHub API returned HTTP %d — is there a published release?", resp.StatusCode) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) - if err != nil { - return nil, fmt.Errorf("reading response: %w", err) - } - - var release githubRelease - if err := json.Unmarshal(body, &release); err != nil { - return nil, fmt.Errorf("parsing release: %w", err) - } - - if release.TagName == "" || release.TarballURL == "" { - return nil, fmt.Errorf("release is missing tag or tarball URL") - } - - return &release, nil -} - -func (ci *ClaudeInstaller) downloadAndExtract(tarballURL, destDir string) error { - if !ci.skipURLValidation { - if err := validateGitHubURL(tarballURL); err != nil { - return fmt.Errorf("invalid tarball URL: %w", err) - } - } - - req, err := http.NewRequest("GET", tarballURL, nil) //nolint:gosec // URL validated by validateGitHubURL above - if err != nil { - return fmt.Errorf("creating request: %w", err) - } - req.Header.Set("Accept", "application/vnd.github+json") - - resp, err := ci.httpClient.Do(req) //nolint:gosec // URL validated by validateGitHubURL above - if err != nil { - return fmt.Errorf("downloading archive: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("GitHub API returned HTTP %d", resp.StatusCode) - } - - reader := io.LimitReader(resp.Body, maxArchiveBytes) - gz, err := gzip.NewReader(reader) - if err != nil { - return fmt.Errorf("decompressing archive: %w", err) - } - defer func() { _ = gz.Close() }() - - tr := tar.NewReader(gz) - var totalExtracted int64 - var entryCount int - var prefix string - - for { - header, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - return fmt.Errorf("reading archive: %w", err) - } - - entryCount++ - if entryCount > maxArchiveEntries { - return fmt.Errorf("archive exceeds %d entry limit", maxArchiveEntries) - } - - if header.Typeflag == tar.TypeXGlobalHeader || header.Typeflag == tar.TypeXHeader { - continue - } - - // GitHub tarballs have a top-level directory like "org-repo-sha/" - // Strip it to extract files directly into destDir. - if prefix == "" { - parts := strings.SplitN(header.Name, "/", 2) - if len(parts) > 0 { - prefix = parts[0] + "/" - } - } - - name := strings.TrimPrefix(header.Name, prefix) - if name == "" || name == "." { - continue - } - - // CWE-22: reject any entry containing path traversal sequences before cleaning - if strings.Contains(name, "..") { - continue - } - - clean := filepath.Clean(filepath.FromSlash(name)) - if filepath.IsAbs(clean) { - continue - } - - target := filepath.Join(destDir, clean) - absTarget, err := filepath.Abs(target) - if err != nil { - continue - } - absDestDir, err := filepath.Abs(destDir) - if err != nil { - continue - } - if !strings.HasPrefix(absTarget, absDestDir+string(os.PathSeparator)) && absTarget != absDestDir { - continue - } - - switch header.Typeflag { - case tar.TypeDir: - if err := os.MkdirAll(absTarget, 0o750); err != nil { - return fmt.Errorf("creating directory %s: %w", name, err) - } - case tar.TypeReg: - if header.Size > maxFileSize { - continue - } - totalExtracted += header.Size - if totalExtracted > maxExtractedSize { - return fmt.Errorf("extracted archive exceeds %d MB safety limit", maxExtractedSize/1024/1024) - } - if err := os.MkdirAll(filepath.Dir(absTarget), 0o750); err != nil { - return fmt.Errorf("creating parent directory: %w", err) - } - perm := os.FileMode(0o644) - if header.Mode&0o100 != 0 { - perm = 0o750 - } - if err := extractFile(absTarget, tr, perm); err != nil { - return fmt.Errorf("writing file %s: %w", name, err) - } - } - } - - if prefix == "" { - return fmt.Errorf("archive appears to be empty") - } - - return nil -} - -func (ci *ClaudeInstaller) createVenv(pluginDir string) error { - python := findPython() - if python == "" { - return fmt.Errorf("Python 3.11+ is required but not found in PATH") //nolint:staticcheck // proper noun - } - - venvDir := filepath.Join(pluginDir, ".venv") - venvCmd := exec.Command(python, "-m", "venv", venvDir) //nolint:gosec // python validated by findPython allowlist - venvCmd.Stdout = os.Stderr - venvCmd.Stderr = os.Stderr - if err := venvCmd.Run(); err != nil { - return fmt.Errorf("creating venv: %w", err) - } - - pip := filepath.Join(venvDir, "bin", "pip") - if runtime.GOOS == "windows" { - pip = filepath.Join(venvDir, "Scripts", "pip.exe") - } - reqsFile := filepath.Join(pluginDir, "requirements.txt") - pipCmd := exec.Command(pip, "install", "-q", "-r", reqsFile) //nolint:gosec // pip path derived from our own venv - pipCmd.Stdout = os.Stderr - pipCmd.Stderr = os.Stderr - if err := pipCmd.Run(); err != nil { - return fmt.Errorf("installing dependencies: %w", err) - } - - return nil -} - func (ci *ClaudeInstaller) registerMarketplace(pluginDir string) error { mktsFile := filepath.Join(ci.claudeDir, "plugins", "known_marketplaces.json") data := make(map[string]interface{}) @@ -372,7 +143,7 @@ func (ci *ClaudeInstaller) registerPlugin(pluginDir string) error { map[string]interface{}{ "scope": "user", "installPath": pluginDir, - "version": ci.installedVersion, + "version": ci.plugin.InstalledVersion(), "installedAt": now, "lastUpdated": now, }, @@ -399,62 +170,3 @@ func (ci *ClaudeInstaller) enablePlugin() error { return writeJSON(settingsFile, data) } - -func validateGitHubURL(rawURL string) error { - u, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("malformed URL: %w", err) - } - if u.Scheme != "https" { - return fmt.Errorf("URL scheme must be https, got %q", u.Scheme) - } - if u.Host != githubAPIHost { - return fmt.Errorf("URL host must be %s, got %q", githubAPIHost, u.Host) - } - return nil -} - -func extractFile(target string, r io.Reader, perm os.FileMode) error { - f, err := os.OpenFile(filepath.Clean(target), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm) //nolint:gosec // target validated by caller - if err != nil { - return err - } - if _, err := io.Copy(f, io.LimitReader(r, maxFileSize)); err != nil { - _ = f.Close() - return err - } - return f.Close() -} - -func writeJSON(path string, data interface{}) error { - if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { - return err - } - b, err := json.MarshalIndent(data, "", " ") - if err != nil { - return err - } - return os.WriteFile(filepath.Clean(path), append(b, '\n'), 0o600) -} - -func findPython() string { - for _, name := range []string{"python3", "python"} { - resolved, err := exec.LookPath(name) - if err != nil { - continue - } - // CWE-426: resolve symlinks and verify the path is absolute - resolved, err = filepath.EvalSymlinks(resolved) - if err != nil || !filepath.IsAbs(resolved) { - continue - } - out, err := exec.Command(resolved, "-c", "import sys; print(sys.version_info >= (3, 11))").Output() //nolint:gosec // resolved path validated above - if err != nil { - continue - } - if strings.TrimSpace(string(out)) == "True" { - return resolved - } - } - return "" -} diff --git a/internal/install/claude_test.go b/internal/install/claude_test.go index 1b89ec5..3386bb4 100644 --- a/internal/install/claude_test.go +++ b/internal/install/claude_test.go @@ -1,11 +1,7 @@ package install import ( - "archive/tar" - "compress/gzip" "encoding/json" - "net/http" - "net/http/httptest" "os" "path/filepath" "testing" @@ -18,14 +14,14 @@ func TestNewClaudeInstaller(t *testing.T) { if ci.claudeDir == "" { t.Fatal("claudeDir should not be empty") } - if ci.httpClient == nil { - t.Fatal("httpClient should not be nil") + if ci.plugin == nil { + t.Fatal("plugin should not be nil") } } func TestPluginCacheDir(t *testing.T) { base := filepath.Join("home", "test", ".claude") - ci := &ClaudeInstaller{claudeDir: base} + ci := &ClaudeInstaller{claudeDir: base, plugin: newPluginInstaller()} got := ci.pluginCacheDir() want := filepath.Join(base, "plugins", "cache", "armis-appsec-mcp", "armis-appsec", "latest") if got != want { @@ -35,7 +31,7 @@ func TestPluginCacheDir(t *testing.T) { func TestHasExistingEnv(t *testing.T) { dir := t.TempDir() - ci := &ClaudeInstaller{claudeDir: dir} + ci := &ClaudeInstaller{claudeDir: dir, plugin: newPluginInstaller()} if ci.HasExistingEnv() { t.Error("HasExistingEnv() should return false when .env doesn't exist") @@ -53,85 +49,12 @@ func TestHasExistingEnv(t *testing.T) { } } -func TestFetchLatestRelease(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, _ = w.Write([]byte(`{"tag_name":"v1.2.3","tarball_url":"https://api.github.com/repos/test/tarball/v1.2.3"}`)) - })) - defer server.Close() - - ci := &ClaudeInstaller{ - httpClient: server.Client(), - releasesURL: server.URL, - skipURLValidation: true, - } - - release, err := ci.fetchLatestRelease() - if err != nil { - t.Fatalf("fetchLatestRelease() error: %v", err) - } - if release.TagName != "v1.2.3" { - t.Errorf("TagName = %q, want %q", release.TagName, "v1.2.3") - } - if release.TarballURL == "" { - t.Error("TarballURL should not be empty") - } -} - -func TestFetchLatestRelease_NoRelease(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - })) - defer server.Close() - - ci := &ClaudeInstaller{ - httpClient: server.Client(), - releasesURL: server.URL, - skipURLValidation: true, - } - - _, err := ci.fetchLatestRelease() - if err == nil { - t.Fatal("expected error for 404 response") - } -} - -func TestDownloadAndExtract(t *testing.T) { - tarball := createTestTarball(t) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/gzip") - _, _ = w.Write(tarball) - })) - defer server.Close() - - ci := &ClaudeInstaller{ - claudeDir: t.TempDir(), - httpClient: server.Client(), - skipURLValidation: true, - } - - destDir := filepath.Join(ci.claudeDir, "extract") - if err := os.MkdirAll(destDir, 0o750); err != nil { - t.Fatal(err) - } - - if err := ci.downloadAndExtract(server.URL, destDir); err != nil { - t.Fatalf("downloadAndExtract() error: %v", err) - } - - if _, err := os.Stat(filepath.Join(destDir, "server.py")); err != nil { - t.Error("server.py not extracted") - } - if _, err := os.Stat(filepath.Join(destDir, "requirements.txt")); err != nil { - t.Error("requirements.txt not extracted") - } -} - func TestInstalledVersion(t *testing.T) { - ci := &ClaudeInstaller{} + ci := &ClaudeInstaller{plugin: newPluginInstaller()} if v := ci.InstalledVersion(); v != "" { t.Errorf("InstalledVersion() = %q, want empty", v) } - ci.installedVersion = testVersion + ci.plugin.installedVersion = testVersion if v := ci.InstalledVersion(); v != testVersion { t.Errorf("InstalledVersion() = %q, want %q", v, testVersion) } @@ -139,7 +62,7 @@ func TestInstalledVersion(t *testing.T) { func TestRegisterMarketplace(t *testing.T) { dir := t.TempDir() - ci := &ClaudeInstaller{claudeDir: dir} + ci := &ClaudeInstaller{claudeDir: dir, plugin: newPluginInstaller()} pluginsDir := filepath.Join(dir, "plugins") if err := os.MkdirAll(pluginsDir, 0o750); err != nil { t.Fatal(err) @@ -166,7 +89,9 @@ func TestRegisterMarketplace(t *testing.T) { func TestRegisterPlugin(t *testing.T) { dir := t.TempDir() - ci := &ClaudeInstaller{claudeDir: dir, installedVersion: testVersion} + pi := newPluginInstaller() + pi.installedVersion = testVersion + ci := &ClaudeInstaller{claudeDir: dir, plugin: pi} pluginsDir := filepath.Join(dir, "plugins") if err := os.MkdirAll(pluginsDir, 0o750); err != nil { t.Fatal(err) @@ -204,13 +129,14 @@ func TestRegisterPlugin(t *testing.T) { func TestGetInstalledVersion(t *testing.T) { dir := t.TempDir() - ci := &ClaudeInstaller{claudeDir: dir} + pi := newPluginInstaller() + ci := &ClaudeInstaller{claudeDir: dir, plugin: pi} if v := ci.GetInstalledVersion(); v != "" { t.Errorf("GetInstalledVersion() = %q, want empty for missing file", v) } - ci.installedVersion = "2.1.0" + pi.installedVersion = "2.1.0" pluginsDir := filepath.Join(dir, "plugins") if err := os.MkdirAll(pluginsDir, 0o750); err != nil { t.Fatal(err) @@ -227,7 +153,7 @@ func TestGetInstalledVersion(t *testing.T) { func TestEnablePlugin(t *testing.T) { dir := t.TempDir() - ci := &ClaudeInstaller{claudeDir: dir} + ci := &ClaudeInstaller{claudeDir: dir, plugin: newPluginInstaller()} if err := ci.enablePlugin(); err != nil { t.Fatalf("enablePlugin() error: %v", err) @@ -255,7 +181,7 @@ func TestEnablePlugin(t *testing.T) { func TestEnablePluginPreservesExistingSettings(t *testing.T) { dir := t.TempDir() - ci := &ClaudeInstaller{claudeDir: dir} + ci := &ClaudeInstaller{claudeDir: dir, plugin: newPluginInstaller()} existing := map[string]interface{}{ "permissions": map[string]interface{}{"allow": []string{"Bash"}}, @@ -282,7 +208,6 @@ func TestEnablePluginPreservesExistingSettings(t *testing.T) { t.Fatal(err) } - // Verify existing settings preserved if result["permissions"] == nil { t.Error("existing permissions key was lost") } @@ -297,16 +222,10 @@ func TestEnablePluginPreservesExistingSettings(t *testing.T) { } } -func TestFindPython(t *testing.T) { - // This test just verifies findPython doesn't panic. - // On CI without Python 3.11+, it may return "". - _ = findPython() -} - func TestInstallMissingClaudeDir(t *testing.T) { ci := &ClaudeInstaller{ - claudeDir: "/nonexistent/path/.claude", - httpClient: http.DefaultClient, + claudeDir: "/nonexistent/path/.claude", + plugin: newPluginInstaller(), } err := ci.Install() if err == nil { @@ -329,110 +248,3 @@ func searchString(s, substr string) bool { } return false } - -func TestDownloadAndExtractFlattensPrefix(t *testing.T) { - tarball := createTestTarball(t, true) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/gzip") - _, _ = w.Write(tarball) - })) - defer server.Close() - - ci := &ClaudeInstaller{ - claudeDir: t.TempDir(), - httpClient: server.Client(), - skipURLValidation: true, - } - - destDir := filepath.Join(ci.claudeDir, "extract") - if err := os.MkdirAll(destDir, 0o750); err != nil { - t.Fatal(err) - } - - if err := ci.downloadAndExtract(server.URL, destDir); err != nil { - t.Fatalf("downloadAndExtract() error: %v", err) - } - - wantFiles := []string{"server.py", "requirements.txt"} - for _, f := range wantFiles { - if _, err := os.Stat(filepath.Join(destDir, f)); err != nil { - t.Errorf("expected file %q not found in extracted directory", f) - } - } -} - -// createTestTarball creates a gzipped tarball matching GitHub's format: -// top-level directory prefix like "org-repo-sha/" with files inside. -// If withPaxHeader is true, includes a pax_global_header like real GitHub tarballs. -func createTestTarball(t *testing.T, withPaxHeader ...bool) []byte { - t.Helper() - var buf []byte - - tmpFile := filepath.Join(t.TempDir(), "test.tar.gz") - f, err := os.Create(filepath.Clean(tmpFile)) - if err != nil { - t.Fatal(err) - } - gw := gzip.NewWriter(f) - tw := tar.NewWriter(gw) - - if len(withPaxHeader) > 0 && withPaxHeader[0] { - if err := tw.WriteHeader(&tar.Header{ - Typeflag: tar.TypeXGlobalHeader, - Name: "pax_global_header", - Size: 0, - }); err != nil { - t.Fatal(err) - } - } - - writeEntry := func(hdr *tar.Header, data []byte) { - t.Helper() - if err := tw.WriteHeader(hdr); err != nil { - t.Fatal(err) - } - if len(data) > 0 { - if _, err := tw.Write(data); err != nil { - t.Fatal(err) - } - } - } - - writeEntry(&tar.Header{ - Name: "ArmisSecurity-armis-appsec-mcp-abc1234/", - Typeflag: tar.TypeDir, - Mode: 0o755, - }, nil) - - content := []byte("print('hello')\n") - writeEntry(&tar.Header{ - Name: "ArmisSecurity-armis-appsec-mcp-abc1234/server.py", - Typeflag: tar.TypeReg, - Mode: 0o644, - Size: int64(len(content)), - }, content) - - reqs := []byte("mcp[cli]==1.25.0\nhttpx==0.28.1\n") - writeEntry(&tar.Header{ - Name: "ArmisSecurity-armis-appsec-mcp-abc1234/requirements.txt", - Typeflag: tar.TypeReg, - Mode: 0o644, - Size: int64(len(reqs)), - }, reqs) - - if err := tw.Close(); err != nil { - t.Fatal(err) - } - if err := gw.Close(); err != nil { - t.Fatal(err) - } - if err := f.Close(); err != nil { - t.Fatal(err) - } - - buf, err = os.ReadFile(filepath.Clean(tmpFile)) - if err != nil { - t.Fatal(err) - } - return buf -} diff --git a/internal/install/editors.go b/internal/install/editors.go new file mode 100644 index 0000000..f940cff --- /dev/null +++ b/internal/install/editors.go @@ -0,0 +1,306 @@ +package install + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" +) + +const mcpServerName = "armis-appsec" + +// EditorID identifies a supported editor. +type EditorID string + +const ( + EditorVSCode EditorID = "vscode" + EditorCursor EditorID = "cursor" + EditorWindsurf EditorID = "windsurf" + EditorZed EditorID = "zed" + EditorCline EditorID = "cline" + EditorAmazonQ EditorID = "amazonq" + EditorContinue EditorID = "continue" + EditorAntigravity EditorID = "antigravity" + EditorGemini EditorID = "gemini" +) + +// Editor represents a code editor with MCP server support. +type Editor struct { + ID EditorID + Name string +} + +// AllEditors lists every editor that can be auto-configured. +var AllEditors = []Editor{ + {EditorVSCode, "VS Code"}, + {EditorCursor, "Cursor"}, + {EditorWindsurf, "Windsurf"}, + {EditorZed, "Zed"}, + {EditorCline, "Cline"}, + {EditorAmazonQ, "Amazon Q"}, + {EditorContinue, "Continue"}, + {EditorAntigravity, "Antigravity"}, + {EditorGemini, "Gemini CLI"}, +} + +// EditorByID returns the editor with the given ID. +func EditorByID(id EditorID) (Editor, bool) { + for _, e := range AllEditors { + if e.ID == id { + return e, true + } + } + return Editor{}, false +} + +// configPathOverrides lets tests inject custom config paths. +var configPathOverrides map[EditorID]string + +// ConfigPath returns the MCP config file path for this editor on the current OS. +func (e Editor) ConfigPath() string { + if configPathOverrides != nil { + if p, ok := configPathOverrides[e.ID]; ok { + return p + } + } + return defaultConfigPath(e.ID) +} + +// IsDetected checks whether the editor appears to be installed by looking +// for the parent directory of its config file. +func (e Editor) IsDetected() bool { + p := e.ConfigPath() + if p == "" { + return false + } + _, err := os.Stat(filepath.Dir(p)) + return err == nil +} + +// Register adds the Armis MCP server to this editor's configuration. +func (e Editor) Register(pluginDir string) error { + configFile := e.ConfigPath() + if configFile == "" { + return fmt.Errorf("%s is not supported on this platform", e.Name) + } + return registerEditor(e.ID, pluginDir, configFile) +} + +// DetectedEditors returns editors that appear to be installed on this system. +func DetectedEditors() []Editor { + var detected []Editor + for _, e := range AllEditors { + if e.IsDetected() { + detected = append(detected, e) + } + } + return detected +} + +// EditorInstaller downloads the plugin once and registers it across editors. +type EditorInstaller struct { + pluginDir string + plugin *PluginInstaller +} + +// NewEditorInstaller creates an installer using the shared plugin directory (~/.armis/plugins/armis-appsec-mcp). +func NewEditorInstaller() *EditorInstaller { + home, _ := os.UserHomeDir() + return &EditorInstaller{ + pluginDir: filepath.Join(home, ".armis", "plugins", "armis-appsec-mcp"), + plugin: newPluginInstaller(), + } +} + +// InstalledVersion returns the version that was installed (available after FetchPlugin). +func (ei *EditorInstaller) InstalledVersion() string { return ei.plugin.InstalledVersion() } + +// PluginDir returns the shared plugin installation directory. +func (ei *EditorInstaller) PluginDir() string { return ei.pluginDir } + +// EnvFilePath returns the path to the shared .env credentials file. +func (ei *EditorInstaller) EnvFilePath() string { return filepath.Join(ei.pluginDir, ".env") } + +// HasExistingEnv checks whether credentials are already configured. +func (ei *EditorInstaller) HasExistingEnv() bool { + _, err := os.Stat(ei.EnvFilePath()) + return err == nil +} + +// FetchPlugin downloads and sets up the plugin (venv + deps), writes credentials +// from the environment, and records the installed version. +func (ei *EditorInstaller) FetchPlugin() error { + if err := ei.plugin.FetchAndInstall(ei.pluginDir); err != nil { + return err + } + writeEnvFromEnvironment(ei.EnvFilePath()) + versionFile := filepath.Join(ei.pluginDir, ".installed-version") + _ = os.WriteFile(filepath.Clean(versionFile), []byte(ei.plugin.InstalledVersion()), 0o600) + return nil +} + +// GetInstalledVersion reads the version from the shared plugin directory. +func (ei *EditorInstaller) GetInstalledVersion() string { + versionFile := filepath.Join(ei.pluginDir, ".installed-version") + v, err := os.ReadFile(filepath.Clean(versionFile)) + if err != nil { + return "" + } + return string(v) +} + +// RegisterJetBrains writes a .jb-mcp.json file at the given path. +func RegisterJetBrains(pluginDir, configFile string) error { + return registerMCPServersFormat(pluginDir, configFile) +} + +// --- Config path resolution --- + +func defaultConfigPath(id EditorID) string { + switch id { + case EditorVSCode: + return appSupportPath("Code", "User", "mcp.json") + case EditorCursor: + return homeDir(".cursor", "mcp.json") + case EditorWindsurf: + return homeDir(".codeium", "windsurf", "mcp_config.json") + case EditorContinue: + return homeDir(".continue", "mcpServers", "armis-appsec.json") + case EditorZed: + if runtime.GOOS == osWindows { + return "" + } + return appSupportPath("Zed", "settings.json") + case EditorCline: + return appSupportPath("Code", "User", "globalStorage", + "saoudrizwan.claude-dev", "settings", "cline_mcp_settings.json") + case EditorAmazonQ: + return homeDir(".aws", "amazonq", "mcp.json") + case EditorAntigravity: + return homeDir(".gemini", "antigravity", "mcp_config.json") + case EditorGemini: + return homeDir(".gemini", "settings.json") + } + return "" +} + +func homeDir(parts ...string) string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(append([]string{home}, parts...)...) +} + +func appSupportPath(parts ...string) string { + var base string + switch runtime.GOOS { + case "darwin": + home, err := os.UserHomeDir() + if err != nil { + return "" + } + base = filepath.Join(home, "Library", "Application Support") + case "linux": + base = os.Getenv("XDG_CONFIG_HOME") + if base == "" { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + base = filepath.Join(home, ".config") + } + case osWindows: + base = os.Getenv("APPDATA") + if base == "" { + return "" + } + default: + return "" + } + return filepath.Join(append([]string{base}, parts...)...) +} + +// --- Registration --- + +func registerEditor(id EditorID, pluginDir, configFile string) error { + switch id { + case EditorVSCode: + return registerVSCodeFormat(pluginDir, configFile) + case EditorZed: + return registerZedFormat(pluginDir, configFile) + default: + // Cursor, Windsurf, Cline, Amazon Q, Continue, Antigravity all use mcpServers map + return registerMCPServersFormat(pluginDir, configFile) + } +} + +// registerMCPServersFormat handles {"mcpServers": {"name": {command, args}}}. +// Used by Cursor, Windsurf, Cline, Amazon Q, JetBrains. +func registerMCPServersFormat(pluginDir, configFile string) error { + data := readJSONFileAsMap(configFile) + + servers, ok := data["mcpServers"].(map[string]interface{}) + if !ok { + servers = make(map[string]interface{}) + } + servers[mcpServerName] = stdServerEntry(pluginDir) + data["mcpServers"] = servers + + return writeJSON(configFile, data) +} + +// registerVSCodeFormat handles {"servers": {"name": {type, command, args, envFile}}}. +func registerVSCodeFormat(pluginDir, configFile string) error { + data := readJSONFileAsMap(configFile) + + servers, ok := data["servers"].(map[string]interface{}) + if !ok { + servers = make(map[string]interface{}) + } + servers[mcpServerName] = map[string]interface{}{ + "type": "stdio", + "command": venvPython(pluginDir), + "args": []string{filepath.Join(pluginDir, "server.py")}, + "envFile": filepath.Join(pluginDir, ".env"), + } + data["servers"] = servers + + return writeJSON(configFile, data) +} + +// registerZedFormat handles {"context_servers": {"name": {command: {path, args}}}}. +func registerZedFormat(pluginDir, configFile string) error { + data := readJSONFileAsMap(configFile) + + servers, ok := data["context_servers"].(map[string]interface{}) + if !ok { + servers = make(map[string]interface{}) + } + servers[mcpServerName] = map[string]interface{}{ + "command": map[string]interface{}{ + "path": venvPython(pluginDir), + "args": []string{filepath.Join(pluginDir, "server.py")}, + }, + "settings": map[string]interface{}{}, + } + data["context_servers"] = servers + + return writeJSON(configFile, data) +} + +func stdServerEntry(pluginDir string) map[string]interface{} { + return map[string]interface{}{ + "command": venvPython(pluginDir), + "args": []string{filepath.Join(pluginDir, "server.py")}, + } +} + +func readJSONFileAsMap(path string) map[string]interface{} { + data := make(map[string]interface{}) + if b, err := os.ReadFile(filepath.Clean(path)); err == nil { + _ = json.Unmarshal(b, &data) + } + return data +} diff --git a/internal/install/editors_test.go b/internal/install/editors_test.go new file mode 100644 index 0000000..a6399c5 --- /dev/null +++ b/internal/install/editors_test.go @@ -0,0 +1,332 @@ +package install + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestEditorByID(t *testing.T) { + e, ok := EditorByID(EditorVSCode) + if !ok { + t.Fatal("EditorByID(EditorVSCode) not found") + } + if e.Name != "VS Code" { + t.Errorf("Name = %q, want %q", e.Name, "VS Code") + } + + _, ok = EditorByID("nonexistent") + if ok { + t.Error("EditorByID(nonexistent) should return false") + } +} + +func TestEditorConfigPath(t *testing.T) { + for _, e := range AllEditors { + p := e.ConfigPath() + if p == "" { + t.Logf("skipping %s (not supported on this OS)", e.Name) + continue + } + if !filepath.IsAbs(p) { + t.Errorf("%s config path %q is not absolute", e.Name, p) + } + } +} + +func TestEditorConfigPathOverride(t *testing.T) { + configPathOverrides = map[EditorID]string{ + EditorCursor: "/tmp/test-cursor-mcp.json", + } + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(EditorCursor) + if got := e.ConfigPath(); got != "/tmp/test-cursor-mcp.json" { + t.Errorf("ConfigPath() = %q, want override path", got) + } +} + +func TestEditorIsDetected(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "mcp.json") + + configPathOverrides = map[EditorID]string{ + EditorCursor: configFile, + } + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(EditorCursor) + if !e.IsDetected() { + t.Error("IsDetected() should return true when parent dir exists") + } + + configPathOverrides[EditorCursor] = "/nonexistent/dir/mcp.json" + if e.IsDetected() { + t.Error("IsDetected() should return false when parent dir missing") + } +} + +func TestDetectedEditors(t *testing.T) { + dir := t.TempDir() + configPathOverrides = make(map[EditorID]string) + for _, e := range AllEditors { + configPathOverrides[e.ID] = "/nonexistent/" + string(e.ID) + "/mcp.json" + } + cursorFile := filepath.Join(dir, "cursor-mcp.json") + configPathOverrides[EditorCursor] = cursorFile + defer func() { configPathOverrides = nil }() + + detected := DetectedEditors() + if len(detected) != 1 { + t.Fatalf("DetectedEditors() = %d editors, want 1", len(detected)) + } + if detected[0].ID != EditorCursor { + t.Errorf("detected[0].ID = %q, want %q", detected[0].ID, EditorCursor) + } +} + +func TestRegisterMCPServersFormat(t *testing.T) { + editors := []EditorID{EditorCursor, EditorWindsurf, EditorCline, EditorAmazonQ, EditorAntigravity, EditorContinue, EditorGemini} + for _, id := range editors { + t.Run(string(id), func(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "mcp.json") + pluginDir := filepath.Join(dir, "plugin") + + configPathOverrides = map[EditorID]string{id: configFile} + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(id) + if err := e.Register(pluginDir); err != nil { + t.Fatalf("Register() error: %v", err) + } + + var data map[string]interface{} + b, _ := os.ReadFile(filepath.Clean(configFile)) + if err := json.Unmarshal(b, &data); err != nil { + t.Fatal(err) + } + + servers, ok := data["mcpServers"].(map[string]interface{}) + if !ok { + t.Fatal("mcpServers key missing") + } + server, ok := servers[mcpServerName].(map[string]interface{}) + if !ok { + t.Fatal("armis-appsec server not registered") + } + if server["command"] != venvPython(pluginDir) { + t.Errorf("command = %q, want %q", server["command"], venvPython(pluginDir)) + } + }) + } +} + +func TestRegisterVSCodeFormat(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "mcp.json") + pluginDir := filepath.Join(dir, "plugin") + + configPathOverrides = map[EditorID]string{EditorVSCode: configFile} + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(EditorVSCode) + if err := e.Register(pluginDir); err != nil { + t.Fatalf("Register() error: %v", err) + } + + var data map[string]interface{} + b, _ := os.ReadFile(filepath.Clean(configFile)) + if err := json.Unmarshal(b, &data); err != nil { + t.Fatal(err) + } + + servers, ok := data["servers"].(map[string]interface{}) + if !ok { + t.Fatal("servers key missing") + } + server, ok := servers[mcpServerName].(map[string]interface{}) + if !ok { + t.Fatal("armis-appsec server not registered") + } + if server["type"] != "stdio" { + t.Errorf("type = %q, want %q", server["type"], "stdio") + } + if server["envFile"] == nil { + t.Error("envFile should be set for VS Code") + } +} + +func TestRegisterZedFormat(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "settings.json") + pluginDir := filepath.Join(dir, "plugin") + + configPathOverrides = map[EditorID]string{EditorZed: configFile} + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(EditorZed) + if err := e.Register(pluginDir); err != nil { + t.Fatalf("Register() error: %v", err) + } + + var data map[string]interface{} + b, _ := os.ReadFile(filepath.Clean(configFile)) + if err := json.Unmarshal(b, &data); err != nil { + t.Fatal(err) + } + + ctxServers, ok := data["context_servers"].(map[string]interface{}) + if !ok { + t.Fatal("context_servers key missing") + } + server, ok := ctxServers[mcpServerName].(map[string]interface{}) + if !ok { + t.Fatal("armis-appsec server not registered") + } + cmd, ok := server["command"].(map[string]interface{}) + if !ok { + t.Fatal("command object missing") + } + if cmd["path"] != venvPython(pluginDir) { + t.Errorf("command.path = %q, want %q", cmd["path"], venvPython(pluginDir)) + } +} + +func TestRegisterContinueCreatesDirectoryFile(t *testing.T) { + dir := t.TempDir() + mcpServersDir := filepath.Join(dir, "mcpServers") + if err := os.MkdirAll(mcpServersDir, 0o750); err != nil { + t.Fatal(err) + } + configFile := filepath.Join(mcpServersDir, "armis-appsec.json") + pluginDir := filepath.Join(dir, "plugin") + + configPathOverrides = map[EditorID]string{EditorContinue: configFile} + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(EditorContinue) + if err := e.Register(pluginDir); err != nil { + t.Fatalf("Register() error: %v", err) + } + + var data map[string]interface{} + b, _ := os.ReadFile(filepath.Clean(configFile)) + if err := json.Unmarshal(b, &data); err != nil { + t.Fatal(err) + } + + servers, ok := data["mcpServers"].(map[string]interface{}) + if !ok { + t.Fatal("mcpServers key missing") + } + server, ok := servers[mcpServerName].(map[string]interface{}) + if !ok { + t.Fatal("armis-appsec server not registered") + } + if server["command"] != venvPython(pluginDir) { + t.Errorf("command = %q, want %q", server["command"], venvPython(pluginDir)) + } +} + +func TestRegisterPreservesExistingConfig(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, "mcp.json") + pluginDir := filepath.Join(dir, "plugin") + + existing := map[string]interface{}{ + "mcpServers": map[string]interface{}{ + "other-server": map[string]interface{}{ + "command": "node", + "args": []string{"server.js"}, + }, + }, + } + b, _ := json.MarshalIndent(existing, "", " ") + _ = os.WriteFile(configFile, b, 0o600) + + configPathOverrides = map[EditorID]string{EditorCursor: configFile} + defer func() { configPathOverrides = nil }() + + e, _ := EditorByID(EditorCursor) + if err := e.Register(pluginDir); err != nil { + t.Fatalf("Register() error: %v", err) + } + + var data map[string]interface{} + b, _ = os.ReadFile(filepath.Clean(configFile)) + _ = json.Unmarshal(b, &data) + + servers := data["mcpServers"].(map[string]interface{}) + if _, ok := servers["other-server"]; !ok { + t.Error("existing server was lost") + } + if _, ok := servers[mcpServerName]; !ok { + t.Error("armis-appsec not registered") + } +} + +func TestRegisterJetBrains(t *testing.T) { + dir := t.TempDir() + configFile := filepath.Join(dir, ".jb-mcp.json") + pluginDir := filepath.Join(dir, "plugin") + + if err := RegisterJetBrains(pluginDir, configFile); err != nil { + t.Fatalf("RegisterJetBrains() error: %v", err) + } + + var data map[string]interface{} + b, _ := os.ReadFile(filepath.Clean(configFile)) + if err := json.Unmarshal(b, &data); err != nil { + t.Fatal(err) + } + + servers, ok := data["mcpServers"].(map[string]interface{}) + if !ok { + t.Fatal("mcpServers key missing") + } + if _, ok := servers[mcpServerName]; !ok { + t.Fatal("armis-appsec server not registered") + } +} + +func TestEditorInstallerFields(t *testing.T) { + dir := t.TempDir() + ei := &EditorInstaller{pluginDir: dir, plugin: newPluginInstaller()} + if ei.PluginDir() == "" { + t.Error("PluginDir() should not be empty") + } + if ei.EnvFilePath() == "" { + t.Error("EnvFilePath() should not be empty") + } + if ei.InstalledVersion() != "" { + t.Error("InstalledVersion() should be empty before install") + } + if v := ei.GetInstalledVersion(); v != "" { + t.Errorf("GetInstalledVersion() = %q, want empty", v) + } +} + +func TestNewEditorInstaller(t *testing.T) { + ei := NewEditorInstaller() + if ei.PluginDir() == "" { + t.Error("PluginDir() should not be empty") + } +} + +func TestEditorInstallerHasExistingEnv(t *testing.T) { + dir := t.TempDir() + ei := &EditorInstaller{pluginDir: dir, plugin: newPluginInstaller()} + + if ei.HasExistingEnv() { + t.Error("HasExistingEnv() should return false when .env doesn't exist") + } + + if err := os.WriteFile(filepath.Join(dir, ".env"), []byte("test"), 0o600); err != nil { + t.Fatal(err) + } + if !ei.HasExistingEnv() { + t.Error("HasExistingEnv() should return true when .env exists") + } +} diff --git a/internal/install/plugin.go b/internal/install/plugin.go new file mode 100644 index 0000000..fc05934 --- /dev/null +++ b/internal/install/plugin.go @@ -0,0 +1,406 @@ +// Package install provides installation logic for Armis integrations. +package install + +import ( + "archive/tar" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" +) + +const githubAPIHost = "api.github.com" + +const osWindows = "windows" + +const ( + pluginRepo = "ArmisSecurity/armis-appsec-mcp" + releasesURL = "https://api.github.com/repos/" + pluginRepo + "/releases/latest" + downloadTimeout = 60 * time.Second + maxArchiveBytes = 50 * 1024 * 1024 // 50 MB safety limit + maxExtractedSize = 100 * 1024 * 1024 // 100 MB total extracted size + maxFileSize = 10 * 1024 * 1024 // 10 MB per file + maxArchiveEntries = 10000 // max tar entries to prevent resource exhaustion +) + +type githubRelease struct { + TagName string `json:"tag_name"` + TarballURL string `json:"tarball_url"` +} + +// PluginInstaller handles downloading and setting up the Armis AppSec MCP plugin. +type PluginInstaller struct { + httpClient *http.Client + releasesURL string + installedVersion string + skipURLValidation bool +} + +func newPluginInstaller() *PluginInstaller { + return &PluginInstaller{ + httpClient: &http.Client{Timeout: downloadTimeout}, + releasesURL: releasesURL, + } +} + +// InstalledVersion returns the version that was installed (available after FetchAndInstall). +func (pi *PluginInstaller) InstalledVersion() string { + return pi.installedVersion +} + +// FetchAndInstall downloads the latest release and sets up the plugin in destDir. +func (pi *PluginInstaller) FetchAndInstall(destDir string) error { + release, err := pi.fetchLatestRelease() + if err != nil { + return fmt.Errorf("failed to fetch latest release: %w", err) + } + pi.installedVersion = strings.TrimPrefix(release.TagName, "v") + + if err := os.MkdirAll(destDir, 0o750); err != nil { + return fmt.Errorf("failed to create plugin directory: %w", err) + } + + if err := pi.downloadAndExtract(release.TarballURL, destDir); err != nil { + return fmt.Errorf("failed to download plugin: %w", err) + } + + if err := pi.createVenv(destDir); err != nil { + return fmt.Errorf("failed to set up Python environment: %w", err) + } + + if err := writeHelperScript(destDir); err != nil { + return fmt.Errorf("failed to write helper script: %w", err) + } + + return nil +} + +func (pi *PluginInstaller) fetchLatestRelease() (*githubRelease, error) { + if !pi.skipURLValidation { + if err := validateGitHubURL(pi.releasesURL); err != nil { + return nil, fmt.Errorf("invalid releases URL: %w", err) + } + } + + req, err := http.NewRequest("GET", pi.releasesURL, nil) //nolint:gosec // URL validated by validateGitHubURL above + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + + resp, err := pi.httpClient.Do(req) //nolint:gosec // URL validated by validateGitHubURL above + if err != nil { + return nil, fmt.Errorf("querying GitHub releases: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned HTTP %d — is there a published release?", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + var release githubRelease + if err := json.Unmarshal(body, &release); err != nil { + return nil, fmt.Errorf("parsing release: %w", err) + } + + if release.TagName == "" || release.TarballURL == "" { + return nil, fmt.Errorf("release is missing tag or tarball URL") + } + + return &release, nil +} + +func (pi *PluginInstaller) downloadAndExtract(tarballURL, destDir string) error { + if !pi.skipURLValidation { + if err := validateGitHubURL(tarballURL); err != nil { + return fmt.Errorf("invalid tarball URL: %w", err) + } + } + + req, err := http.NewRequest("GET", tarballURL, nil) //nolint:gosec // URL validated by validateGitHubURL above + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := pi.httpClient.Do(req) //nolint:gosec // URL validated by validateGitHubURL above + if err != nil { + return fmt.Errorf("downloading archive: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("GitHub API returned HTTP %d", resp.StatusCode) + } + + reader := io.LimitReader(resp.Body, maxArchiveBytes) + gz, err := gzip.NewReader(reader) + if err != nil { + return fmt.Errorf("decompressing archive: %w", err) + } + defer func() { _ = gz.Close() }() + + tr := tar.NewReader(gz) + var totalExtracted int64 + var entryCount int + var prefix string + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("reading archive: %w", err) + } + + entryCount++ + if entryCount > maxArchiveEntries { + return fmt.Errorf("archive exceeds %d entry limit", maxArchiveEntries) + } + + if header.Typeflag == tar.TypeXGlobalHeader || header.Typeflag == tar.TypeXHeader { + continue + } + + if prefix == "" { + parts := strings.SplitN(header.Name, "/", 2) + if len(parts) > 0 { + prefix = parts[0] + "/" + } + } + + name := strings.TrimPrefix(header.Name, prefix) + if name == "" || name == "." { + continue + } + + clean := filepath.Clean(filepath.FromSlash(name)) + if filepath.IsAbs(clean) || clean == ".." || strings.HasPrefix(clean, ".."+string(filepath.Separator)) { + continue + } + + target := filepath.Join(destDir, clean) + absTarget, err := filepath.Abs(target) + if err != nil { + continue + } + absDestDir, err := filepath.Abs(destDir) + if err != nil { + continue + } + if !strings.HasPrefix(absTarget, absDestDir+string(os.PathSeparator)) && absTarget != absDestDir { + continue + } + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(absTarget, 0o750); err != nil { + return fmt.Errorf("creating directory %s: %w", name, err) + } + case tar.TypeReg: + if header.Size > maxFileSize { + continue + } + totalExtracted += header.Size + if totalExtracted > maxExtractedSize { + return fmt.Errorf("extracted archive exceeds %d MB safety limit", maxExtractedSize/1024/1024) + } + if err := os.MkdirAll(filepath.Dir(absTarget), 0o750); err != nil { + return fmt.Errorf("creating parent directory: %w", err) + } + perm := os.FileMode(0o644) + if header.Mode&0o100 != 0 { + perm = 0o750 + } + if err := extractFile(absTarget, tr, perm); err != nil { + return fmt.Errorf("writing file %s: %w", name, err) + } + } + } + + if prefix == "" { + return fmt.Errorf("archive appears to be empty") + } + + return nil +} + +func (pi *PluginInstaller) createVenv(pluginDir string) error { + python := findPython() + if python == "" { + return fmt.Errorf("Python 3.11+ is required but not found in PATH") //nolint:staticcheck // proper noun + } + + venvDir := filepath.Join(pluginDir, ".venv") + venvCmd := exec.Command(python, "-m", "venv", venvDir) //nolint:gosec // python validated by findPython allowlist + venvCmd.Stdout = os.Stderr + venvCmd.Stderr = os.Stderr + if err := venvCmd.Run(); err != nil { + return fmt.Errorf("creating venv: %w", err) + } + + pip := filepath.Join(venvDir, "bin", "pip") + if runtime.GOOS == osWindows { + pip = filepath.Join(venvDir, "Scripts", "pip.exe") + } + reqsFile := filepath.Join(pluginDir, "requirements.txt") + pipCmd := exec.Command(pip, "install", "-q", "-r", reqsFile) //nolint:gosec // pip path derived from our own venv + pipCmd.Stdout = os.Stderr + pipCmd.Stderr = os.Stderr + if err := pipCmd.Run(); err != nil { + return fmt.Errorf("installing dependencies: %w", err) + } + + return nil +} + +func validateGitHubURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("malformed URL: %w", err) + } + if u.Scheme != "https" { + return fmt.Errorf("URL scheme must be https, got %q", u.Scheme) + } + if u.Host != githubAPIHost { + return fmt.Errorf("URL host must be %s, got %q", githubAPIHost, u.Host) + } + return nil +} + +func extractFile(target string, r io.Reader, perm os.FileMode) error { + f, err := os.OpenFile(filepath.Clean(target), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm) //nolint:gosec // target validated by caller + if err != nil { + return err + } + if _, err := io.Copy(f, io.LimitReader(r, maxFileSize)); err != nil { + _ = f.Close() + return err + } + return f.Close() +} + +func writeJSON(path string, data interface{}) error { + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { + return err + } + b, err := json.MarshalIndent(data, "", " ") + if err != nil { + return err + } + return os.WriteFile(filepath.Clean(path), append(b, '\n'), 0o600) +} + +func findPython() string { + for _, name := range []string{"python3", "python"} { + resolved, err := exec.LookPath(name) + if err != nil { + continue + } + resolved, err = filepath.EvalSymlinks(resolved) + if err != nil || !filepath.IsAbs(resolved) { + continue + } + out, err := exec.Command(resolved, "-c", "import sys; print(sys.version_info >= (3, 11))").Output() //nolint:gosec // resolved path validated above + if err != nil { + continue + } + if strings.TrimSpace(string(out)) == "True" { + return resolved + } + } + return "" +} + +// writeEnvFromEnvironment writes ARMIS_CLIENT_ID and ARMIS_CLIENT_SECRET to a .env +// file if both are set in the current process environment. Returns true if the file +// was written. Skips writing if the file already exists (to preserve user edits). +func writeEnvFromEnvironment(envPath string) bool { + if _, err := os.Stat(envPath); err == nil { + return false + } + + clientID := os.Getenv("ARMIS_CLIENT_ID") + clientSecret := os.Getenv("ARMIS_CLIENT_SECRET") + if clientID == "" || clientSecret == "" { + return false + } + + content := fmt.Sprintf("ARMIS_CLIENT_ID=%s\nARMIS_CLIENT_SECRET=%s\n", clientID, clientSecret) + if err := os.MkdirAll(filepath.Dir(envPath), 0o750); err != nil { + return false + } + if err := os.WriteFile(filepath.Clean(envPath), []byte(content), 0o600); err != nil { // #nosec G703 - envPath is constructed from pluginDir + ".env" + return false + } + return true +} + +// writeHelperScript writes a standalone scan script that editors without +// native MCP support (e.g. Copilot) can invoke to scan files directly. +func writeHelperScript(pluginDir string) error { + scriptPath := filepath.Join(pluginDir, "scan_file.py") + script := `#!/usr/bin/env python3 +"""Scan a file for security vulnerabilities using the Armis AppSec scanner. + +Usage: python3 scan_file.py + +This script calls the same scanning engine as the MCP server but can be +invoked directly from editors that cannot call MCP tools natively. +""" +import os +import sys + +_plugin_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, _plugin_dir) + +from dotenv import load_dotenv + +_env_file = os.path.join(_plugin_dir, ".env") +if os.path.isfile(_env_file): + load_dotenv(_env_file, override=False) + +from auth import init_auth +from scanner_core import call_appsec_api, format_findings, parse_findings + +if len(sys.argv) != 2: + print("Usage: python3 scan_file.py ", file=sys.stderr) + sys.exit(1) + +file_path = os.path.abspath(sys.argv[1]) +if not os.path.isfile(file_path): + print(f"Error: {file_path} not found", file=sys.stderr) + sys.exit(1) + +init_auth() +with open(file_path) as f: + code = f.read() + +raw = call_appsec_api(code) +findings = parse_findings(raw) +print(format_findings(findings, os.path.basename(file_path))) +` + return os.WriteFile(filepath.Clean(scriptPath), []byte(script), 0o750) // #nosec G306 - script needs execute permission +} + +// venvPython returns the path to the Python interpreter inside a venv. +func venvPython(pluginDir string) string { + if runtime.GOOS == osWindows { + return filepath.Join(pluginDir, ".venv", "Scripts", "python.exe") + } + return filepath.Join(pluginDir, ".venv", "bin", "python") +} diff --git a/internal/install/plugin_test.go b/internal/install/plugin_test.go new file mode 100644 index 0000000..1a753c8 --- /dev/null +++ b/internal/install/plugin_test.go @@ -0,0 +1,381 @@ +package install + +import ( + "archive/tar" + "compress/gzip" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestFetchLatestRelease(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{"tag_name":"v1.2.3","tarball_url":"https://api.github.com/repos/test/tarball/v1.2.3"}`)) + })) + defer server.Close() + + pi := &PluginInstaller{ + httpClient: server.Client(), + releasesURL: server.URL, + skipURLValidation: true, + } + + release, err := pi.fetchLatestRelease() + if err != nil { + t.Fatalf("fetchLatestRelease() error: %v", err) + } + if release.TagName != "v1.2.3" { + t.Errorf("TagName = %q, want %q", release.TagName, "v1.2.3") + } + if release.TarballURL == "" { + t.Error("TarballURL should not be empty") + } +} + +func TestFetchLatestRelease_NoRelease(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + pi := &PluginInstaller{ + httpClient: server.Client(), + releasesURL: server.URL, + skipURLValidation: true, + } + + _, err := pi.fetchLatestRelease() + if err == nil { + t.Fatal("expected error for 404 response") + } +} + +func TestDownloadAndExtract(t *testing.T) { + tarball := createTestTarball(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + _, _ = w.Write(tarball) + })) + defer server.Close() + + pi := &PluginInstaller{ + httpClient: server.Client(), + skipURLValidation: true, + } + + destDir := filepath.Join(t.TempDir(), "extract") + if err := os.MkdirAll(destDir, 0o750); err != nil { + t.Fatal(err) + } + + if err := pi.downloadAndExtract(server.URL, destDir); err != nil { + t.Fatalf("downloadAndExtract() error: %v", err) + } + + if _, err := os.Stat(filepath.Join(destDir, "server.py")); err != nil { + t.Error("server.py not extracted") + } + if _, err := os.Stat(filepath.Join(destDir, "requirements.txt")); err != nil { + t.Error("requirements.txt not extracted") + } +} + +func TestDownloadAndExtractFlattensPrefix(t *testing.T) { + tarball := createTestTarball(t, true) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + _, _ = w.Write(tarball) + })) + defer server.Close() + + pi := &PluginInstaller{ + httpClient: server.Client(), + skipURLValidation: true, + } + + destDir := filepath.Join(t.TempDir(), "extract") + if err := os.MkdirAll(destDir, 0o750); err != nil { + t.Fatal(err) + } + + if err := pi.downloadAndExtract(server.URL, destDir); err != nil { + t.Fatalf("downloadAndExtract() error: %v", err) + } + + for _, f := range []string{"server.py", "requirements.txt"} { + if _, err := os.Stat(filepath.Join(destDir, f)); err != nil { + t.Errorf("expected file %q not found in extracted directory", f) + } + } +} + +func TestDownloadAndExtractRejectsTraversal(t *testing.T) { + tarball := createTraversalTarball(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + _, _ = w.Write(tarball) + })) + defer server.Close() + + pi := &PluginInstaller{ + httpClient: server.Client(), + skipURLValidation: true, + } + + destDir := filepath.Join(t.TempDir(), "extract") + if err := os.MkdirAll(destDir, 0o750); err != nil { + t.Fatal(err) + } + + if err := pi.downloadAndExtract(server.URL, destDir); err != nil { + t.Fatalf("downloadAndExtract() error: %v", err) + } + + // The legitimate file should be extracted + if _, err := os.Stat(filepath.Join(destDir, "safe.txt")); err != nil { + t.Error("safe.txt should be extracted") + } + + // Traversal files must not escape destDir + if _, err := os.Stat(filepath.Join(destDir, "..", "escaped.txt")); err == nil { + t.Error("path traversal file escaped destDir") + } +} + +func TestPluginInstalledVersion(t *testing.T) { + pi := newPluginInstaller() + if v := pi.InstalledVersion(); v != "" { + t.Errorf("InstalledVersion() = %q, want empty", v) + } + pi.installedVersion = testVersion + if v := pi.InstalledVersion(); v != testVersion { + t.Errorf("InstalledVersion() = %q, want %q", v, testVersion) + } +} + +func TestFindPython(t *testing.T) { + _ = findPython() +} + +func TestValidateGitHubURL(t *testing.T) { + tests := []struct { + name string + url string + wantErr bool + }{ + {"valid", "https://api.github.com/repos/test/releases/latest", false}, + {"http scheme", "http://api.github.com/repos/test", true}, + {"wrong host", "https://evil.com/repos/test", true}, + {"malformed", "://bad", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateGitHubURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("validateGitHubURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr) + } + }) + } +} + +func TestWriteEnvFromEnvironment(t *testing.T) { + dir := t.TempDir() + envPath := filepath.Join(dir, ".env") + + t.Run("writes env when both vars set", func(t *testing.T) { + t.Setenv("ARMIS_CLIENT_ID", "test-id") + t.Setenv("ARMIS_CLIENT_SECRET", "test-secret") + + if !writeEnvFromEnvironment(envPath) { + t.Fatal("writeEnvFromEnvironment() returned false, want true") + } + + b, err := os.ReadFile(filepath.Clean(envPath)) + if err != nil { + t.Fatal(err) + } + content := string(b) + if !searchString(content, "ARMIS_CLIENT_ID=test-id") { + t.Error("missing ARMIS_CLIENT_ID") + } + if !searchString(content, "ARMIS_CLIENT_SECRET=test-secret") { + t.Error("missing ARMIS_CLIENT_SECRET") + } + + if runtime.GOOS != "windows" { + info, _ := os.Stat(envPath) + if perm := info.Mode().Perm(); perm != 0o600 { + t.Errorf("file permissions = %o, want 600", perm) + } + } + }) + + t.Run("skips when file exists", func(t *testing.T) { + t.Setenv("ARMIS_CLIENT_ID", "new-id") + t.Setenv("ARMIS_CLIENT_SECRET", "new-secret") + + if writeEnvFromEnvironment(envPath) { + t.Error("writeEnvFromEnvironment() returned true for existing file") + } + + b, _ := os.ReadFile(filepath.Clean(envPath)) + if searchString(string(b), "new-id") { + t.Error("existing file was overwritten") + } + }) + + t.Run("skips when vars missing", func(t *testing.T) { + freshPath := filepath.Join(t.TempDir(), ".env") + t.Setenv("ARMIS_CLIENT_ID", "") + t.Setenv("ARMIS_CLIENT_SECRET", "") + + if writeEnvFromEnvironment(freshPath) { + t.Error("writeEnvFromEnvironment() returned true with empty vars") + } + if _, err := os.Stat(freshPath); err == nil { + t.Error("file should not exist when vars are empty") + } + }) + + t.Run("skips when only one var set", func(t *testing.T) { + freshPath := filepath.Join(t.TempDir(), ".env") + t.Setenv("ARMIS_CLIENT_ID", "test-id") + t.Setenv("ARMIS_CLIENT_SECRET", "") + + if writeEnvFromEnvironment(freshPath) { + t.Error("writeEnvFromEnvironment() returned true with only client ID") + } + }) +} + +// createTestTarball creates a gzipped tarball matching GitHub's format. +func createTestTarball(t *testing.T, withPaxHeader ...bool) []byte { + t.Helper() + + tmpFile := filepath.Join(t.TempDir(), "test.tar.gz") + f, err := os.Create(filepath.Clean(tmpFile)) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + if len(withPaxHeader) > 0 && withPaxHeader[0] { + if err := tw.WriteHeader(&tar.Header{ + Typeflag: tar.TypeXGlobalHeader, + Name: "pax_global_header", + Size: 0, + }); err != nil { + t.Fatal(err) + } + } + + writeEntry := func(hdr *tar.Header, data []byte) { + t.Helper() + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if len(data) > 0 { + if _, err := tw.Write(data); err != nil { + t.Fatal(err) + } + } + } + + writeEntry(&tar.Header{ + Name: "ArmisSecurity-armis-appsec-mcp-abc1234/", + Typeflag: tar.TypeDir, + Mode: 0o755, + }, nil) + + content := []byte("print('hello')\n") + writeEntry(&tar.Header{ + Name: "ArmisSecurity-armis-appsec-mcp-abc1234/server.py", + Typeflag: tar.TypeReg, + Mode: 0o644, + Size: int64(len(content)), + }, content) + + reqs := []byte("mcp[cli]==1.25.0\nhttpx==0.28.1\n") + writeEntry(&tar.Header{ + Name: "ArmisSecurity-armis-appsec-mcp-abc1234/requirements.txt", + Typeflag: tar.TypeReg, + Mode: 0o644, + Size: int64(len(reqs)), + }, reqs) + + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + buf, err := os.ReadFile(filepath.Clean(tmpFile)) + if err != nil { + t.Fatal(err) + } + return buf +} + +func createTraversalTarball(t *testing.T) []byte { + t.Helper() + + tmpFile := filepath.Join(t.TempDir(), "traversal.tar.gz") + f, err := os.Create(filepath.Clean(tmpFile)) + if err != nil { + t.Fatal(err) + } + gw := gzip.NewWriter(f) + tw := tar.NewWriter(gw) + + writeEntry := func(hdr *tar.Header, data []byte) { + t.Helper() + if err := tw.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if len(data) > 0 { + if _, err := tw.Write(data); err != nil { + t.Fatal(err) + } + } + } + + prefix := "repo-abc1234/" + + writeEntry(&tar.Header{ + Name: prefix, Typeflag: tar.TypeDir, Mode: 0o755, + }, nil) + + safe := []byte("safe content\n") + writeEntry(&tar.Header{ + Name: prefix + "safe.txt", Typeflag: tar.TypeReg, Mode: 0o644, Size: int64(len(safe)), + }, safe) + + evil := []byte("escaped\n") + writeEntry(&tar.Header{ + Name: prefix + "../../escaped.txt", Typeflag: tar.TypeReg, Mode: 0o644, Size: int64(len(evil)), + }, evil) + + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + buf, err := os.ReadFile(filepath.Clean(tmpFile)) + if err != nil { + t.Fatal(err) + } + return buf +}