diff --git a/acceptance/ssh/connect-local-cluster/out.test.toml b/acceptance/ssh/connect-local-cluster/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/ssh/connect-local-cluster/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/ssh/connect-local-cluster/output.txt b/acceptance/ssh/connect-local-cluster/output.txt new file mode 100644 index 0000000000..bf22895059 --- /dev/null +++ b/acceptance/ssh/connect-local-cluster/output.txt @@ -0,0 +1,75 @@ +{ + "method": "GET", + "path": "/api/2.0/preview/scim/v2/Me" +} +{ + "method": "GET", + "path": "/api/2.0/preview/scim/v2/Me" +} +{ + "method": "GET", + "path": "/api/2.0/preview/scim/v2/Me" +} +{ + "method": "GET", + "path": "/api/2.0/secrets/get", + "q": { + "key": "client-private-key", + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys" + } +} +{ + "method": "GET", + "path": "/api/2.0/secrets/get", + "q": { + "key": "client-private-key", + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys" + } +} +{ + "method": "GET", + "path": "/api/2.0/secrets/get", + "q": { + "key": "client-public-key", + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys" + } +} +{ + "method": "GET", + "path": "/api/2.0/secrets/list", + "q": { + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys" + } +} +{ + "method": "GET", + "path": "/api/2.0/secrets/list", + "q": { + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys" + } +} +{ + "method": "POST", + "path": "/api/2.0/secrets/put", + "body": { + "key": "client-private-key", + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys", + "string_value": "[PRIVATE_KEY]" + } +} +{ + "method": "POST", + "path": "/api/2.0/secrets/put", + "body": { + "key": "client-public-key", + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys", + "string_value": "[PUBLIC_KEY]" + } +} +{ + "method": "POST", + "path": "/api/2.0/secrets/scopes/create", + "body": { + "scope": "[USERNAME]-test-cluster-123-ssh-tunnel-keys" + } +} diff --git a/acceptance/ssh/connect-local-cluster/script b/acceptance/ssh/connect-local-cluster/script new file mode 100644 index 0000000000..ca9d86115a --- /dev/null +++ b/acceptance/ssh/connect-local-cluster/script @@ -0,0 +1,8 @@ +# Test the SSH connect flow for a classic cluster with pre-set metadata. +# The --metadata flag bypasses job submission, so this tests: +# cluster state check, secret scope creation, SSH key generation/storage, metadata parsing. +# The flow will fail at SSH spawn (expected, no real SSH server). +errcode $CLI ssh connect --cluster=test-cluster-123 --metadata=root,2222,test-cluster-123 2>LOG.stderr + +# Verify the API sequence: secrets scope creation + key storage +print_requests.py //secrets //scim --get --sort diff --git a/acceptance/ssh/connect-local-cluster/test.toml b/acceptance/ssh/connect-local-cluster/test.toml new file mode 100644 index 0000000000..b0dc2c499d --- /dev/null +++ b/acceptance/ssh/connect-local-cluster/test.toml @@ -0,0 +1,18 @@ +Local = true +Cloud = false +RecordRequests = true +Timeout = "10s" + +# Return a running cluster for the cluster check +[[Server]] +Pattern = "GET /api/2.1/clusters/get" +Response.Body = '{"cluster_id":"test-cluster-123","state":"RUNNING"}' + +# Replace generated RSA key content in request recordings +[[Repls]] +Old = '"-----BEGIN RSA PRIVATE KEY-----\\n[^"]*-----END RSA PRIVATE KEY-----\\n"' +New = '"[PRIVATE_KEY]"' + +[[Repls]] +Old = '"ssh-rsa [A-Za-z0-9+/=]+\\n"' +New = '"[PUBLIC_KEY]"' diff --git a/acceptance/ssh/connect-local-validation/out.test.toml b/acceptance/ssh/connect-local-validation/out.test.toml new file mode 100644 index 0000000000..d560f1de04 --- /dev/null +++ b/acceptance/ssh/connect-local-validation/out.test.toml @@ -0,0 +1,5 @@ +Local = true +Cloud = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/ssh/connect-local-validation/output.txt b/acceptance/ssh/connect-local-validation/output.txt new file mode 100644 index 0000000000..505f5fdd7c --- /dev/null +++ b/acceptance/ssh/connect-local-validation/output.txt @@ -0,0 +1,25 @@ + +=== Accelerator without serverless +Error: --accelerator flag can only be used with serverless compute (--name flag) + +Exit code: 1 + +=== Invalid accelerator value +Error: invalid accelerator value: "CPU_1x", expected "GPU_1xA10" or "GPU_8xH100" + +Exit code: 1 + +=== Invalid connection name +Error: connection name "bad name!" must consist of letters, numbers, dashes, and underscores + +Exit code: 1 + +=== Invalid IDE value +Error: invalid IDE value: "vim", expected "vscode" or "cursor" + +Exit code: 1 + +=== Environment version too low +Error: environment version must be >= 4, got 3 + +Exit code: 1 diff --git a/acceptance/ssh/connect-local-validation/script b/acceptance/ssh/connect-local-validation/script new file mode 100644 index 0000000000..9127167671 --- /dev/null +++ b/acceptance/ssh/connect-local-validation/script @@ -0,0 +1,14 @@ +title "Accelerator without serverless\n" +errcode $CLI ssh connect --cluster abc --accelerator GPU_1xA10 + +title "Invalid accelerator value\n" +errcode $CLI ssh connect --name my-conn --accelerator CPU_1x + +title "Invalid connection name\n" +errcode $CLI ssh connect --name "bad name!" + +title "Invalid IDE value\n" +errcode $CLI ssh connect --cluster abc --ide vim + +title "Environment version too low\n" +errcode $CLI ssh connect --cluster abc --environment-version 3 diff --git a/acceptance/ssh/connect-local-validation/test.toml b/acceptance/ssh/connect-local-validation/test.toml new file mode 100644 index 0000000000..7d36fb9dc1 --- /dev/null +++ b/acceptance/ssh/connect-local-validation/test.toml @@ -0,0 +1,2 @@ +Local = true +Cloud = false diff --git a/acceptance/ssh/connect-serverless-cpu/out.stdout.txt b/acceptance/ssh/connect-serverless-cpu/out.stdout.txt new file mode 100644 index 0000000000..41cae5e7d1 --- /dev/null +++ b/acceptance/ssh/connect-serverless-cpu/out.stdout.txt @@ -0,0 +1 @@ +Connection successful diff --git a/acceptance/ssh/connect-serverless-cpu/out.test.toml b/acceptance/ssh/connect-serverless-cpu/out.test.toml new file mode 100644 index 0000000000..b57de8531d --- /dev/null +++ b/acceptance/ssh/connect-serverless-cpu/out.test.toml @@ -0,0 +1,9 @@ +Local = false +Cloud = false +RequiresUnityCatalog = true + +[CloudEnvs] + gcp = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["direct"] diff --git a/acceptance/ssh/connect-serverless-cpu/output.txt b/acceptance/ssh/connect-serverless-cpu/output.txt new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/acceptance/ssh/connect-serverless-cpu/output.txt @@ -0,0 +1 @@ + diff --git a/acceptance/ssh/connect-serverless-cpu/script b/acceptance/ssh/connect-serverless-cpu/script new file mode 100644 index 0000000000..2c1a89631b --- /dev/null +++ b/acceptance/ssh/connect-serverless-cpu/script @@ -0,0 +1,6 @@ +errcode $CLI ssh connect --name serverless-cpu-test --releases-dir=$CLI_RELEASES_DIR -- "echo 'Connection successful'" >out.stdout.txt 2>LOG.stderr + +if ! grep -q "Connection successful" out.stdout.txt; then + run_id=$(cat LOG.stderr | grep -o "Job submitted successfully with run ID: [0-9]*" | grep -o "[0-9]*$") + trace $CLI jobs get-run "$run_id" > LOG.job +fi diff --git a/acceptance/ssh/connect-serverless-cpu/test.toml b/acceptance/ssh/connect-serverless-cpu/test.toml new file mode 100644 index 0000000000..11bdf952b7 --- /dev/null +++ b/acceptance/ssh/connect-serverless-cpu/test.toml @@ -0,0 +1,12 @@ +Local = false +Cloud = false + +# Serverless CPU is only available in newer environments +RequiresUnityCatalog = true + +# Serverless CPU is not available in GCP yet +[CloudEnvs] + gcp = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["direct"] diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go index ef9e6fb53b..da34189711 100644 --- a/experimental/ssh/internal/client/client_test.go +++ b/experimental/ssh/internal/client/client_test.go @@ -169,6 +169,112 @@ func TestGenerateDefaultConnectionNameMatchesRegex(t *testing.T) { } } +func TestFormatMetadata(t *testing.T) { + tests := []struct { + name string + userName string + port int + clusterID string + want string + }{ + { + name: "with cluster ID", + userName: "root", + port: 2222, + clusterID: "abc-123", + want: "root,2222,abc-123", + }, + { + name: "without cluster ID", + userName: "root", + port: 2222, + want: "root,2222", + }, + { + name: "empty userName returns empty", + port: 2222, + want: "", + }, + { + name: "zero port returns empty", + userName: "root", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := client.FormatMetadata(tt.userName, tt.port, tt.clusterID) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestIsServerlessMode(t *testing.T) { + tests := []struct { + name string + opts client.ClientOptions + want bool + }{ + { + name: "cluster only", + opts: client.ClientOptions{ClusterID: "abc-123"}, + want: false, + }, + { + name: "connection name only", + opts: client.ClientOptions{ConnectionName: "my-conn"}, + want: true, + }, + { + name: "both cluster and connection name", + opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn"}, + want: false, + }, + { + name: "neither", + opts: client.ClientOptions{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.opts.IsServerlessMode()) + }) + } +} + +func TestSessionIdentifier(t *testing.T) { + tests := []struct { + name string + opts client.ClientOptions + want string + }{ + { + name: "cluster mode", + opts: client.ClientOptions{ClusterID: "abc-123"}, + want: "abc-123", + }, + { + name: "serverless mode", + opts: client.ClientOptions{ConnectionName: "my-conn"}, + want: "my-conn", + }, + { + name: "both returns cluster ID", + opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn"}, + want: "abc-123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.opts.SessionIdentifier()) + }) + } +} + func TestToProxyCommand(t *testing.T) { exe, err := os.Executable() require.NoError(t, err) diff --git a/experimental/ssh/internal/client/job_test.go b/experimental/ssh/internal/client/job_test.go new file mode 100644 index 0000000000..0ef50b03c2 --- /dev/null +++ b/experimental/ssh/internal/client/job_test.go @@ -0,0 +1,236 @@ +package client + +import ( + "context" + "testing" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/compute" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// setupJobTestMocks configures common mocks for job submission tests. +// Returns the mock client and a pointer that will be set with the captured SubmitRun request. +func setupJobTestMocks(t *testing.T, ctx context.Context, runID int64) (*mocks.MockWorkspaceClient, *jobs.SubmitRun) { + t.Helper() + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockCurrentUserAPI().EXPECT().Me(ctx).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + m.GetMockWorkspaceAPI().EXPECT().MkdirsByPath(ctx, mock.AnythingOfType("string")).Return(nil) + m.GetMockWorkspaceAPI().EXPECT().Import(ctx, mock.AnythingOfType("workspace.Import")).Return(nil) + + var capturedRequest jobs.SubmitRun + m.GetMockJobsAPI().EXPECT().Submit(ctx, mock.AnythingOfType("jobs.SubmitRun")). + Run(func(_ context.Context, req jobs.SubmitRun) { + capturedRequest = req + }). + Return(&jobs.WaitGetRunJobTerminatedOrSkipped[jobs.SubmitRunResponse]{RunId: runID}, nil) + + m.GetMockJobsAPI().EXPECT().GetRun(ctx, jobs.GetRunRequest{RunId: runID}).Return(&jobs.Run{ + Tasks: []jobs.RunTask{ + { + TaskKey: sshServerTaskKey, + Status: &jobs.RunStatus{ + State: jobs.RunLifecycleStateV2StateRunning, + }, + }, + }, + }, nil) + + return m, &capturedRequest +} + +func TestSubmitJob_ClassicCluster(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m, captured := setupJobTestMocks(t, ctx, 42) + + opts := ClientOptions{ + ClusterID: "cluster-123", + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + require.Len(t, captured.Tasks, 1) + task := captured.Tasks[0] + assert.Equal(t, sshServerTaskKey, task.TaskKey) + assert.Equal(t, "cluster-123", task.ExistingClusterId) + assert.Empty(t, task.EnvironmentKey) + assert.Nil(t, task.Compute) + assert.Nil(t, captured.Environments) +} + +func TestSubmitJob_ServerlessGPU(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m, captured := setupJobTestMocks(t, ctx, 43) + + opts := ClientOptions{ + ConnectionName: "gpu-conn", + Accelerator: "GPU_1xA10", + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + require.Len(t, captured.Tasks, 1) + task := captured.Tasks[0] + assert.Equal(t, sshServerTaskKey, task.TaskKey) + assert.Empty(t, task.ExistingClusterId) + assert.Equal(t, serverlessEnvironmentKey, task.EnvironmentKey) + require.NotNil(t, task.Compute) + assert.Equal(t, compute.HardwareAcceleratorType("GPU_1xA10"), task.Compute.HardwareAccelerator) + + require.Len(t, captured.Environments, 1) + assert.Equal(t, serverlessEnvironmentKey, captured.Environments[0].EnvironmentKey) + require.NotNil(t, captured.Environments[0].Spec) + assert.Equal(t, "4", captured.Environments[0].Spec.EnvironmentVersion) +} + +func TestSubmitJob_ServerlessCPU(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m, captured := setupJobTestMocks(t, ctx, 44) + + opts := ClientOptions{ + ConnectionName: "cpu-conn", + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + require.Len(t, captured.Tasks, 1) + task := captured.Tasks[0] + assert.Equal(t, sshServerTaskKey, task.TaskKey) + assert.Empty(t, task.ExistingClusterId) + assert.Equal(t, serverlessEnvironmentKey, task.EnvironmentKey) + assert.Nil(t, task.Compute) + + require.Len(t, captured.Environments, 1) + assert.Equal(t, serverlessEnvironmentKey, captured.Environments[0].EnvironmentKey) +} + +func TestSubmitJob_NotebookUpload(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockCurrentUserAPI().EXPECT().Me(ctx).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + m.GetMockWorkspaceAPI().EXPECT().MkdirsByPath(ctx, mock.AnythingOfType("string")).Return(nil) + + var capturedImport workspace.Import + m.GetMockWorkspaceAPI().EXPECT().Import(ctx, mock.AnythingOfType("workspace.Import")). + Run(func(_ context.Context, req workspace.Import) { + capturedImport = req + }). + Return(nil) + + m.GetMockJobsAPI().EXPECT().Submit(ctx, mock.AnythingOfType("jobs.SubmitRun")). + Return(&jobs.WaitGetRunJobTerminatedOrSkipped[jobs.SubmitRunResponse]{RunId: 45}, nil) + + m.GetMockJobsAPI().EXPECT().GetRun(ctx, jobs.GetRunRequest{RunId: 45}).Return(&jobs.Run{ + Tasks: []jobs.RunTask{ + { + TaskKey: sshServerTaskKey, + Status: &jobs.RunStatus{State: jobs.RunLifecycleStateV2StateRunning}, + }, + }, + }, nil) + + opts := ClientOptions{ + ClusterID: "cluster-123", + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + assert.Contains(t, capturedImport.Path, "ssh-server-bootstrap") + assert.Equal(t, workspace.ImportFormatSource, capturedImport.Format) + assert.Equal(t, workspace.LanguagePython, capturedImport.Language) + assert.True(t, capturedImport.Overwrite) + assert.NotEmpty(t, capturedImport.Content) +} + +func TestSubmitJob_TaskParameters(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m, captured := setupJobTestMocks(t, ctx, 46) + + opts := ClientOptions{ + ClusterID: "cluster-123", + ShutdownDelay: 10 * time.Minute, + MaxClients: 5, + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + task := captured.Tasks[0] + params := task.NotebookTask.BaseParameters + assert.Equal(t, "0.1.0", params["version"]) + assert.Equal(t, "test-scope", params["secretScopeName"]) + assert.Equal(t, "10m0s", params["shutdownDelay"]) + assert.Equal(t, "5", params["maxClients"]) + assert.Equal(t, "cluster-123", params["sessionId"]) + assert.Equal(t, "false", params["serverless"]) +} + +func TestSubmitJob_ServerlessParameters(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m, captured := setupJobTestMocks(t, ctx, 47) + + opts := ClientOptions{ + ConnectionName: "gpu-conn", + Accelerator: "GPU_8xH100", + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + task := captured.Tasks[0] + params := task.NotebookTask.BaseParameters + assert.Equal(t, "gpu-conn", params["sessionId"]) + assert.Equal(t, "true", params["serverless"]) + + require.NotNil(t, task.Compute) + assert.Equal(t, compute.HardwareAcceleratorType("GPU_8xH100"), task.Compute.HardwareAccelerator) +} + +func TestSubmitJob_CustomEnvironmentVersion(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m, captured := setupJobTestMocks(t, ctx, 48) + + opts := ClientOptions{ + ConnectionName: "my-conn", + EnvironmentVersion: 7, + ServerTimeout: 30 * time.Minute, + TaskStartupTimeout: 5 * time.Minute, + } + + err := submitSSHTunnelJob(ctx, m.WorkspaceClient, "0.1.0", "test-scope", opts) + require.NoError(t, err) + + require.Len(t, captured.Environments, 1) + assert.Equal(t, "7", captured.Environments[0].Spec.EnvironmentVersion) +} diff --git a/experimental/ssh/internal/client/run_test.go b/experimental/ssh/internal/client/run_test.go new file mode 100644 index 0000000000..8faee9c1e4 --- /dev/null +++ b/experimental/ssh/internal/client/run_test.go @@ -0,0 +1,192 @@ +package client + +import ( + "encoding/base64" + "errors" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/compute" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + testPrivateKeyName = "client-private-key" + testPublicKeyName = "client-public-key" +) + +// mockSecretsForRun sets up secrets mocks (scope exists, keys exist) for a Run() test. +// Run() wraps the context with WithCancel, so all mocks must use mock.Anything for context. +func mockSecretsForRun(m *mocks.MockWorkspaceClient, sessionID string) { + scopeName := "testuser@example.com-" + sessionID + "-ssh-tunnel-keys" + + m.GetMockCurrentUserAPI().EXPECT().Me(mock.Anything).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + m.GetMockSecretsAPI().EXPECT().ListSecretsByScope(mock.Anything, scopeName). + Return(&workspace.ListSecretsResponse{}, nil) + + privKey := base64.StdEncoding.EncodeToString([]byte("fake-private-key")) + pubKey := base64.StdEncoding.EncodeToString([]byte("fake-public-key")) + + m.GetMockSecretsAPI().EXPECT().GetSecret(mock.Anything, workspace.GetSecretRequest{ + Scope: scopeName, + Key: testPrivateKeyName, + }).Return(&workspace.GetSecretResponse{Value: privKey}, nil) + + m.GetMockSecretsAPI().EXPECT().GetSecret(mock.Anything, workspace.GetSecretRequest{ + Scope: scopeName, + Key: testPublicKeyName, + }).Return(&workspace.GetSecretResponse{Value: pubKey}, nil) +} + +func TestRun_EmptySessionID(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ProxyMode: true, + }) + assert.EqualError(t, err, "either --cluster or --name must be provided") +} + +// In Run(), the order for classic clusters is: cluster check -> secrets -> keys -> metadata -> SSH. +// Tests below mock only the calls that happen before the expected failure point. + +func TestRun_ClassicCluster_ClusterCheckFails(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + // Cluster check happens before secrets. No secrets mock needed. + m.GetMockClustersAPI().EXPECT().GetByClusterId(mock.Anything, "cluster-123"). + Return(nil, errors.New("cluster not found")) + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ClusterID: "cluster-123", + ServerMetadata: "root,2222,cluster-123", + SSHKeysDir: t.TempDir(), + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get cluster info") +} + +func TestRun_ClassicCluster_ClusterNotRunning(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockClustersAPI().EXPECT().GetByClusterId(mock.Anything, "cluster-123"). + Return(&compute.ClusterDetails{State: compute.StateTerminated}, nil) + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ClusterID: "cluster-123", + ServerMetadata: "root,2222,cluster-123", + SSHKeysDir: t.TempDir(), + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is not running") + assert.Contains(t, err.Error(), "--auto-start-cluster") +} + +func TestRun_SecretScopeCreationFails(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + // Use serverless mode to skip cluster check, making this test purely about secrets + scopeName := "testuser@example.com-my-conn-ssh-tunnel-keys" + m.GetMockCurrentUserAPI().EXPECT().Me(mock.Anything).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + m.GetMockSecretsAPI().EXPECT().ListSecretsByScope(mock.Anything, scopeName). + Return(nil, databricks.ErrResourceDoesNotExist) + m.GetMockSecretsAPI().EXPECT().CreateScope(mock.Anything, workspace.CreateScope{Scope: scopeName}). + Return(errors.New("limit exceeded")) + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ConnectionName: "my-conn", + ServerMetadata: "root,2222,cluster-123", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create secret scope") +} + +func TestRun_ServerlessSkipsClusterCheck(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + mockSecretsForRun(m, "my-conn") + + // No cluster mock — if Run tries to check cluster state, the mock will panic. + // Using metadata without cluster ID: serverless requires it, so Run fails after secrets. + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ConnectionName: "my-conn", + ClientPrivateKeyName: testPrivateKeyName, + ClientPublicKeyName: testPublicKeyName, + ServerMetadata: "root,2222", + SSHKeysDir: t.TempDir(), + }) + assert.Error(t, err) + // Fails at serverless cluster ID check, proving cluster check was skipped + assert.Contains(t, err.Error(), "cluster ID is required for serverless connections") + assert.NotContains(t, err.Error(), "failed to get cluster") +} + +func TestRun_InvalidMetadataFormat(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + // Cluster check passes, secrets pass, then metadata parsing fails + m.GetMockClustersAPI().EXPECT().GetByClusterId(mock.Anything, "cluster-123"). + Return(&compute.ClusterDetails{State: compute.StateRunning}, nil) + mockSecretsForRun(m, "cluster-123") + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ClusterID: "cluster-123", + ClientPrivateKeyName: testPrivateKeyName, + ClientPublicKeyName: testPublicKeyName, + ServerMetadata: "badformat", + SSHKeysDir: t.TempDir(), + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid metadata") +} + +func TestRun_EmptyUserNameInMetadata(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockClustersAPI().EXPECT().GetByClusterId(mock.Anything, "cluster-123"). + Return(&compute.ClusterDetails{State: compute.StateRunning}, nil) + mockSecretsForRun(m, "cluster-123") + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ClusterID: "cluster-123", + ClientPrivateKeyName: testPrivateKeyName, + ClientPublicKeyName: testPublicKeyName, + ServerMetadata: ",2222,cluster-123", + SSHKeysDir: t.TempDir(), + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "remote user name is empty") +} + +func TestRun_ServerlessMissingClusterID(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + mockSecretsForRun(m, "my-conn") + + err := Run(ctx, m.WorkspaceClient, ClientOptions{ + ConnectionName: "my-conn", + ClientPrivateKeyName: testPrivateKeyName, + ClientPublicKeyName: testPublicKeyName, + ServerMetadata: "root,2222", + SSHKeysDir: t.TempDir(), + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cluster ID is required for serverless connections") +} diff --git a/experimental/ssh/internal/keys/keys_test.go b/experimental/ssh/internal/keys/keys_test.go new file mode 100644 index 0000000000..d920078971 --- /dev/null +++ b/experimental/ssh/internal/keys/keys_test.go @@ -0,0 +1,289 @@ +package keys + +import ( + "crypto/x509" + "encoding/base64" + "encoding/pem" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestGetLocalSSHKeyPath_DefaultDir(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + path, err := GetLocalSSHKeyPath(t.Context(), "cluster-123", "") + require.NoError(t, err) + assert.Equal(t, filepath.Join(tmpDir, ".databricks", "ssh-tunnel-keys", "cluster-123"), path) +} + +func TestGetLocalSSHKeyPath_CustomDir(t *testing.T) { + customDir := "/custom/keys/dir" + path, err := GetLocalSSHKeyPath(t.Context(), "my-session", customDir) + require.NoError(t, err) + assert.Equal(t, filepath.Join(customDir, "my-session"), path) +} + +func TestSaveSSHKeyPair(t *testing.T) { + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "session-1", "id_rsa") + + privateKey := []byte("fake-private-key") + publicKey := []byte("fake-public-key") + + err := SaveSSHKeyPair(keyPath, privateKey, publicKey) + require.NoError(t, err) + + // Verify private key + privData, err := os.ReadFile(keyPath) + require.NoError(t, err) + assert.Equal(t, privateKey, privData) + + privInfo, err := os.Stat(keyPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o600), privInfo.Mode().Perm()) + + // Verify public key + pubData, err := os.ReadFile(keyPath + ".pub") + require.NoError(t, err) + assert.Equal(t, publicKey, pubData) + + pubInfo, err := os.Stat(keyPath + ".pub") + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o644), pubInfo.Mode().Perm()) +} + +func TestSaveSSHKeyPair_OverwritesExisting(t *testing.T) { + tmpDir := t.TempDir() + keyDir := filepath.Join(tmpDir, "session-1") + keyPath := filepath.Join(keyDir, "id_rsa") + + // Create existing keys + require.NoError(t, os.MkdirAll(keyDir, 0o700)) + require.NoError(t, os.WriteFile(keyPath, []byte("old-private"), 0o600)) + require.NoError(t, os.WriteFile(keyPath+".pub", []byte("old-public"), 0o644)) + + err := SaveSSHKeyPair(keyPath, []byte("new-private"), []byte("new-public")) + require.NoError(t, err) + + privData, err := os.ReadFile(keyPath) + require.NoError(t, err) + assert.Equal(t, []byte("new-private"), privData) + + pubData, err := os.ReadFile(keyPath + ".pub") + require.NoError(t, err) + assert.Equal(t, []byte("new-public"), pubData) +} + +func TestGenerateSSHKeyPair(t *testing.T) { + privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair() + require.NoError(t, err) + + // Verify private key is valid PEM-encoded RSA key + block, _ := pem.Decode(privateKeyBytes) + require.NotNil(t, block, "expected PEM block for private key") + assert.Equal(t, "RSA PRIVATE KEY", block.Type) + + _, err = x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + // Verify public key is in authorized_keys format + assert.Contains(t, string(publicKeyBytes), "ssh-rsa ") +} + +func TestCreateKeysSecretScope_ScopeAlreadyExists(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + currentUserAPI := m.GetMockCurrentUserAPI() + currentUserAPI.EXPECT().Me(ctx).Return(&iam.User{UserName: "testuser@example.com"}, nil) + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().ListSecretsByScope(ctx, "testuser@example.com-cluster-123-ssh-tunnel-keys"). + Return(&workspace.ListSecretsResponse{}, nil) + + scopeName, err := CreateKeysSecretScope(ctx, m.WorkspaceClient, "cluster-123") + require.NoError(t, err) + assert.Equal(t, "testuser@example.com-cluster-123-ssh-tunnel-keys", scopeName) +} + +func TestCreateKeysSecretScope_ScopeDoesNotExist(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + currentUserAPI := m.GetMockCurrentUserAPI() + currentUserAPI.EXPECT().Me(ctx).Return(&iam.User{UserName: "testuser@example.com"}, nil) + + scopeName := "testuser@example.com-my-conn-ssh-tunnel-keys" + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().ListSecretsByScope(ctx, scopeName). + Return(nil, databricks.ErrResourceDoesNotExist) + secretsAPI.EXPECT().CreateScope(ctx, workspace.CreateScope{Scope: scopeName}). + Return(nil) + + result, err := CreateKeysSecretScope(ctx, m.WorkspaceClient, "my-conn") + require.NoError(t, err) + assert.Equal(t, scopeName, result) +} + +func TestCreateKeysSecretScope_ListError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + currentUserAPI := m.GetMockCurrentUserAPI() + currentUserAPI.EXPECT().Me(ctx).Return(&iam.User{UserName: "testuser@example.com"}, nil) + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().ListSecretsByScope(ctx, "testuser@example.com-cluster-123-ssh-tunnel-keys"). + Return(nil, errors.New("permission denied")) + + _, err := CreateKeysSecretScope(ctx, m.WorkspaceClient, "cluster-123") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to check if secret scope") + assert.Contains(t, err.Error(), "permission denied") +} + +func TestCreateKeysSecretScope_CreateError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + currentUserAPI := m.GetMockCurrentUserAPI() + currentUserAPI.EXPECT().Me(ctx).Return(&iam.User{UserName: "testuser@example.com"}, nil) + + scopeName := "testuser@example.com-cluster-123-ssh-tunnel-keys" + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().ListSecretsByScope(ctx, scopeName). + Return(nil, databricks.ErrResourceDoesNotExist) + secretsAPI.EXPECT().CreateScope(ctx, workspace.CreateScope{Scope: scopeName}). + Return(errors.New("limit exceeded")) + + _, err := CreateKeysSecretScope(ctx, m.WorkspaceClient, "cluster-123") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create secrets scope") + assert.Contains(t, err.Error(), "limit exceeded") +} + +func TestGetSecret_DecodesBase64(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + encoded := base64.StdEncoding.EncodeToString([]byte("my-secret-value")) + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().GetSecret(ctx, workspace.GetSecretRequest{ + Scope: "my-scope", + Key: "my-key", + }).Return(&workspace.GetSecretResponse{ + Key: "my-key", + Value: encoded, + }, nil) + + value, err := GetSecret(ctx, m.WorkspaceClient, "my-scope", "my-key") + require.NoError(t, err) + assert.Equal(t, []byte("my-secret-value"), value) +} + +func TestGetSecret_Error(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().GetSecret(ctx, workspace.GetSecretRequest{ + Scope: "my-scope", + Key: "my-key", + }).Return(nil, errors.New("not found")) + + _, err := GetSecret(ctx, m.WorkspaceClient, "my-scope", "my-key") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get secret my-key from scope my-scope") +} + +func TestCheckAndGenerateSSHKeyPairFromSecrets_ExistingKeys(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + privKeyContent := []byte("private-key-content") + pubKeyContent := []byte("public-key-content") + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().GetSecret(ctx, workspace.GetSecretRequest{ + Scope: "my-scope", + Key: "client-private-key", + }).Return(&workspace.GetSecretResponse{ + Value: base64.StdEncoding.EncodeToString(privKeyContent), + }, nil) + secretsAPI.EXPECT().GetSecret(ctx, workspace.GetSecretRequest{ + Scope: "my-scope", + Key: "client-public-key", + }).Return(&workspace.GetSecretResponse{ + Value: base64.StdEncoding.EncodeToString(pubKeyContent), + }, nil) + + privBytes, pubBytes, err := CheckAndGenerateSSHKeyPairFromSecrets(ctx, m.WorkspaceClient, "my-scope", "client-private-key", "client-public-key") + require.NoError(t, err) + assert.Equal(t, privKeyContent, privBytes) + assert.Equal(t, pubKeyContent, pubBytes) +} + +func TestCheckAndGenerateSSHKeyPairFromSecrets_GeneratesNewKeys(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + secretsAPI := m.GetMockSecretsAPI() + // First GetSecret fails (no existing private key) + secretsAPI.EXPECT().GetSecret(ctx, workspace.GetSecretRequest{ + Scope: "my-scope", + Key: "client-private-key", + }).Return(nil, errors.New("not found")) + + // Expect both PutSecret calls for new keys (use mock.MatchedBy since key values are generated) + secretsAPI.EXPECT().PutSecret(ctx, mock.MatchedBy(func(req workspace.PutSecret) bool { + return req.Scope == "my-scope" && req.Key == "client-private-key" && req.StringValue != "" + })).Return(nil) + secretsAPI.EXPECT().PutSecret(ctx, mock.MatchedBy(func(req workspace.PutSecret) bool { + return req.Scope == "my-scope" && req.Key == "client-public-key" && req.StringValue != "" + })).Return(nil) + + privBytes, pubBytes, err := CheckAndGenerateSSHKeyPairFromSecrets(ctx, m.WorkspaceClient, "my-scope", "client-private-key", "client-public-key") + require.NoError(t, err) + + // Verify the generated keys are valid + block, _ := pem.Decode(privBytes) + require.NotNil(t, block) + assert.Equal(t, "RSA PRIVATE KEY", block.Type) + + assert.Contains(t, string(pubBytes), "ssh-rsa ") +} + +func TestCheckAndGenerateSSHKeyPairFromSecrets_PutError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + secretsAPI := m.GetMockSecretsAPI() + secretsAPI.EXPECT().GetSecret(ctx, workspace.GetSecretRequest{ + Scope: "my-scope", + Key: "client-private-key", + }).Return(nil, errors.New("not found")) + + secretsAPI.EXPECT().PutSecret(ctx, mock.MatchedBy(func(req workspace.PutSecret) bool { + return req.Scope == "my-scope" && req.Key == "client-private-key" + })).Return(errors.New("quota exceeded")) + + _, _, err := CheckAndGenerateSSHKeyPairFromSecrets(ctx, m.WorkspaceClient, "my-scope", "client-private-key", "client-public-key") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to store secret client-private-key") +} diff --git a/experimental/ssh/internal/workspace/workspace_test.go b/experimental/ssh/internal/workspace/workspace_test.go new file mode 100644 index 0000000000..853ec90131 --- /dev/null +++ b/experimental/ssh/internal/workspace/workspace_test.go @@ -0,0 +1,102 @@ +package workspace + +import ( + "encoding/json" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetWorkspaceContentDir(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockCurrentUserAPI().EXPECT().Me(ctx).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + dir, err := GetWorkspaceContentDir(ctx, m.WorkspaceClient, "0.1.0", "cluster-123") + require.NoError(t, err) + assert.Equal(t, "/Workspace/Users/testuser@example.com/.databricks/ssh-tunnel/0.1.0/cluster-123", dir) +} + +func TestGetWorkspaceVersionedDir(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockCurrentUserAPI().EXPECT().Me(ctx).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + dir, err := GetWorkspaceVersionedDir(ctx, m.WorkspaceClient, "0.2.0") + require.NoError(t, err) + assert.Equal(t, "/Workspace/Users/testuser@example.com/.databricks/ssh-tunnel/0.2.0", dir) +} + +func TestGetWorkspaceContentDir_ServerlessSession(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockCurrentUserAPI().EXPECT().Me(ctx).Return(&iam.User{ + UserName: "testuser@example.com", + }, nil) + + dir, err := GetWorkspaceContentDir(ctx, m.WorkspaceClient, "0.1.0", "databricks-gpu-1xa10-961dabbd") + require.NoError(t, err) + assert.Equal(t, "/Workspace/Users/testuser@example.com/.databricks/ssh-tunnel/0.1.0/databricks-gpu-1xa10-961dabbd", dir) +} + +func TestGetWorkspaceContentDir_CurrentUserError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + + m.GetMockCurrentUserAPI().EXPECT().Me(ctx).Return(nil, assert.AnError) + + _, err := GetWorkspaceContentDir(ctx, m.WorkspaceClient, "0.1.0", "cluster-123") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get current user") +} + +func TestWorkspaceMetadata_JSON(t *testing.T) { + tests := []struct { + name string + metadata WorkspaceMetadata + json string + }{ + { + name: "with cluster ID", + metadata: WorkspaceMetadata{Port: 2222, ClusterID: "abc-123"}, + json: `{"port":2222,"cluster_id":"abc-123"}`, + }, + { + name: "without cluster ID", + metadata: WorkspaceMetadata{Port: 3333}, + json: `{"port":3333}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal + data, err := json.Marshal(tt.metadata) + require.NoError(t, err) + assert.JSONEq(t, tt.json, string(data)) + + // Unmarshal + var parsed WorkspaceMetadata + err = json.Unmarshal([]byte(tt.json), &parsed) + require.NoError(t, err) + assert.Equal(t, tt.metadata, parsed) + }) + } +} + +func TestWorkspaceMetadata_InvalidJSON(t *testing.T) { + var metadata WorkspaceMetadata + err := json.Unmarshal([]byte("not json"), &metadata) + assert.Error(t, err) +} diff --git a/libs/testserver/handlers.go b/libs/testserver/handlers.go index b2a95b1902..5f640b503b 100644 --- a/libs/testserver/handlers.go +++ b/libs/testserver/handlers.go @@ -261,6 +261,10 @@ func AddDefaultHandlers(server *Server) { return req.Workspace.JobsGetRunOutput(req) }) + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req Request) any { + return req.Workspace.JobsSubmit(req) + }) + server.Handle("GET", "/api/2.2/jobs/runs/list", func(req Request) any { return MapList(req.Workspace, req.Workspace.JobRuns, "runs") }) @@ -613,6 +617,10 @@ func AddDefaultHandlers(server *Server) { }) // Secrets: + server.Handle("GET", "/api/2.0/secrets/list", func(req Request) any { + return req.Workspace.SecretsListSecrets(req) + }) + server.Handle("POST", "/api/2.0/secrets/put", func(req Request) any { return req.Workspace.SecretsPut(req) }) diff --git a/libs/testserver/jobs.go b/libs/testserver/jobs.go index 15800341de..3a9f2898d2 100644 --- a/libs/testserver/jobs.go +++ b/libs/testserver/jobs.go @@ -62,6 +62,44 @@ func (s *FakeWorkspace) JobsCreate(req Request) Response { return Response{Body: jobs.CreateResponse{JobId: jobId}} } +func (s *FakeWorkspace) JobsSubmit(req Request) Response { + var request jobs.SubmitRun + if err := json.Unmarshal(req.Body, &request); err != nil { + return Response{ + StatusCode: 400, + Body: fmt.Sprintf("request parsing error: %s", err), + } + } + + defer s.LockUnlock()() + + runId := nextID() + + var tasks []jobs.RunTask + for _, t := range request.Tasks { + tasks = append(tasks, jobs.RunTask{ + RunId: nextID(), + TaskKey: t.TaskKey, + State: &jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateRunning, + }, + Status: &jobs.RunStatus{ + State: jobs.RunLifecycleStateV2StateRunning, + }, + }) + } + + s.JobRuns[runId] = jobs.Run{ + RunId: runId, + State: &jobs.RunState{LifeCycleState: jobs.RunLifeCycleStateRunning}, + RunType: jobs.RunTypeSubmitRun, + RunName: request.RunName, + Tasks: tasks, + } + + return Response{Body: jobs.SubmitRunResponse{RunId: runId}} +} + func (s *FakeWorkspace) JobsReset(req Request) Response { var request jobs.ResetJob if err := json.Unmarshal(req.Body, &request); err != nil { diff --git a/libs/testserver/secret_scopes.go b/libs/testserver/secret_scopes.go index 8dab821738..90fb382bee 100644 --- a/libs/testserver/secret_scopes.go +++ b/libs/testserver/secret_scopes.go @@ -84,7 +84,10 @@ func (s *FakeWorkspace) SecretsDeleteScope(req Request) Response { if _, exists := s.SecretScopes[request.Scope]; !exists { return Response{ StatusCode: 404, - Body: map[string]string{"message": fmt.Sprintf("Scope %s does not exist", request.Scope)}, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Scope %s does not exist", request.Scope), + }, } } @@ -105,7 +108,10 @@ func (s *FakeWorkspace) SecretsListAcls(req Request) Response { if _, exists := s.SecretScopes[scope]; !exists { return Response{ StatusCode: 404, - Body: map[string]string{"message": fmt.Sprintf("Scope %s does not exist", scope)}, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Scope %s does not exist", scope), + }, } } diff --git a/libs/testserver/secrets.go b/libs/testserver/secrets.go index 1f298ea9f0..73efa17086 100644 --- a/libs/testserver/secrets.go +++ b/libs/testserver/secrets.go @@ -23,7 +23,10 @@ func (s *FakeWorkspace) SecretsPut(req Request) Response { if _, exists := s.SecretScopes[request.Scope]; !exists { return Response{ StatusCode: 404, - Body: map[string]string{"message": fmt.Sprintf("Scope %s does not exist", request.Scope)}, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Scope %s does not exist", request.Scope), + }, } } @@ -40,6 +43,35 @@ func (s *FakeWorkspace) SecretsPut(req Request) Response { return Response{} } +func (s *FakeWorkspace) SecretsListSecrets(req Request) Response { + defer s.LockUnlock()() + + scope := req.URL.Query().Get("scope") + + if _, exists := s.SecretScopes[scope]; !exists { + return Response{ + StatusCode: 404, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Scope %s does not exist", scope), + }, + } + } + + var secrets []workspace.SecretMetadata + if s.Secrets != nil && s.Secrets[scope] != nil { + for key := range s.Secrets[scope] { + secrets = append(secrets, workspace.SecretMetadata{Key: key}) + } + } + + return Response{ + Body: workspace.ListSecretsResponse{ + Secrets: secrets, + }, + } +} + func (s *FakeWorkspace) SecretsGet(req Request) Response { defer s.LockUnlock()() @@ -50,7 +82,10 @@ func (s *FakeWorkspace) SecretsGet(req Request) Response { if _, exists := s.SecretScopes[scope]; !exists { return Response{ StatusCode: 404, - Body: map[string]string{"message": fmt.Sprintf("Scope %s does not exist", scope)}, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Scope %s does not exist", scope), + }, } } @@ -58,7 +93,10 @@ func (s *FakeWorkspace) SecretsGet(req Request) Response { if s.Secrets == nil || s.Secrets[scope] == nil { return Response{ StatusCode: 404, - Body: map[string]string{"message": fmt.Sprintf("Secret %s/%s not found", scope, key)}, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Secret %s/%s not found", scope, key), + }, } } @@ -66,7 +104,10 @@ func (s *FakeWorkspace) SecretsGet(req Request) Response { if !exists { return Response{ StatusCode: 404, - Body: map[string]string{"message": fmt.Sprintf("Secret %s/%s not found", scope, key)}, + Body: map[string]string{ + "error_code": "RESOURCE_DOES_NOT_EXIST", + "message": fmt.Sprintf("Secret %s/%s not found", scope, key), + }, } }