Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ var (
ErrKeysNotForEncrypt = errors.New("envector: keys opened without KeyPartEnc cannot encrypt")
ErrKeysNotForDecrypt = errors.New("envector: keys opened without KeyPartSec cannot decrypt")
ErrKeysNotForRegister = errors.New("envector: keys opened without KeyPartEval have no eval key for register/activate")
ErrAlreadyExists = errors.New("envector: request_id already exists (operation already committed)")
ErrRequestIDTooLong = errors.New("envector: RequestID exceeds MaxRequestIDLength")
)
25 changes: 25 additions & 0 deletions header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package envector

import (
"fmt"
"strings"

es2pb "github.com/CryptoLabInc/envector-go-sdk/internal/transport/pb/es2"
)

// Marks a server side idempotency rejection
const alreadyExistsPrefix = "ALREADY_EXISTS:"

// Shared by gRPC in SDK which returns ResponseHeader
func checkHeader(rpc string, h *es2pb.ResponseHeader) error {
if h == nil {
return nil
}
if rc := h.GetReturnCode(); rc != es2pb.ReturnCode_Success {
if strings.HasPrefix(h.GetErrorMessage(), alreadyExistsPrefix) {
return fmt.Errorf("envector: %s: %w: %s", rpc, ErrAlreadyExists, h.GetErrorMessage())
}
return fmt.Errorf("envector: %s: server returned %s: %s", rpc, rc, h.GetErrorMessage())
}
return nil
}
9 changes: 0 additions & 9 deletions index.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,4 @@ func (i *Index) Drop(ctx context.Context) error {
return checkHeader("delete_index", resp.GetHeader())
}

func checkHeader(rpc string, h *es2pb.ResponseHeader) error {
if h == nil {
return nil
}
if rc := h.GetReturnCode(); rc != es2pb.ReturnCode_Success {
return fmt.Errorf("envector: %s: server returned %s: %s", rpc, rc, h.GetErrorMessage())
}
return nil
}

24 changes: 18 additions & 6 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ const insertChunkSize = 1 * 1024 * 1024
// per vector. Metadata is stored verbatim — the SDK never interprets it
// (callers commonly pass a JSON envelope; opaque blobs work equally well).
type InsertRequest struct {
Vectors [][]float32
Metadata []string
Vectors [][]float32
Metadata []string
RequestID string // server side idempotency key
}

// InsertResult reports the server-assigned item IDs in insertion order.
Expand All @@ -27,6 +28,9 @@ type InsertResult struct {
// streams the ciphertexts through BatchInsertData. Frames are split at
// ~1 MiB of payload. Returns ErrKeysRequired when the Index was opened
// without WithIndexKeys.
//
// Idempotency: every call carrires a RequestID and retrying with the same RequestID
// makes server reject with ErrAlreadyExists
func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, error) {
if i.client.conn == nil {
return nil, ErrClientClosed
Expand All @@ -37,6 +41,9 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
if len(req.Vectors) == 0 {
return &InsertResult{}, nil
}
if len(req.RequestID) > MaxRequestIDLength {
return nil, ErrRequestIDTooLong
}
if d := i.keys.Dim(); d > 0 {
for j, v := range req.Vectors {
if len(v) != d {
Expand All @@ -45,6 +52,11 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
}
}

requestID := req.RequestID
if requestID == "" {
requestID = newRequestID()
}

ciphers, innerCounts, err := i.keys.Encrypt(req.Vectors)
if err != nil {
return nil, fmt.Errorf("envector: batch_insert_data encrypt: %w", err)
Expand All @@ -61,7 +73,7 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
for idx, blob := range ciphers {
count := innerCounts[idx]
if cur > 0 && cur+len(blob) > insertChunkSize {
if err := sendInsertFrame(stream, i.name, packed); err != nil {
if err := sendInsertFrame(stream, i.name, requestID, packed); err != nil {
return nil, err
}
packed = packed[:0]
Expand Down Expand Up @@ -93,7 +105,7 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
cur += len(blob)
}
if len(packed) > 0 {
if err := sendInsertFrame(stream, i.name, packed); err != nil {
if err := sendInsertFrame(stream, i.name, requestID, packed); err != nil {
return nil, err
}
}
Expand All @@ -108,9 +120,9 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
return &InsertResult{ItemIDs: resp.GetItemIds()}, nil
}

func sendInsertFrame(stream es2epb.ES2EService_BatchInsertDataClient, indexName string, packed []*es2pb.PackedVectors) error {
func sendInsertFrame(stream es2epb.ES2EService_BatchInsertDataClient, indexName, requestID string, packed []*es2pb.PackedVectors) error {
msg := &es2epb.BatchInsertDataRequest{
Header: &es2pb.RequestHeader{Type: es2pb.MessageType_BatchInsertData},
Header: &es2pb.RequestHeader{Type: es2pb.MessageType_BatchInsertData, Id: requestID},
IndexName: indexName,
PackedVectors: packed,
}
Expand Down
121 changes: 121 additions & 0 deletions insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"reflect"
"strings"
"testing"
)

Expand Down Expand Up @@ -93,6 +94,126 @@ func TestIndex_Insert_StreamsPackedVectorsAndPassesMetadata(t *testing.T) {
}
}

func TestIndex_Insert_PopulatesRequestHeaderID(t *testing.T) {
c, fake := newFakeClient(t)
fake.indexList = []string{"demo"}
fake.itemIDs = []int64{1}
keys := openTestKeys(t)

idx, _ := c.Index(context.Background(),
WithIndexName("demo"),
WithIndexKeys(keys),
)
if _, err := idx.Insert(context.Background(), InsertRequest{Vectors: [][]float32{make([]float32, 128)}}); err != nil {
t.Fatalf("Insert: %v", err)
}

if len(fake.batchInsertHeaders) == 0 {
t.Fatal("expected at least one frame with a populated header")
}
for i, h := range fake.batchInsertHeaders {
if h.GetId() == "" {
t.Errorf("frame %d: RequestHeader.Id is empty — server-side dedup will not engage", i)
}
if len(h.GetId()) > MaxRequestIDLength {
t.Errorf("frame %d: RequestHeader.Id %q exceeds MaxRequestIDLength=%d", i, h.GetId(), MaxRequestIDLength)
}
}

first := fake.batchInsertHeaders[0].GetId()
for i, h := range fake.batchInsertHeaders[1:] {
if h.GetId() != first {
t.Errorf("frame %d Id=%q, want %q (all frames must share the same request id)", i+1, h.GetId(), first)
}
}
}

func TestIndex_Insert_PropagatesCallerSuppliedRequestID(t *testing.T) {
c, fake := newFakeClient(t)
fake.indexList = []string{"demo"}
fake.itemIDs = []int64{1}
keys := openTestKeys(t)

idx, _ := c.Index(context.Background(),
WithIndexName("demo"),
WithIndexKeys(keys),
)
const supplied = "caller-supplied-id-0808"
_, err := idx.Insert(context.Background(), InsertRequest{
Vectors: [][]float32{make([]float32, 128)},
RequestID: supplied,
})
if err != nil {
t.Fatalf("Insert: %v", err)
}

if len(fake.batchInsertHeaders) == 0 || fake.batchInsertHeaders[0].GetId() != supplied {
t.Errorf("RequestHeader.Id = %q, want %q", fake.batchInsertHeaders[0].GetId(), supplied)
}
}

func TestIndex_Insert_RejectsOverlongRequestID(t *testing.T) {
c, fake := newFakeClient(t)
fake.indexList = []string{"demo"}
keys := openTestKeys(t)

idx, _ := c.Index(context.Background(),
WithIndexName("demo"),
WithIndexKeys(keys),
)

_, err := idx.Insert(context.Background(), InsertRequest{
Vectors: [][]float32{make([]float32, 128)},
RequestID: strings.Repeat("x", MaxRequestIDLength+1),
})
if !errors.Is(err, ErrRequestIDTooLong) {
t.Errorf("got %v, want ErrRequestIDTooLong", err)
}
if len(fake.batchInsertPackets) != 0 {
t.Error("no RPC should have been sent for an invalid RequestID")
}
}

func TestIndex_Insert_AlreadyExistsMapsToTypedError(t *testing.T) {
c, fake := newFakeClient(t)
fake.indexList = []string{"demo"}
fake.batchInsertRespErr = "ALREADY_EXISTS: request_id already exists"
keys := openTestKeys(t)

idx, _ := c.Index(context.Background(),
WithIndexName("demo"),
WithIndexKeys(keys),
)

_, err := idx.Insert(context.Background(), InsertRequest{
Vectors: [][]float32{make([]float32, 128)},
RequestID: "retry-id-001",
})
if !errors.Is(err, ErrAlreadyExists) {
t.Fatalf("got %v, want ErrAlreadyExists", err)
}
}

func TestIndex_Insert_GenericFailDoesNotMatchAlreadyExists(t *testing.T) {
c, fake := newFakeClient(t)
fake.indexList = []string{"demo"}
fake.batchInsertRespErr = "internal error: disk full"
keys := openTestKeys(t)

idx, _ := c.Index(context.Background(),
WithIndexName("demo"),
WithIndexKeys(keys),
)

_, err := idx.Insert(context.Background(), InsertRequest{Vectors: [][]float32{make([]float32, 128)}})
if err == nil {
t.Fatal("expected error")
}
if errors.Is(err, ErrAlreadyExists) {
t.Errorf("non-ALREADY_EXISTS failure should not match ErrAlreadyExists: %v", err)
}
}

func TestIndex_Insert_ChunksAboveThreshold(t *testing.T) {
c, fake := newFakeClient(t)
fake.indexList = []string{"demo"}
Expand Down
17 changes: 17 additions & 0 deletions requestid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package envector

import (
"crypto/rand"
"encoding/hex"
)

const MaxRequestIDLength = 30 // mirrors envector-msa
const requestIDByteLen = 14 // 28-char hex string

func newRequestID() string {
var b [requestIDByteLen]byte
if _, err := rand.Read(b[:]); err != nil {
return ""
}
return hex.EncodeToString(b[:])
}
16 changes: 16 additions & 0 deletions testutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ type fakeES2E struct {
createIndexInfo *es2pb.IndexInfo
deleteIndexCalls []string
batchInsertPackets [][]*es2pb.PackedVectors
batchInsertHeaders []*es2pb.RequestHeader
batchInsertIndex string
batchInsertRespErr string // make error response only if not empty
getMetadataReq *es2epb.GetMetadataRequest
innerProductReq *es2epb.InnerProductRequest

Expand Down Expand Up @@ -181,8 +183,22 @@ func (f *fakeES2E) BatchInsertData(stream grpc.ClientStreamingServer[es2epb.Batc
f.mu.Lock()
f.batchInsertIndex = msg.GetIndexName()
f.batchInsertPackets = append(f.batchInsertPackets, msg.GetPackedVectors())
f.batchInsertHeaders = append(f.batchInsertHeaders, msg.GetHeader())
f.mu.Unlock()
}

f.mu.Lock()
respErr := f.batchInsertRespErr
f.mu.Unlock()
if respErr != "" {
return stream.SendAndClose(&es2epb.BatchInsertDataResponse{
Header: &es2pb.ResponseHeader{
ReturnCode: es2pb.ReturnCode_Fail,
ErrorMessage: respErr,
},
})
}

return stream.SendAndClose(&es2epb.BatchInsertDataResponse{Header: f.header(), ItemIds: append([]int64{}, f.itemIDs...)})
}

Expand Down
Loading