Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 114 additions & 19 deletions sqlite-vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ static void vec_slice(sqlite3_context *context, int argc,
i8 *out = sqlite3_malloc(outSize);
if (!out) {
sqlite3_result_error_nomem(context);
return;
goto done;
}
memset(out, 0, outSize);
for (size_t i = 0; i < n; i++) {
Expand All @@ -1672,7 +1672,7 @@ static void vec_slice(sqlite3_context *context, int argc,
u8 *out = sqlite3_malloc(outSize);
if (!out) {
sqlite3_result_error_nomem(context);
return;
goto done;
}
memset(out, 0, outSize);
for (size_t i = 0; i < n / CHAR_BIT; i++) {
Expand Down Expand Up @@ -2535,6 +2535,7 @@ static int vec_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
int rc = vector_from_value(argv[0], &pCur->vector, &pCur->dimensions,
&pCur->vector_type, &pCur->cleanup, &pzErrMsg);
if (rc != SQLITE_OK) {
sqlite3_free(pzErrMsg);
return SQLITE_ERROR;
}
pCur->iRowid = 0;
Expand Down Expand Up @@ -3616,6 +3617,24 @@ void vec0_free(vec0_vtab *p) {
sqlite3_free(p->vector_columns[i].name);
p->vector_columns[i].name = NULL;
}

for (int i = 0; i < p->numPartitionColumns; i++) {
sqlite3_free(p->paritition_columns[i].name);
p->paritition_columns[i].name = NULL;
}

for (int i = 0; i < p->numAuxiliaryColumns; i++) {
sqlite3_free(p->auxiliary_columns[i].name);
p->auxiliary_columns[i].name = NULL;
}

for (int i = 0; i < p->numMetadataColumns; i++) {
sqlite3_free(p->shadowMetadataChunksNames[i]);
p->shadowMetadataChunksNames[i] = NULL;

sqlite3_free(p->metadata_columns[i].name);
p->metadata_columns[i].name = NULL;
}
}

int vec0_num_defined_user_columns(vec0_vtab *p) {
Expand Down Expand Up @@ -5143,6 +5162,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
goto error;
}
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
sqlite3_free(zSql);
if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
sqlite3_finalize(stmt);
*pzErr = sqlite3_mprintf(
Expand All @@ -5160,6 +5180,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,

error:
vec0_free(pNew);
sqlite3_free(pNew);
return SQLITE_ERROR;
}

Expand Down Expand Up @@ -5991,6 +6012,7 @@ int vec0_metadata_filter_text(vec0_vtab * p, sqlite3_value * value, const void *
rc = sqlite3_blob_read(rowidsBlob, rowids, sqlite3_blob_bytes(rowidsBlob), 0);
if(rc != SQLITE_OK) {
sqlite3_blob_close(rowidsBlob);
sqlite3_free(rowids);
return rc;
}
sqlite3_blob_close(rowidsBlob);
Expand Down Expand Up @@ -6979,12 +7001,14 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
i64 v = sqlite3_value_int64(entry);
rc = array_append(&item.array, &v);
if (rc != SQLITE_OK) {
array_cleanup(&item.array);
goto cleanup;
}
}

if (rc != SQLITE_DONE) {
vtab_set_error(&p->base, "Error fetching next value in `x in (...)` integer expression");
array_cleanup(&item.array);
goto cleanup;
}

Expand All @@ -7004,17 +7028,33 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
entry.zString = sqlite3_mprintf("%.*s", n, s);
if(!entry.zString) {
rc = SQLITE_NOMEM;
// Clean up already-added text entries
for(size_t j = 0; j < item.array.length; j++) {
sqlite3_free(((struct Vec0MetadataInTextEntry*)item.array.z)[j].zString);
}
array_cleanup(&item.array);
goto cleanup;
}
entry.n = n;
rc = array_append(&item.array, &entry);
if (rc != SQLITE_OK) {
sqlite3_free(entry.zString);
// Clean up already-added text entries
for(size_t j = 0; j < item.array.length; j++) {
sqlite3_free(((struct Vec0MetadataInTextEntry*)item.array.z)[j].zString);
}
array_cleanup(&item.array);
goto cleanup;
}
}

if (rc != SQLITE_DONE) {
vtab_set_error(&p->base, "Error fetching next value in `x in (...)` text expression");
// Clean up text entries
for(size_t j = 0; j < item.array.length; j++) {
sqlite3_free(((struct Vec0MetadataInTextEntry*)item.array.z)[j].zString);
}
array_cleanup(&item.array);
goto cleanup;
}

Expand All @@ -7028,6 +7068,13 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,

rc = array_append(aMetadataIn, &item);
if(rc != SQLITE_OK) {
// Clean up item.array since it wasn't added to aMetadataIn
if(p->metadata_columns[item.metadata_idx].kind == VEC0_METADATA_COLUMN_KIND_TEXT) {
for(size_t j = 0; j < item.array.length; j++) {
sqlite3_free(((struct Vec0MetadataInTextEntry*)item.array.z)[j].zString);
}
}
array_cleanup(&item.array);
goto cleanup;
}
}
Expand Down Expand Up @@ -7082,6 +7129,12 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,

sqlite3_free(aMetadataIn);

// Free knn_data if not assigned to cursor (error case)
if (rc != SQLITE_OK && knn_data) {
vec0_query_knn_data_clear(knn_data);
sqlite3_free(knn_data);
}

return rc;
}

Expand Down Expand Up @@ -8055,6 +8108,7 @@ int vec0_write_metadata_value(vec0_vtab *p, int metadata_column_idx, i64 rowid,
}
sqlite3_stmt * stmt;
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free((void *)zSql);
if(rc != SQLITE_OK) {
goto done;
}
Expand All @@ -8076,6 +8130,7 @@ int vec0_write_metadata_value(vec0_vtab *p, int metadata_column_idx, i64 rowid,
}
sqlite3_stmt * stmt;
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free((void *)zSql);
if(rc != SQLITE_OK) {
goto done;
}
Expand Down Expand Up @@ -8274,6 +8329,7 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
goto cleanup;
}
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free(zSql);
if(rc != SQLITE_OK) {
goto cleanup;
}
Expand Down Expand Up @@ -8516,12 +8572,14 @@ int vec0Update_Delete_ClearMetadata(vec0_vtab *p, int metadata_idx, i64 rowid, i
}
sqlite3_stmt * stmt;
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free((void *)zSql);
if(rc != SQLITE_OK) {
goto done;
}
sqlite3_bind_int64(stmt, 1, rowid);
rc = sqlite3_step(stmt);
if(rc != SQLITE_DONE) {
sqlite3_finalize(stmt);
rc = SQLITE_ERROR;
goto done;
}
Expand Down Expand Up @@ -8608,6 +8666,7 @@ int vec0Update_UpdateAuxColumn(vec0_vtab *p, int auxiliary_column_idx, sqlite3_v
return SQLITE_NOMEM;
}
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
sqlite3_free((void *)zSql);
if(rc != SQLITE_OK) {
return rc;
}
Expand Down Expand Up @@ -9394,52 +9453,67 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
if (idxNum == VEC_SBE__QUERYPLAN_KNN) {
assert(argc == 2);
pCur->query_plan = VEC_SBE__QUERYPLAN_KNN;
struct sbe_query_knn_data *knn_data;
struct sbe_query_knn_data *knn_data = NULL;
void *queryVector = NULL;
vector_cleanup cleanup = vector_cleanup_noop;
i32 *topk_rowids = NULL;
f32 *distances = NULL;
u8 *candidates = NULL;
u8 *taken = NULL;
int rc = SQLITE_OK;

knn_data = sqlite3_malloc(sizeof(*knn_data));
if (!knn_data) {
return SQLITE_NOMEM;
rc = SQLITE_NOMEM;
goto knn_cleanup;
}
memset(knn_data, 0, sizeof(*knn_data));

void *queryVector;
size_t dimensions;
enum VectorElementType elementType;
vector_cleanup cleanup;
char *err;
int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
&cleanup, &err);
if (rc != SQLITE_OK) {
return SQLITE_ERROR;
sqlite3_free(err);
rc = SQLITE_ERROR;
goto knn_cleanup;
}
if (elementType != p->blob->element_type) {
return SQLITE_ERROR;
rc = SQLITE_ERROR;
goto knn_cleanup;
}
if (dimensions != p->blob->dimensions) {
return SQLITE_ERROR;
rc = SQLITE_ERROR;
goto knn_cleanup;
}

i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors);
if (k < 0) {
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
return SQLITE_ERROR;
rc = SQLITE_ERROR;
goto knn_cleanup;
}
if (k == 0) {
knn_data->k = 0;
pCur->knn_data = knn_data;
cleanup(queryVector);
return SQLITE_OK;
}

size_t bsize = (p->blob->nvectors + 7) & ~7;

i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32));
topk_rowids = sqlite3_malloc(k * sizeof(i32));
if (!topk_rowids) {
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
return SQLITE_ERROR;
rc = SQLITE_ERROR;
goto knn_cleanup;
}
f32 *distances = sqlite3_malloc(bsize * sizeof(f32));
distances = sqlite3_malloc(bsize * sizeof(f32));
if (!distances) {
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
return SQLITE_ERROR;
rc = SQLITE_ERROR;
goto knn_cleanup;
}

for (size_t i = 0; i < p->blob->nvectors; i++) {
Expand All @@ -9448,11 +9522,17 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
distances[i] =
distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions);
}
u8 *candidates = bitmap_new(bsize);
assert(candidates);
candidates = bitmap_new(bsize);
if (!candidates) {
rc = SQLITE_NOMEM;
goto knn_cleanup;
}

u8 *taken = bitmap_new(bsize);
assert(taken);
taken = bitmap_new(bsize);
if (!taken) {
rc = SQLITE_NOMEM;
goto knn_cleanup;
}

bitmap_fill(candidates, bsize);
for (size_t i = bsize; i >= p->blob->nvectors; i--) {
Expand All @@ -9466,6 +9546,21 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
knn_data->rowids = topk_rowids;

pCur->knn_data = knn_data;

// Cleanup temporary allocations (not owned by knn_data)
sqlite3_free(candidates);
sqlite3_free(taken);
cleanup(queryVector);
return SQLITE_OK;

knn_cleanup:
sqlite3_free(knn_data);
sqlite3_free(topk_rowids);
sqlite3_free(distances);
sqlite3_free(candidates);
sqlite3_free(taken);
cleanup(queryVector);
return rc;
} else {
pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN;
pCur->iRowid = 0;
Expand Down