Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions runner/internal/common/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package types
// TerminationReason is a string code naming why a job or task was terminated.
type TerminationReason string

// Termination reason codes. The quoted values are the strings used when a
// reason is serialized, e.g. as the JSON "reason" field of a terminate request.
const (
TerminationReasonExecutorError TerminationReason = "executor_error"
TerminationReasonCreatingContainerError TerminationReason = "creating_container_error"
TerminationReasonContainerExitedWithError TerminationReason = "container_exited_with_error"
TerminationReasonDoneByRunner TerminationReason = "done_by_runner"
TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user"
TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server"
TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded"
TerminationReasonLogQuotaExceeded TerminationReason = "log_quota_exceeded"
TerminationReasonExecutorError TerminationReason = "executor_error"
TerminationReasonCreatingContainerError TerminationReason = "creating_container_error"
TerminationReasonContainerExitedWithError TerminationReason = "container_exited_with_error"
TerminationReasonDoneByRunner TerminationReason = "done_by_runner"
TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user"
TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server"
TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded"
TerminationReasonLogQuotaExceeded TerminationReason = "log_quota_exceeded"
// Reported when a task's outbound data transfer exceeds its configured quota.
TerminationReasonDataTransferQuotaExceeded TerminationReason = "data_transfer_quota_exceeded"
)
18 changes: 18 additions & 0 deletions runner/internal/runner/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,24 @@ func (s *Server) stopPostHandler(w http.ResponseWriter, r *http.Request) (interf
return nil, nil
}

// terminatePostHandler serves the terminate endpoint. It decodes a
// schemas.TerminateBody from the request, logs the requested reason and
// message, and marks the job as failed with that termination reason.
//
// It returns (nil, nil) on success, or the decoding error if the request
// body could not be decoded.
func (s *Server) terminatePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
	ctx := r.Context()

	req := schemas.TerminateBody{}
	err := api.DecodeJSONBody(w, r, &req, true)
	if err != nil {
		return nil, err
	}

	reason, message := req.Reason, req.Message
	log.Error(ctx, "Terminate requested", "reason", reason, "message", message)

	// The executor's external lock is deliberately not taken here:
	// SetJobStateWithTerminationReason does its own locking, and the external
	// lock is held by io.Copy for the duration of the job, so acquiring it
	// from this handler would deadlock.
	s.executor.SetJobStateWithTerminationReason(ctx, schemas.JobStateFailed, reason, message)
	return nil, nil
}

func isMaxBytesError(err error) bool {
var maxBytesError *http.MaxBytesError
return errors.As(err, &maxBytesError)
Expand Down
1 change: 1 addition & 0 deletions runner/internal/runner/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func NewServer(ctx context.Context, address string, version string, ex executor.
r.AddHandler("POST", "/api/run", s.runPostHandler)
r.AddHandler("GET", "/api/pull", s.pullGetHandler)
r.AddHandler("POST", "/api/stop", s.stopPostHandler)
r.AddHandler("POST", "/api/terminate", s.terminatePostHandler)
r.AddHandler("GET", "/logs_ws", s.logsWsGetHandler)
return s, nil
}
Expand Down
5 changes: 5 additions & 0 deletions runner/internal/runner/schemas/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ type SubmitBody struct {
LogQuotaHour int `json:"log_quota_hour"` // bytes per hour, 0 = unlimited
}

// TerminateBody is the JSON request body of a terminate request sent to the runner.
type TerminateBody struct {
// Reason is the termination reason to record for the job.
Reason types.TerminationReason `json:"reason"`
// Message is a free-form explanation accompanying Reason.
Message string `json:"message"`
}

type PullResponse struct {
JobStates []JobStateEvent `json:"job_states"`
JobLogs []LogEvent `json:"job_logs"`
Expand Down
77 changes: 76 additions & 1 deletion runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -37,6 +38,7 @@ import (
"github.com/dstackai/dstack/runner/internal/common/types"
"github.com/dstackai/dstack/runner/internal/shim/backends"
"github.com/dstackai/dstack/runner/internal/shim/host"
"github.com/dstackai/dstack/runner/internal/shim/netmeter"
)

// TODO: Allow for configuration via cli arguments or environment variables.
Expand Down Expand Up @@ -380,7 +382,8 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := d.tasks.Update(task); err != nil {
return fmt.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err)
}
err = d.waitContainer(ctx, &task)

err = d.waitContainerWithQuota(ctx, &task, cfg)
}
if err != nil {
log.Error(ctx, "failed to run container", "err", err)
Expand Down Expand Up @@ -910,6 +913,49 @@ func (d *DockerRunner) waitContainer(ctx context.Context, task *Task) error {
return nil
}

// waitContainerWithQuota waits for the task's container to finish, enforcing
// the task's data transfer quota when one is configured.
//
// When cfg.DataTransferQuota <= 0 it simply delegates to waitContainer.
// Otherwise it starts a netmeter for the task and waits for whichever happens
// first:
//   - the container finishes: the result of waitContainer is returned as is;
//   - the quota is exceeded: the runner is notified through /api/terminate
//     (so the server can read the termination reason via /api/pull), the
//     container is stopped, and nil is returned once waitContainer has returned.
//
// If the meter cannot be started, the task is marked terminated with
// executor_error and a wrapped error is returned.
func (d *DockerRunner) waitContainerWithQuota(ctx context.Context, task *Task, cfg TaskConfig) error {
	if cfg.DataTransferQuota <= 0 {
		return d.waitContainer(ctx, task)
	}

	nm := netmeter.New(task.ID, cfg.DataTransferQuota)
	if err := nm.Start(ctx); err != nil {
		errMessage := fmt.Sprintf("data transfer quota configured but metering unavailable: %s", err)
		log.Error(ctx, errMessage)
		task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
		return fmt.Errorf("data transfer meter: %w", err)
	}
	defer nm.Stop()

	// Buffered so the goroutine can always deliver its result and exit.
	waitDone := make(chan error, 1)
	go func() { waitDone <- d.waitContainer(ctx, task) }()

	select {
	case err := <-waitDone:
		return err
	case <-nm.Exceeded():
	}

	log.Error(ctx, "Data transfer quota exceeded", "task", task.ID, "quota", cfg.DataTransferQuota)
	terminateMsg := fmt.Sprintf("Outbound data transfer exceeded quota of %d bytes", cfg.DataTransferQuota)

	// Best effort: a failed notification must not prevent stopping the container.
	runnerNotified := true
	if err := terminateRunner(ctx, d.dockerParams.RunnerHTTPPort(),
		types.TerminationReasonDataTransferQuotaExceeded, terminateMsg); err != nil {
		log.Error(ctx, "failed to notify runner of termination", "err", err)
		runnerNotified = false
	}

	stopTimeout := 10
	stopOpts := container.StopOptions{Timeout: &stopTimeout}
	if err := d.client.ContainerStop(ctx, task.containerID, stopOpts); err != nil {
		log.Error(ctx, "failed to stop container after quota exceeded", "err", err)
	}

	// Wait for waitContainer to return before touching the task status, and
	// surface its error instead of dropping it silently.
	if err := <-waitDone; err != nil {
		log.Error(ctx, "failed to wait for container after quota exceeded", "err", err)
	}

	if runnerNotified {
		// The runner already set the job state with the termination reason.
		// The server will read it via /api/pull.
		task.SetStatusTerminated(string(types.TerminationReasonDoneByRunner), "")
	} else {
		// The runner never received the reason, so record it on the task itself;
		// otherwise the quota breach would be reported as a plain done_by_runner.
		// NOTE(review): assumes consumers of the task status accept this reason
		// value — confirm against the server.
		task.SetStatusTerminated(string(types.TerminationReasonDataTransferQuotaExceeded), terminateMsg)
	}
	return nil
}

func encodeRegistryAuth(username string, password string) (string, error) {
if username == "" && password == "" {
return "", nil
Expand Down Expand Up @@ -1180,6 +1226,31 @@ func getContainerLastLogs(ctx context.Context, client docker.APIClient, containe
return lines, nil
}

// terminateRunner calls the runner's /api/terminate endpoint on localhost to set
// the job termination state. This allows the server to read the termination
// reason via /api/pull before the container dies.
//
// The request body is the JSON object {"reason": reason, "message": message}.
// The call is bounded by a 5 second timeout derived from ctx. An error is
// returned if the body cannot be encoded, the request cannot be created or
// sent, or the runner answers with a status other than 200 OK.
func terminateRunner(ctx context.Context, runnerPort int, reason types.TerminationReason, message string) error {
	// Encode with encoding/json rather than fmt's %q: %q produces Go string
	// syntax (e.g. \x01, \a, \v), which is not valid JSON for arbitrary input.
	payload, err := json.Marshal(struct {
		Reason  types.TerminationReason `json:"reason"`
		Message string                  `json:"message"`
	}{Reason: reason, Message: message})
	if err != nil {
		return fmt.Errorf("encode request body: %w", err)
	}

	url := fmt.Sprintf("http://localhost:%d/api/terminate", runnerPort)
	// 5s is generous for a localhost HTTP call; if the runner doesn't respond in time,
	// the caller proceeds with stopping the container anyway.
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(payload)))
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("request failed: %w", err)
	}
	// Close error is irrelevant here: the response body is never read.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status: %d", resp.StatusCode)
	}
	return nil
}

/* DockerParameters interface implementation for CLIArgs */

func (c *CLIArgs) DockerPrivileged() bool {
Expand Down Expand Up @@ -1228,6 +1299,10 @@ func (c *CLIArgs) DockerPorts() []int {
return []int{c.Runner.HTTPPort, c.Runner.SSHPort}
}

// RunnerHTTPPort returns the runner's HTTP port from the CLI arguments
// (c.Runner.HTTPPort).
func (c *CLIArgs) RunnerHTTPPort() int {
return c.Runner.HTTPPort
}

func (c *CLIArgs) MakeRunnerDir(name string) (string, error) {
runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", name)
if err := os.MkdirAll(runnerTemp, 0o755); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions runner/internal/shim/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ func (c *dockerParametersMock) DockerPorts() []int {
return []int{}
}

// RunnerHTTPPort is a test stub that always returns the fixed port 10999.
func (c *dockerParametersMock) RunnerHTTPPort() int {
return 10999
}

// DockerMounts is a test stub: it ignores its argument and returns no mounts
// and no error.
func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) {
return nil, nil
}
Expand Down
10 changes: 6 additions & 4 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type DockerParameters interface {
DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string
DockerMounts(string) ([]mount.Mount, error)
DockerPorts() []int
RunnerHTTPPort() int
MakeRunnerDir(name string) (string, error)
DockerPJRTDevice() string
}
Expand Down Expand Up @@ -97,10 +98,11 @@ type TaskConfig struct {
InstanceMounts []InstanceMountPoint `json:"instance_mounts"`
// GPUDevices allows the server to set gpu devices instead of relying on the runner default logic.
// E.g. passing nvidia devices directly instead of using nvidia-container-toolkit.
GPUDevices []GPUDevice `json:"gpu_devices"`
HostSshUser string `json:"host_ssh_user"`
HostSshKeys []string `json:"host_ssh_keys"`
ContainerSshKeys []string `json:"container_ssh_keys"`
GPUDevices []GPUDevice `json:"gpu_devices"`
HostSshUser string `json:"host_ssh_user"`
HostSshKeys []string `json:"host_ssh_keys"`
ContainerSshKeys []string `json:"container_ssh_keys"`
DataTransferQuota int64 `json:"data_transfer_quota"` // total bytes for job lifetime; 0 = unlimited
}

type TaskListItem struct {
Expand Down
Loading
Loading