diff --git a/cmd/workspace/clusters/overrides.go b/cmd/workspace/clusters/overrides.go
index 45c530a14a2..c15b56ffcb3 100644
--- a/cmd/workspace/clusters/overrides.go
+++ b/cmd/workspace/clusters/overrides.go
@@ -1,9 +1,13 @@
 package clusters
 
 import (
+	"errors"
+	"net/http"
 	"strings"
 
+	"github.com/databricks/cli/libs/cmdctx"
 	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/databricks-sdk-go/apierr"
 	"github.com/databricks/databricks-sdk-go/service/compute"
 	"github.com/spf13/cobra"
 )
@@ -93,8 +97,51 @@ func sparkVersionsOverride(sparkVersionsCmd *cobra.Command) {
 `)
 }
 
+// startOverride wraps the start command's RunE so that an INVALID_STATE
+// error from the Start API is treated as success when the cluster is in a
+// non-TERMINATED state and is not a job cluster. In every other case,
+// including a failed cluster lookup, the original Start error is returned
+// unchanged.
+func startOverride(startCmd *cobra.Command, startReq *compute.StartCluster) {
+	run := startCmd.RunE
+	startCmd.RunE = func(cmd *cobra.Command, args []string) error {
+		err := run(cmd, args)
+		if err == nil || !isInvalidState(err) || startReq.ClusterId == "" {
+			return err
+		}
+
+		ctx := cmd.Context()
+		w := cmdctx.WorkspaceClient(ctx)
+		cluster, getErr := w.Clusters.Get(ctx, compute.GetClusterRequest{
+			ClusterId: startReq.ClusterId,
+		})
+		if getErr != nil {
+			// The lookup is best effort: report the original Start error.
+			return err
+		}
+		isJobCluster := cluster.ClusterSource == compute.ClusterSourceJob
+		isNonTerminated := cluster.State != "" && cluster.State != compute.StateTerminated
+		// The Start API returns INVALID_STATE for already-running clusters,
+		// while the CLI help documents non-TERMINATED clusters as a no-op.
+		// The help separately documents that job clusters cannot be started.
+		if isNonTerminated && !isJobCluster {
+			return nil
+		}
+
+		return err
+	}
+}
+
+// isInvalidState reports whether err is an API error with HTTP status 400
+// and error code INVALID_STATE.
+func isInvalidState(err error) bool {
+	apiErr, ok := errors.AsType[*apierr.APIError](err)
+	return ok && apiErr.StatusCode == http.StatusBadRequest && apiErr.ErrorCode == "INVALID_STATE"
+}
+
 func init() {
 	listOverrides = append(listOverrides, listOverride)
 	listNodeTypesOverrides = append(listNodeTypesOverrides, listNodeTypesOverride)
 	sparkVersionsOverrides = append(sparkVersionsOverrides, sparkVersionsOverride)
+	startOverrides = append(startOverrides, startOverride)
 }
diff --git a/cmd/workspace/clusters/overrides_test.go b/cmd/workspace/clusters/overrides_test.go
new file mode 100644
index 00000000000..ec87a323994
--- /dev/null
+++ b/cmd/workspace/clusters/overrides_test.go
@@ -0,0 +1,198 @@
+package clusters
+
+import (
+	"encoding/json"
+	"net/http"
+	"testing"
+
+	"github.com/databricks/cli/libs/cmdctx"
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/cli/libs/testserver"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/compute"
+	"github.com/spf13/cobra"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// newStartTestCommand builds the start command with a context that carries a
+// cmdio.MockDiscard setup and a workspace client pointed at the test server.
+func newStartTestCommand(t *testing.T, server *testserver.Server) *cobra.Command {
+	t.Helper()
+
+	w, err := databricks.NewWorkspaceClient(&databricks.Config{
+		Host:  server.URL,
+		Token: "token",
+	})
+	require.NoError(t, err)
+
+	cmd := newStart()
+	ctx := cmdio.MockDiscard(t.Context())
+	ctx = cmdctx.SetWorkspaceClient(ctx, w)
+	cmd.SetContext(ctx)
+	return cmd
+}
+
+// TestStartRunningClusterIsNoOp checks that an INVALID_STATE response from
+// Start is turned into success when Get reports the cluster as RUNNING.
+func TestStartRunningClusterIsNoOp(t *testing.T) {
+	const clusterId = "abc"
+
+	var startRequests int
+	var getRequests int
+
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
+		startRequests++
+		var request compute.StartCluster
+		// Use assert, not require, inside handlers: require calls t.FailNow,
+		// which must only be called from the goroutine running the test.
+		assert.NoError(t, json.Unmarshal(req.Body, &request))
+		assert.Equal(t, clusterId, request.ClusterId)
+
+		return testserver.Response{
+			StatusCode: http.StatusBadRequest,
+			Body: map[string]string{
+				"error_code": "INVALID_STATE",
+				"message":    "Cluster abc is in unexpected state Running.",
+			},
+		}
+	})
+	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
+		getRequests++
+		assert.Equal(t, clusterId, req.URL.Query().Get("cluster_id"))
+
+		return compute.ClusterDetails{
+			ClusterId: clusterId,
+			State:     compute.StateRunning,
+		}
+	})
+
+	cmd := newStartTestCommand(t, server)
+	err := cmd.RunE(cmd, []string{clusterId})
+	require.NoError(t, err)
+	assert.Equal(t, 1, startRequests)
+	assert.Equal(t, 1, getRequests)
+}
+
+// TestStartInvalidStatePreservesErrorForTerminatedCluster checks that the
+// Start error is returned when Get reports the cluster as TERMINATED.
+func TestStartInvalidStatePreservesErrorForTerminatedCluster(t *testing.T) {
+	const clusterId = "abc"
+
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
+		return testserver.Response{
+			StatusCode: http.StatusBadRequest,
+			Body: map[string]string{
+				"error_code": "INVALID_STATE",
+				"message":    "Cluster abc cannot be started.",
+			},
+		}
+	})
+	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
+		return compute.ClusterDetails{
+			ClusterId: clusterId,
+			State:     compute.StateTerminated,
+		}
+	})
+
+	cmd := newStartTestCommand(t, server)
+	err := cmd.RunE(cmd, []string{clusterId})
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "Cluster abc cannot be started.")
+}
+
+// TestStartInvalidStatePreservesErrorWhenGetFails checks that the original
+// Start error, not the Get error, is returned when the lookup fails.
+func TestStartInvalidStatePreservesErrorWhenGetFails(t *testing.T) {
+	const clusterId = "abc"
+
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
+		return testserver.Response{
+			StatusCode: http.StatusBadRequest,
+			Body: map[string]string{
+				"error_code": "INVALID_STATE",
+				"message":    "Cluster abc is in unexpected state.",
+			},
+		}
+	})
+	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
+		return testserver.Response{
+			StatusCode: http.StatusNotFound,
+			Body: map[string]string{
+				"error_code": "RESOURCE_DOES_NOT_EXIST",
+				"message":    "Cluster abc does not exist.",
+			},
+		}
+	})
+
+	cmd := newStartTestCommand(t, server)
+	err := cmd.RunE(cmd, []string{clusterId})
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "Cluster abc is in unexpected state.")
+}
+
+// TestStartInvalidStatePreservesErrorForJobCluster checks that the Start
+// error is returned for a RUNNING cluster whose source is a job.
+func TestStartInvalidStatePreservesErrorForJobCluster(t *testing.T) {
+	const clusterId = "abc"
+
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
+		return testserver.Response{
+			StatusCode: http.StatusBadRequest,
+			Body: map[string]string{
+				"error_code": "INVALID_STATE",
+				"message":    "Clusters launched to run a job cannot be started.",
+			},
+		}
+	})
+	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
+		return compute.ClusterDetails{
+			ClusterId:     clusterId,
+			ClusterSource: compute.ClusterSourceJob,
+			State:         compute.StateRunning,
+		}
+	})
+
+	cmd := newStartTestCommand(t, server)
+	err := cmd.RunE(cmd, []string{clusterId})
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "Clusters launched to run a job cannot be started.")
+}
+
+// TestStartTerminatedClusterKeepsNormalStartBehavior checks that a successful
+// Start with --no-wait returns no error and never issues the follow-up Get.
+func TestStartTerminatedClusterKeepsNormalStartBehavior(t *testing.T) {
+	const clusterId = "abc"
+
+	var startRequests int
+	var getRequests int
+
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
+		startRequests++
+		var request compute.StartCluster
+		// Use assert, not require, inside handlers: require calls t.FailNow,
+		// which must only be called from the goroutine running the test.
+		assert.NoError(t, json.Unmarshal(req.Body, &request))
+		assert.Equal(t, clusterId, request.ClusterId)
+		return testserver.Response{}
+	})
+	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
+		getRequests++
+		return compute.ClusterDetails{
+			ClusterId: clusterId,
+			State:     compute.StateRunning,
+		}
+	})
+
+	cmd := newStartTestCommand(t, server)
+	require.NoError(t, cmd.Flags().Set("no-wait", "true"))
+	err := cmd.RunE(cmd, []string{clusterId})
+	require.NoError(t, err)
+	assert.Equal(t, 1, startRequests)
+	assert.Equal(t, 0, getRequests)
+}