diff --git a/task/task.go b/task/task.go index b56c469..8bfdc54 100644 --- a/task/task.go +++ b/task/task.go @@ -1,16 +1,21 @@ package task import ( - "os" + "errors" + "os/exec" "github.com/google/uuid" ) +var ( + ErrNotExists = errors.New("task does not exist") +) + type Task struct { ID uuid.UUID Name string Executable string Args []string - Process *os.Process + Cmd *exec.Cmd } diff --git a/tests/helper/task.go b/tests/helper/task.go index e542044..2f1420c 100644 --- a/tests/helper/task.go +++ b/tests/helper/task.go @@ -2,11 +2,16 @@ package helper import ( "bytes" + "context" "dirigeant/task" "encoding/json" + "fmt" "io" + "net/http" + "net/http/httptest" "runtime" + "github.com/go-chi/chi/v5" "github.com/google/uuid" ) @@ -52,3 +57,29 @@ func JsonEncodeTask(t task.Task) io.Reader { json.NewEncoder(w).Encode(t) return w } + +func NewTaskGetRequest(id uuid.UUID) *http.Request { + if id == uuid.Nil { + return httptest.NewRequest("GET", "/tasks", nil) + } + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("id", id.String()) + + r := httptest.NewRequest("GET", fmt.Sprintf("/tasks/%s", id), nil) + + return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) +} + +func NewTaskPostRequest(t task.Task) *http.Request { + return httptest.NewRequest("POST", "/tasks", JsonEncodeTask(t)) +} + +func NewTaskDeleteRequest(id uuid.UUID) *http.Request { + rctx := chi.NewRouteContext() + rctx.URLParams.Add("id", id.String()) + + r := httptest.NewRequest("DELETE", fmt.Sprintf("/tasks/%s", id), nil) + + return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) +} diff --git a/tests/worker/list_tasks_test.go b/tests/worker/list_tasks_test.go index 270329d..89c43c3 100644 --- a/tests/worker/list_tasks_test.go +++ b/tests/worker/list_tasks_test.go @@ -2,26 +2,28 @@ package worker import ( "dirigeant/task" + "dirigeant/tests/helper" "dirigeant/worker" "encoding/json" "net/http" "net/http/httptest" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) func TestListTasks__ShouldReturnAnEmptySlice(t *testing.T) { - request := httptest.NewRequest("GET", "/tasks", nil) - responseRecorder := httptest.NewRecorder() - api := &worker.Api{ Worker: &worker.Worker{}, } + request := helper.NewTaskGetRequest(uuid.Nil) + responseRecorder := httptest.NewRecorder() + api.HandleListTasks(responseRecorder, request) tasks := []task.Task{} json.NewDecoder(responseRecorder.Body).Decode(&tasks) assert.Equal(t, http.StatusOK, responseRecorder.Code, "Response status code should be 200 OK") - assert.Equal(t, []task.Task{}, tasks, "Response body should be an empty slice") + assert.Empty(t, tasks, "Response body should be an empty slice") } diff --git a/tests/worker/stop_task_test.go b/tests/worker/stop_task_test.go new file mode 100644 index 0000000..d95afe6 --- /dev/null +++ b/tests/worker/stop_task_test.go @@ -0,0 +1,55 @@ +package worker + +import ( + "dirigeant/task" + "dirigeant/tests/helper" + "dirigeant/worker" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestStopTask__ShouldReturnAnErrorIfNotFound(t *testing.T) { + api := &worker.Api{ + Worker: &worker.Worker{}, + } + request := helper.NewTaskDeleteRequest(uuid.New()) + responseRecorder := httptest.NewRecorder() + + api.HandleDeleteTask(responseRecorder, request) + + assert.Equal(t, http.StatusNotFound, responseRecorder.Code, "Response status code should be 404 Not Found") + assert.Equal(t, fmt.Sprintf("Error when stopping the task: %v", task.ErrNotExists), responseRecorder.Body.String(), "Response body should contain error message") +} + +func TestStopTask__ShouldStopCompletedTask(t *testing.T) { + api := &worker.Api{ + Worker: &worker.Worker{ + Tasks: make(map[uuid.UUID]*task.Task), + }, + } + testTask := helper.PrintFileTask("print-task", helper.HostsFilePath) + + // 1 - Create a task + request := helper.NewTaskPostRequest(testTask) + responseRecorder := httptest.NewRecorder() + api.HandleCreateTask(responseRecorder, request) + + assert.Equal(t, http.StatusCreated, responseRecorder.Code, "Response status code should be 201 Created") + assert.Empty(t, responseRecorder.Body, "Response body should be empty") + assert.Equal(t, 1, len(api.Worker.Tasks), "Tasks map should contain 1 task") + + // 2 - Delete a task + request = helper.NewTaskDeleteRequest(testTask.ID) + responseRecorder = httptest.NewRecorder() + + api.HandleDeleteTask(responseRecorder, request) + + assert.Equal(t, http.StatusNoContent, responseRecorder.Code, "Response status code should be 204 No Content") + assert.Empty(t, responseRecorder.Body, "Response body should be empty") + assert.Empty(t, api.Worker.Tasks, "Tasks map should be empty") +} diff --git a/tests/worker/task_logs_test.go b/tests/worker/task_logs_test.go index 30c7081..37f6e73 100644 --- a/tests/worker/task_logs_test.go +++ b/tests/worker/task_logs_test.go @@ -39,15 +39,15 @@ func TestTaskLogs__PrintFile(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("POST", "/tasks", helper.JsonEncodeTask(helper.PrintFileTask(tc.name, tc.path))) + api := &worker.Api{ + Worker: &worker.Worker{ + Tasks: make(map[uuid.UUID]*task.Task), + }, + } + request := helper.NewTaskPostRequest(helper.PrintFileTask(tc.name, tc.path)) responseRecorder := httptest.NewRecorder() stdout := helper.CaptureStdout(func() { - api := &worker.Api{ - Worker: &worker.Worker{ - Tasks: make(map[uuid.UUID]*task.Task), - }, - } api.HandleCreateTask(responseRecorder, request) }) diff --git a/worker/api.go b/worker/api.go index 838349f..cce293b 100644 --- a/worker/api.go +++ b/worker/api.go @@ -3,6 +3,7 @@ package worker import ( "dirigeant/task" "encoding/json" + "errors" "fmt" "net/http" "slices" @@ -86,7 +87,12 @@ func (a *Api) HandleDeleteTask(w http.ResponseWriter, r *http.Request) { } if err := a.Worker.StopTask(taskId); err != nil { - w.WriteHeader(http.StatusInternalServerError) + if errors.Is(err, task.ErrNotExists) { + w.WriteHeader(http.StatusNotFound) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + fmt.Fprintf(w, "Error when stopping the task: %v", err) return } diff --git a/worker/worker.go b/worker/worker.go index 8b55832..b37139a 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -2,7 +2,6 @@ package worker import ( "dirigeant/task" - "fmt" "iter" "maps" "os" @@ -24,10 +23,9 @@ func (w *Worker) GetTask(id uuid.UUID) *task.Task { } func (w *Worker) StartTask(t task.Task) error { - cmd := exec.Command(t.Executable, t.Args...) - t.Process = cmd.Process + t.Cmd = exec.Command(t.Executable, t.Args...) - stdout, err := cmd.CombinedOutput() + stdout, err := t.Cmd.CombinedOutput() os.Stdout.Write(stdout) if err != nil { return err @@ -41,8 +39,16 @@ func (w *Worker) StartTask(t task.Task) error { func (w *Worker) StopTask(id uuid.UUID) error { t := w.GetTask(id) if t == nil { - return fmt.Errorf("%s not found", id) + return task.ErrNotExists } - return t.Process.Kill() + if t.Cmd.ProcessState == nil { + if err := t.Cmd.Process.Kill(); err != nil { + return err + } + } + + delete(w.Tasks, t.ID) + + return nil }