diff --git a/errors.go b/errors.go
index 392decd..da8c7d1 100644
--- a/errors.go
+++ b/errors.go
@@ -11,4 +11,6 @@ var (
 	ErrKeysNotForEncrypt  = errors.New("envector: keys opened without KeyPartEnc cannot encrypt")
 	ErrKeysNotForDecrypt  = errors.New("envector: keys opened without KeyPartSec cannot decrypt")
 	ErrKeysNotForRegister = errors.New("envector: keys opened without KeyPartEval have no eval key for register/activate")
+	ErrAlreadyExists      = errors.New("envector: request_id already exists (operation already committed)")
+	ErrRequestIDTooLong   = errors.New("envector: RequestID exceeds MaxRequestIDLength")
 )
diff --git a/header.go b/header.go
new file mode 100644
index 0000000..7a5761b
--- /dev/null
+++ b/header.go
@@ -0,0 +1,29 @@
+package envector
+
+import (
+	"fmt"
+	"strings"
+
+	es2pb "github.com/CryptoLabInc/envector-go-sdk/internal/transport/pb/es2"
+)
+
+// alreadyExistsPrefix is the ErrorMessage prefix that marks a server-side
+// idempotency rejection.
+const alreadyExistsPrefix = "ALREADY_EXISTS:"
+
+// checkHeader converts a non-success ResponseHeader into an error naming rpc.
+// A nil header is treated as success. Messages that start with
+// alreadyExistsPrefix wrap ErrAlreadyExists so callers can match it with
+// errors.Is. Shared by every gRPC call in the SDK that returns a ResponseHeader.
+func checkHeader(rpc string, h *es2pb.ResponseHeader) error {
+	if h == nil {
+		return nil
+	}
+	if rc := h.GetReturnCode(); rc != es2pb.ReturnCode_Success {
+		if strings.HasPrefix(h.GetErrorMessage(), alreadyExistsPrefix) {
+			return fmt.Errorf("envector: %s: %w: %s", rpc, ErrAlreadyExists, h.GetErrorMessage())
+		}
+		return fmt.Errorf("envector: %s: server returned %s: %s", rpc, rc, h.GetErrorMessage())
+	}
+	return nil
+}
diff --git a/index.go b/index.go
index 9f14191..9b5a73b 100644
--- a/index.go
+++ b/index.go
@@ -122,13 +122,4 @@ func (i *Index) Drop(ctx context.Context) error {
 	return checkHeader("delete_index", resp.GetHeader())
 }
 
-func checkHeader(rpc string, h *es2pb.ResponseHeader) error {
-	if h == nil {
-		return nil
-	}
-	if rc := h.GetReturnCode(); rc != es2pb.ReturnCode_Success {
-		return fmt.Errorf("envector: %s: server returned %s: %s", rpc, rc, h.GetErrorMessage())
-	}
-	return nil
-}
 
diff --git a/insert.go b/insert.go
index f4286ee..9a19158 100644
--- a/insert.go
+++ b/insert.go
@@ -14,8 +14,9 @@ const insertChunkSize = 1 * 1024 * 1024
 // per vector. Metadata is stored verbatim — the SDK never interprets it
 // (callers commonly pass a JSON envelope; opaque blobs work equally well).
 type InsertRequest struct {
-	Vectors  [][]float32
-	Metadata []string
+	Vectors   [][]float32
+	Metadata  []string
+	RequestID string // idempotency key, at most MaxRequestIDLength bytes; generated when empty
 }
 
 // InsertResult reports the server-assigned item IDs in insertion order.
@@ -27,6 +28,9 @@ type InsertResult struct {
 // streams the ciphertexts through BatchInsertData. Frames are split at
 // ~1 MiB of payload. Returns ErrKeysRequired when the Index was opened
 // without WithIndexKeys.
+//
+// Idempotency: each frame carries req.RequestID (generated when empty); the
+// server rejects a replayed RequestID with an error matching ErrAlreadyExists.
 func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, error) {
 	if i.client.conn == nil {
 		return nil, ErrClientClosed
@@ -37,6 +41,9 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
 	if len(req.Vectors) == 0 {
 		return &InsertResult{}, nil
 	}
+	if len(req.RequestID) > MaxRequestIDLength {
+		return nil, ErrRequestIDTooLong
+	}
 	if d := i.keys.Dim(); d > 0 {
 		for j, v := range req.Vectors {
 			if len(v) != d {
@@ -45,6 +52,11 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
 		}
 	}
 
+	requestID := req.RequestID
+	if requestID == "" {
+		requestID = newRequestID()
+	}
+
 	ciphers, innerCounts, err := i.keys.Encrypt(req.Vectors)
 	if err != nil {
 		return nil, fmt.Errorf("envector: batch_insert_data encrypt: %w", err)
@@ -61,7 +73,7 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
 	for idx, blob := range ciphers {
 		count := innerCounts[idx]
 		if cur > 0 && cur+len(blob) > insertChunkSize {
-			if err := sendInsertFrame(stream, i.name, packed); err != nil {
+			if err := sendInsertFrame(stream, i.name, requestID, packed); err != nil {
 				return nil, err
 			}
 			packed = packed[:0]
@@ -93,7 +105,7 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
 		cur += len(blob)
 	}
 	if len(packed) > 0 {
-		if err := sendInsertFrame(stream, i.name, packed); err != nil {
+		if err := sendInsertFrame(stream, i.name, requestID, packed); err != nil {
 			return nil, err
 		}
 	}
@@ -108,9 +120,9 @@ func (i *Index) Insert(ctx context.Context, req InsertRequest) (*InsertResult, e
 	return &InsertResult{ItemIDs: resp.GetItemIds()}, nil
 }
 
-func sendInsertFrame(stream es2epb.ES2EService_BatchInsertDataClient, indexName string, packed []*es2pb.PackedVectors) error {
+func sendInsertFrame(stream es2epb.ES2EService_BatchInsertDataClient, indexName, requestID string, packed []*es2pb.PackedVectors) error {
 	msg := &es2epb.BatchInsertDataRequest{
-		Header:        &es2pb.RequestHeader{Type: es2pb.MessageType_BatchInsertData},
+		Header:        &es2pb.RequestHeader{Type: es2pb.MessageType_BatchInsertData, Id: requestID},
 		IndexName:     indexName,
 		PackedVectors: packed,
 	}
diff --git a/insert_test.go b/insert_test.go
index a7eab85..9cf44a6 100644
--- a/insert_test.go
+++ b/insert_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -93,6 +94,129 @@ func TestIndex_Insert_StreamsPackedVectorsAndPassesMetadata(t *testing.T) {
 	}
 }
 
+func TestIndex_Insert_PopulatesRequestHeaderID(t *testing.T) {
+	c, fake := newFakeClient(t)
+	fake.indexList = []string{"demo"}
+	fake.itemIDs = []int64{1}
+	keys := openTestKeys(t)
+
+	idx, _ := c.Index(context.Background(),
+		WithIndexName("demo"),
+		WithIndexKeys(keys),
+	)
+	if _, err := idx.Insert(context.Background(), InsertRequest{Vectors: [][]float32{make([]float32, 128)}}); err != nil {
+		t.Fatalf("Insert: %v", err)
+	}
+
+	if len(fake.batchInsertHeaders) == 0 {
+		t.Fatal("expected at least one frame with a populated header")
+	}
+	for i, h := range fake.batchInsertHeaders {
+		if h.GetId() == "" {
+			t.Errorf("frame %d: RequestHeader.Id is empty — server-side dedup will not engage", i)
+		}
+		if len(h.GetId()) > MaxRequestIDLength {
+			t.Errorf("frame %d: RequestHeader.Id %q exceeds MaxRequestIDLength=%d", i, h.GetId(), MaxRequestIDLength)
+		}
+	}
+
+	first := fake.batchInsertHeaders[0].GetId()
+	for i, h := range fake.batchInsertHeaders[1:] {
+		if h.GetId() != first {
+			t.Errorf("frame %d Id=%q, want %q (all frames must share the same request id)", i+1, h.GetId(), first)
+		}
+	}
+}
+
+func TestIndex_Insert_PropagatesCallerSuppliedRequestID(t *testing.T) {
+	c, fake := newFakeClient(t)
+	fake.indexList = []string{"demo"}
+	fake.itemIDs = []int64{1}
+	keys := openTestKeys(t)
+
+	idx, _ := c.Index(context.Background(),
+		WithIndexName("demo"),
+		WithIndexKeys(keys),
+	)
+	const supplied = "caller-supplied-id-0808"
+	_, err := idx.Insert(context.Background(), InsertRequest{
+		Vectors:   [][]float32{make([]float32, 128)},
+		RequestID: supplied,
+	})
+	if err != nil {
+		t.Fatalf("Insert: %v", err)
+	}
+
+	// Guard separately: indexing element 0 of an empty slice would panic
+	// instead of reporting a test failure.
+	if len(fake.batchInsertHeaders) == 0 {
+		t.Fatal("expected at least one frame with a populated header")
+	}
+	if got := fake.batchInsertHeaders[0].GetId(); got != supplied {
+		t.Errorf("RequestHeader.Id = %q, want %q", got, supplied)
+	}
+}
+
+func TestIndex_Insert_RejectsOverlongRequestID(t *testing.T) {
+	c, fake := newFakeClient(t)
+	fake.indexList = []string{"demo"}
+	keys := openTestKeys(t)
+
+	idx, _ := c.Index(context.Background(),
+		WithIndexName("demo"),
+		WithIndexKeys(keys),
+	)
+
+	_, err := idx.Insert(context.Background(), InsertRequest{
+		Vectors:   [][]float32{make([]float32, 128)},
+		RequestID: strings.Repeat("x", MaxRequestIDLength+1),
+	})
+	if !errors.Is(err, ErrRequestIDTooLong) {
+		t.Errorf("got %v, want ErrRequestIDTooLong", err)
+	}
+	if len(fake.batchInsertPackets) != 0 {
+		t.Error("no RPC should have been sent for an invalid RequestID")
+	}
+}
+
+func TestIndex_Insert_AlreadyExistsMapsToTypedError(t *testing.T) {
+	c, fake := newFakeClient(t)
+	fake.indexList = []string{"demo"}
+	fake.batchInsertRespErr = "ALREADY_EXISTS: request_id already exists"
+	keys := openTestKeys(t)
+
+	idx, _ := c.Index(context.Background(),
+		WithIndexName("demo"),
+		WithIndexKeys(keys),
+	)
+
+	_, err := idx.Insert(context.Background(), InsertRequest{
+		Vectors:   [][]float32{make([]float32, 128)},
+		RequestID: "retry-id-001",
+	})
+	if !errors.Is(err, ErrAlreadyExists) {
+		t.Fatalf("got %v, want ErrAlreadyExists", err)
+	}
+}
+
+func TestIndex_Insert_GenericFailDoesNotMatchAlreadyExists(t *testing.T) {
+	c, fake := newFakeClient(t)
+	fake.indexList = []string{"demo"}
+	fake.batchInsertRespErr = "internal error: disk full"
+	keys := openTestKeys(t)
+
+	idx, _ := c.Index(context.Background(),
+		WithIndexName("demo"),
+		WithIndexKeys(keys),
+	)
+
+	_, err := idx.Insert(context.Background(), InsertRequest{Vectors: [][]float32{make([]float32, 128)}})
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if errors.Is(err, ErrAlreadyExists) {
+		t.Errorf("non-ALREADY_EXISTS failure should not match ErrAlreadyExists: %v", err)
+	}
+}
+
 func TestIndex_Insert_ChunksAboveThreshold(t *testing.T) {
 	c, fake := newFakeClient(t)
 	fake.indexList = []string{"demo"}
diff --git a/requestid.go b/requestid.go
new file mode 100644
index 0000000..403dc4f
--- /dev/null
+++ b/requestid.go
@@ -0,0 +1,33 @@
+package envector
+
+import (
+	"crypto/rand"
+	"encoding/hex"
+	"fmt"
+	"sync/atomic"
+	"time"
+)
+
+// MaxRequestIDLength is the maximum length, in bytes, of a RequestID
+// (mirrors envector-msa).
+const MaxRequestIDLength = 30
+
+// requestIDByteLen is the number of random bytes in a generated ID; hex
+// encoding turns them into a 28-character string.
+const requestIDByteLen = 14
+
+// fallbackSeq keeps fallback IDs distinct within one process.
+var fallbackSeq uint64
+
+// newRequestID returns a random 28-character hex idempotency key. It never
+// returns "": an empty RequestHeader.Id would send the insert without an
+// idempotency key, so if the system random source fails the ID is built from
+// the wall clock and a process-wide counter instead.
+func newRequestID() string {
+	var b [requestIDByteLen]byte
+	if _, err := rand.Read(b[:]); err != nil {
+		seq := atomic.AddUint64(&fallbackSeq, 1) & 0xffffffffffff
+		return fmt.Sprintf("%016x%012x", uint64(time.Now().UnixNano()), seq)
+	}
+	return hex.EncodeToString(b[:])
+}
diff --git a/testutil_test.go b/testutil_test.go
index e6d476b..b6b1c81 100644
--- a/testutil_test.go
+++ b/testutil_test.go
@@ -32,7 +32,9 @@ type fakeES2E struct {
 	createIndexInfo    *es2pb.IndexInfo
 	deleteIndexCalls   []string
 	batchInsertPackets [][]*es2pb.PackedVectors
+	batchInsertHeaders []*es2pb.RequestHeader
 	batchInsertIndex   string
+	batchInsertRespErr string // when non-empty, BatchInsertData replies with ReturnCode_Fail and this message
 	getMetadataReq     *es2epb.GetMetadataRequest
 	innerProductReq    *es2epb.InnerProductRequest
 
@@ -181,8 +183,22 @@ func (f *fakeES2E) BatchInsertData(stream grpc.ClientStreamingServer[es2epb.Batc
 		f.mu.Lock()
 		f.batchInsertIndex = msg.GetIndexName()
 		f.batchInsertPackets = append(f.batchInsertPackets, msg.GetPackedVectors())
+		f.batchInsertHeaders = append(f.batchInsertHeaders, msg.GetHeader())
 		f.mu.Unlock()
 	}
+
+	f.mu.Lock()
+	respErr := f.batchInsertRespErr
+	f.mu.Unlock()
+	if respErr != "" {
+		return stream.SendAndClose(&es2epb.BatchInsertDataResponse{
+			Header: &es2pb.ResponseHeader{
+				ReturnCode:   es2pb.ReturnCode_Fail,
+				ErrorMessage: respErr,
+			},
+		})
+	}
+
 	return stream.SendAndClose(&es2epb.BatchInsertDataResponse{Header: f.header(), ItemIds: append([]int64{}, f.itemIDs...)})
 }
 