diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 8651d406..de4f4cb5 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -58,6 +58,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/workspacegroups" "github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent" "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/featureflag" "github.com/brevdev/brev-cli/pkg/files" "github.com/brevdev/brev-cli/pkg/remoteversion" @@ -257,8 +258,20 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin cmds.SetUsageTemplate(usageTemplate) // In-memory auth for external node commands — never touches credentials.json. - memAuthStore := store.NewMemoryAuthStore() - memAuthenticator := auth.StandardLogin("", "", nil) + // Pre-fill the cached email so the user sees a confirmation prompt instead of + // having to type it from scratch every time. + cachedEmail, _ := fsStore.GetCachedEmail() + memAuthenticator := auth.StandardLogin("", cachedEmail, nil) + if cachedEmail != "" { + if kas, ok := memAuthenticator.(auth.KasAuthenticator); ok { + kas.ShouldPromptEmail = true + memAuthenticator = kas + } + } + memAuthStore := &emailCachingAuthStore{ + MemoryAuthStore: store.NewMemoryAuthStore(), + fileStore: fsStore, + } memLoginAuth := auth.NewLoginAuth(memAuthStore, memAuthenticator) memLoginAuth.WithShouldLogin(func() (bool, error) { return true, nil }) @@ -555,4 +568,22 @@ var ( _ store.Auth = auth.NoLoginAuth{} _ auth.AuthStore = store.FileStore{} _ auth.AuthStore = &store.MemoryAuthStore{} + _ auth.AuthStore = &emailCachingAuthStore{} ) + +// emailCachingAuthStore wraps MemoryAuthStore and persists the login email +// to ~/.brev/cached-email after each successful authentication. +type emailCachingAuthStore struct { + *store.MemoryAuthStore + fileStore *store.FileStore +} + +func (e *emailCachingAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error { + if err := e.MemoryAuthStore.SaveAuthTokens(tokens); err != nil { + return breverrors.WrapAndTrace(err) + } + if email := auth.GetEmailFromToken(tokens.AccessToken); email != "" { + _ = e.fileStore.SaveCachedEmail(email) + } + return nil +} diff --git a/pkg/cmd/cmd_test.go b/pkg/cmd/cmd_test.go new file mode 100644 index 00000000..c9289a4c --- /dev/null +++ b/pkg/cmd/cmd_test.go @@ -0,0 +1,75 @@ +package cmd + +import ( + "encoding/base64" + "encoding/json" + "testing" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeJWT builds an unsigned JWT with the given claims (header.payload.signature). +func fakeJWT(t *testing.T, claims map[string]interface{}) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payload, err := json.Marshal(claims) + require.NoError(t, err) + return header + "." + base64.RawURLEncoding.EncodeToString(payload) + "." +} + +func newTestFileStore(t *testing.T) *store.FileStore { + t.Helper() + fs := afero.NewMemMapFs() + err := fs.MkdirAll("/home/testuser/.brev", 0o755) + require.NoError(t, err) + return store.NewBasicStore().WithFileSystem(fs).WithUserHomeDirGetter( + func() (string, error) { return "/home/testuser", nil }, + ) +} + +func TestEmailCachingAuthStore_SaveCachesEmail(t *testing.T) { + fs := newTestFileStore(t) + s := &emailCachingAuthStore{ + MemoryAuthStore: store.NewMemoryAuthStore(), + fileStore: fs, + } + + token := fakeJWT(t, map[string]interface{}{"email": "user@example.com"}) + err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: token}) + require.NoError(t, err) + + cached, err := fs.GetCachedEmail() + require.NoError(t, err) + assert.Equal(t, "user@example.com", cached) +} + +func TestEmailCachingAuthStore_NoEmailInToken(t *testing.T) { + fs := newTestFileStore(t) + s := &emailCachingAuthStore{ + MemoryAuthStore: store.NewMemoryAuthStore(), + fileStore: fs, + } + + token := fakeJWT(t, map[string]interface{}{"sub": "12345"}) + err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: token}) + require.NoError(t, err) + + cached, err := fs.GetCachedEmail() + require.NoError(t, err) + assert.Equal(t, "", cached) +} + +func TestEmailCachingAuthStore_EmptyAccessToken(t *testing.T) { + fs := newTestFileStore(t) + s := &emailCachingAuthStore{ + MemoryAuthStore: store.NewMemoryAuthStore(), + fileStore: fs, + } + + err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: ""}) + require.Error(t, err) +} diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go index 64870429..b36ac993 100644 --- a/pkg/cmd/deregister/deregister_test.go +++ b/pkg/cmd/deregister/deregister_test.go @@ -91,8 +91,9 @@ type mockNetBirdManager struct { err error } -func (m *mockNetBirdManager) Install() error { return m.err } -func (m *mockNetBirdManager) Uninstall() error { m.called = true; return m.err } +func (m *mockNetBirdManager) Install() error { return m.err } +func (m *mockNetBirdManager) Uninstall() error { m.called = true; return m.err } +func (m *mockNetBirdManager) EnsureRunning() error { return m.err } type mockNodeClientFactory struct { serverURL string diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go index b58c7b94..777e236e 100644 --- a/pkg/cmd/gpucreate/gpucreate.go +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -18,6 +18,7 @@ import ( "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/featureflag" + "github.com/brevdev/brev-cli/pkg/names" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -194,8 +195,8 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra } } - if name == "" { - return breverrors.NewValidationError("name is required (as argument or --name flag)") + if err := names.ValidateNodeName(name); err != nil { + return breverrors.WrapAndTrace(err) } if count < 1 { diff --git a/pkg/cmd/register/device_registration_store.go b/pkg/cmd/register/device_registration_store.go index 86dba36b..8a5243c6 100644 --- a/pkg/cmd/register/device_registration_store.go +++ b/pkg/cmd/register/device_registration_store.go @@ -85,6 +85,9 @@ func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) { if err := files.ReadJSON(files.AppFs, path, ®); err != nil { return nil, breverrors.WrapAndTrace(err) } + if reg.ExternalNodeID == "" || reg.OrgID == "" { + return nil, breverrors.New("malformed registration") + } return ®, nil } @@ -125,6 +128,9 @@ func sudoWriteFile(path string, data []byte) error { if err := cmd.Run(); err != nil { return fmt.Errorf("sudo tee %s failed: %w", path, err) } + if err := exec.Command("sudo", "chmod", "644", path).Run(); err != nil { //nolint:gosec // fixed base path + return fmt.Errorf("sudo chmod %s failed: %w", path, err) + } return nil } diff --git a/pkg/cmd/register/device_registration_store_test.go b/pkg/cmd/register/device_registration_store_test.go index 256acd4a..50834da3 100644 --- a/pkg/cmd/register/device_registration_store_test.go +++ b/pkg/cmd/register/device_registration_store_test.go @@ -144,6 +144,48 @@ func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { } } +func Test_LoadRegistration_RejectsMissingExternalNodeID(t *testing.T) { + cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore() + + reg := &DeviceRegistration{ + ExternalNodeID: "", + DisplayName: "Test", + OrgID: "org_xyz", + } + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + _, err := store.Load() + if err == nil { + t.Fatal("expected error loading registration with empty ExternalNodeID") + } +} + +func Test_LoadRegistration_RejectsMissingOrgID(t *testing.T) { + cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore() + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "Test", + OrgID: "", + } + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + _, err := store.Load() + if err == nil { + t.Fatal("expected error loading registration with empty OrgID") + } +} + func Test_DeleteRegistration_FailsWhenMissing(t *testing.T) { cleanup := setupTestFs(t) defer cleanup() diff --git a/pkg/cmd/register/providers.go b/pkg/cmd/register/providers.go index cabfa1c4..a3c27bd5 100644 --- a/pkg/cmd/register/providers.go +++ b/pkg/cmd/register/providers.go @@ -1,7 +1,10 @@ package register import ( + "fmt" + "os/exec" "runtime" + "strings" nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" @@ -38,6 +41,33 @@ type Netbird struct{} func (Netbird) Install() error { return InstallNetbird() } func (Netbird) Uninstall() error { return UninstallNetbird() } +// EnsureRunning checks if the netbird systemd service is active and attempts +// to start it if it is not. It also checks the netbird peer connection status +// and runs "netbird up" if the peer is disconnected. +func (Netbird) EnsureRunning() error { + out, err := exec.Command("systemctl", "is-active", "netbird").Output() //nolint:gosec // fixed service name + if err != nil || strings.TrimSpace(string(out)) != "active" { + if startErr := exec.Command("sudo", "systemctl", "start", "netbird").Run(); startErr != nil { //nolint:gosec // fixed service name + return fmt.Errorf("failed to start Brev tunnel service: %w", startErr) + } + } + + statusOut, err := exec.Command("netbird", "status").Output() //nolint:gosec // fixed command + if err != nil { + // Service is running, just can't confirm peer status. + return nil + } + + if netbirdManagementConnected(string(statusOut)) { + return nil + } + + if upErr := exec.Command("sudo", "netbird", "up").Run(); upErr != nil { //nolint:gosec // fixed command + return fmt.Errorf("failed to reconnect Brev tunnel: %w", upErr) + } + return nil +} + // ShellSetupRunner runs setup scripts via shell. type ShellSetupRunner struct{} diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 4e4cf0db..a1573960 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "os" - "os/exec" "os/user" "strings" "time" @@ -18,6 +17,7 @@ import ( "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/externalnode" + "github.com/brevdev/brev-cli/pkg/names" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -27,6 +27,7 @@ import ( type RegisterStore interface { GetCurrentUser() (*entity.User, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetOrganizationsByName(name string) ([]entity.Organization, error) GetAccessToken() (string, error) } @@ -41,10 +42,14 @@ func (r OSFileReader) ReadFile(path string) ([]byte, error) { return data, nil } -// NetBirdManager installs and uninstalls the NetBird network agent. +// NetBirdManager installs, uninstalls, and monitors the NetBird network agent. type NetBirdManager interface { Install() error Uninstall() error + // EnsureRunning checks whether the NetBird service is active and + // connected, starting or reconnecting it if needed. Returns nil when + // the tunnel is healthy. + EnsureRunning() error } // SetupRunner runs a setup script on the local machine. @@ -87,6 +92,8 @@ This command sets up network connectivity and registers this machine with Brev.` ) func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { + var orgFlag string + cmd := &cobra.Command{ Annotations: map[string]string{"configuration": ""}, Use: "register [name]", @@ -100,14 +107,16 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { if len(args) > 0 { name = args[0] } - return runRegister(cmd.Context(), t, store, name, defaultRegisterDeps()) + return runRegister(cmd.Context(), t, store, name, orgFlag, defaultRegisterDeps()) }, } + cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization name (overrides active org)") + return cmd } -func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow +func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, deps registerDeps) error { //nolint:funlen // registration flow if !deps.platform.IsCompatible() { return breverrors.New("brev register is only supported on Linux") } @@ -120,15 +129,15 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return checkExistingRegistration(ctx, t, s, name, deps) } - if name == "" { - return fmt.Errorf("please provide a name for this device\n\nUsage: brev register \nExample: brev register \"my-DGX-Spark\"") + if err := names.ValidateNodeName(name); err != nil { + return breverrors.WrapAndTrace(err) } brevUser, err := s.GetCurrentUser() if err != nil { return breverrors.WrapAndTrace(err) } - org, err := getOrgToRegisterFor(s) + org, err := getOrgToRegisterFor(s, orgName) if err != nil { return breverrors.WrapAndTrace(err) } @@ -215,7 +224,21 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return nil } -func getOrgToRegisterFor(s RegisterStore) (*entity.Organization, error) { +func getOrgToRegisterFor(s RegisterStore, orgName string) (*entity.Organization, error) { + if orgName != "" { + orgs, err := s.GetOrganizationsByName(orgName) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if len(orgs) == 0 { + return nil, fmt.Errorf("no organization found with name %q", orgName) + } + if len(orgs) > 1 { + return nil, fmt.Errorf("multiple organizations found with name %q", orgName) + } + return &orgs[0], nil + } + org, err := s.GetActiveOrganizationOrDefault() if err != nil { return nil, breverrors.WrapAndTrace(err) @@ -223,7 +246,6 @@ func getOrgToRegisterFor(s RegisterStore) (*entity.Organization, error) { if org == nil { return nil, fmt.Errorf("no organization found; please create or join an organization first") } - return org, nil } @@ -273,7 +295,9 @@ func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s Regi // Check local netbird service and start it if down. t.Vprint(" Checking local Brev tunnel...") - if ensureNetbirdRunning(t) { + if err := deps.netbird.EnsureRunning(); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: %v", err))) + } else { t.Vprint(t.Green(" Brev tunnel is running.")) } @@ -282,41 +306,6 @@ func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s Regi return nil } -// ensureNetbirdRunning checks if the netbird systemd service is active and -// attempts to start it if it is not. It also checks the netbird peer -// connection status and runs "netbird up" if the peer is disconnected. -// Returns true if the service is running and connected after the check. -func ensureNetbirdRunning(t *terminal.Terminal) bool { - out, err := exec.Command("systemctl", "is-active", "netbird").Output() //nolint:gosec // fixed service name - if err != nil || strings.TrimSpace(string(out)) != "active" { - t.Vprintf(" %s\n", t.Yellow("Brev tunnel service is not running. Attempting to start...")) - if startErr := exec.Command("sudo", "systemctl", "start", "netbird").Run(); startErr != nil { //nolint:gosec // fixed service name - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to start Brev tunnel service: %v", startErr))) - return false - } - t.Vprint(t.Green(" Brev tunnel service started.")) - } - - // Service is running — now check peer connection status. - statusOut, err := exec.Command("netbird", "status").Output() //nolint:gosec // fixed command - if err != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not check Brev tunnel status: %v", err))) - return true // service is running, just can't confirm peer status - } - - if netbirdManagementConnected(string(statusOut)) { - return true - } - - t.Vprintf(" %s\n", t.Yellow("Brev tunnel peer is disconnected. Reconnecting...")) - if upErr := exec.Command("sudo", "netbird", "up").Run(); upErr != nil { //nolint:gosec // fixed command - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to reconnect Brev tunnel: %v", upErr))) - return false - } - t.Vprint(t.Green(" Brev tunnel reconnected.")) - return true -} - // netbirdManagementConnected parses "netbird status" output and returns true // when the Management line reports "Connected". func netbirdManagementConnected(statusOutput string) bool { diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index b031fea4..885f1481 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -20,6 +20,7 @@ import ( type mockRegisterStore struct { user *entity.User org *entity.Organization + orgs []entity.Organization token string err error } @@ -35,6 +36,16 @@ func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organizati return m.org, nil } +func (m *mockRegisterStore) GetOrganizationsByName(name string) ([]entity.Organization, error) { + var matched []entity.Organization + for _, o := range m.orgs { + if o.Name == name { + matched = append(matched, o) + } + } + return matched, nil +} + func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } // mockRegistrationStore satisfies RegistrationStore for orchestration tests. @@ -75,8 +86,9 @@ func (m mockConfirmer) ConfirmYesNo(_ string) bool { return m.confirm } type mockNetBirdManager struct{ err error } -func (m mockNetBirdManager) Install() error { return m.err } -func (m mockNetBirdManager) Uninstall() error { return m.err } +func (m mockNetBirdManager) Install() error { return m.err } +func (m mockNetBirdManager) Uninstall() error { return m.err } +func (m mockNetBirdManager) EnsureRunning() error { return m.err } type mockSetupRunner struct { called bool @@ -144,7 +156,7 @@ func Test_runRegister_HappyPath(t *testing.T) { if req.GetOrganizationId() != "org_123" { t.Errorf("unexpected org: %s", req.GetOrganizationId()) } - if req.GetName() != "My Spark" { + if req.GetName() != "my-spark" { t.Errorf("unexpected name: %s", req.GetName()) } return &nodev1.AddNodeResponse{ @@ -172,7 +184,7 @@ func Test_runRegister_HappyPath(t *testing.T) { defer ClearTestSSHPort() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err != nil { t.Fatalf("runRegister failed: %v", err) } @@ -193,8 +205,8 @@ func Test_runRegister_HappyPath(t *testing.T) { if reg.ExternalNodeID != "unode_abc" { t.Errorf("expected ExternalNodeID unode_abc, got %s", reg.ExternalNodeID) } - if reg.DisplayName != "My Spark" { - t.Errorf("expected display name 'My Spark', got %s", reg.DisplayName) + if reg.DisplayName != "my-spark" { + t.Errorf("expected display name 'my-spark', got %s", reg.DisplayName) } if reg.OrgID != "org_123" { t.Errorf("expected org org_123, got %s", reg.OrgID) @@ -223,7 +235,7 @@ func Test_runRegister_UserCancels(t *testing.T) { deps.prompter = mockConfirmer{confirm: false} term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err != nil { t.Fatalf("expected nil error on cancel, got: %v", err) } @@ -311,7 +323,7 @@ func Test_runRegister_AlreadyRegistered(t *testing.T) { term := terminal.New() // Pass the same name as the existing registration so we go through // the checkExistingRegistration path (not the different-name path). - err := runRegister(context.Background(), term, store, "Existing", deps) + err := runRegister(context.Background(), term, store, "Existing", "", deps) if err != nil { t.Fatalf("expected nil error, got: %v", err) } @@ -339,12 +351,90 @@ func Test_runRegister_NoOrganization(t *testing.T) { defer server.Close() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err == nil { t.Fatal("expected error when no org exists") } } +func Test_runRegister_WithOrgFlag(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_default", Name: "DefaultOrg"}, + orgs: []entity.Organization{ + {ID: "org_456", Name: "SpecificOrg"}, + }, + token: "tok", + } + + var capturedOrgID string + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + capturedOrgID = req.GetOrganizationId() + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: req.GetOrganizationId(), + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + }, nil + }, + } + + setupRunner := &mockSetupRunner{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + deps.setupRunner = setupRunner + + SetTestSSHPort(22) + defer ClearTestSSHPort() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "my-spark", "SpecificOrg", deps) + if err != nil { + t.Fatalf("runRegister with --org failed: %v", err) + } + + if capturedOrgID != "org_456" { + t.Errorf("expected org_456, got %s", capturedOrgID) + } + + reg, err := regStore.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if reg.OrgID != "org_456" { + t.Errorf("expected registration org org_456, got %s", reg.OrgID) + } +} + +func Test_runRegister_WithOrgFlag_NotFound(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_default", Name: "DefaultOrg"}, + orgs: []entity.Organization{}, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "my-spark", "NonexistentOrg", deps) + if err == nil { + t.Fatal("expected error when org not found") + } + if !strings.Contains(err.Error(), "no organization found") { + t.Errorf("expected 'no organization found' error, got: %v", err) + } +} + func Test_runRegister_AddNodeFails(t *testing.T) { regStore := &mockRegistrationStore{} @@ -365,7 +455,7 @@ func Test_runRegister_AddNodeFails(t *testing.T) { defer server.Close() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err == nil { t.Fatal("expected error when AddNode fails") } @@ -415,7 +505,7 @@ func Test_runRegister_NoSetupCommand(t *testing.T) { defer ClearTestSSHPort() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err != nil { t.Fatalf("runRegister failed: %v", err) } @@ -548,7 +638,7 @@ func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *test defer ClearTestSSHPort() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err != nil { t.Fatalf("runRegister failed: %v", err) } @@ -597,7 +687,7 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) { defer ClearTestSSHPort() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + err := runRegister(context.Background(), term, store, "my-spark", "", deps) if err != nil { t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err) } @@ -607,6 +697,71 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) { } } +func Test_runRegister_NameValidation(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + errSubstr string + }{ + {"Valid", "my-dgx-spark", false, ""}, + {"WithDots", "node.local.1", false, ""}, + {"WithUnderscore", "my_node", false, ""}, + {"Spaces", "My Spark", true, "letters, digits"}, + {"ShellInjection", "$(whoami)", true, "letters, digits"}, + {"PathTraversal", "../etc/passwd", true, "letters, digits"}, + {"Backticks", "`rm -rf`", true, "letters, digits"}, + {"Semicolon", "a;rm -rf /", true, "letters, digits"}, + {"LeadingHyphen", "-node", true, "start with"}, + {"LeadingDot", ".hidden", true, "start with"}, + {"TooLong", strings.Repeat("a", 64), true, "63 characters"}, + {"Empty", "", true, "name is required"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + regStore := &mockRegistrationStore{} + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", + } + + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + }, nil + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + SetTestSSHPort(22) + defer ClearTestSSHPort() + + term := terminal.New() + err := runRegister(context.Background(), term, store, tt.input, "", deps) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.errSubstr) { + t.Errorf("expected error containing %q, got: %v", tt.errSubstr, err) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + func Test_runRegister_NoNameNotRegistered(t *testing.T) { regStore := &mockRegistrationStore{} @@ -621,12 +776,12 @@ func Test_runRegister_NoNameNotRegistered(t *testing.T) { defer server.Close() term := terminal.New() - err := runRegister(context.Background(), term, store, "", deps) + err := runRegister(context.Background(), term, store, "", "", deps) if err == nil { t.Fatal("expected error when no name provided and not registered") } - if !strings.Contains(err.Error(), "please provide a name") { - t.Errorf("expected 'please provide a name' error, got: %v", err) + if !strings.Contains(err.Error(), "name is required") { + t.Errorf("expected 'name is required' error, got: %v", err) } } @@ -662,7 +817,7 @@ func Test_runRegister_NoNameAlreadyRegistered(t *testing.T) { defer server.Close() term := terminal.New() - err := runRegister(context.Background(), term, store, "", deps) + err := runRegister(context.Background(), term, store, "", "", deps) if err != nil { t.Fatalf("expected nil error when already registered with no name, got: %v", err) } diff --git a/pkg/names/validate.go b/pkg/names/validate.go new file mode 100644 index 00000000..b107aab6 --- /dev/null +++ b/pkg/names/validate.go @@ -0,0 +1,27 @@ +package names + +import ( + "fmt" + "regexp" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +var validNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) + +const maxNameLen = 63 + +func ValidateNodeName(name string) error { + if name == "" { + return breverrors.NewValidationError("name is required") + } + if len(name) > maxNameLen { + return breverrors.NewValidationError( + fmt.Sprintf("name must be %d characters or fewer (got %d)", maxNameLen, len(name))) + } + if !validNameRe.MatchString(name) { + return breverrors.NewValidationError( + "name must start with a letter or digit and contain only letters, digits, hyphens, underscores, and dots") + } + return nil +} diff --git a/pkg/names/validate_test.go b/pkg/names/validate_test.go new file mode 100644 index 00000000..2a34814c --- /dev/null +++ b/pkg/names/validate_test.go @@ -0,0 +1,48 @@ +package names + +import ( + "strings" + "testing" +) + +func TestValidateNodeName(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + errSubstr string + }{ + {"Valid", "my-dgx-spark", false, ""}, + {"WithDots", "node.local.1", false, ""}, + {"WithUnderscore", "my_node", false, ""}, + {"SingleChar", "a", false, ""}, + {"MaxLength", strings.Repeat("a", 63), false, ""}, + {"Spaces", "My Spark", true, "letters, digits"}, + {"ShellInjection", "$(whoami)", true, "letters, digits"}, + {"PathTraversal", "../etc/passwd", true, "letters, digits"}, + {"Backticks", "`rm -rf`", true, "letters, digits"}, + {"Semicolon", "a;rm -rf /", true, "letters, digits"}, + {"Pipe", "a|cat", true, "letters, digits"}, + {"Ampersand", "a&bg", true, "letters, digits"}, + {"LeadingHyphen", "-node", true, "start with"}, + {"LeadingDot", ".hidden", true, "start with"}, + {"TooLong", strings.Repeat("a", 64), true, "63 characters"}, + {"Empty", "", true, "name is required"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateNodeName(tt.input) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.errSubstr) { + t.Errorf("expected error containing %q, got: %v", tt.errSubstr, err) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} diff --git a/pkg/store/email_cache.go b/pkg/store/email_cache.go new file mode 100644 index 00000000..2c9bba4c --- /dev/null +++ b/pkg/store/email_cache.go @@ -0,0 +1,46 @@ +package store + +import ( + "os" + "path/filepath" + "strings" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/spf13/afero" +) + +const cachedEmailFile = "cached-email" + +// GetCachedEmail returns the previously cached login email, or "" if none exists. +func (f FileStore) GetCachedEmail() (string, error) { + brevHome, err := f.GetBrevHomePath() + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + data, err := afero.ReadFile(f.fs, filepath.Join(brevHome, cachedEmailFile)) + if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", breverrors.WrapAndTrace(err) + } + return strings.TrimSpace(string(data)), nil +} + +// SaveCachedEmail writes the login email to ~/.brev/cached-email (0600). +func (f FileStore) SaveCachedEmail(email string) error { + brevHome, err := f.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + path := filepath.Join(brevHome, cachedEmailFile) + err = f.fs.MkdirAll(brevHome, 0o755) + if err != nil { + return breverrors.WrapAndTrace(err) + } + err = afero.WriteFile(f.fs, path, []byte(email), 0o600) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} diff --git a/pkg/store/email_cache_test.go b/pkg/store/email_cache_test.go new file mode 100644 index 00000000..08d7d5f1 --- /dev/null +++ b/pkg/store/email_cache_test.go @@ -0,0 +1,51 @@ +package store + +import ( + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestFileStore(t *testing.T) FileStore { + t.Helper() + fs := afero.NewMemMapFs() + err := fs.MkdirAll("/home/testuser/.brev", 0o755) + require.NoError(t, err) + return FileStore{ + b: BasicStore{}, + fs: fs, + userHomeDirGetter: func() (string, error) { return "/home/testuser", nil }, + } +} + +func TestCachedEmail_RoundTrip(t *testing.T) { + s := newTestFileStore(t) + + err := s.SaveCachedEmail("user@example.com") + require.NoError(t, err) + + email, err := s.GetCachedEmail() + require.NoError(t, err) + assert.Equal(t, "user@example.com", email) +} + +func TestCachedEmail_MissingFile(t *testing.T) { + s := newTestFileStore(t) + + email, err := s.GetCachedEmail() + require.NoError(t, err) + assert.Equal(t, "", email) +} + +func TestCachedEmail_Overwrites(t *testing.T) { + s := newTestFileStore(t) + + require.NoError(t, s.SaveCachedEmail("first@example.com")) + require.NoError(t, s.SaveCachedEmail("second@example.com")) + + email, err := s.GetCachedEmail() + require.NoError(t, err) + assert.Equal(t, "second@example.com", email) +} diff --git a/pkg/store/organization.go b/pkg/store/organization.go index 3414ba6b..0c4e80ae 100644 --- a/pkg/store/organization.go +++ b/pkg/store/organization.go @@ -130,6 +130,11 @@ func (s AuthHTTPStore) GetOrganizations(options *GetOrganizationsOptions) ([]ent return filteredOrgs, nil } +// GetOrganizationsByName returns organizations matching the given name. +func (s AuthHTTPStore) GetOrganizationsByName(name string) ([]entity.Organization, error) { + return s.GetOrganizations(&GetOrganizationsOptions{Name: name}) +} + func (s AuthHTTPStore) getOrganizations() ([]entity.Organization, error) { var result []entity.Organization res, err := s.authHTTPClient.restyClient.R().