diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 00c1e05d0d..3748030387 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -578,7 +578,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId)) // Return the run ID even on error so callers can fetch the run's failure details. - return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) + return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts) } func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { @@ -642,7 +642,7 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime()) defer sp.Close() if autoStart { - sp.Update("Ensuring the cluster is running...") + sp.Update("Waiting for compute to start...") err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID) if err != nil { return fmt.Errorf("failed to ensure that the cluster is running: %w", err) @@ -662,13 +662,21 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, // waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates. // Returns an error if the task fails to start or if polling times out. -func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error { +func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, opts ClientOptions) error { + waitingMessage := "Waiting for compute to start..." + if opts.Accelerator != "" { + // GPU capacity is acquired on demand and routinely takes 10+ minutes; without + // this notice users assume a long PENDING wait means the service is down. + cmdio.LogString(ctx, fmt.Sprintf("Waiting for %s compute to be provisioned. This can take upwards of 10 minutes depending on capacity...", opts.Accelerator)) + waitingMessage = fmt.Sprintf("Waiting for %s compute to be provisioned...", opts.Accelerator) + } + sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime()) defer sp.Close() - sp.Update("Starting SSH server...") + sp.Update(waitingMessage) var prevState jobs.RunLifecycleStateV2State - _, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { + _, err := retries.Poll(ctx, opts.TaskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{ RunId: runID, }) @@ -697,7 +705,7 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, // Update spinner if state changed if currentState != prevState { - sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState)) + sp.Update(fmt.Sprintf("%s (task: %s)", waitingMessage, currentState)) prevState = currentState } diff --git a/experimental/ssh/internal/client/client_internal_test.go b/experimental/ssh/internal/client/client_internal_test.go index 35740d8cee..a7347b08c7 100644 --- a/experimental/ssh/internal/client/client_internal_test.go +++ b/experimental/ssh/internal/client/client_internal_test.go @@ -110,7 +110,7 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) { api.EXPECT().GetRunOutput(mock.Anything, jobs.GetRunOutputRequest{RunId: 99}).Return( &jobs.RunOutput{}, nil) - err := waitForJobToStart(ctx, m.WorkspaceClient, 1, 30*time.Second) + err := waitForJobToStart(ctx, m.WorkspaceClient, 1, ClientOptions{TaskStartupTimeout: 30 * time.Second}) require.Error(t, err) assert.Contains(t, err.Error(), "ssh server bootstrap job failed") assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.")