Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions cmd/sandbox/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sandbox

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -41,6 +42,14 @@ const (
// of an unsupported type for this API."
const orgIDHeader = "X-Databricks-Org-Id"

// sandboxAPITimeout bounds every sandbox HTTP request. The long-running
// flows (cold start, etc.) are made of many short calls polled in a
// loop, so this caps the *individual* call — not the overall flow. The
// cap exists so that in regions where the manager isn't deployed the
// CLI surfaces a clear error instead of hanging indefinitely behind the
// gateway's silent drop.
const sandboxAPITimeout = 10 * time.Second

// maxNameBytes mirrors the server-side `Sandbox.name` cap. The server
// measures bytes (not runes), so emoji hit the limit faster than expected;
// mirroring it client-side lets us fail fast with the observed byte count.
Expand All @@ -56,7 +65,8 @@ func validateName(name string) error {

// sandboxAPI wraps the SDK ApiClient with workspace-id-aware request headers
// and a per-request timeout.
type sandboxAPI struct {
// c is the underlying SDK client that actually performs each request.
c *client.DatabricksClient
c *client.DatabricksClient
// timeout bounds every individual request issued through do;
// newSandboxAPI initialises it to sandboxAPITimeout.
timeout time.Duration
}

// sandboxCreateBody is the inner `Sandbox` message in the create payload.
Expand Down Expand Up @@ -180,7 +190,7 @@ func newSandboxAPI(w *databricks.WorkspaceClient) (*sandboxAPI, error) {
if err != nil {
return nil, fmt.Errorf("failed to create sandbox API client: %w", err)
}
return &sandboxAPI{c: c}, nil
return &sandboxAPI{c: c, timeout: sandboxAPITimeout}, nil
}

// headers attaches the workspace routing identifier so multi-workspace
Expand All @@ -195,13 +205,27 @@ func (a *sandboxAPI) headers() map[string]string {
return map[string]string{orgIDHeader: wsID}
}

// do issues a single sandbox API request bounded by a.timeout (set to
// sandboxAPITimeout by newSandboxAPI) and translates expiry of that
// per-call timeout into a user-facing message that hints at the most
// likely cause: the region not having sandbox enabled.
//
// The translation is applied only when the caller's ctx is still live.
// If the caller's own deadline expired (for example the overall budget
// of a polling loop), the original error is returned untouched so the
// user is not told "region not enabled" for an unrelated timeout.
//
// The translated error wraps the underlying one, so callers can still
// match it with errors.Is(err, context.DeadlineExceeded).
//
// Every sandbox API method must call through this wrapper rather than
// a.c.Do directly so the timeout and the translation stay uniform.
func (a *sandboxAPI) do(ctx context.Context, method, path string, request, response any) error {
	callCtx, cancel := context.WithTimeout(ctx, a.timeout)
	defer cancel()

	err := a.c.Do(callCtx, method, path, a.headers(), nil, request, response)
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	// The parent context is already done, so the deadline that fired was
	// the caller's rather than our per-call cap; do not misattribute it.
	if ctx.Err() != nil {
		return err
	}
	return fmt.Errorf("sandbox API timed out after %s — this region may not have sandbox enabled, or the manager is unreachable: %w", a.timeout, err)
}

// create calls POST /api/2.0/lakebox/sandboxes. An empty `name` is omitted
// so the server treats it as "unset" rather than "explicit empty string".
func (a *sandboxAPI) create(ctx context.Context, name string) (*createResponse, error) {
body := createRequest{Sandbox: sandboxCreateBody{Name: name}}
var resp createResponse
err := a.c.Do(ctx, http.MethodPost, sandboxAPIPath, a.headers(), nil, body, &resp)
if err != nil {
if err := a.do(ctx, http.MethodPost, sandboxAPIPath, body, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand Down Expand Up @@ -239,8 +263,7 @@ func (a *sandboxAPI) listPage(ctx context.Context, pageToken string) (*listRespo
query["page_token"] = pageToken
}
var resp listResponse
err := a.c.Do(ctx, http.MethodGet, sandboxAPIPath, a.headers(), nil, query, &resp)
if err != nil {
if err := a.do(ctx, http.MethodGet, sandboxAPIPath, query, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand All @@ -249,8 +272,7 @@ func (a *sandboxAPI) listPage(ctx context.Context, pageToken string) (*listRespo
// get calls GET /api/2.0/lakebox/sandboxes/{id}.
func (a *sandboxAPI) get(ctx context.Context, id string) (*sandboxEntry, error) {
var resp sandboxEntry
err := a.c.Do(ctx, http.MethodGet, sandboxPath(id), a.headers(), nil, nil, &resp)
if err != nil {
if err := a.do(ctx, http.MethodGet, sandboxPath(id), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand All @@ -273,25 +295,23 @@ func (a *sandboxAPI) update(ctx context.Context, id string, name *string, idleTi
NoAutostop: noAutostop,
}
var resp sandboxEntry
err := a.c.Do(ctx, http.MethodPatch, sandboxPath(id), a.headers(), nil, body, &resp)
if err != nil {
if err := a.do(ctx, http.MethodPatch, sandboxPath(id), body, &resp); err != nil {
return nil, err
}
return &resp, nil
}

// delete calls DELETE /api/2.0/lakebox/sandboxes/{id}. The call carries no
// request body and decodes no response (both arguments are nil), so the
// returned error is the only result.
func (a *sandboxAPI) delete(ctx context.Context, id string) error {
return a.c.Do(ctx, http.MethodDelete, sandboxPath(id), a.headers(), nil, nil, nil)
return a.do(ctx, http.MethodDelete, sandboxPath(id), nil, nil)
}

// stop calls POST /api/2.0/lakebox/sandboxes/{id}/stop and returns the
// refreshed sandbox.
func (a *sandboxAPI) stop(ctx context.Context, id string) (*sandboxEntry, error) {
body := map[string]string{"sandbox_id": id}
var resp sandboxEntry
err := a.c.Do(ctx, http.MethodPost, sandboxPath(id)+"/stop", a.headers(), nil, body, &resp)
if err != nil {
if err := a.do(ctx, http.MethodPost, sandboxPath(id)+"/stop", body, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand All @@ -302,8 +322,7 @@ func (a *sandboxAPI) stop(ctx context.Context, id string) (*sandboxEntry, error)
func (a *sandboxAPI) start(ctx context.Context, id string) (*sandboxEntry, error) {
body := map[string]string{"sandbox_id": id}
var resp sandboxEntry
err := a.c.Do(ctx, http.MethodPost, sandboxPath(id)+"/start", a.headers(), nil, body, &resp)
if err != nil {
if err := a.do(ctx, http.MethodPost, sandboxPath(id)+"/start", body, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand All @@ -316,8 +335,7 @@ func (a *sandboxAPI) start(ctx context.Context, id string) (*sandboxEntry, error
// `create` call.
func (a *sandboxAPI) registerKey(ctx context.Context, publicKey, name string) (*sshKeyEntry, error) {
var resp sshKeyEntry
err := a.c.Do(ctx, http.MethodPost, sandboxKeysAPIPath, a.headers(), nil, registerKeyRequest{PublicKey: publicKey, Name: name}, &resp)
if err != nil {
if err := a.do(ctx, http.MethodPost, sandboxKeysAPIPath, registerKeyRequest{PublicKey: publicKey, Name: name}, &resp); err != nil {
return nil, err
}
return &resp, nil
Expand Down Expand Up @@ -345,14 +363,13 @@ type listKeysResponse struct {
// listKeys calls GET /api/2.0/lakebox/ssh-keys.
func (a *sandboxAPI) listKeys(ctx context.Context) ([]sshKeyEntry, error) {
var resp listKeysResponse
err := a.c.Do(ctx, http.MethodGet, sandboxKeysAPIPath, a.headers(), nil, nil, &resp)
if err != nil {
if err := a.do(ctx, http.MethodGet, sandboxKeysAPIPath, nil, &resp); err != nil {
return nil, err
}
return resp.SshKeys, nil
}

// deleteKey calls DELETE /api/2.0/lakebox/ssh-keys/{key_hash}. The key hash
// is path-escaped before being appended to the ssh-keys path. No request
// body is sent and no response is decoded (both arguments are nil), so the
// returned error is the only result.
func (a *sandboxAPI) deleteKey(ctx context.Context, keyHash string) error {
return a.c.Do(ctx, http.MethodDelete, sandboxKeysAPIPath+"/"+url.PathEscape(keyHash), a.headers(), nil, nil, nil)
return a.do(ctx, http.MethodDelete, sandboxKeysAPIPath+"/"+url.PathEscape(keyHash), nil, nil)
}
35 changes: 35 additions & 0 deletions cmd/sandbox/api_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package sandbox

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/client"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -29,3 +34,33 @@ func TestValidateNameCountsBytesNotRunes(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "260 bytes")
}

// TestSandboxAPIDoTranslatesTimeout verifies that a request which never
// receives a response is reported as a user-language timeout. That is
// what happens in regions where the manager isn't deployed: the gateway
// silently holds the connection open rather than returning a structured
// error, so without the do wrapper the call would simply hang.
func TestSandboxAPIDoTranslatesTimeout(t *testing.T) {
	// hang never answers sandbox API requests; it blocks until the client
	// gives up. The SDK probes .well-known/databricks-config before the
	// first API call, so that path is answered with a fast 404 to keep
	// the test on the sandbox API path and off the SDK's host-metadata
	// fallback (which has its own 60s timeout).
	hang := func(w http.ResponseWriter, r *http.Request) {
		if !strings.HasPrefix(r.URL.Path, "/.well-known/") {
			<-r.Context().Done()
			return
		}
		http.NotFound(w, r)
	}
	server := httptest.NewServer(http.HandlerFunc(hang))
	t.Cleanup(server.Close)

	cfg := &databricks.Config{Host: server.URL, Token: "test-token"}
	workspace, err := databricks.NewWorkspaceClient(cfg)
	require.NoError(t, err)
	apiClient, err := client.New(workspace.Config)
	require.NoError(t, err)

	// A very short per-call timeout keeps the test fast.
	api := &sandboxAPI{c: apiClient, timeout: 50 * time.Millisecond}

	_, err = api.get(t.Context(), "any-id")
	require.Error(t, err)
	for _, want := range []string{
		"sandbox API timed out",
		"this region may not have sandbox enabled",
	} {
		assert.Contains(t, err.Error(), want)
	}
}
8 changes: 8 additions & 0 deletions cmd/sandbox/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ import (

const defaultGatewayPort = "2222"

// sshConnectTimeoutSecs caps the TCP+SSH handshake, in seconds; it is
// passed to ssh as `-o ConnectTimeout=<n>` by buildSSHArgs. Once the
// channel is up the timeout no longer applies, so a long-lived
// interactive session is not affected; only an unreachable gateway
// (region not enabled, firewall, etc.) is bounded, which would otherwise
// hang behind ssh's default connect timeout (~75s).
// NOTE(review): without ConnectTimeout ssh falls back to the operating
// system's TCP timeout, so the default wait may exceed 75s on some
// platforms — confirm before relying on the exact figure.
const sshConnectTimeoutSecs = 10

func newSSHCommand() *cobra.Command {
var gatewayPort string

Expand Down Expand Up @@ -289,6 +296,7 @@ func buildSSHArgs(sandboxID, host, port, keyPath string, extraArgs []string) []s
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "LogLevel=ERROR",
"-o", fmt.Sprintf("ConnectTimeout=%d", sshConnectTimeoutSecs),
fmt.Sprintf("%s@%s", sandboxID, host),
}
if len(extraArgs) == 1 {
Expand Down
4 changes: 4 additions & 0 deletions cmd/sandbox/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ func TestBuildSSHArgsBaseFlags(t *testing.T) {
assert.Contains(t, got, "/keys/id")
assert.Contains(t, got, "-p")
assert.Contains(t, got, "2222")
// ConnectTimeout bounds the handshake — a region without the
// sandbox manager deployed would otherwise hang in ssh's default
// ~75s connect timeout.
assert.Contains(t, got, "ConnectTimeout=10")
}

func TestBuildSSHArgsQuoting(t *testing.T) {
Expand Down
Loading