From 252875c1204dfc70a8fe2700de3f511bea1d68d8 Mon Sep 17 00:00:00 2001 From: sjmiller609 <7516283+sjmiller609@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:22:18 +0000 Subject: [PATCH] Add fork identity wait plumbing --- server/cmd/api/fork_identity.go | 99 +++++++++++++++++++++ server/cmd/api/fork_identity_test.go | 107 +++++++++++++++++++++++ server/cmd/api/main.go | 3 + server/cmd/wrapper/fork_identity.go | 78 +++++++++++++++++ server/cmd/wrapper/fork_identity_test.go | 50 +++++++++++ server/cmd/wrapper/main.go | 68 ++++++++------ server/cmd/wrapper/supervisord.go | 9 +- server/lib/forkidentity/config.go | 48 ++++++++++ server/lib/forkidentity/env.go | 65 ++++++++++++++ server/lib/forkidentity/paths.go | 15 ++++ server/lib/forkidentity/payload.go | 96 ++++++++++++++++++++ 11 files changed, 610 insertions(+), 28 deletions(-) create mode 100644 server/cmd/api/fork_identity.go create mode 100644 server/cmd/api/fork_identity_test.go create mode 100644 server/cmd/wrapper/fork_identity.go create mode 100644 server/cmd/wrapper/fork_identity_test.go create mode 100644 server/lib/forkidentity/config.go create mode 100644 server/lib/forkidentity/env.go create mode 100644 server/lib/forkidentity/paths.go create mode 100644 server/lib/forkidentity/payload.go diff --git a/server/cmd/api/fork_identity.go b/server/cmd/api/fork_identity.go new file mode 100644 index 00000000..13e9fdb4 --- /dev/null +++ b/server/cmd/api/fork_identity.go @@ -0,0 +1,99 @@ +package main + +import ( + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os" + "time" + + "github.com/kernel/kernel-images/server/lib/forkidentity" +) + +func forkIdentityHandler(log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + enabled, err := forkidentity.WaitEnabled() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !enabled { + http.Error(w, "fork identity wait is disabled", http.StatusConflict) + return + } + + var payload forkidentity.Payload + dec := json.NewDecoder(http.MaxBytesReader(w, r.Body, forkidentity.MaxPayloadBytes)) + if err := dec.Decode(&payload); err != nil { + http.Error(w, fmt.Sprintf("decode payload: %v", err), http.StatusBadRequest) + return + } + if err := payload.Validate(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := os.Remove(forkidentity.AppliedFile); err != nil && !os.IsNotExist(err) { + log.Error("fork identity applied marker reset failed", "err", err) + http.Error(w, "failed to reset fork identity", http.StatusInternalServerError) + return + } + if err := forkidentity.WritePayload(payload); err != nil { + log.Error("fork identity payload write failed", "err", err) + http.Error(w, "failed to write fork identity", http.StatusInternalServerError) + return + } + if err := forkidentity.WaitAppliedMarker(payload.InstanceName(), 30*time.Second); err != nil { + log.Error("fork identity apply wait failed", "err", err) + http.Error(w, "fork identity not applied", http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNoContent) + } +} + +func forkIdentityConfigHandler(log *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + payload, err := forkidentity.ReadPayload() + if err != nil { + if os.IsNotExist(err) { + enabled, parseErr := forkidentity.WaitEnabled() + if parseErr != nil { + http.Error(w, parseErr.Error(), http.StatusInternalServerError) + return + } + if enabled { + w.WriteHeader(http.StatusAccepted) + return + } + http.NotFound(w, r) + return + } + log.Error("fork identity config read failed", "err", err) + http.Error(w, "failed to read fork identity", http.StatusInternalServerError) + return + } + enabled, err := forkidentity.WaitEnabled() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if enabled { + applied, err := forkidentity.AppliedMarkerMatches(payload.InstanceName()) + if err != nil { + log.Error("fork identity applied marker read failed", "err", err) + http.Error(w, "failed to read fork identity", http.StatusInternalServerError) + return + } + if !applied { + w.WriteHeader(http.StatusAccepted) + return + } + } + resp := forkidentity.ExtensionConfigFromPayload(payload) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Error("fork identity config encode failed", "err", err) + } + } +} diff --git a/server/cmd/api/fork_identity_test.go b/server/cmd/api/fork_identity_test.go new file mode 100644 index 00000000..de7ccff7 --- /dev/null +++ b/server/cmd/api/fork_identity_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/kernel/kernel-images/server/lib/forkidentity" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestForkIdentityConfigHandlerReturnsNotFoundWithoutPayload(t *testing.T) { + useTempForkIdentityFiles(t) + + req := httptest.NewRequest(http.MethodGet, "/internal/fork-identity/config", nil) + rec := httptest.NewRecorder() + forkIdentityConfigHandler(slog.Default()).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestForkIdentityConfigHandlerReturnsAcceptedWhileWaiting(t *testing.T) { + useTempForkIdentityFiles(t) + t.Setenv(forkidentity.WaitEnv, "true") + + req := httptest.NewRequest(http.MethodGet, "/internal/fork-identity/config", nil) + rec := httptest.NewRecorder() + forkIdentityConfigHandler(slog.Default()).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusAccepted, rec.Code) +} + +func TestForkIdentityConfigHandlerReturnsExtensionConfig(t *testing.T) { + useTempForkIdentityFiles(t) + payload := forkidentity.Payload{ + "instance_name": "browser-1", + "metro_api_url": "https://metro.example.test/browser/kernel", + "kernel_metro_api_base_url": "https://kernel-metro.example.test/browser/kernel", + "session_intel_url": "https://intel.example.test", + "future_identity_field_name": "future-value", + } + writeForkIdentityPayloadForTest(t, payload) + + req := httptest.NewRequest(http.MethodGet, "/internal/fork-identity/config", nil) + rec := httptest.NewRecorder() + forkIdentityConfigHandler(slog.Default()).ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var got forkidentity.ExtensionConfig + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, forkidentity.ExtensionConfig{ + InstanceName: "browser-1", + MetroAPIURL: "https://intel.example.test", + }, got) +} + +func TestForkIdentityConfigHandlerReturnsAcceptedUntilPayloadApplied(t *testing.T) { + useTempForkIdentityFiles(t) + t.Setenv(forkidentity.WaitEnv, "true") + payload := forkidentity.Payload{ + "instance_name": "browser-1", + "session_intel_url": "https://intel.example.test", + } + writeForkIdentityPayloadForTest(t, payload) + + req := httptest.NewRequest(http.MethodGet, "/internal/fork-identity/config", nil) + rec := httptest.NewRecorder() + forkIdentityConfigHandler(slog.Default()).ServeHTTP(rec, req) + require.Equal(t, http.StatusAccepted, rec.Code) + + require.NoError(t, forkidentity.WriteAppliedMarker("browser-1")) + rec = httptest.NewRecorder() + forkIdentityConfigHandler(slog.Default()).ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var got forkidentity.ExtensionConfig + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, forkidentity.ExtensionConfig{ + InstanceName: "browser-1", + MetroAPIURL: "https://intel.example.test", + }, got) +} + +func useTempForkIdentityFiles(t *testing.T) { + t.Helper() + dir := t.TempDir() + oldPayloadFile := forkidentity.PayloadFile + oldAppliedFile := forkidentity.AppliedFile + forkidentity.PayloadFile = filepath.Join(dir, "fork-identity.json") + forkidentity.AppliedFile = filepath.Join(dir, "fork-identity-applied") + t.Cleanup(func() { + forkidentity.PayloadFile = oldPayloadFile + forkidentity.AppliedFile = oldAppliedFile + }) +} + +func writeForkIdentityPayloadForTest(t *testing.T, payload forkidentity.Payload) { + t.Helper() + data, err := json.Marshal(payload) + require.NoError(t, err) + require.NoError(t, os.WriteFile(forkidentity.PayloadFile, data, 0o600)) +} diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index b3500a52..7668a5eb 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -147,6 +147,9 @@ func main() { }) oapi.HandlerFromMux(strictHandler, r) + r.Post("/internal/fork-identity", forkIdentityHandler(slogger)) + r.Get("/internal/fork-identity/config", forkIdentityConfigHandler(slogger)) + // endpoints to expose the spec r.Get("/spec.yaml", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/vnd.oai.openapi") diff --git a/server/cmd/wrapper/fork_identity.go b/server/cmd/wrapper/fork_identity.go new file mode 100644 index 00000000..aa6807cd --- /dev/null +++ b/server/cmd/wrapper/fork_identity.go @@ -0,0 +1,78 @@ +package main + +import ( + "context" + "errors" + "os" + "path/filepath" + goruntime "runtime" + + "github.com/kernel/kernel-images/server/lib/forkidentity" +) + +func forkIdentityWaitEnabled() (bool, error) { + return forkidentity.WaitEnabled() +} + +func waitForForkIdentityIfEnabled(ctx context.Context, enabled bool) bool { + if !enabled { + return true + } + stopAll("envoy") + + for _, path := range []string{forkidentity.AppliedFile, forkidentity.PayloadFile} { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + fatalf("fork identity reset %s: %v", path, err) + } + } + if err := os.MkdirAll(filepath.Dir(forkidentity.ReadyFile), 0o755); err != nil { + fatalf("fork identity ready dir: %v", err) + } + if err := os.WriteFile(forkidentity.ReadyFile, []byte("waiting\n"), 0o644); err != nil { + fatalf("fork identity ready file: %v", err) + } + + logf("fork identity waiting payload=%s", forkidentity.PayloadFile) + payload, err := waitForForkIdentityPayload(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + logf("fork identity wait canceled") + return false + } + fatalf("fork identity payload wait: %v", err) + } + if err := applyForkIdentityPayload(payload); err != nil { + fatalf("fork identity apply: %v", err) + } + if err := forkidentity.WriteAppliedMarker(payload.InstanceName()); err != nil { + fatalf("fork identity applied file: %v", err) + } + logf("fork identity applied instance=%s", payload.InstanceName()) + return true +} + +func waitForForkIdentityPayload(ctx context.Context) (forkidentity.Payload, error) { + for { + payload, err := forkidentity.ReadPayload() + if err == nil { + return payload, nil + } + if !os.IsNotExist(err) { + return nil, err + } + if err := ctx.Err(); err != nil { + return nil, ctx.Err() + } + goruntime.Gosched() + } +} + +func applyForkIdentityPayload(payload forkidentity.Payload) error { + for _, key := range forkidentity.ClearEnvKeys(payload) { + _ = os.Unsetenv(key) + } + for key, value := range forkidentity.Env(payload) { + _ = os.Setenv(key, value) + } + return nil +} diff --git a/server/cmd/wrapper/fork_identity_test.go b/server/cmd/wrapper/fork_identity_test.go new file mode 100644 index 00000000..69ad10bf --- /dev/null +++ b/server/cmd/wrapper/fork_identity_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "os" + "testing" + + "github.com/kernel/kernel-images/server/lib/forkidentity" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyForkIdentityPayloadSetsAndClearsEnv(t *testing.T) { + t.Setenv("METRO_NAME", "old-metro") + t.Setenv("S2_STREAM", "old-stream") + t.Setenv("FUTURE_IDENTITY_FIELD_NAME", "old-future") + + err := applyForkIdentityPayload(forkidentity.Payload{ + "instance_name": "browser-1", + "metro_name": "iad", + "xds_server": "xds.example.test", + "kernel_instance_jwt": "jwt", + "metro_api_url": "https://metro.example.test/browser/kernel", + "session_intel_url": "https://intel.example.test", + "future_identity_field_name": "future-value", + "empty_future_identity_field": "", + }) + require.NoError(t, err) + + assert.Equal(t, "browser-1", os.Getenv("INSTANCE_NAME")) + assert.Equal(t, "browser-1", os.Getenv("INST_NAME")) + assert.Equal(t, "iad", os.Getenv("METRO_NAME")) + assert.Equal(t, "xds.example.test", os.Getenv("XDS_SERVER")) + assert.Equal(t, "jwt", os.Getenv("KERNEL_INSTANCE_JWT")) + assert.Equal(t, "https://metro.example.test/browser/kernel", os.Getenv("KERNEL_METRO_API_BASE_URL")) + assert.Equal(t, "https://intel.example.test", os.Getenv("SESSION_INTEL_URL")) + assert.Equal(t, "future-value", os.Getenv("FUTURE_IDENTITY_FIELD_NAME")) + assert.Empty(t, os.Getenv("EMPTY_FUTURE_IDENTITY_FIELD")) + assert.Empty(t, os.Getenv("S2_STREAM")) +} + +func TestForkIdentityURLPrecedence(t *testing.T) { + payload := forkidentity.Payload{ + "metro_api_url": "https://legacy.example.test/browser/kernel", + "kernel_metro_api_base_url": "https://metro.example.test/browser/kernel", + "session_intel_url": "https://intel.example.test", + } + + assert.Equal(t, "https://metro.example.test/browser/kernel", forkidentity.MetroAPIURL(payload)) + assert.Equal(t, "https://intel.example.test", forkidentity.ExtensionAPIURL(payload)) +} diff --git a/server/cmd/wrapper/main.go b/server/cmd/wrapper/main.go index b263fe83..500097c6 100644 --- a/server/cmd/wrapper/main.go +++ b/server/cmd/wrapper/main.go @@ -13,6 +13,7 @@ package main import ( + "context" "fmt" "os" "os/exec" @@ -21,6 +22,8 @@ import ( "strings" "syscall" "time" + + "github.com/kernel/kernel-images/server/lib/forkidentity" ) const ( @@ -60,6 +63,10 @@ func main() { prof := detectProfile() stzManaged := scaleToZeroManaged() logf("starting wrapper (profile=%s stz=%s)", profileName(prof), stzMode(stzManaged)) + forkIdentityWait, err := forkIdentityWaitEnabled() + if err != nil { + fatalf("fork identity config: %v", err) + } // Register signal handling early so a SIGTERM/SIGINT during the // seconds-long startup window queues into the channel instead of @@ -67,6 +74,8 @@ func main() { // goroutine is installed below, once supervisord is running. sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + startupCtx, cancelStartup := context.WithCancel(context.Background()) + defer cancelStartup() // /dev/shm: only mount when not running under Docker (Docker manages it). if os.Getenv("WITHDOCKER") == "" { @@ -142,6 +151,7 @@ func main() { // before this point gets picked up on the first iteration. go func() { <-sigs + cancelStartup() logf("shutdown: stopping services") _ = exec.Command("supervisorctl", "-c", supervisorConf, "stop", "all").Run() _ = supCmd.Process.Signal(syscall.SIGTERM) @@ -182,41 +192,36 @@ func main() { } waitForSocket(pulseSocket, 10*time.Second) startAll("chromium") + if forkIdentityWait { + waitForHTTPProbe("chromium devtools", "http://127.0.0.1:"+forkidentity.FirstNonEmpty(os.Getenv("INTERNAL_PORT"), defaultIntPort)+"/json/version", 30*time.Second) + startAll("kernel-images-api") + } waitForSocket(dbusSocket, 10*time.Second) if prof == profileHeadful && webrtc { startAll("neko") } + if forkIdentityWait { + waitForHTTPProbe("public cdp", "http://127.0.0.1:"+os.Getenv("CHROME_PORT")+"/json/version", 10*time.Second) + } browserDone := time.Now() - // FORK HOOK: - // When this binary runs as a forked snapshot restore, the per-fork - // identity envs (INST_NAME, METRO_NAME, XDS_SERVER, KERNEL_INSTANCE_JWT, - // plus any future per-tenant secrets) won't be set yet at this point — - // the snapshot was taken from a different instance. Insert the - // following sequence here once the env-delivery channel exists: - // 1. Block on the host-pushed env bundle (vsock socket, virtio-fs - // drop file, or whatever transport the control plane settles on). - // 2. Apply the bundle to this process's environ via os.Setenv so - // the identity phase below picks them up via the existing $VAR - // expansion in init-envoy.sh and the supervisorctl-spawned - // services inherit them. - // 3. The identity phase uses `supervisorctl restart envoy` - // (idempotent — start on first boot, stop+start on a re-render - // after fork) so a restored snapshot drops its stale identity - // cleanly. - // Boot path keeps running through unchanged: the wait simply no-ops - // when there's no fork bundle to receive. - - // Identity phase: identity-bound services. Render envoy bootstrap with - // INST_NAME/JWT/etc and (re)start envoy + kernel-images-api. Both - // services use `restart` so the same code path works for boot (start a - // stopped service) and post-fork (stop+start to force a re-read of - // refreshed envs). + if !waitForForkIdentityIfEnabled(startupCtx, forkIdentityWait) { + if err := supCmd.Wait(); err != nil { + logf("supervisord exited: %v", err) + } + return + } + + // Identity phase: render envoy bootstrap with INST_NAME/JWT/etc. In fork + // identity wait mode, kernel-images-api was started early and is not + // restarted here, so public CDP stays connected after identity apply. identityStart := time.Now() if isExecutable("/usr/local/bin/init-envoy.sh") { runStreamFatal("envoy-init", "/usr/local/bin/init-envoy.sh") } - restartAll("kernel-images-api") + if !forkIdentityWait { + restartAll("kernel-images-api") + } identityDone := time.Now() // Wait for the union of caller-visible ready signals. Each probe runs @@ -293,6 +298,19 @@ func waitAllReady(t0 time.Time, webrtc bool) map[string]time.Duration { return durations } +func waitForHTTPProbe(name, url string, timeout time.Duration) { + start := time.Now() + deadline := start.Add(timeout) + for time.Now().Before(deadline) { + if httpProbeOK(url) { + logf("%s ready in %s", name, since(start)) + return + } + time.Sleep(20 * time.Millisecond) + } + fatalf("%s unavailable after %s", name, timeout) +} + type probe struct { name string fn func() bool diff --git a/server/cmd/wrapper/supervisord.go b/server/cmd/wrapper/supervisord.go index 47d9339d..4672db47 100644 --- a/server/cmd/wrapper/supervisord.go +++ b/server/cmd/wrapper/supervisord.go @@ -24,10 +24,13 @@ func startAll(progs ...string) { supervisorctl("start", progs...) } +func stopAll(progs ...string) { + supervisorctl("stop", progs...) +} + // restartAll is the start-or-stop+start variant. It's used for services -// that may already be running from a snapshot restore (post-fork, see the -// FORK HOOK in main) so they pick up refreshed envs cleanly. supervisorctl -// `restart` is a no-op stop on cold programs followed by a normal start. +// that need to pick up refreshed envs cleanly. supervisorctl `restart` is +// a no-op stop on cold programs followed by a normal start. func restartAll(progs ...string) { supervisorctl("restart", progs...) } diff --git a/server/lib/forkidentity/config.go b/server/lib/forkidentity/config.go new file mode 100644 index 00000000..56ae0ff7 --- /dev/null +++ b/server/lib/forkidentity/config.go @@ -0,0 +1,48 @@ +package forkidentity + +import ( + "fmt" + "os" + "strings" +) + +type ExtensionConfig struct { + InstanceName string `json:"instanceName"` + MetroAPIURL string `json:"metroApiUrl"` +} + +func WaitEnabled() (bool, error) { + raw := os.Getenv(WaitEnv) + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "0", "false", "no", "off": + return false, nil + case "1", "true", "yes", "on": + return true, nil + default: + return false, fmt.Errorf("%s must be a boolean, got %q", WaitEnv, raw) + } +} + +func MetroAPIURL(payload Payload) string { + return FirstNonEmpty(payload.Get("kernel_metro_api_base_url"), payload.Get("metro_api_url"), payload.Get("session_intel_url")) +} + +func ExtensionAPIURL(payload Payload) string { + return FirstNonEmpty(payload.Get("session_intel_url"), MetroAPIURL(payload)) +} + +func ExtensionConfigFromPayload(payload Payload) ExtensionConfig { + return ExtensionConfig{ + InstanceName: payload.InstanceName(), + MetroAPIURL: ExtensionAPIURL(payload), + } +} + +func FirstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} diff --git a/server/lib/forkidentity/env.go b/server/lib/forkidentity/env.go new file mode 100644 index 00000000..65842b90 --- /dev/null +++ b/server/lib/forkidentity/env.go @@ -0,0 +1,65 @@ +package forkidentity + +import ( + "sort" + "strings" +) + +var KnownEnvKeys = []string{ + "INSTANCE_NAME", + "INST_NAME", + "METRO_NAME", + "XDS_SERVER", + "KERNEL_INSTANCE_JWT", + "METRO_API_URL", + "KERNEL_METRO_API_BASE_URL", + "SESSION_INTEL_URL", + "S2_STREAM", +} + +func Env(payload Payload) map[string]string { + values := map[string]string{} + for key, value := range payload { + envKey := payloadKeyToEnvKey(key) + if envKey == "" || strings.TrimSpace(value) == "" { + continue + } + values[envKey] = strings.TrimSpace(value) + } + if payload.InstanceName() != "" { + values["INSTANCE_NAME"] = payload.InstanceName() + } + if values["INST_NAME"] == "" && payload.InstanceName() != "" { + values["INST_NAME"] = payload.InstanceName() + } + if metroAPIURL := MetroAPIURL(payload); metroAPIURL != "" { + values["KERNEL_METRO_API_BASE_URL"] = metroAPIURL + } + return values +} + +func ClearEnvKeys(payload Payload) []string { + keys := map[string]struct{}{} + for _, key := range KnownEnvKeys { + keys[key] = struct{}{} + } + for key := range payload { + if envKey := payloadKeyToEnvKey(key); envKey != "" { + keys[envKey] = struct{}{} + } + } + out := make([]string, 0, len(keys)) + for key := range keys { + out = append(out, key) + } + sort.Strings(out) + return out +} + +func payloadKeyToEnvKey(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + return strings.ToUpper(key) +} diff --git a/server/lib/forkidentity/paths.go b/server/lib/forkidentity/paths.go new file mode 100644 index 00000000..29a9a52b --- /dev/null +++ b/server/lib/forkidentity/paths.go @@ -0,0 +1,15 @@ +package forkidentity + +const ( + WaitEnv = "KERNEL_FORK_IDENTITY_WAIT" + + DefaultReadyFile = "/run/kernel/fork-identity-ready" + DefaultPayloadFile = "/run/kernel/fork-identity.json" + DefaultAppliedFile = "/run/kernel/fork-identity-applied" +) + +var ( + ReadyFile = DefaultReadyFile + PayloadFile = DefaultPayloadFile + AppliedFile = DefaultAppliedFile +) diff --git a/server/lib/forkidentity/payload.go b/server/lib/forkidentity/payload.go new file mode 100644 index 00000000..29716bd7 --- /dev/null +++ b/server/lib/forkidentity/payload.go @@ -0,0 +1,96 @@ +package forkidentity + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +const MaxPayloadBytes = 64 * 1024 + +type Payload map[string]string + +func (p Payload) Get(key string) string { + return strings.TrimSpace(p[key]) +} + +func (p Payload) InstanceName() string { + return p.Get("instance_name") +} + +func (p Payload) Validate() error { + if p.InstanceName() == "" { + return fmt.Errorf("instance_name is required") + } + if ExtensionAPIURL(p) == "" { + return fmt.Errorf("one of session_intel_url, kernel_metro_api_base_url, or metro_api_url is required") + } + return nil +} + +func ReadPayload() (Payload, error) { + data, err := os.ReadFile(PayloadFile) + if err != nil { + return nil, err + } + var payload Payload + if err := json.Unmarshal(data, &payload); err != nil { + return nil, fmt.Errorf("decode payload: %w", err) + } + return payload, payload.Validate() +} + +func WritePayload(payload Payload) error { + if err := payload.Validate(); err != nil { + return err + } + data, err := json.Marshal(payload) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(PayloadFile), 0o755); err != nil { + return err + } + tmp := PayloadFile + ".tmp" + if err := os.WriteFile(tmp, append(data, '\n'), 0o600); err != nil { + return err + } + return os.Rename(tmp, PayloadFile) +} + +func WriteAppliedMarker(instanceName string) error { + if err := os.MkdirAll(filepath.Dir(AppliedFile), 0o755); err != nil { + return err + } + return os.WriteFile(AppliedFile, []byte(strings.TrimSpace(instanceName)+"\n"), 0o644) +} + +func AppliedMarkerMatches(instanceName string) (bool, error) { + data, err := os.ReadFile(AppliedFile) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + return strings.TrimSpace(string(data)) == strings.TrimSpace(instanceName), nil +} + +func WaitAppliedMarker(instanceName string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + matches, err := AppliedMarkerMatches(instanceName) + if err != nil { + return err + } + if matches { + return nil + } + runtime.Gosched() + } + return fmt.Errorf("timed out waiting for applied marker") +}