diff --git a/server/main.go b/server/main.go index b13b7bb..06d6d4e 100644 --- a/server/main.go +++ b/server/main.go @@ -95,10 +95,14 @@ func main() { sort.Sort(tasks) hashToTask := make(map[string]pkg.Task) + operationAliasToID := make(map[string]string) var buf bytes.Buffer for _, task := range tasks { buf.Write([]byte(task.IDWithHostPort())) - hashToTask[pkg.Hash([]byte(task.ID()))] = task + hashToTask[task.Hash()] = task + if task.OperationAlias() != "" { + operationAliasToID[task.OperationAlias()] = task.OperationID() + } } newVersion := pkg.Hash(buf.Bytes()) @@ -108,7 +112,7 @@ func main() { logger.Infof("%d tasks discovered:\n%s", len(tasks), tasks) version = newVersion - err = taskUpdater.Update(ctx, hashToTask, version) + err = taskUpdater.Update(ctx, hashToTask, operationAliasToID, version) if err != nil { logger.Errorf("failed to update tasks: %v", err) version = "" // drop version so we will retry update on next iteration diff --git a/server/pkg/auth.go b/server/pkg/auth.go index b3411a2..c9410df 100644 --- a/server/pkg/auth.go +++ b/server/pkg/auth.go @@ -2,6 +2,7 @@ package pkg import ( "context" + "fmt" "net/http" "strings" "sync" @@ -18,12 +19,13 @@ import ( type authServer struct { authv3.UnimplementedAuthorizationServer - mx sync.RWMutex - hashToTasks map[string]Task - yt ytsdk.Client - ytProxy string - logger *SimpleLogger - authCookieName string + mx sync.RWMutex + hashToTasks map[string]Task + operationAliasToID map[string]string + yt ytsdk.Client + ytProxy string + logger *SimpleLogger + authCookieName string } func CreateAuthServer(yt ytsdk.Client, ytProxy string, logger *SimpleLogger, authCookieName string) *authServer { @@ -41,22 +43,11 @@ func (s *authServer) Check(ctx context.Context, req *authv3.CheckRequest) (*auth httpAttrs := req.GetAttributes().GetRequest().GetHttp() path := httpAttrs.GetPath() headers := httpAttrs.GetHeaders() + host := httpAttrs.GetHost() - var hash string - if routerHeaderValue, ok := httpAttrs.Headers[routerHeaderName]; ok { - hash = routerHeaderValue - } else if host := httpAttrs.Host; host != "" { - hash = strings.Split(host, ".")[0] - } else { - s.logger.Warnf("authority (host) or %s headers are missing in request", routerHeaderName) - return deniedResponse, nil - } - - s.logger.Debugf("checking auth for hash %q, path %q", hash, path) - - task, ok := s.getHashToTasks()[hash] - if !ok { - s.logger.Warnf("no entry for hash %q in tasks registry", hash) + task, err := s.findTaskByRequest(host, headers) + if err != nil { + s.logger.Warnf("failed to find task during auth check: %s", err) return deniedResponse, nil } @@ -66,7 +57,7 @@ func (s *authServer) Check(ctx context.Context, req *authv3.CheckRequest) (*auth return okResponse, nil } - s.logger.Debugf("auth for hash %q, path %q, task %v", hash, path, task) + s.logger.Debugf("auth for path %q, task %v", path, task) allowed, err := s.checkOperationPermission(ctx, task.operationID, headers) if err != nil { @@ -80,21 +71,43 @@ func (s *authServer) Check(ctx context.Context, req *authv3.CheckRequest) (*auth return okResponse, nil } -func (s *authServer) SetHashToTasks(hashToTasks map[string]Task) { +func (s *authServer) SetTasksData(hashToTasks map[string]Task, operationAliasToID map[string]string) { s.mx.Lock() defer s.mx.Unlock() s.hashToTasks = hashToTasks + s.operationAliasToID = operationAliasToID } -func (s *authServer) getHashToTasks() map[string]Task { +func (s *authServer) findTaskByRequest(host string, headers map[string]string) (*Task, error) { s.mx.RLock() defer s.mx.RUnlock() - return s.hashToTasks + var hash string + if routerHeaderValue, ok := headers[routerHeaderName]; ok { + hash = routerHeaderValue + } else if host != "" { + subdomain := strings.Split(host, ".")[0] + if operationAlias, taskName, service, ok := tryParseAliasSubdomain(subdomain); ok { + operationID, ok := s.operationAliasToID[operationAlias] + if !ok { + return nil, fmt.Errorf("operation by alias %q from subdomain was not found", operationAlias) + } + hash = (&Task{operationID: operationID, taskName: taskName, service: service}).Hash() + } else { + hash = subdomain + } + } else { + return nil, fmt.Errorf("authority (host) or %s headers are missing in request", routerHeaderName) + } + + if task, ok := s.hashToTasks[hash]; !ok { + return nil, fmt.Errorf("no entry for hash %q in tasks registry", hash) + } else { + return &task, nil + } } -// TODO: temporary implementation, use YT Go SDK instead func (s *authServer) checkOperationPermission(ctx context.Context, operationID string, headers map[string]string) (bool, error) { userCredentials := s.getYTCredentialsFromHeaders(headers) if userCredentials == nil { diff --git a/server/pkg/auth_test.go b/server/pkg/auth_test.go new file mode 100644 index 0000000..1cf149b --- /dev/null +++ b/server/pkg/auth_test.go @@ -0,0 +1,164 @@ +package pkg + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFindTaskByRequest(t *testing.T) { + // Setup test data + task1 := Task{ + operationID: "op-123", + taskName: "worker", + service: "api", + } + task1Hash := task1.Hash() + + task2 := Task{ + operationID: "op-456", + operationAlias: "myalias", + taskName: "master", + service: "ui", + } + task2Hash := task2.Hash() + + task3 := Task{ + operationID: "op-789", + taskName: "executor", + service: "grpc", + } + task3Hash := task3.Hash() + + hashToTasks := map[string]Task{ + task1Hash: task1, + task2Hash: task2, + task3Hash: task3, + } + + operationAliasToID := map[string]string{ + "myalias": "op-456", + "anotheralias": "op-999", + } + + server := CreateAuthServer(nil, "", &SimpleLogger{}, "") + server.SetTasksData(hashToTasks, operationAliasToID) + + tests := []struct { + name string + host string + headers map[string]string + expectedID string + errorMsg string + }{ + // Source 1: Direct hash from x-yt-taskproxy-id header + { + name: "hash from header - valid task", + host: "ignored.example.com", + headers: map[string]string{ + "x-yt-taskproxy-id": task1Hash, + }, + expectedID: task1.operationID, + }, + { + name: "hash from header - invalid hash", + host: "ignored.example.com", + headers: map[string]string{ + "x-yt-taskproxy-id": "nonexistent", + }, + errorMsg: "no entry for hash \"nonexistent\" in tasks registry", + }, + { + name: "hash from header - empty hash", + host: "ignored.example.com", + headers: map[string]string{ + "x-yt-taskproxy-id": "", + }, + errorMsg: "no entry for hash \"\" in tasks registry", + }, + { + name: "hash from header - header takes precedence over host", + host: task3Hash + ".example.com", + headers: map[string]string{ + "x-yt-taskproxy-id": task1Hash, + }, + expectedID: task1.operationID, + }, + + // Source 2: Alias-based subdomain (format: alias-taskname-service) + { + name: "alias subdomain - valid alias", + host: "myalias-master-ui.example.com", + headers: map[string]string{ + "other-header": "value", + }, + expectedID: task2.operationID, + }, + { + name: "alias subdomain - unknown alias", + host: "unknownalias-master-ui.example.com", + errorMsg: "operation by alias \"unknownalias\" from subdomain was not found", + }, + { + name: "alias subdomain - valid alias but task not found", + host: "anotheralias-worker-api.example.com", + errorMsg: "no entry for hash", + }, + { + name: "alias subdomain - with port", + host: "myalias-master-ui.example.com:8080", + expectedID: task2.operationID, + }, + + // Source 3: Direct hash from subdomain (fallback) + { + name: "direct hash subdomain - valid hash", + host: task1Hash + ".example.com", + expectedID: task1.operationID, + }, + { + name: "direct hash subdomain - invalid hash", + host: "badhash.example.com", + errorMsg: "no entry for hash \"badhash\" in tasks registry", + }, + { + name: "direct hash subdomain - single part (no dots)", + host: task3Hash, + expectedID: task3.operationID, + }, + + // Misc errors + { + name: "invalid alias domain format", + host: "part1-part2.example.com", + errorMsg: "no entry for hash \"part1-part2\" in tasks registry", + }, + { + name: "empty host with other headers", + headers: map[string]string{ + "authorization": "Bearer token", + }, + errorMsg: "authority (host) or x-yt-taskproxy-id headers are missing in request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := tt.headers + if headers == nil { + headers = map[string]string{} + } + task, err := server.findTaskByRequest(tt.host, headers) + + if tt.errorMsg != "" { + require.ErrorContains(t, err, tt.errorMsg) + assert.Nil(t, task) + } else { + require.NoError(t, err) + require.NotNil(t, task) + assert.Equal(t, tt.expectedID, task.operationID) + } + }) + } +} diff --git a/server/pkg/discovery.go b/server/pkg/discovery.go index 81a1762..5673528 100644 --- a/server/pkg/discovery.go +++ b/server/pkg/discovery.go @@ -307,7 +307,7 @@ func (d *taskDiscovery) listOperations(ctx context.Context) ([]ytsdk.OperationSt for { d.logger.Debugf( - "loading running operations chunk, limit %d, cursor %s, already loaded %d operations", + "loading running operations chunk, limit %d, cursor %v, already loaded %d operations", limit, cursor, len(operations), diff --git a/server/pkg/task.go b/server/pkg/task.go index 7e36381..d030bdf 100644 --- a/server/pkg/task.go +++ b/server/pkg/task.go @@ -28,13 +28,25 @@ type Task struct { jobs []HostPort } -var valueRegexp = regexp.MustCompile(`^[a-z0-9]+$`) +var valueRegexp = regexp.MustCompile(`^[a-z0-9]{1,30}$`) // Identifies task, for sorting and domain hash func (t *Task) ID() string { return t.operationID + t.taskName + t.service } +func (t *Task) Hash() string { + return Hash([]byte(t.ID())) +} + +func (t *Task) OperationID() string { + return t.operationID +} + +func (t *Task) OperationAlias() string { + return t.operationAlias +} + // ID with jobs (host, port)-s to create correct version for xDS data (jobs can move between hosts) func (t *Task) IDWithHostPort() string { sb := strings.Builder{} @@ -91,6 +103,14 @@ func getTaskAliasDomain(task Task, baseDomain string) string { return fmt.Sprintf("%s-%s-%s.%s", task.operationAlias, task.taskName, task.service, baseDomain) } +func tryParseAliasSubdomain(subdomain string) (string, string, string, bool) { + parts := strings.Split(subdomain, "-") + if len(parts) != 3 { + return "", "", "", false + } + return parts[0], parts[1], parts[2], true +} + func Hash(source []byte) string { hash := fmt.Sprintf("%x", sha256.Sum256(source)) return hash[len(hash)-8:] @@ -105,7 +125,7 @@ func (a TaskList) Less(i, j int) bool { return a[i].ID() < a[j].ID() } func (a TaskList) String() string { sb := strings.Builder{} for _, task := range a { - sb.WriteString(fmt.Sprintf("\t%v\n", task)) + fmt.Fprintf(&sb, "\t%v\n", task) } return sb.String() } diff --git a/server/pkg/task_test.go b/server/pkg/task_test.go index 6d290f6..e1f8313 100644 --- a/server/pkg/task_test.go +++ b/server/pkg/task_test.go @@ -30,7 +30,7 @@ func TestValidateTask(t *testing.T) { taskName: "task", service: "service", }, - err: errors.New("field \"operationAlias\" value \"ali-as\" does not match regexp \"^[a-z0-9]+$\""), + err: errors.New("field \"operationAlias\" value \"ali-as\" does not match regexp \"^[a-z0-9]{1,30}$\""), }, { name: "invalid task name", @@ -40,7 +40,7 @@ func TestValidateTask(t *testing.T) { taskName: "Task", service: "service", }, - err: errors.New("field \"taskName\" value \"Task\" does not match regexp \"^[a-z0-9]+$\""), + err: errors.New("field \"taskName\" value \"Task\" does not match regexp \"^[a-z0-9]{1,30}$\""), }, { name: "invalid service", @@ -50,7 +50,17 @@ func TestValidateTask(t *testing.T) { taskName: "task", service: "$ervice", }, - err: errors.New("field \"service\" value \"$ervice\" does not match regexp \"^[a-z0-9]+$\""), + err: errors.New("field \"service\" value \"$ervice\" does not match regexp \"^[a-z0-9]{1,30}$\""), + }, + { + name: "invalid service", + task: Task{ + operationID: "123", + operationAlias: "alias", + taskName: "task", + service: "serviceserviceserviceserviceserviceserviceserviceservice", + }, + err: errors.New("field \"service\" value \"serviceserviceserviceserviceserviceserviceserviceservice\" does not match regexp \"^[a-z0-9]{1,30}$\""), }, { name: "do not check if no alias", diff --git a/server/pkg/updater.go b/server/pkg/updater.go index 776657c..bfcb937 100644 --- a/server/pkg/updater.go +++ b/server/pkg/updater.go @@ -35,13 +35,18 @@ func CreateTaskUpdater( } } -func (u *taskUpdater) Update(ctx context.Context, hashToTask map[string]Task, version string) error { +func (u *taskUpdater) Update( + ctx context.Context, + hashToTask map[string]Task, + operationAliasToID map[string]string, + version string, +) error { snapshot, err := makeSnapshot(hashToTask, version, u.baseDomain, u.tls, u.authEnabled) if err != nil { return fmt.Errorf("failed to make snapshot: %v", err) } - u.authServer.SetHashToTasks(hashToTask) + u.authServer.SetTasksData(hashToTask, operationAliasToID) if err := u.cache.SetSnapshot(ctx, NodeID, snapshot); err != nil { return fmt.Errorf("failed to set snapshot: %v", err)