diff --git a/apps/druid/adapters/cli/worker_test.go b/apps/druid/adapters/cli/worker_test.go index a8a0b14a..302145b4 100644 --- a/apps/druid/adapters/cli/worker_test.go +++ b/apps/druid/adapters/cli/worker_test.go @@ -167,6 +167,10 @@ func (fakeRestoreOCI) ResolveDigest(string) (string, error) { return "sha256:restored", nil } +func (fakeRestoreOCI) ResolveAnnotationInfo(string) (domain.AnnotationInfo, error) { + return domain.AnnotationInfo{}, nil +} + func (fakeRestoreOCI) CanUpdateTag(v1.Descriptor, string, string) (bool, error) { return false, nil } diff --git a/apps/druid/core/services/runtime_access.go b/apps/druid/core/services/runtime_access.go index 4c9bd4b4..5e881873 100644 --- a/apps/druid/core/services/runtime_access.go +++ b/apps/druid/core/services/runtime_access.go @@ -67,7 +67,7 @@ func (s *RuntimeSupervisor) Restore(id string, artifact string, restart bool, re session.markError(err) return nil, err } - materialized, err := s.runPullWorker(context.Background(), s.runtimeBackend, ports.RuntimeWorkerModeRestore, id, artifact, root, registryCredentials) + materialized, err := s.runPullWorker(context.Background(), s.runtimeBackend, ports.RuntimeWorkerModeRestore, id, artifact, root, registryCredentials, "") if err != nil { session.markError(err) return nil, err diff --git a/apps/druid/core/services/runtime_materialization.go b/apps/druid/core/services/runtime_materialization.go index 91fac4de..260cac6a 100644 --- a/apps/druid/core/services/runtime_materialization.go +++ b/apps/druid/core/services/runtime_materialization.go @@ -14,10 +14,11 @@ import ( ) func (s *RuntimeSupervisor) materializeNewScroll(ctx context.Context, runtimeService ports.RuntimeBackendInterface, artifact string, runtimeID string, namespace string, registryCredentials []domain.RegistryCredential) (*ports.RuntimeMaterialization, error) { - return s.runPullWorker(ctx, runtimeService, ports.RuntimeWorkerModeCreate, runtimeID, artifact, runtimeService.RootRef(runtimeID, namespace), registryCredentials) + storage := resolveArtifactMinDisk(artifact, registryCredentials) + return s.runPullWorker(ctx, runtimeService, ports.RuntimeWorkerModeCreate, runtimeID, artifact, runtimeService.RootRef(runtimeID, namespace), registryCredentials, storage) } -func (s *RuntimeSupervisor) runPullWorker(ctx context.Context, runtimeService ports.RuntimeBackendInterface, mode ports.RuntimeWorkerMode, runtimeID string, artifact string, root string, registryCredentials []domain.RegistryCredential) (*ports.RuntimeMaterialization, error) { +func (s *RuntimeSupervisor) runPullWorker(ctx context.Context, runtimeService ports.RuntimeBackendInterface, mode ports.RuntimeWorkerMode, runtimeID string, artifact string, root string, registryCredentials []domain.RegistryCredential, storage string) (*ports.RuntimeMaterialization, error) { if s.workerCallbacks == nil || s.workerCallbackURL == "" { return nil, fmt.Errorf("daemon materialization requires --worker-callback-url and --worker-callback-listen") } @@ -32,6 +33,7 @@ func (s *RuntimeSupervisor) runPullWorker(ctx context.Context, runtimeService po Mode: mode, RuntimeID: runtimeID, Artifact: artifact, + Storage: storage, RootRef: root, MountPath: "/scroll", CallbackURL: callbackURL, @@ -62,6 +64,22 @@ func (s *RuntimeSupervisor) runPullWorker(ctx context.Context, runtimeService po } } +func resolveArtifactMinDisk(artifact string, registryCredentials []domain.RegistryCredential) string { + if artifact == "" { + return "" + } + if _, err := os.Stat(artifact); err == nil { + return "" + } + oci := registry.NewOciClient(registry.NewCredentialStore(registryCredentials)) + info, err := oci.ResolveAnnotationInfo(artifact) + if err != nil { + logger.Log().Warn("Unable to resolve artifact min disk", zap.String("artifact", artifact), zap.Error(err)) + return "" + } + return info.MinDisk +} + func resolveArtifactDigest(artifact string, registryCredentials []domain.RegistryCredential) string { if artifact == "" { return "" diff --git a/apps/druid/core/services/runtime_update.go b/apps/druid/core/services/runtime_update.go index 21f2699e..03e22bf9 100644 --- a/apps/druid/core/services/runtime_update.go +++ b/apps/druid/core/services/runtime_update.go @@ -51,7 +51,7 @@ func (s *RuntimeSupervisor) updateExistingScroll(runtimeScroll *domain.RuntimeSc _ = s.store.UpdateScroll(runtimeScroll) return nil, err } - materialized, err := s.runPullWorker(context.Background(), s.runtimeBackend, ports.RuntimeWorkerModeUpdate, runtimeScroll.ID, artifact, runtimeScroll.Root, registryCredentials) + materialized, err := s.runPullWorker(context.Background(), s.runtimeBackend, ports.RuntimeWorkerModeUpdate, runtimeScroll.ID, artifact, runtimeScroll.Root, registryCredentials, "") if err != nil { runtimeScroll.Status = domain.RuntimeScrollStatusError runtimeScroll.LastError = err.Error() diff --git a/internal/core/ports/services_ports.go b/internal/core/ports/services_ports.go index 42a9755b..5de1eb3f 100644 --- a/internal/core/ports/services_ports.go +++ b/internal/core/ports/services_ports.go @@ -121,6 +121,7 @@ type RuntimeWorkerAction struct { Mode RuntimeWorkerMode RuntimeID string Artifact string + Storage string RootRef string MountPath string CallbackURL string @@ -164,6 +165,7 @@ type OciRegistryInterface interface { GetRepo(repoUrl string) (*remote.Repository, error) FetchFile(artifact string, filePath string) ([]byte, error) ResolveDigest(artifact string) (string, error) + ResolveAnnotationInfo(artifact string) (domain.AnnotationInfo, error) Pull(dir string, artifact string) error PullSelective(dir string, artifact string, includeData bool, progress *domain.SnapshotProgress) error CanUpdateTag(descriptor v1.Descriptor, folder string, tag string) (bool, error) diff --git a/internal/core/services/registry/oci.go b/internal/core/services/registry/oci.go index 3c612027..63cd9ce2 100644 --- a/internal/core/services/registry/oci.go +++ b/internal/core/services/registry/oci.go @@ -375,6 +375,37 @@ func (c *OciClient) ResolveDigest(artifact string) (string, error) { return desc.Digest.String(), nil } +func (c *OciClient) ResolveAnnotationInfo(artifact string) (domain.AnnotationInfo, error) { + repo, ref, _ := utils.ParseArtifactRef(artifact) + if repo == "" || ref == "" { + return domain.AnnotationInfo{}, fmt.Errorf("reference (tag or digest) must be set") + } + repoInstance, err := c.GetRepo(repo) + if err != nil { + return domain.AnnotationInfo{}, err + } + desc, err := oras.Resolve(context.Background(), repoInstance, ref, oras.DefaultResolveOptions) + if err != nil { + return domain.AnnotationInfo{}, fmt.Errorf("failed to resolve %s: %w", ref, err) + } + manifest, err := content.FetchAll(context.Background(), repoInstance, desc) + if err != nil { + return domain.AnnotationInfo{}, fmt.Errorf("failed to fetch manifest for %s: %w", ref, err) + } + var fullDesc v1.Descriptor + if err := json.Unmarshal(manifest, &fullDesc); err != nil { + return domain.AnnotationInfo{}, fmt.Errorf("failed to parse manifest for %s: %w", ref, err) + } + annotations := fullDesc.Annotations + return domain.AnnotationInfo{ + MinRam: annotations["gg.druid.scroll.minRam"], + MinDisk: annotations["gg.druid.scroll.minDisk"], + MinCpu: annotations["gg.druid.scroll.minCpu"], + Image: annotations["gg.druid.scroll.image"], + Smart: annotations["gg.druid.scroll.smart"] == "true", + }, nil +} + func fetchFileFromOCI(ctx context.Context, fetcher content.Fetcher, rootDesc v1.Descriptor, filePath string) ([]byte, error) { seen := map[string]bool{} queue := []v1.Descriptor{rootDesc} diff --git a/internal/core/services/registry/oci_test.go b/internal/core/services/registry/oci_test.go index c6a5c77a..782a3ef6 100644 --- a/internal/core/services/registry/oci_test.go +++ b/internal/core/services/registry/oci_test.go @@ -172,6 +172,44 @@ func TestPushDataChunkPathNotDoubled(t *testing.T) { } } +func TestResolveAnnotationInfoReadsManifestAnnotations(t *testing.T) { + tmpDir := t.TempDir() + t.Chdir(tmpDir) + + srv := fakeRegistry(t) + registryHost := strings.TrimPrefix(srv.URL, "http://") + + folder := filepath.Join("scrolls", "cs2server") + if err := os.MkdirAll(folder, 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(folder, "scroll.yaml"), []byte("name: test\nversion: 0.1.0\napp_version: cs2server\n"), 0644); err != nil { + t.Fatal(err) + } + + client := &OciClient{ + credentialStore: NewCredentialStore([]domain.RegistryCredential{}), + plainHTTP: true, + } + repoRef := registryHost + "/test/scroll" + if _, err := client.Push(folder, repoRef, "cs2server-prebuild", map[string]string{ + "gg.druid.scroll.minDisk": "70Gi", + "gg.druid.scroll.minRam": "1Gi", + "gg.druid.scroll.minCpu": "0.5", + "gg.druid.scroll.smart": "true", + }, false, nil); err != nil { + t.Fatalf("Push failed unexpectedly: %v", err) + } + + info, err := client.ResolveAnnotationInfo(repoRef + ":cs2server-prebuild") + if err != nil { + t.Fatal(err) + } + if info.MinDisk != "70Gi" || info.MinRam != "1Gi" || info.MinCpu != "0.5" || !info.Smart { + t.Fatalf("annotation info = %#v", info) + } +} + func TestPushPullExecutableDataChunkPreservesMode(t *testing.T) { tmpDir := t.TempDir() t.Chdir(tmpDir) diff --git a/internal/runtime/kubernetes/resources.go b/internal/runtime/kubernetes/resources.go index 2895fc99..b7d192cf 100644 --- a/internal/runtime/kubernetes/resources.go +++ b/internal/runtime/kubernetes/resources.go @@ -16,8 +16,14 @@ import ( "github.com/highcard-dev/daemon/internal/core/ports" ) -func pvcSpec(namespace string, name string, storageClass string) *corev1.PersistentVolumeClaim { - quantity := resource.MustParse("1Gi") +func pvcSpec(namespace string, name string, storageClass string, storageRequest string) *corev1.PersistentVolumeClaim { + if storageRequest == "" { + storageRequest = "1Gi" + } + quantity, err := resource.ParseQuantity(storageRequest) + if err != nil || quantity.Sign() <= 0 { + quantity = resource.MustParse("1Gi") + } spec := corev1.PersistentVolumeClaimSpec{ AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce}, Resources: corev1.VolumeResourceRequirements{ diff --git a/internal/runtime/kubernetes/resources_test.go b/internal/runtime/kubernetes/resources_test.go index 2f10714d..c8021731 100644 --- a/internal/runtime/kubernetes/resources_test.go +++ b/internal/runtime/kubernetes/resources_test.go @@ -305,6 +305,7 @@ func TestSpawnPullWorkerCreateUsesFinalPVCAndWorkerJob(t *testing.T) { Mode: ports.RuntimeWorkerModeCreate, RuntimeID: "deployment-123", Artifact: "registry.local/lab:1.0", + Storage: "25Gi", RootRef: ref("games", dataPVCName("deployment-123")), MountPath: "/scroll", CallbackURL: "http://druid-cli:8083/internal/v1/workers/deployment-123/complete", @@ -320,6 +321,9 @@ func TestSpawnPullWorkerCreateUsesFinalPVCAndWorkerJob(t *testing.T) { if len(pvcs.Items) != 1 || pvcs.Items[0].Name != dataPVCName("deployment-123") { t.Fatalf("pvcs = %#v, want final PVC", pvcs.Items) } + if got := pvcs.Items[0].Spec.Resources.Requests.Storage().String(); got != "25Gi" { + t.Fatalf("pvc storage = %s, want 25Gi", got) + } if len(jobs) != 1 { t.Fatalf("jobs = %d, want 1", len(jobs)) } @@ -1184,7 +1188,7 @@ func TestStopRuntimeDeletesWorkloadsButPreservesDataAndServices(t *testing.T) { } for _, create := range []func() error{ func() error { - _, err := client.CoreV1().PersistentVolumeClaims("druid").Create(context.Background(), pvcSpec("druid", "druid-static-web-data", ""), metav1.CreateOptions{}) + _, err := client.CoreV1().PersistentVolumeClaims("druid").Create(context.Background(), pvcSpec("druid", "druid-static-web-data", "", ""), metav1.CreateOptions{}) return err }, func() error { @@ -1237,7 +1241,7 @@ func TestDeleteRuntimePurgesServicesAndDataWhenRequested(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := client.CoreV1().PersistentVolumeClaims("druid").Create(context.Background(), pvcSpec("druid", "druid-static-web-data", ""), metav1.CreateOptions{}); err != nil { + if _, err := client.CoreV1().PersistentVolumeClaims("druid").Create(context.Background(), pvcSpec("druid", "druid-static-web-data", "", ""), metav1.CreateOptions{}); err != nil { t.Fatal(err) } if _, err := client.CoreV1().Services("druid").Create(context.Background(), service, metav1.CreateOptions{}); err != nil { diff --git a/internal/runtime/kubernetes/workers.go b/internal/runtime/kubernetes/workers.go index 3b2a4e04..6eeceed8 100644 --- a/internal/runtime/kubernetes/workers.go +++ b/internal/runtime/kubernetes/workers.go @@ -40,6 +40,7 @@ func (b *Backend) SpawnPullWorker(ctx context.Context, action ports.RuntimeWorke zap.String("namespace", namespace), zap.String("pvc", pvc), zap.String("artifact", action.Artifact), + zap.String("storage", action.Storage), ) logger.Log().Debug("Kubernetes pull worker details", zap.String("runtime_id", action.RuntimeID), @@ -50,7 +51,7 @@ func (b *Backend) SpawnPullWorker(ctx context.Context, action ports.RuntimeWorke zap.Bool("has_registry_credentials", len(action.RegistryCredentials) > 0), ) if action.Mode == ports.RuntimeWorkerModeCreate { - if err := b.ensurePVC(ctx, namespace, pvc); err != nil { + if err := b.ensurePVC(ctx, namespace, pvc, action.Storage); err != nil { logger.Log().Error("Failed to ensure runtime PVC for pull worker", zap.String("runtime_id", action.RuntimeID), zap.String("namespace", namespace), zap.String("pvc", pvc), zap.Error(err)) return err } @@ -113,9 +114,9 @@ func setJobDeadlineFromContext(ctx context.Context, job *batchv1.Job) { job.Spec.ActiveDeadlineSeconds = &seconds } -func (b *Backend) ensurePVC(ctx context.Context, namespace string, name string) error { - pvc := pvcSpec(namespace, name, b.config.StorageClass) - logger.Log().Debug("Ensuring Kubernetes PVC", zap.String("namespace", namespace), zap.String("pvc", name), zap.String("storage_class", b.config.StorageClass)) +func (b *Backend) ensurePVC(ctx context.Context, namespace string, name string, storageRequest string) error { + pvc := pvcSpec(namespace, name, b.config.StorageClass, storageRequest) + logger.Log().Debug("Ensuring Kubernetes PVC", zap.String("namespace", namespace), zap.String("pvc", name), zap.String("storage_class", b.config.StorageClass), zap.String("storage", storageRequest)) _, err := b.client.CoreV1().PersistentVolumeClaims(namespace).Create(ctx, pvc, metav1.CreateOptions{}) if apierrors.IsAlreadyExists(err) { logger.Log().Debug("Kubernetes PVC already exists", zap.String("namespace", namespace), zap.String("pvc", name)) diff --git a/test/mock/services.go b/test/mock/services.go index b919a411..88e3cdbf 100644 --- a/test/mock/services.go +++ b/test/mock/services.go @@ -880,6 +880,21 @@ func (mr *MockOciRegistryInterfaceMockRecorder) ResolveDigest(artifact any) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveDigest", reflect.TypeOf((*MockOciRegistryInterface)(nil).ResolveDigest), artifact) } +// ResolveAnnotationInfo mocks base method. +func (m *MockOciRegistryInterface) ResolveAnnotationInfo(artifact string) (domain.AnnotationInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolveAnnotationInfo", artifact) + ret0, _ := ret[0].(domain.AnnotationInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResolveAnnotationInfo indicates an expected call of ResolveAnnotationInfo. +func (mr *MockOciRegistryInterfaceMockRecorder) ResolveAnnotationInfo(artifact any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveAnnotationInfo", reflect.TypeOf((*MockOciRegistryInterface)(nil).ResolveAnnotationInfo), artifact) +} + // MockQueueManagerInterface is a mock of QueueManagerInterface interface. type MockQueueManagerInterface struct { ctrl *gomock.Controller