diff --git a/go.mod b/go.mod index 37ff180..6ea1593 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.bug.st/f -go 1.22.3 +go 1.25 require github.com/stretchr/testify v1.9.0 diff --git a/iter.go b/iter.go new file mode 100644 index 0000000..a94fb4b --- /dev/null +++ b/iter.go @@ -0,0 +1,41 @@ +package f + +import "iter" + +// FilterIter takes an iterator and a matcher and returns only those elements that satisfy the matcher. +func FilterIter[T any](values iter.Seq[T], matcher Matcher[T]) iter.Seq[T] { + return func(yield func(x T) bool) { + for x := range values { + if matcher(x) { + if !yield(x) { + return + } + } + } + } +} + +// Map applies the Mapper function to each element of the iterator. +func MapIter[T, U any](values iter.Seq[T], mapper Mapper[T, U]) iter.Seq[U] { + return func(yield func(x U) bool) { + for x := range values { + if !yield(mapper(x)) { + return + } + } + } +} + +// Reducer is a function that reduces an iterator's elements to a single value. +func ReduceIter[T any](values iter.Seq[T], reducer Reducer[T], initialValue ...T) T { + var result T + if len(initialValue) > 1 { + panic("initialValue must be a single value") + } else if len(initialValue) == 1 { + result = initialValue[0] + } + for v := range values { + result = reducer(result, v) + } + return result +} diff --git a/iter_test.go b/iter_test.go new file mode 100644 index 0000000..1e52d5a --- /dev/null +++ b/iter_test.go @@ -0,0 +1,117 @@ +package f + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilterIter(t *testing.T) { + tests := []struct { + name string + input []int + matcher Matcher[int] + expected []int + }{ + { + name: "even numbers", + input: []int{1, 2, 3, 4, 5}, + matcher: func(x int) bool { return x%2 == 0 }, + expected: []int{2, 4}, + }, + { + name: "odd numbers", + input: []int{1, 2, 3, 4, 5}, + matcher: func(x int) bool { return x%2 == 1 }, + expected: []int{1, 3, 5}, + }, + { + name: "none", + input: []int{2, 4, 6}, + matcher: func(x int) bool { return x < 0 }, + expected: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + values := slices.Values(tt.input) + result := slices.Collect(FilterIter(values, tt.matcher)) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestMapIter(t *testing.T) { + tests := []struct { + name string + input []int + mapper Mapper[int, int] + expected []int + }{ + { + name: "double", + input: []int{1, 2, 3}, + mapper: func(x int) int { return x * 2 }, + expected: []int{2, 4, 6}, + }, + { + name: "negate", + input: []int{1, -2, 3}, + mapper: func(x int) int { return -x }, + expected: []int{-1, 2, -3}, + }, + { + name: "identity", + input: []int{1, 2, 3}, + mapper: func(x int) int { return x }, + expected: []int{1, 2, 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + values := slices.Values(tt.input) + result := slices.Collect(MapIter(values, tt.mapper)) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestReduceIter(t *testing.T) { + tests := []struct { + name string + input []int + reducer Reducer[int] + initialValue int + expected int + }{ + { + name: "sum", + input: []int{1, 2, 3, 4}, + reducer: func(a, b int) int { return a + b }, + initialValue: 0, + expected: 10, + }, + { + name: "product", + input: []int{1, 2, 3, 4}, + reducer: func(a, b int) int { return a * b }, + initialValue: 1, + expected: 24, + }, + { + name: "subtract", + input: []int{10, 2, 3}, + reducer: func(a, b int) int { return a - b }, + initialValue: 20, + expected: 5, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + values := slices.Values(tt.input) + result := ReduceIter(values, tt.reducer, tt.initialValue) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/slices.go b/slices.go index 392b785..af31158 100644 --- a/slices.go +++ b/slices.go @@ -9,6 +9,7 @@ package f import ( + "iter" "runtime" "sync" "sync/atomic" @@ -124,3 +125,16 @@ func Count[T any](in []T, matcher Matcher[T]) int { } return count } + +// RefIter takes a slice of type []T and returns an iterator that yields +// pointers to each element of the slice. +func RefIter[T any](slice []T) iter.Seq[*T] { + return func(yield func(*T) bool) { + for i := range slice { + if !yield(&slice[i]) { + return + } + } + } +} + diff --git a/slices_test.go b/slices_test.go index 5ea7a84..4678f12 100644 --- a/slices_test.go +++ b/slices_test.go @@ -9,9 +9,11 @@ package f_test import ( + "slices" "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" f "go.bug.st/f" ) @@ -100,3 +102,35 @@ func TestCount(t *testing.T) { require.Equal(t, 0, f.Count(a, f.Equals("ddd"))) require.Equal(t, 3, f.Count(a, f.NotEquals("ddd"))) } + +func TestRefIter(t *testing.T) { + type foo struct { + value int + } + values := []foo{ + {value: 1}, + {value: 2}, + {value: 3}, + } + + t.Run("not working for range", func(t *testing.T) { + for _, v := range values { + v.value *= 10 + } + assert.Equal(t, []foo{{1}, {2}, {3}}, values) + }) + + t.Run("not working slices.Values", func(t *testing.T) { + for v := range slices.Values(values) { + v.value *= 10 + } + assert.Equal(t, []foo{{1}, {2}, {3}}, values) + }) + + t.Run("working RefIter", func(t *testing.T) { + for v := range f.RefIter(values) { + v.value *= 10 + } + assert.Equal(t, []foo{{10}, {20}, {30}}, values) + }) +}