Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions cmd/workspace/clusters/overrides.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package clusters

import (
"errors"
"net/http"
"strings"

"github.com/databricks/cli/libs/cmdctx"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -93,8 +97,43 @@ func sparkVersionsOverride(sparkVersionsCmd *cobra.Command) {
`)
}

// startOverride wraps the start command's RunE so that an INVALID_STATE
// failure is turned into success when the target is a non-job cluster that is
// not in the TERMINATED state.
//
// In every other situation (success, a different error, an empty cluster ID
// on the request, a failed follow-up lookup, a job cluster, or a cluster that
// is TERMINATED or reports no state) the result of the wrapped RunE is
// returned unchanged.
func startOverride(startCmd *cobra.Command, startReq *compute.StartCluster) {
	wrapped := startCmd.RunE
	startCmd.RunE = func(cmd *cobra.Command, args []string) error {
		startErr := wrapped(cmd, args)
		if startErr == nil {
			return startErr
		}
		if !isInvalidState(startErr) {
			return startErr
		}
		if startReq.ClusterId == "" {
			return startErr
		}

		// Look the cluster up to find out why the start request was rejected.
		ctx := cmd.Context()
		lookup := compute.GetClusterRequest{ClusterId: startReq.ClusterId}
		details, lookupErr := cmdctx.WorkspaceClient(ctx).Clusters.Get(ctx, lookup)
		if lookupErr != nil {
			// Best effort only: the original start error is more useful to
			// the caller than the failure of this diagnostic lookup.
			return startErr
		}

		// The help separately documents that job clusters cannot be started,
		// so their error is always preserved.
		if details.ClusterSource == compute.ClusterSourceJob {
			return startErr
		}
		if details.State == "" || details.State == compute.StateTerminated {
			return startErr
		}

		// The Start API returns INVALID_STATE for already-running clusters,
		// while the CLI help documents non-TERMINATED clusters as a no-op.
		return nil
	}
}

// isInvalidState reports whether err is (or wraps) an *apierr.APIError with
// HTTP status 400 and the error code "INVALID_STATE".
func isInvalidState(err error) bool {
	apiErr, ok := errors.AsType[*apierr.APIError](err)
	if !ok {
		return false
	}
	if apiErr.StatusCode != http.StatusBadRequest {
		return false
	}
	return apiErr.ErrorCode == "INVALID_STATE"
}

// init appends this file's override functions to the package-level override
// lists for the list, list-node-types, spark-versions and start commands.
func init() {
	listOverrides = append(listOverrides, listOverride)
	listNodeTypesOverrides = append(listNodeTypesOverrides, listNodeTypesOverride)
	sparkVersionsOverrides = append(sparkVersionsOverrides, sparkVersionsOverride)
	startOverrides = append(startOverrides, startOverride)
}
182 changes: 182 additions & 0 deletions cmd/workspace/clusters/overrides_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package clusters

import (
"encoding/json"
"net/http"
"testing"

"github.com/databricks/cli/libs/cmdctx"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/testserver"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// newStartTestCommand builds the start command with a context that carries a
// discarding cmdio mock and a workspace client pointed at the given test
// server. The test fails immediately if the client cannot be created.
func newStartTestCommand(t *testing.T, server *testserver.Server) *cobra.Command {
	t.Helper()

	cfg := &databricks.Config{
		Host:  server.URL,
		Token: "token",
	}
	client, err := databricks.NewWorkspaceClient(cfg)
	require.NoError(t, err)

	startCmd := newStart()
	startCmd.SetContext(cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), client))
	return startCmd
}

// TestStartRunningClusterIsNoOp verifies that an INVALID_STATE response from
// the start endpoint is swallowed when the follow-up lookup reports the
// cluster as RUNNING, and that exactly one start and one get request are made.
func TestStartRunningClusterIsNoOp(t *testing.T) {
	const clusterId = "abc"

	var startRequests int
	var getRequests int

	server := testserver.New(t)
	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
		startRequests++
		// Use assert rather than require inside handlers: require ends in
		// t.FailNow, which must only be called from the goroutine running the
		// test, and handlers presumably run on a server goroutine.
		var request compute.StartCluster
		if assert.NoError(t, json.Unmarshal(req.Body, &request)) {
			assert.Equal(t, clusterId, request.ClusterId)
		}

		return testserver.Response{
			StatusCode: http.StatusBadRequest,
			Body: map[string]string{
				"error_code": "INVALID_STATE",
				"message":    "Cluster abc is in unexpected state Running.",
			},
		}
	})
	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
		getRequests++
		assert.Equal(t, clusterId, req.URL.Query().Get("cluster_id"))

		return compute.ClusterDetails{
			ClusterId: clusterId,
			State:     compute.StateRunning,
		}
	})

	cmd := newStartTestCommand(t, server)
	err := cmd.RunE(cmd, []string{clusterId})
	require.NoError(t, err)
	assert.Equal(t, 1, startRequests)
	assert.Equal(t, 1, getRequests)
}

// TestStartInvalidStatePreservesErrorForTerminatedCluster verifies that the
// INVALID_STATE error from the start endpoint is still returned when the
// follow-up lookup reports the cluster as TERMINATED.
func TestStartInvalidStatePreservesErrorForTerminatedCluster(t *testing.T) {
	const clusterId = "abc"
	const startMessage = "Cluster abc cannot be started."

	server := testserver.New(t)
	server.Handle("POST", "/api/2.1/clusters/start", func(testserver.Request) any {
		rejection := map[string]string{
			"error_code": "INVALID_STATE",
			"message":    startMessage,
		}
		return testserver.Response{StatusCode: http.StatusBadRequest, Body: rejection}
	})
	server.Handle("GET", "/api/2.1/clusters/get", func(testserver.Request) any {
		return compute.ClusterDetails{ClusterId: clusterId, State: compute.StateTerminated}
	})

	startCmd := newStartTestCommand(t, server)
	runErr := startCmd.RunE(startCmd, []string{clusterId})
	require.Error(t, runErr)
	assert.Contains(t, runErr.Error(), startMessage)
}

// TestStartInvalidStatePreservesErrorWhenGetFails verifies that the original
// INVALID_STATE error from the start endpoint is returned when the follow-up
// cluster lookup itself fails (here with a 404).
func TestStartInvalidStatePreservesErrorWhenGetFails(t *testing.T) {
	const clusterId = "abc"
	const startMessage = "Cluster abc is in unexpected state."

	server := testserver.New(t)
	server.Handle("POST", "/api/2.1/clusters/start", func(testserver.Request) any {
		rejection := map[string]string{
			"error_code": "INVALID_STATE",
			"message":    startMessage,
		}
		return testserver.Response{StatusCode: http.StatusBadRequest, Body: rejection}
	})
	server.Handle("GET", "/api/2.1/clusters/get", func(testserver.Request) any {
		missing := map[string]string{
			"error_code": "RESOURCE_DOES_NOT_EXIST",
			"message":    "Cluster abc does not exist.",
		}
		return testserver.Response{StatusCode: http.StatusNotFound, Body: missing}
	})

	startCmd := newStartTestCommand(t, server)
	runErr := startCmd.RunE(startCmd, []string{clusterId})
	require.Error(t, runErr)
	assert.Contains(t, runErr.Error(), startMessage)
}

// TestStartInvalidStatePreservesErrorForJobCluster verifies that the
// INVALID_STATE error from the start endpoint is still returned for a cluster
// whose source is a job, even though the lookup reports it as RUNNING.
func TestStartInvalidStatePreservesErrorForJobCluster(t *testing.T) {
	const clusterId = "abc"
	const startMessage = "Clusters launched to run a job cannot be started."

	server := testserver.New(t)
	server.Handle("POST", "/api/2.1/clusters/start", func(testserver.Request) any {
		rejection := map[string]string{
			"error_code": "INVALID_STATE",
			"message":    startMessage,
		}
		return testserver.Response{StatusCode: http.StatusBadRequest, Body: rejection}
	})
	server.Handle("GET", "/api/2.1/clusters/get", func(testserver.Request) any {
		jobCluster := compute.ClusterDetails{
			ClusterId:     clusterId,
			ClusterSource: compute.ClusterSourceJob,
			State:         compute.StateRunning,
		}
		return jobCluster
	})

	startCmd := newStartTestCommand(t, server)
	runErr := startCmd.RunE(startCmd, []string{clusterId})
	require.Error(t, runErr)
	assert.Contains(t, runErr.Error(), startMessage)
}

// TestStartTerminatedClusterKeepsNormalStartBehavior verifies that when the
// start endpoint succeeds, the command succeeds after exactly one start
// request and never calls the get endpoint.
func TestStartTerminatedClusterKeepsNormalStartBehavior(t *testing.T) {
	const clusterId = "abc"

	var startRequests int
	var getRequests int

	server := testserver.New(t)
	server.Handle("POST", "/api/2.1/clusters/start", func(req testserver.Request) any {
		startRequests++
		// Use assert rather than require inside handlers: require ends in
		// t.FailNow, which must only be called from the goroutine running the
		// test, and handlers presumably run on a server goroutine.
		var request compute.StartCluster
		if assert.NoError(t, json.Unmarshal(req.Body, &request)) {
			assert.Equal(t, clusterId, request.ClusterId)
		}
		return testserver.Response{}
	})
	// This handler exists only to count lookups; the test asserts below that
	// it is never called.
	server.Handle("GET", "/api/2.1/clusters/get", func(req testserver.Request) any {
		getRequests++
		return compute.ClusterDetails{
			ClusterId: clusterId,
			State:     compute.StateRunning,
		}
	})

	cmd := newStartTestCommand(t, server)
	// NOTE(review): presumably --no-wait stops the command from polling the
	// cluster after the start request; confirm against the generated command.
	require.NoError(t, cmd.Flags().Set("no-wait", "true"))
	err := cmd.RunE(cmd, []string{clusterId})
	require.NoError(t, err)
	assert.Equal(t, 1, startRequests)
	assert.Equal(t, 0, getRequests)
}