From 9242452b0b60e94b97fbbe9d316567639a6ccadc Mon Sep 17 00:00:00 2001 From: Seth Hoenig Date: Fri, 22 May 2026 11:20:51 -0500 Subject: [PATCH] c: able to join contexts --- README.md | 111 +++++++++++++++++++++++++++++++++- context.go | 6 ++ join.go | 135 +++++++++++++++++++++++++++++++++++++++++ join_test.go | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 418 insertions(+), 2 deletions(-) create mode 100644 join.go create mode 100644 join_test.go diff --git a/README.md b/README.md index ee35750..457f2b6 100644 --- a/README.md +++ b/README.md @@ -11,16 +11,123 @@ maintaining 100% compatibility. ### Requirements -The minimum Go version is `go1.23`. +The minimum Go version is `go1.26`. ### Install -The `forms` package can be added to a project with `go get`. +The `scope` package can be added to a project with `go get`. ```shell go get -u cattlecloud.net/go/scope@latest ``` +### Examples + +##### New + +```go +ctx := scope.New() +``` + +##### TTL + +```go +ctx, cancel := scope.TTL(5 * time.Second) +// ctx is canceled after 5 seconds +defer cancel() +``` + +##### Deadline + +```go +ctx, cancel := scope.Deadline(time.Now().Add(10 * time.Second)) +// ctx is canceled at the specified time +defer cancel() +``` + +##### Cancelable + +```go +ctx, cancel := scope.Cancelable() +// ctx can be canceled manually +defer cancel() +``` + +##### WithCancel + +```go +ctx, cancel := scope.WithCancel(parentCtx) +defer cancel() +``` + +##### WithTTL + +```go +ctx, cancel := scope.WithTTL(parentCtx, 3 * time.Second) +// parentCtx with a 3 second timeout +defer cancel() +``` + +##### WithValue + +```go +ctx := scope.WithValue(parentCtx, "userID", 123) +``` + +##### Value + +```go +userID := scope.Value[int](ctx, "userID") +``` + +##### Join + +```go +ctx1, cancel1 := scope.WithCancel(scope.New()) +ctx2, cancel2 := scope.TTL(5 * time.Second) + +joined, cancel := scope.Join(ctx1, ctx2) +// joined is canceled when ctx1 or ctx2 is canceled +defer cancel() +defer cancel1() +defer cancel2() +``` + +###### Deadline + +```go +ctx1, _ := scope.Deadline(time.Now().Add(10 * time.Second)) +ctx2, _ := scope.Deadline(time.Now().Add(20 * time.Second)) + +joined, _ := scope.Join(ctx1, ctx2) +deadline, ok := joined.Deadline() // deadline is 10 seconds, ok is true +``` + +###### Done + +```go +joined, cancel := scope.Join(ctx1, ctx2) +<-joined.Done() // blocks until either ctx1 or ctx2 is done +``` + +###### Err + +```go +joined, cancel := scope.Join(ctx1, ctx2) +<-joined.Done() +err := joined.Err() // returns the error from the first canceled context +``` + +###### Value + +```go +ctx1 := scope.WithValue(scope.New(), "key", "value1") +ctx2 := scope.WithValue(scope.New(), "key", "value2") + +joined, _ := scope.Join(ctx1, ctx2) +val := joined.Value("key") // returns value1 (ctx1's value is checked first) +``` + ### License The `cattlecloud.net/go/scope` module is open source under the [BSD](LICENSE) license. diff --git a/context.go b/context.go index fb63dea..7104097 100644 --- a/context.go +++ b/context.go @@ -27,6 +27,12 @@ func TTL(duration time.Duration) (C, Cancel) { return context.WithTimeout(New(), duration) } +// Deadline will create a fresh context not part of any preceding chain of +// values, and will expire at the given expiration time. +func Deadline(expiration time.Time) (C, Cancel) { + return context.WithDeadline(New(), expiration) +} + // Cancelable will create a fresh context not part of any preceding chain of // values, and includes a Cancel function. func Cancelable() (C, Cancel) { diff --git a/join.go b/join.go new file mode 100644 index 0000000..08ea079 --- /dev/null +++ b/join.go @@ -0,0 +1,135 @@ +package scope + +import ( + "context" + "sync" + "time" +) + +// join implements a context that is canceled when either of two +// contexts is canceled. It is a variation on the original implementation +// with race condition bug fixes. +// +// https://github.com/LK4D4/joincontext/blob/master/context.go +type join struct { + once sync.Once + a C + b C + done chan struct{} + + lock *sync.Mutex + err error +} + +// Join combines two contexts into a single context that is canceled when +// either input context is canceled. The returned context's Deadline, +// Value, and Err methods delegate to the earliest of the two input +// contexts. If either context is already done, the returned context is +// immediately done with that context's error. +func Join(a, b C) (C, Cancel) { + j := &join{ + a: a, + b: b, + done: make(chan struct{}), + lock: new(sync.Mutex), + } + + // check if either context is already done before spawning the goroutine + select { + case <-a.Done(): + j.lock.Lock() + j.err = a.Err() + j.lock.Unlock() + + close(j.done) + return j, func() {} + + case <-b.Done(): + j.lock.Lock() + j.err = b.Err() + j.lock.Unlock() + + close(j.done) + return j, func() {} + + default: + } + + go j.run() + return j, j.cancel +} + +// Deadline returns the earliest deadline from either context. +// If neither context has a deadline, ok is false. +func (j *join) Deadline() (deadline time.Time, ok bool) { + a, aok := j.a.Deadline() + if !aok { + return j.b.Deadline() + } + + b, bok := j.b.Deadline() + if !bok { + return a, true + } + + if b.Before(a) { + return b, true + } + + return a, true +} + +// Done returns a channel that is closed when either context is done. +func (j *join) Done() <-chan struct{} { + return j.done +} + +// Err returns the error from whichever context was canceled first, +// or ErrCanceled if Cancel was called on the joined context. +func (j *join) Err() error { + j.lock.Lock() + defer j.lock.Unlock() + return j.err +} + +// Value returns the value associated with key in either context, +// prioritizing the first context's value if present. +func (j *join) Value(key any) any { + v := j.a.Value(key) + + if v == nil { + v = j.b.Value(key) + } + + return v +} + +func (j *join) run() { + select { + case <-j.a.Done(): + j.once.Do(func() { + j.lock.Lock() + j.err = j.a.Err() + j.lock.Unlock() + close(j.done) + }) + case <-j.b.Done(): + j.once.Do(func() { + j.lock.Lock() + j.err = j.b.Err() + j.lock.Unlock() + close(j.done) + }) + case <-j.done: + return + } +} + +func (j *join) cancel() { + j.once.Do(func() { + j.lock.Lock() + j.err = context.Canceled + j.lock.Unlock() + close(j.done) + }) +} diff --git a/join_test.go b/join_test.go new file mode 100644 index 0000000..0d7d543 --- /dev/null +++ b/join_test.go @@ -0,0 +1,168 @@ +// Copyright (c) CattleCloud LLC +// SPDX-License-Identifier: BSD-3-Clause + +package scope + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestJoin_CancelA(t *testing.T) { + t.Parallel() + + ctxA, cancelA := Cancelable() + ctxB := New() + + j, _ := Join(ctxA, ctxB) + cancelA() + + <-j.Done() + if !errors.Is(j.Err(), context.Canceled) { + t.Errorf("expected context.Canceled, got %v", j.Err()) + } +} + +func TestJoin_CancelB(t *testing.T) { + t.Parallel() + + ctxA := New() + ctxB, cancelB := Cancelable() + + j, _ := Join(ctxA, ctxB) + cancelB() + + <-j.Done() + if !errors.Is(j.Err(), context.Canceled) { + t.Errorf("expected context.Canceled, got %v", j.Err()) + } +} + +func TestJoin_CancelSelf(t *testing.T) { + t.Parallel() + + ctxA := New() + ctxB := New() + + j, cancel := Join(ctxA, ctxB) + cancel() + + <-j.Done() + if !errors.Is(j.Err(), context.Canceled) { + t.Errorf("expected context.Canceled, got %v", j.Err()) + } +} + +func TestJoin_AlreadyDone(t *testing.T) { + t.Parallel() + + t.Run("a already done", func(t *testing.T) { + ctxA, cancelA := Cancelable() + cancelA() + ctxB := New() + + j1, _ := Join(ctxA, ctxB) + select { + case <-j1.Done(): + default: + t.Fatal("expected joined context to be done immediately when A is already done") + } + if !errors.Is(j1.Err(), context.Canceled) { + t.Errorf("expected context.Canceled, got %v", j1.Err()) + } + }) + + t.Run("b already done", func(t *testing.T) { + ctxC := New() + ctxD, cancelD := Cancelable() + cancelD() + + j2, _ := Join(ctxC, ctxD) + select { + case <-j2.Done(): + default: + t.Fatal("expected joined context to be done immediately when B is already done") + } + if !errors.Is(j2.Err(), context.Canceled) { + t.Errorf("expected context.Canceled, got %v", j2.Err()) + } + }) +} + +func TestJoin_Value(t *testing.T) { + t.Parallel() + + type key string + + ctxA := WithValue(New(), key("k1"), "v1") + ctxB := WithValue(New(), key("k2"), "v2") + ctxBConflicting := WithValue(ctxB, key("k1"), "v1-b") + + j1, cancel1 := Join(ctxA, ctxB) + defer cancel1() + + if v := j1.Value(key("k1")); v != "v1" { + t.Errorf("expected v1, got %v", v) + } + if v := j1.Value(key("k2")); v != "v2" { + t.Errorf("expected v2, got %v", v) + } + if v := j1.Value(key("missing")); v != nil { + t.Errorf("expected nil, got %v", v) + } + + j2, cancel2 := Join(ctxA, ctxBConflicting) + defer cancel2() + + // A must take precedence over B + if v := j2.Value(key("k1")); v != "v1" { + t.Errorf("expected v1 (from ctxA), got %v", v) + } +} + +func TestJoin_Deadline(t *testing.T) { + t.Parallel() + + now := time.Now() + + ctxNoDeadline := New() + + ctxEarly, cancelEarly := Deadline(now.Add(1 * time.Hour)) + defer cancelEarly() + + ctxLate, cancelLate := Deadline(now.Add(2 * time.Hour)) + defer cancelLate() + + tests := []struct { + name string + a context.Context + b context.Context + wantDeadline time.Time + wantOk bool + }{ + {"no deadlines", ctxNoDeadline, ctxNoDeadline, time.Time{}, false}, + {"only a has deadline", ctxEarly, ctxNoDeadline, now.Add(1 * time.Hour), true}, + {"only b has deadline", ctxNoDeadline, ctxEarly, now.Add(1 * time.Hour), true}, + {"both have deadlines, a is earlier", ctxEarly, ctxLate, now.Add(1 * time.Hour), true}, + {"both have deadlines, b is earlier", ctxLate, ctxEarly, now.Add(1 * time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + j, cancel := Join(tt.a, tt.b) + defer cancel() + + gotDeadline, gotOk := j.Deadline() + + if gotOk != tt.wantOk { + t.Errorf("Deadline() ok = %v, want %v", gotOk, tt.wantOk) + } + + if gotOk && !gotDeadline.Equal(tt.wantDeadline) { + t.Errorf("Deadline() = %v, want %v", gotDeadline, tt.wantDeadline) + } + }) + } +}