Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId))

// Return the run ID even on error so callers can fetch the run's failure details.
- return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout)
+ return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts)
}

func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error {
Expand Down Expand Up @@ -642,7 +642,7 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
defer sp.Close()
if autoStart {
- sp.Update("Ensuring the cluster is running...")
+ sp.Update("Waiting for compute to start...")
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
if err != nil {
return fmt.Errorf("failed to ensure that the cluster is running: %w", err)
Expand All @@ -662,13 +662,21 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,

// waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates.
// Returns an error if the task fails to start or if polling times out.
- func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error {
+ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, opts ClientOptions) error {
+ waitingMessage := "Waiting for compute to start..."
+ if opts.Accelerator != "" {
+ // GPU capacity is acquired on demand and routinely takes 10+ minutes; without
+ // this notice users assume a long PENDING wait means the service is down.
+ cmdio.LogString(ctx, fmt.Sprintf("Waiting for %s compute to be provisioned. This can take upwards of 10 minutes depending on capacity...", opts.Accelerator))
+ waitingMessage = fmt.Sprintf("Waiting for %s compute to be provisioned...", opts.Accelerator)
+ }
+
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
defer sp.Close()
- sp.Update("Starting SSH server...")
+ sp.Update(waitingMessage)
var prevState jobs.RunLifecycleStateV2State

- _, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
+ _, err := retries.Poll(ctx, opts.TaskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{
RunId: runID,
})
Expand Down Expand Up @@ -697,7 +705,7 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient,

// Update spinner if state changed
if currentState != prevState {
- sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState))
+ sp.Update(fmt.Sprintf("%s (task: %s)", waitingMessage, currentState))
prevState = currentState
}

Expand Down
2 changes: 1 addition & 1 deletion experimental/ssh/internal/client/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) {
api.EXPECT().GetRunOutput(mock.Anything, jobs.GetRunOutputRequest{RunId: 99}).Return(
&jobs.RunOutput{}, nil)

- err := waitForJobToStart(ctx, m.WorkspaceClient, 1, 30*time.Second)
+ err := waitForJobToStart(ctx, m.WorkspaceClient, 1, ClientOptions{TaskStartupTimeout: 30 * time.Second})
require.Error(t, err)
assert.Contains(t, err.Error(), "ssh server bootstrap job failed")
assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.")
Expand Down
Loading