diff --git a/Cargo.lock b/Cargo.lock
index 11ae1e1..3c2803f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -13,6 +13,7 @@ dependencies = [
  "ryx-query",
  "serde",
  "serde_json",
+ "smallvec",
  "sqlx",
  "thiserror",
  "tokio",
@@ -1650,6 +1651,7 @@ dependencies = [
  "once_cell",
  "serde",
  "serde_json",
+ "smallvec",
  "sqlx",
  "thiserror",
  "tracing",
diff --git a/Cargo.toml b/Cargo.toml
index 1333102..b4b518b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -62,6 +62,7 @@ sqlx = { version = "0.8.6", features = [
 # Full tokio runtime. "full" is fine for a library crate — callers can restrict
 # features if they need a lighter binary.
 tokio = { version = "1.40", features = ["full"] }
+smallvec = "1.13"
 
 # ── Serialization ─────────────────────────────────────────────────────────────
 # serde + serde_json: used to pass structured data between Rust and Python
diff --git a/ryx-query/Cargo.toml b/ryx-query/Cargo.toml
index 146b139..a537cd1 100644
--- a/ryx-query/Cargo.toml
+++ b/ryx-query/Cargo.toml
@@ -11,6 +11,7 @@ serde_json = "1"
 thiserror = "2"
 once_cell = "1"
 tracing = "0.1"
+smallvec = "1.13"
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["async_tokio"] }
diff --git a/ryx-query/src/ast.rs b/ryx-query/src/ast.rs
index 2701b6a..f93b38f 100644
--- a/ryx-query/src/ast.rs
+++ b/ryx-query/src/ast.rs
@@ -32,7 +32,7 @@ pub enum SqlValue {
     Text(String),
     /// Used by `__in` and `__range` lookups. The compiler expands it into
     /// multiple bind placeholders.
-    List(Vec<SqlValue>),
+    List(smallvec::SmallVec<[Box<SqlValue>; 4]>),
 }
 
 impl SqlValue {
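`SqlValue::List` backs the `__in` and `__range` lookups, and boxing its elements keeps the enum (and therefore the inline `SmallVec<[SqlValue; 8]>` buffers introduced below) small, while the compiler still expands every element into its own bind placeholder. A minimal sketch of what this means at the Python API level, assuming a hypothetical `User` model built on this query layer:

    # Compiles to: SELECT ... WHERE "id" IN (?, ?, ?), one placeholder per
    # element; an empty list short-circuits to (1 = 0) in compile_single_filter.
    some_users = User.objects.filter(id__in=[1, 2, 3])
    no_users = User.objects.filter(id__in=[])
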
diff --git a/ryx-query/src/compiler/compiler.rs b/ryx-query/src/compiler/compiler.rs
index 7ddab47..6987319 100644
--- a/ryx-query/src/compiler/compiler.rs
+++ b/ryx-query/src/compiler/compiler.rs
@@ -16,6 +16,7 @@ use crate::errors::{QueryError, QueryResult};
 use crate::lookups::date_lookups as date;
 use crate::lookups::json_lookups as json;
 use crate::lookups::{self, LookupContext};
+use smallvec::SmallVec;
 
 pub use super::helpers::{apply_like_wrapping, qualified_col, split_qualified, KNOWN_TRANSFORMS};
 
@@ -24,12 +25,12 @@ use super::helpers;
 #[derive(Debug, Clone)]
 pub struct CompiledQuery {
     pub sql: String,
-    pub values: Vec<SqlValue>,
+    pub values: SmallVec<[SqlValue; 8]>,
     pub db_alias: Option<String>,
 }
 
 pub fn compile(node: &QueryNode) -> QueryResult<CompiledQuery> {
-    let mut values: Vec<SqlValue> = Vec::new();
+    let mut values: SmallVec<[SqlValue; 8]> = SmallVec::new();
     let sql = match &node.operation {
         QueryOperation::Select { columns } => {
             compile_select(node, columns.as_deref(), &mut values)?
@@ -53,7 +54,7 @@ pub fn compile(node: &QueryNode) -> QueryResult<CompiledQuery> {
 fn compile_select(
     node: &QueryNode,
     columns: Option<&[String]>,
-    values: &mut Vec<SqlValue>,
+    values: &mut SmallVec<[SqlValue; 8]>,
 ) -> QueryResult<String> {
     let base_cols = match columns {
         None => "*".to_string(),
@@ -134,7 +135,7 @@ fn compile_select(
     Ok(sql)
 }
 
-fn compile_aggregate(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<String> {
+fn compile_aggregate(node: &QueryNode, values: &mut SmallVec<[SqlValue; 8]>) -> QueryResult<String> {
     if node.annotations.is_empty() {
         return Err(QueryError::Internal(
             "aggregate() called with no aggregate expressions".into(),
@@ -158,7 +159,7 @@ fn compile_aggregate(
     Ok(sql)
 }
 
-fn compile_count(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<String> {
+fn compile_count(node: &QueryNode, values: &mut SmallVec<[SqlValue; 8]>) -> QueryResult<String> {
     let mut sql = format!("SELECT COUNT(*) FROM {}", helpers::quote_col(&node.table));
     if !node.joins.is_empty() {
         sql.push(' ');
@@ -173,7 +174,7 @@ fn compile_count(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<S
-fn compile_delete(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<String> {
+fn compile_delete(node: &QueryNode, values: &mut SmallVec<[SqlValue; 8]>) -> QueryResult<String> {
     let mut sql = format!("DELETE FROM {}", helpers::quote_col(&node.table));
     let where_sql =
         compile_where_combined(&node.filters, node.q_filter.as_ref(), values, node.backend)?;
@@ -187,7 +188,7 @@ fn compile_delete(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<S
 fn compile_update(
     node: &QueryNode,
     assignments: &[(String, SqlValue)],
-    values: &mut Vec<SqlValue>,
+    values: &mut SmallVec<[SqlValue; 8]>,
 ) -> QueryResult<String> {
     if assignments.is_empty() {
         return Err(QueryError::Internal("UPDATE with no assignments".into()));
@@ -217,7 +218,7 @@ fn compile_insert(
     node: &QueryNode,
     cols_vals: &[(String, SqlValue)],
     returning_id: bool,
-    values: &mut Vec<SqlValue>,
+    values: &mut SmallVec<[SqlValue; 8]>,
 ) -> QueryResult<String> {
     if cols_vals.is_empty() {
         return Err(QueryError::Internal("INSERT with no values".into()));
@@ -340,7 +341,7 @@ pub fn compile_order_by(clauses: &[crate::ast::OrderByClause]) -> String {
 fn compile_where_combined(
     filters: &[FilterNode],
     q: Option<&QNode>,
-    values: &mut Vec<SqlValue>,
+    values: &mut SmallVec<[SqlValue; 8]>,
     backend: Backend,
 ) -> QueryResult<Option<String>> {
     let flat = if filters.is_empty() {
@@ -361,7 +362,11 @@ fn compile_where_combined(
     })
 }
 
-pub fn compile_q(q: &QNode, values: &mut Vec<SqlValue>, backend: Backend) -> QueryResult<String> {
+pub fn compile_q(
+    q: &QNode,
+    values: &mut SmallVec<[SqlValue; 8]>,
+    backend: Backend,
+) -> QueryResult<String> {
     match q {
         QNode::Leaf {
             field,
@@ -392,7 +397,7 @@ pub fn compile_q(
 
 fn compile_filters(
     filters: &[FilterNode],
-    values: &mut Vec<SqlValue>,
+    values: &mut SmallVec<[SqlValue; 8]>,
     backend: Backend,
 ) -> QueryResult<String> {
     let parts: Vec<String> = filters
@@ -407,7 +412,7 @@ fn compile_single_filter(
     lookup: &str,
     value: &SqlValue,
     negated: bool,
-    values: &mut Vec<SqlValue>,
+    values: &mut SmallVec<[SqlValue; 8]>,
     backend: Backend,
 ) -> QueryResult<String> {
     let (base_column, applied_transforms, json_key) = if field.contains("__") {
@@ -474,9 +479,9 @@ fn compile_single_filter(
     }
 
     if lookup == "in" {
-        let items = match value {
-            SqlValue::List(v) => v.clone(),
-            other => vec![other.clone()],
+        let items: SmallVec<[SqlValue; 4]> = match value {
+            SqlValue::List(v) => v.iter().map(|x| (**x).clone()).collect(),
+            other => smallvec::smallvec![(*other).clone()],
         };
         if items.is_empty() {
             return Ok("(1 = 0)".into());
@@ -495,9 +500,9 @@ fn compile_single_filter(
     }
 
     if lookup == "has_any" || lookup == "has_all" {
-        let items = match value {
-            SqlValue::List(v) => v.clone(),
-            other => vec![other.clone()],
+        let items: SmallVec<[SqlValue; 4]> = match value {
+            SqlValue::List(v) => v.iter().map(|x| (**x).clone()).collect(),
+            other => smallvec::smallvec![(*other).clone()],
         };
         if items.is_empty() {
             return Ok("(1 = 0)".into());
@@ -537,7 +542,7 @@ fn compile_single_filter(
 
     if lookup == "range" {
         let (lo, hi) = match value {
-            SqlValue::List(v) if v.len() == 2 => (v[0].clone(), v[1].clone()),
+            SqlValue::List(v) if v.len() == 2 => (v[0].as_ref().clone(), v[1].as_ref().clone()),
             _ => return Err(QueryError::Internal("range needs exactly 2 values".into())),
         };
         values.push(lo);
diff --git a/ryx/bulk.py b/ryx/bulk.py
index a46f0b1..b5a5119 100644
--- a/ryx/bulk.py
+++ b/ryx/bulk.py
@@ -25,13 +25,13 @@
 
 from __future__ import annotations
 
-# import asyncio
-# import itertools
 from typing import List, Sequence, Type, TYPE_CHECKING
 
 if TYPE_CHECKING:
     from ryx.models import Model
 
+from ryx import ryx_core as _core
+
 
 def _detect_backend() -> str:
     """Detect the database backend from the RYX_DATABASE_URL env var.
@@ -113,12 +113,20 @@ async def bulk_create(
 
     pk_field = model._meta.pk_field
 
-    # Process in batches
+    # Process in batches — all SQL and execution handled in Rust
     for batch in _chunked(instances, batch_size):
-        pks = await _insert_batch(model, batch, fields, col_names, ignore_conflicts)
-        # Assign returned PKs to instances
-        for inst, pk in zip(batch, pks):
-            object.__setattr__(inst, pk_field.attname, pk)
+        rows = [[f.to_db(getattr(inst, f.attname)) for f in fields] for inst in batch]
+        res = await _core.bulk_insert(
+            model._meta.table_name,
+            col_names,
+            rows,
+            True,  # returning_id
+            ignore_conflicts,
+        )
+        # On PostgreSQL/SQLite res is a list of ids; on MySQL res is rows_affected
+        if pk_field and isinstance(res, list):
+            for inst, pk in zip(batch, res):
+                object.__setattr__(inst, pk_field.attname, pk)
 
     return list(instances)
@@ -255,11 +263,7 @@ async def bulk_update(
     }
 
     total = 0
-    from ryx import ryx_core as _core
-    from ryx.pool_ext import execute_with_params
-
     for batch in _chunked(instances, batch_size):
-        # Collect valid instances (with pk set)
         valid = [inst for inst in batch if inst.pk is not None]
         if not valid:
             continue
@@ -268,55 +272,27 @@ async def bulk_update(
         pk_col = pk_field.column
         table = model._meta.table_name
 
-        # Build CASE WHEN clauses.
-        # Strategy: inline integers directly in SQL (zero FFI cost),
-        # use ? placeholders only for non-integer values.
-        case_clauses = []
-        all_values = []
-
+        # Collect values per column in the order of pks
+        col_names: List[str] = []
+        field_values: List[List[object]] = []
         for fname in update_fields:
             if fname not in field_objs:
                 continue
             fobj = field_objs[fname]
-            col = fobj.column
-            case_parts = [f'"{col}" = CASE "{pk_col}"']
-            for inst in valid:
-                val = fobj.to_db(getattr(inst, fname))
-                if isinstance(val, int) and not isinstance(val, bool):
-                    # Inline integers — zero FFI overhead
-                    case_parts.append(f"WHEN {inst.pk} THEN {val}")
-                else:
-                    case_parts.append("WHEN ? THEN ?")
-                    all_values.append(inst.pk)
-                    all_values.append(val)
-            case_parts.append("END")
-            case_clauses.append(" ".join(case_parts))
-
-        if not case_clauses:
-            continue
+            col_names.append(fobj.column)
+            vals = [fobj.to_db(getattr(inst, fname)) for inst in valid]
+            field_values.append(vals)
 
-        # WHERE IN — inline integer PKs
-        pk_parts = []
-        for pk in pks:
-            if isinstance(pk, int):
-                pk_parts.append(str(pk))
-            else:
-                pk_parts.append("?")
-                all_values.append(pk)
+        if not col_names:
+            continue
 
-        sql = (
-            f'UPDATE "{table}" SET '
-            f"{', '.join(case_clauses)} "
-            f'WHERE "{pk_col}" IN ({", ".join(pk_parts)})'
+        result = await _core.bulk_update(
+            table,
+            pk_col,
+            list(zip(col_names, field_values)),
+            pks,
         )
-
-        if all_values:
-            await execute_with_params(sql, all_values)
-        else:
-            from ryx.executor_helpers import raw_execute
-
-            await raw_execute(sql)
-        total += len(valid)
+        total += result
 
     return total
"""Override ordering. Pass ``"-field"`` for DESC, ``"field"`` for ASC.""" builder = self._builder - for f in fields: - builder = builder.add_order_by(f) + if fields: + builder = builder.add_order_by_batch(list(fields)) return self._clone(builder) def limit(self, n: int) -> "QuerySet": @@ -723,7 +728,8 @@ async def delete(self) -> int: async def update(self, **kwargs: Any) -> int: """Bulk update. Fires pre_update / post_update signals.""" - alias = self._resolve_db_alias("write") + # Resolve database alias: .using() -> Meta.database -> default + alias = self._using or self._model._meta.database builder = self._builder if alias: @@ -740,23 +746,6 @@ async def bulk_delete(self) -> int: """Alias for delete().""" return await self.delete() - async def update(self, **kwargs: Any) -> int: - """Bulk update. Fires pre_update / post_update signals.""" - - # Resolve database alias: .using() -> Meta.database -> default - alias = self._using or self._model._meta.database - - builder = self._builder - if alias: - builder = builder.set_using(alias) - - await pre_update.send(sender=self._model, queryset=self, fields=kwargs) - n = await builder.execute_update(list(kwargs.items())) - await post_update.send( - sender=self._model, queryset=self, updated_count=n, fields=kwargs - ) - return n - async def in_bulk(self, id_list: list, *, field_name: str = "pk") -> dict: """Return a dict of {pk: instance} for the given list of PKs.""" diff --git a/src/executor.rs b/src/executor.rs index f37dbb1..892db1b 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -39,14 +39,14 @@ use std::collections::HashMap; -use serde_json::Value as JsonValue; use sqlx::{Column, Row, any::AnyRow}; use tracing::{debug, instrument}; use crate::errors::{RyxError, RyxResult}; use crate::pool; -use ryx_query::{ast::SqlValue, compiler::CompiledQuery}; +use ryx_query::{ast::{SqlValue, QueryNode}, compiler::CompiledQuery}; use crate::transaction; +use smallvec::SmallVec; // ### // Result types @@ -57,7 +57,7 @@ use crate::transaction; /// Using `serde_json::Value` lets us represent NULL, integers, floats, strings, /// and booleans without a custom enum. JSON values convert cleanly to Python /// objects in the PyO3 layer. -pub type DecodedRow = HashMap; +pub type DecodedRow = HashMap; /// Result of a non-SELECT query (INSERT/UPDATE/DELETE). #[derive(Debug)] @@ -87,7 +87,7 @@ pub async fn fetch_all(query: CompiledQuery) -> RyxResult> { } return Err(RyxError::Internal("Transaction is no longer active".into())); } - + let pool = pool::get(query.db_alias.as_deref())?; debug!(sql = %query.sql, "Executing SELECT"); @@ -96,10 +96,25 @@ pub async fn fetch_all(query: CompiledQuery) -> RyxResult> { q = bind_values(q, &query.values); let rows = q.fetch_all(&*pool).await.map_err(RyxError::Database)?; - - let decoded = rows.iter().map(decode_row).collect(); + + let decoded = decode_rows(&rows); Ok(decoded) } + +/// Execute raw SQL (no binds) directly, bypassing compiler. +#[instrument(skip(sql))] +pub async fn fetch_raw(sql: String, db_alias: Option) -> RyxResult> { + let pool = pool::get(db_alias.as_deref())?; + let rows = sqlx::query(&sql).fetch_all(&*pool).await.map_err(RyxError::Database)?; + Ok(decode_rows(&rows)) +} + +/// Compile a QueryNode then fetch all (single FFI hop helper). +#[instrument(skip(node))] +pub async fn fetch_all_compiled(node: QueryNode) -> RyxResult> { + let compiled = ryx_query::compiler::compile(&node).map_err(RyxError::from)?; + fetch_all(compiled).await +} /// Execute a SELECT COUNT(*) query and return the count. 
diff --git a/src/executor.rs b/src/executor.rs
index f37dbb1..892db1b 100644
--- a/src/executor.rs
+++ b/src/executor.rs
@@ -39,14 +39,14 @@
 
 use std::collections::HashMap;
 
-use serde_json::Value as JsonValue;
 use sqlx::{Column, Row, any::AnyRow};
 use tracing::{debug, instrument};
 
 use crate::errors::{RyxError, RyxResult};
 use crate::pool;
-use ryx_query::{ast::SqlValue, compiler::CompiledQuery};
+use ryx_query::{ast::{SqlValue, QueryNode}, compiler::CompiledQuery};
 use crate::transaction;
+use smallvec::SmallVec;
 
 // ###
 // Result types
@@ -57,7 +57,7 @@ use crate::transaction;
-/// Using `serde_json::Value` lets us represent NULL, integers, floats, strings,
-/// and booleans without a custom enum. JSON values convert cleanly to Python
-/// objects in the PyO3 layer.
-pub type DecodedRow = HashMap<String, JsonValue>;
+/// Using `SqlValue` lets us represent NULL, integers, floats, strings, and
+/// booleans without a round trip through `serde_json`; the values convert
+/// directly to Python objects in the PyO3 layer.
+pub type DecodedRow = HashMap<String, SqlValue>;
 
 /// Result of a non-SELECT query (INSERT/UPDATE/DELETE).
 #[derive(Debug)]
@@ -87,7 +87,7 @@ pub async fn fetch_all(query: CompiledQuery) -> RyxResult<Vec<DecodedRow>> {
         }
         return Err(RyxError::Internal("Transaction is no longer active".into()));
     }
-    
+
     let pool = pool::get(query.db_alias.as_deref())?;
 
     debug!(sql = %query.sql, "Executing SELECT");
@@ -96,10 +96,25 @@ pub async fn fetch_all(query: CompiledQuery) -> RyxResult<Vec<DecodedRow>> {
     q = bind_values(q, &query.values);
 
     let rows = q.fetch_all(&*pool).await.map_err(RyxError::Database)?;
-    
-    let decoded = rows.iter().map(decode_row).collect();
+
+    let decoded = decode_rows(&rows);
     Ok(decoded)
 }
+
+/// Execute raw SQL (no binds) directly, bypassing the compiler.
+#[instrument(skip(sql))]
+pub async fn fetch_raw(sql: String, db_alias: Option<String>) -> RyxResult<Vec<DecodedRow>> {
+    let pool = pool::get(db_alias.as_deref())?;
+    let rows = sqlx::query(&sql).fetch_all(&*pool).await.map_err(RyxError::Database)?;
+    Ok(decode_rows(&rows))
+}
+
+/// Compile a QueryNode, then fetch all rows (single FFI hop helper).
+#[instrument(skip(node))]
+pub async fn fetch_all_compiled(node: QueryNode) -> RyxResult<Vec<DecodedRow>> {
+    let compiled = ryx_query::compiler::compile(&node).map_err(RyxError::from)?;
+    fetch_all(compiled).await
+}
 
 /// Execute a SELECT COUNT(*) query and return the count.
@@ -116,11 +131,10 @@ pub async fn fetch_count(query: CompiledQuery) -> RyxResult<i64> {
             return Ok(0);
         }
         if let Some(value) = rows[0].values().next() {
-            if let Some(i) = value.as_i64() {
-                return Ok(i);
-            }
-            if let Some(f) = value.as_f64() {
-                return Ok(f as i64);
+            match value {
+                SqlValue::Int(i) => return Ok(*i),
+                SqlValue::Float(f) => return Ok(*f as i64),
+                _ => {}
             }
         }
         return Err(RyxError::Internal(
@@ -147,6 +161,12 @@ pub async fn fetch_count(query: CompiledQuery) -> RyxResult<i64> {
     Ok(count)
 }
 
+#[instrument(skip(node))]
+pub async fn fetch_count_compiled(node: QueryNode) -> RyxResult<i64> {
+    let compiled = ryx_query::compiler::compile(&node).map_err(RyxError::from)?;
+    fetch_count(compiled).await
+}
+
 /// Execute a SELECT and return at most one row.
 ///
@@ -185,12 +205,18 @@ pub async fn fetch_one(query: CompiledQuery) -> RyxResult<DecodedRow> {
 
     match rows.len() {
         0 => Err(RyxError::DoesNotExist),
-        1 => Ok(decode_row(&rows[0])),
+        1 => Ok(decode_row(&rows[0], None)),
         _ => Err(RyxError::MultipleObjectsReturned),
     }
 }
 
+#[instrument(skip(node))]
+pub async fn fetch_one_compiled(node: QueryNode) -> RyxResult<DecodedRow> {
+    let compiled = ryx_query::compiler::compile(&node).map_err(RyxError::from)?;
+    fetch_one(compiled).await
+}
+
 /// Execute an INSERT, UPDATE, or DELETE query.
 ///
@@ -208,9 +234,13 @@ pub async fn execute(query: CompiledQuery) -> RyxResult<MutationResult> {
         // Check if this is a RETURNING query
         if query.sql.to_uppercase().contains("RETURNING") {
             let rows = active_tx.fetch_query(query).await?;
-            let last_insert_id = rows
-                .first()
-                .and_then(|row| row.values().next().and_then(|v| v.as_i64()));
+            let last_insert_id = rows.first().and_then(|row| {
+                row.values().next().and_then(|v| match v {
+                    SqlValue::Int(i) => Some(*i),
+                    SqlValue::Float(f) => Some(*f as i64),
+                    _ => None,
+                })
+            });
             return Ok(MutationResult {
                 rows_affected: 1,
                 last_insert_id,
@@ -233,11 +263,11 @@ pub async fn execute(query: CompiledQuery) -> RyxResult<MutationResult> {
     if query.sql.to_uppercase().contains("RETURNING") {
         let mut q = sqlx::query(&query.sql);
         q = bind_values(q, &query.values);
-        
+
         let rows = q.fetch_all(&*pool).await.map_err(RyxError::Database)?;
-        
+
         let last_insert_id = rows.first().and_then(|row| row.try_get::<i64, _>(0).ok());
-        
+
         return Ok(MutationResult {
             rows_affected: rows.len() as u64,
             last_insert_id,
@@ -256,6 +286,151 @@ pub async fn execute(query: CompiledQuery) -> RyxResult<MutationResult> {
 }
 
+/// Compile a QueryNode and execute it.
+#[instrument(skip(node))]
+pub async fn execute_compiled(node: QueryNode) -> RyxResult<MutationResult> {
+    let compiled = ryx_query::compiler::compile(&node).map_err(RyxError::from)?;
+    execute(compiled).await
+}
+
+/// Bulk insert rows with values already mapped to SqlValue in one shot.
+pub async fn bulk_insert(
+    table: String,
+    columns: Vec<String>,
+    rows: Vec<Vec<SqlValue>>,
+    returning_id: bool,
+    ignore_conflicts: bool,
+    db_alias: Option<String>,
+) -> RyxResult<MutationResult> {
+    if rows.is_empty() {
+        return Ok(MutationResult { rows_affected: 0, last_insert_id: None });
+    }
+    let pool = pool::get(db_alias.as_deref())?;
+    let backend = pool::get_backend(db_alias.as_deref())?;
+
+    let col_list = columns.iter().map(|c| format!("\"{}\"", c)).collect::<Vec<_>>().join(", ");
+    let row_ph = format!("({})", std::iter::repeat("?").take(columns.len()).collect::<Vec<_>>().join(", "));
+    let values_sql = std::iter::repeat(row_ph.clone()).take(rows.len()).collect::<Vec<_>>().join(", ");
+
+    let mut flat: SmallVec<[SqlValue; 8]> = SmallVec::new();
+    for row in rows {
+        for v in row {
+            flat.push(v);
+        }
+    }
+
+    let (insert_kw, conflict_suffix) = if ignore_conflicts {
+        match backend {
+            ryx_query::Backend::PostgreSQL => ("INSERT INTO", " ON CONFLICT DO NOTHING"),
+            ryx_query::Backend::MySQL => ("INSERT IGNORE INTO", ""),
+            ryx_query::Backend::SQLite => ("INSERT OR IGNORE INTO", ""),
+        }
+    } else {
+        ("INSERT INTO", "")
+    };
+
+    let sql = format!(
+        "{} \"{}\" ({}) VALUES {}{}{}",
+        insert_kw,
+        table,
+        col_list,
+        values_sql,
+        conflict_suffix,
+        if returning_id { " RETURNING id" } else { "" }
+    );
+    let mut q = sqlx::query(&sql);
+    q = bind_values(q, &flat);
+    if returning_id {
+        let rows = q.fetch_all(&*pool).await.map_err(RyxError::Database)?;
+        let last_insert_id = rows.first().and_then(|r| r.try_get::<i64, _>(0).ok());
+        Ok(MutationResult { rows_affected: rows.len() as u64, last_insert_id })
+    } else {
+        let res = q.execute(&*pool).await.map_err(RyxError::Database)?;
+        Ok(MutationResult { rows_affected: res.rows_affected(), last_insert_id: None })
+    }
+}
+
+/// Bulk delete by primary key values in one shot.
+pub async fn bulk_delete(
+    table: String,
+    pk_col: String,
+    pks: Vec<SqlValue>,
+    db_alias: Option<String>,
+) -> RyxResult<MutationResult> {
+    if pks.is_empty() {
+        return Ok(MutationResult { rows_affected: 0, last_insert_id: None });
+    }
+    let pool = pool::get(db_alias.as_deref())?;
+    let ph = std::iter::repeat("?").take(pks.len()).collect::<Vec<_>>().join(", ");
+    let sql = format!(
+        "DELETE FROM \"{}\" WHERE \"{}\" IN ({})",
+        table, pk_col, ph
+    );
+    let mut q = sqlx::query(&sql);
+    q = bind_values(q, &pks);
+    let res = q.execute(&*pool).await.map_err(RyxError::Database)?;
+    Ok(MutationResult { rows_affected: res.rows_affected(), last_insert_id: None })
+}
+
+/// Bulk update using CASE WHEN, with values already mapped to SqlValue.
+pub async fn bulk_update(
+    table: String,
+    pk_col: String,
+    col_names: Vec<String>,
+    field_values: Vec<Vec<SqlValue>>,
+    pks: Vec<SqlValue>,
+    db_alias: Option<String>,
+) -> RyxResult<MutationResult> {
+    let pool = pool::get(db_alias.as_deref())?;
+    let n = pks.len();
+    let f = field_values.len();
+    if n == 0 || f == 0 {
+        return Ok(MutationResult { rows_affected: 0, last_insert_id: None });
+    }
+
+    let mut case_clauses = Vec::with_capacity(f);
+    let mut all_values: SmallVec<[SqlValue; 8]> = SmallVec::with_capacity(n * f * 2 + n);
+
+    for (fi, col_name) in col_names.iter().enumerate() {
+        let mut case_parts = Vec::with_capacity(n * 3 + 2);
+        case_parts.push(format!("\"{}\" = CASE \"{}\"", col_name, pk_col));
+        for i in 0..n {
+            case_parts.push("WHEN ? THEN ?".to_string());
+            all_values.push(pks[i].clone());
+            all_values.push(field_values[fi][i].clone());
+        }
+        case_parts.push("END".to_string());
+        case_clauses.push(case_parts.join(" "));
+    }
+
+    let pk_placeholders: Vec<String> = (0..n).map(|_| "?".to_string()).collect();
+    for pk in &pks {
+        all_values.push(pk.clone());
+    }
+
+    let sql = format!(
+        "UPDATE \"{}\" SET {} WHERE \"{}\" IN ({})",
+        table,
+        case_clauses.join(", "),
+        pk_col,
+        pk_placeholders.join(", ")
+    );
+
+    let mut q = sqlx::query(&sql);
+    q = bind_values(q, &all_values);
+    let res = q.execute(&*pool).await.map_err(RyxError::Database)?;
+    Ok(MutationResult { rows_affected: res.rows_affected(), last_insert_id: None })
+}
+
+/// Execute raw SQL without bind params.
+#[instrument(skip(sql))]
+pub async fn execute_raw(sql: String, db_alias: Option<String>) -> RyxResult<()> {
+    let pool = pool::get(db_alias.as_deref())?;
+    sqlx::query(&sql).execute(&*pool).await.map_err(RyxError::Database)?;
+    Ok(())
+}
+
+
 // ###
 // Internal helpers
 // ###
@@ -289,39 +464,32 @@ fn bind_values<'q>(
     q
 }
 
-/// Decode a single `AnyRow` into a `DecodedRow` (HashMap).
-///
-/// We iterate over the columns and use sqlx's `try_get` to extract each value.
-/// The `Any` database driver supports a limited set of types natively:
-/// - i64 (maps to Bool and Int as well)
-/// - f64
-/// - String
-/// - Vec<u8> (bytes)
-/// - bool
-///
-/// Decode an AnyRow into a HashMap.
-///
-/// We try each type in order and fall back to String if nothing else works.
-///
-/// Boolean detection on SQLite uses a zero-allocation case-insensitive check
-/// on the column name (no `to_lowercase()` allocation).
-fn decode_row(row: &AnyRow) -> DecodedRow {
-    let mut map = HashMap::new();
-
-    for column in row.columns() {
-        let name = column.name().to_string();
-
-        // Try to extract values in type priority order.
-        // On SQLite, booleans are stored as INTEGER (0/1), so we try i64 first
-        // and then check if the value looks like a bool.
-        // On Postgres/MySQL, bool columns decode as bool natively.
-        //
-        // null: sqlx signals NULL by returning an Err on every typed get.
-        // We detect this by trying Option last.
-
-        let value: JsonValue = if let Ok(i) = row.try_get::<i64, _>(column.ordinal()) {
-            // Zero-allocation boolean detection: check common boolean column
-            // prefixes/suffixes without allocating a lowercase string.
+/// Decode all rows with a precomputed column-name vector to reduce per-row allocations.
+fn decode_rows(rows: &[AnyRow]) -> Vec<DecodedRow> {
+    if rows.is_empty() {
+        return Vec::new();
+    }
+
+    let col_names: Vec<String> = rows[0]
+        .columns()
+        .iter()
+        .map(|c| c.name().to_string())
+        .collect();
+
+    rows.iter()
+        .map(|row| decode_row(row, Some(&col_names)))
+        .collect()
+}
+
+fn decode_row(row: &AnyRow, names: Option<&Vec<String>>) -> DecodedRow {
+    let mut map = HashMap::with_capacity(row.columns().len());
+
+    for (idx, column) in row.columns().iter().enumerate() {
+        let name = names
+            .and_then(|n| n.get(idx).cloned())
+            .unwrap_or_else(|| column.name().to_string());
+
+        let value = if let Ok(i) = row.try_get::<i64, _>(column.ordinal()) {
             let looks_bool = name.starts_with("is_")
                 || name.starts_with("Is_")
                 || name.starts_with("IS_")
@@ -335,21 +503,18 @@ fn decode_row(row: &AnyRow) -> DecodedRow {
                 || name.ends_with("_Flag")
                 || name.ends_with("_FLAG");
             if looks_bool && (i == 0 || i == 1) {
-                JsonValue::Bool(i != 0)
+                SqlValue::Bool(i != 0)
             } else {
-                JsonValue::Number(i.into())
+                SqlValue::Int(i)
             }
         } else if let Ok(b) = row.try_get::<bool, _>(column.ordinal()) {
-            JsonValue::Bool(b)
+            SqlValue::Bool(b)
        } else if let Ok(f) = row.try_get::<f64, _>(column.ordinal()) {
-            serde_json::Number::from_f64(f)
-                .map(JsonValue::Number)
-                .unwrap_or(JsonValue::Null)
+            SqlValue::Float(f)
        } else if let Ok(s) = row.try_get::<String, _>(column.ordinal()) {
-            JsonValue::String(s)
+            SqlValue::Text(s)
        } else {
-            // Either NULL or a type we don't handle — represent as null.
-            JsonValue::Null
+            SqlValue::Null
        };
 
         map.insert(name, value);
diff --git a/src/lib.rs b/src/lib.rs
index 92f7e9e..e42619e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,7 +4,6 @@ use std::sync::Arc;
 
 use pyo3::prelude::IntoPyObject;
 use pyo3::{IntoPyObjectExt, prelude::*};
 use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple};
-use serde_json::Value as JsonValue;
 use tokio::sync::Mutex as TokioMutex;
 
 pub mod errors;
@@ -123,12 +122,7 @@ fn raw_fetch<'py>(
     alias: Option<String>,
 ) -> PyResult<Bound<'py, PyAny>> {
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
-        let compiled = compiler::CompiledQuery {
-            sql,
-            values: vec![],
-            db_alias: alias,
-        };
-        let rows = executor::fetch_all(compiled).await.map_err(PyErr::from)?;
+        let rows = executor::fetch_raw(sql, alias).await.map_err(PyErr::from)?;
         Python::attach(|py| {
             let py_rows = decoded_rows_to_py(py, rows)?;
             Ok(py_rows.unbind())
@@ -144,12 +138,7 @@ fn raw_execute<'py>(
     alias: Option<String>,
 ) -> PyResult<Bound<'py, PyAny>> {
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
-        let compiled = compiler::CompiledQuery {
-            sql,
-            values: vec![],
-            db_alias: alias,
-        };
-        executor::execute(compiled).await.map_err(PyErr::from)?;
+        executor::execute_raw(sql, alias).await.map_err(PyErr::from)?;
         Python::attach(|py| Ok(py.None().into_pyobject(py)?.unbind()))
     })
 }
@@ -163,7 +152,7 @@
 #[pyclass(from_py_object, name = "QueryBuilder")]
 #[derive(Clone)]
 pub struct PyQueryBuilder {
-    node: QueryNode,
+    node: Arc<QueryNode>,
 }
 
 #[pymethods]
@@ -174,13 +163,13 @@ impl PyQueryBuilder {
         let backend = pool::get_backend(None).unwrap_or(ryx_query::Backend::PostgreSQL);
 
         Ok(Self {
-            node: QueryNode::select(table).with_backend(backend),
+            node: Arc::new(QueryNode::select(table).with_backend(backend)),
         })
     }
 
     fn set_using(&self, alias: String) -> PyResult<PyQueryBuilder> {
         Ok(PyQueryBuilder {
-            node: self.node.clone().with_db_alias(alias),
+            node: Arc::new(self.node.as_ref().clone().with_db_alias(alias)),
         })
     }
 
@@ -193,19 +182,38 @@ impl PyQueryBuilder {
     ) -> PyResult<PyQueryBuilder> {
         let sql_value = py_to_sql_value(value)?;
         Ok(PyQueryBuilder {
-            node: self.node.clone().with_filter(FilterNode {
+            node: Arc::new(self.node.as_ref().clone().with_filter(FilterNode {
                 field,
                 lookup,
                 value: sql_value,
                 negated,
-            }),
+            })),
         })
     }
 
+    /// Add multiple filters in a single FFI call to reduce overhead when applying
+    /// many kwargs-based filters from Python.
+    fn add_filters_batch(
+        &self,
+        filters: Vec<(String, String, Bound<'_, PyAny>, bool)>,
+    ) -> PyResult<PyQueryBuilder> {
+        let mut node = self.node.as_ref().clone();
+        for (field, lookup, value, negated) in filters {
+            let sql_value = py_to_sql_value(&value)?;
+            node = node.with_filter(FilterNode {
+                field,
+                lookup,
+                value: sql_value,
+                negated,
+            });
+        }
+        Ok(PyQueryBuilder { node: Arc::new(node) })
+    }
+
     fn add_q_node(&self, node: &Bound<'_, PyAny>) -> PyResult<PyQueryBuilder> {
         let q = py_dict_to_qnode(node)?;
         Ok(PyQueryBuilder {
-            node: self.node.clone().with_q(q),
+            node: Arc::new(self.node.as_ref().clone().with_q(q)),
         })
     }
 
@@ -225,18 +233,18 @@ impl PyQueryBuilder {
             other => AggFunc::Raw(other.to_string()),
         };
         PyQueryBuilder {
-            node: self.node.clone().with_annotation(AggregateExpr {
+            node: Arc::new(self.node.as_ref().clone().with_annotation(AggregateExpr {
                 alias,
                 func: agg_func,
                 field,
                 distinct,
-            }),
+            })),
         }
     }
 
     fn add_group_by(&self, field: String) -> PyQueryBuilder {
         PyQueryBuilder {
-            node: self.node.clone().with_group_by(field),
+            node: Arc::new(self.node.as_ref().clone().with_group_by(field)),
         }
     }
 
@@ -257,58 +265,68 @@ impl PyQueryBuilder {
         };
         let alias_opt = if alias.is_empty() { None } else { Some(alias) };
         PyQueryBuilder {
-            node: self.node.clone().with_join(JoinClause {
+            node: Arc::new(self.node.as_ref().clone().with_join(JoinClause {
                 kind: join_kind,
                 table,
                 alias: alias_opt,
                 on_left,
                 on_right,
-            }),
+            })),
         }
     }
 
     fn add_order_by(&self, field: String) -> PyQueryBuilder {
         PyQueryBuilder {
-            node: self
-                .node
-                .clone()
-                .with_order_by(OrderByClause::parse(&field)),
+            node: Arc::new(
+                self.node
+                    .as_ref()
+                    .clone()
+                    .with_order_by(OrderByClause::parse(&field)),
+            ),
         }
     }
 
+    /// Batch add ORDER BY clauses to reduce repeated FFI crossings.
+    fn add_order_by_batch(&self, fields: Vec<String>) -> PyQueryBuilder {
+        let mut node = self.node.as_ref().clone();
+        for f in fields {
+            node = node.with_order_by(OrderByClause::parse(&f));
         }
+        PyQueryBuilder { node: Arc::new(node) }
+    }
 
     fn set_limit(&self, n: u64) -> PyQueryBuilder {
         PyQueryBuilder {
-            node: self.node.clone().with_limit(n),
+            node: Arc::new(self.node.as_ref().clone().with_limit(n)),
         }
     }
 
     fn set_offset(&self, n: u64) -> PyQueryBuilder {
         PyQueryBuilder {
-            node: self.node.clone().with_offset(n),
+            node: Arc::new(self.node.as_ref().clone().with_offset(n)),
         }
     }
 
     fn set_distinct(&self) -> PyQueryBuilder {
-        let mut node = self.node.clone();
+        let mut node = self.node.as_ref().clone();
         node.distinct = true;
-        PyQueryBuilder { node }
+        PyQueryBuilder { node: Arc::new(node) }
     }
 
     // # Execution methods
 
     fn fetch_all<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let compiled = compiler::compile(&self.node).map_err(RyxError::from)?;
+        let node = self.node.as_ref().clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let rows = executor::fetch_all(compiled).await.map_err(PyErr::from)?;
+            let rows = executor::fetch_all_compiled(node).await.map_err(PyErr::from)?;
             Python::attach(|py| Ok(decoded_rows_to_py(py, rows)?.unbind()))
         })
     }
 
     fn fetch_first<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let node = self.node.clone().with_limit(1);
-        let compiled = compiler::compile(&node).map_err(RyxError::from)?;
+        let node = self.node.as_ref().clone().with_limit(1);
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let rows = executor::fetch_all(compiled).await.map_err(PyErr::from)?;
+            let rows = executor::fetch_all_compiled(node).await.map_err(PyErr::from)?;
             Python::attach(|py| match rows.into_iter().next() {
                 Some(row) => Ok(decoded_row_to_py(py, row)?.into_any().unbind()),
                 None => Ok(py.None().into_pyobject(py)?.unbind()),
@@ -317,29 +335,27 @@ impl PyQueryBuilder {
     }
 
     fn fetch_get<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let compiled = compiler::compile(&self.node).map_err(RyxError::from)?;
+        let node = self.node.as_ref().clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let row = executor::fetch_one(compiled).await.map_err(PyErr::from)?;
+            let row = executor::fetch_one_compiled(node).await.map_err(PyErr::from)?;
             Python::attach(|py| Ok(decoded_row_to_py(py, row)?.into_any().unbind()))
         })
     }
 
     fn fetch_count<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let mut count_node = self.node.clone();
+        let mut count_node = self.node.as_ref().clone();
         count_node.operation = QueryOperation::Count;
-        let compiled = compiler::compile(&count_node).map_err(RyxError::from)?;
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let count = executor::fetch_count(compiled).await.map_err(PyErr::from)?;
+            let count = executor::fetch_count_compiled(count_node).await.map_err(PyErr::from)?;
             Python::attach(|py| Ok(count.into_pyobject(py)?.unbind()))
         })
     }
 
     fn fetch_aggregate<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let mut agg_node = self.node.clone();
+        let mut agg_node = self.node.as_ref().clone();
         agg_node.operation = QueryOperation::Aggregate;
-        let compiled = compiler::compile(&agg_node).map_err(RyxError::from)?;
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let rows = executor::fetch_all(compiled).await.map_err(PyErr::from)?;
+            let rows = executor::fetch_all_compiled(agg_node).await.map_err(PyErr::from)?;
             Python::attach(|py| match rows.into_iter().next() {
                 Some(row) => Ok(decoded_row_to_py(py, row)?.into_any().unbind()),
                 None => Ok(PyDict::new(py).into_any().unbind()),
@@ -348,11 +364,10 @@ impl PyQueryBuilder {
     }
 
     fn execute_delete<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let mut del_node = self.node.clone();
+        let mut del_node = self.node.as_ref().clone();
         del_node.operation = QueryOperation::Delete;
-        let compiled = compiler::compile(&del_node).map_err(RyxError::from)?;
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let res = executor::execute(compiled).await.map_err(PyErr::from)?;
+            let res = executor::execute_compiled(del_node).await.map_err(PyErr::from)?;
             Python::attach(|py| Ok(res.rows_affected.into_pyobject(py)?.unbind()))
         })
     }
 
@@ -367,14 +382,13 @@ impl PyQueryBuilder {
             .map(|(col, val)| Ok::<_, PyErr>((col, py_to_sql_value(&val)?)))
             .collect::<PyResult<Vec<_>>>()?;
 
-        let mut upd_node = self.node.clone();
+        let mut upd_node = self.node.as_ref().clone();
         upd_node.operation = QueryOperation::Update {
             assignments: rust_assignments,
         };
-        let compiled = compiler::compile(&upd_node).map_err(RyxError::from)?;
 
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let res = executor::execute(compiled).await.map_err(PyErr::from)?;
+            let res = executor::execute_compiled(upd_node).await.map_err(PyErr::from)?;
             Python::attach(|py| Ok(res.rows_affected.into_pyobject(py)?.unbind()))
         })
     }
 
@@ -390,15 +404,14 @@ impl PyQueryBuilder {
             .map(|(col, val)| Ok::<_, PyErr>((col, py_to_sql_value(&val)?)))
             .collect::<PyResult<Vec<_>>>()?;
 
-        let mut ins_node = self.node.clone();
+        let mut ins_node = self.node.as_ref().clone();
         ins_node.operation = QueryOperation::Insert {
             values: rust_values,
             returning_id,
         };
-        let compiled = compiler::compile(&ins_node).map_err(RyxError::from)?;
 
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let res = executor::execute(compiled).await.map_err(PyErr::from)?;
+            let res = executor::execute_compiled(ins_node).await.map_err(PyErr::from)?;
             Python::attach(|py| match res.last_insert_id {
                 Some(id) => Ok(id.into_pyobject(py)?.unbind()),
                 None => Ok(res.rows_affected.into_pyobject(py)?.unbind()),
@@ -434,15 +447,15 @@ fn py_to_sql_value(obj: &Bound<'_, PyAny>) -> PyResult<SqlValue> {
     if let Ok(list) = obj.cast::<PyList>() {
         let items = list
             .iter()
-            .map(|i| py_to_sql_value(&i))
-            .collect::<PyResult<Vec<_>>>()?;
+            .map(|i| py_to_sql_value(&i).map(Box::new))
+            .collect::<PyResult<smallvec::SmallVec<[Box<SqlValue>; 4]>>>()?;
         return Ok(SqlValue::List(items));
     }
     if let Ok(tup) = obj.cast::<PyTuple>() {
         let items = tup
             .iter()
-            .map(|i| py_to_sql_value(&i))
-            .collect::<PyResult<Vec<_>>>()?;
+            .map(|i| py_to_sql_value(&i).map(Box::new))
+            .collect::<PyResult<smallvec::SmallVec<[Box<SqlValue>; 4]>>>()?;
         return Ok(SqlValue::List(items));
     }
     Ok(SqlValue::Text(obj.str()?.to_str()?.to_string()))
@@ -528,18 +541,18 @@ fn py_dict_children(dict: &Bound<'_, PyDict>) -> PyResult<Vec<QNode>> {
 
 fn decoded_row_to_py<'py>(
     py: Python<'py>,
-    row: HashMap<String, JsonValue>,
+    row: HashMap<String, SqlValue>,
 ) -> PyResult<Bound<'py, PyDict>> {
     let dict = PyDict::new(py);
     for (k, v) in row {
-        dict.set_item(k, json_to_py(py, &v)?)?;
+        dict.set_item(k, sql_to_py(py, &v)?)?;
     }
     Ok(dict)
 }
 
 fn decoded_rows_to_py<'py>(
     py: Python<'py>,
-    rows: Vec<HashMap<String, JsonValue>>,
+    rows: Vec<HashMap<String, SqlValue>>,
 ) -> PyResult<Bound<'py, PyList>> {
     let list = PyList::empty(py);
     for row in rows {
@@ -548,39 +561,25 @@ fn decoded_rows_to_py<'py>(
     Ok(list)
 }
 
-fn json_to_py<'py>(py: Python<'py>, v: &JsonValue) -> PyResult<Py<PyAny>> {
+fn sql_to_py<'py>(py: Python<'py>, v: &SqlValue) -> PyResult<Py<PyAny>> {
     Ok(match v {
-        JsonValue::Null => py.None(),
-        JsonValue::Bool(b) => {
+        SqlValue::Null => py.None(),
+        SqlValue::Bool(b) => {
             let py_bool = (*b).into_pyobject(py)?;
             <Bound<'py, PyBool> as Clone>::clone(&py_bool)
                 .into_any()
                 .unbind()
         }
-        JsonValue::String(s) => s.into_pyobject(py)?.into_any().unbind(),
-        JsonValue::Number(n) => {
-            if let Some(i) = n.as_i64() {
-                i.into_pyobject(py)?.into_any().unbind()
-            } else if let Some(f) = n.as_f64() {
-                f.into_pyobject(py)?.into_any().unbind()
-            } else {
-                n.to_string().into_pyobject(py)?.into_any().unbind()
-            }
-        }
-        JsonValue::Array(arr) => {
+        SqlValue::Int(i) => i.into_pyobject(py)?.into_any().unbind(),
+        SqlValue::Float(f) => f.into_pyobject(py)?.into_any().unbind(),
+        SqlValue::Text(s) => s.into_pyobject(py)?.into_any().unbind(),
+        SqlValue::List(items) => {
             let list = PyList::empty(py);
-            for item in arr {
-                list.append(json_to_py(py, item)?)?;
+            for item in items {
+                list.append(sql_to_py(py, item)?)?;
             }
             list.into_any().unbind()
         }
-        JsonValue::Object(map) => {
-            let dict = PyDict::new(py);
-            for (k, v2) in map {
-                dict.set_item(k, json_to_py(py, v2)?)?;
-            }
-            dict.into_any().unbind()
-        }
     })
 }
@@ -722,7 +721,7 @@ fn execute_with_params<'py>(
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
         let compiled = compiler::CompiledQuery {
             sql,
-            values: sql_values,
+            values: sql_values.into(),
             db_alias: None,
         };
         let result = executor::execute(compiled).await.map_err(PyErr::from)?;
@@ -745,7 +744,7 @@ fn fetch_with_params<'py>(
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
         let compiled = compiler::CompiledQuery {
             sql,
-            values: sql_values,
+            values: sql_values.into(),
             db_alias: None,
         };
         let rows = executor::fetch_all(compiled).await.map_err(PyErr::from)?;
@@ -770,24 +769,13 @@ fn bulk_delete<'py>(
     pk_col: String,
     pks: Vec<Bound<'py, PyAny>>,
 ) -> PyResult<Bound<'py, PyAny>> {
-    // Fast path: PKs are always integers — skip the full type-checking cascade
     let pk_list = PyList::new(py, pks)?;
     let pk_values = py_int_list_to_sql_values(&pk_list)?;
 
     pyo3_async_runtimes::tokio::future_into_py(py, async move {
-        // Build the DELETE query manually (no QueryBuilder needed)
-        let placeholders: Vec<String> = (0..pk_values.len()).map(|i| format!("?{}", i + 1)).collect();
-        let sql = format!(
-            "DELETE FROM \"{}\" WHERE \"{}\" IN ({})",
-            table, pk_col, placeholders.join(", ")
-        );
-
-        let compiled = compiler::CompiledQuery {
-            sql,
-            values: pk_values,
-            db_alias: None,
-        };
-        let result = executor::execute(compiled).await.map_err(PyErr::from)?;
+        let result = executor::bulk_delete(table, pk_col, pk_values, None)
+            .await
+            .map_err(PyErr::from)?;
         Python::attach(|py| {
             let n = (result.rows_affected as i64).into_pyobject(py)?;
             Ok(n.unbind())
@@ -795,6 +783,44 @@ fn bulk_delete<'py>(
     })
 }
 
+/// Bulk insert: values are mapped in Rust then executed in a single FFI call.
+#[pyfunction]
+#[pyo3(signature = (table, columns, rows, returning_id=true, ignore_conflicts=false))]
+fn bulk_insert<'py>(
+    py: Python<'py>,
+    table: String,
+    columns: Vec<String>,
+    rows: Vec<Vec<Bound<'py, PyAny>>>,
+    returning_id: bool,
+    ignore_conflicts: bool,
+) -> PyResult<Bound<'py, PyAny>> {
+    let mut rust_rows: Vec<Vec<SqlValue>> = Vec::with_capacity(rows.len());
+    for row in rows {
+        let mut vals = Vec::with_capacity(row.len());
+        for v in row {
+            vals.push(py_to_sql_value(&v)?);
+        }
+        rust_rows.push(vals);
+    }
+
+    pyo3_async_runtimes::tokio::future_into_py(py, async move {
+        let res = executor::bulk_insert(
+            table,
+            columns,
+            rust_rows,
+            returning_id,
+            ignore_conflicts,
+            None,
+        )
+        .await
+        .map_err(PyErr::from)?;
+        Python::attach(|py| match res.last_insert_id {
+            Some(id) => Ok(id.into_pyobject(py)?.unbind()),
+            None => Ok(res.rows_affected.into_pyobject(py)?.unbind()),
+        })
+    })
+}
+
 /// Bulk update using CASE WHEN in a single FFI call.
 ///
 /// Builds a single UPDATE statement with CASE WHEN clauses:
@@ -853,25 +879,13 @@ fn bulk_update<'py>(
         }
 
         // WHERE IN clause
-        let pk_placeholders: Vec<String> = (0..n).map(|_| "?".to_string()).collect();
         for pk in &pk_values {
             all_values.push(pk.clone());
         }
 
-        let sql = format!(
-            "UPDATE \"{}\" SET {} WHERE \"{}\" IN ({})",
-            table,
-            case_clauses.join(", "),
-            pk_col,
-            pk_placeholders.join(", ")
-        );
-
-        let compiled = compiler::CompiledQuery {
-            sql,
-            values: all_values,
-            db_alias: None,
-        };
-        let result = executor::execute(compiled).await.map_err(PyErr::from)?;
+        let result = executor::bulk_update(table, pk_col, col_names, field_values, pk_values, None)
+            .await
+            .map_err(PyErr::from)?;
         Python::attach(|py| {
             let n = (result.rows_affected as i64).into_pyobject(py)?;
             Ok(n.unbind())
@@ -901,8 +915,6 @@ fn ryx_core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(available_lookups, m)?)?;
     m.add_function(wrap_pyfunction!(list_lookups, m)?)?;
     m.add_function(wrap_pyfunction!(list_transforms, m)?)?;
-    m.add_function(wrap_pyfunction!(list_lookups, m)?)?;
-    m.add_function(wrap_pyfunction!(list_transforms, m)?)?;
     m.add_function(wrap_pyfunction!(list_aliases, m)?)?;
     m.add_function(wrap_pyfunction!(get_backend, m)?)?;
     m.add_function(wrap_pyfunction!(is_connected, m)?)?;
@@ -911,6 +923,7 @@ fn ryx_core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(raw_execute, m)?)?;
     m.add_function(wrap_pyfunction!(execute_with_params, m)?)?;
     m.add_function(wrap_pyfunction!(fetch_with_params, m)?)?;
+    m.add_function(wrap_pyfunction!(bulk_insert, m)?)?;
     m.add_function(wrap_pyfunction!(bulk_delete, m)?)?;
     m.add_function(wrap_pyfunction!(bulk_update, m)?)?;
     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
diff --git a/src/pool.rs b/src/pool.rs
index fd58f3a..a8a4ff1 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -27,7 +27,6 @@
 use std::collections::HashMap;
 use std::sync::{Arc, OnceLock, RwLock};
 
-use serde::{Deserialize, Serialize};
 use sqlx::{
     AnyPool,
     any::{AnyPoolOptions, install_default_drivers},
@@ -234,4 +233,3 @@ pub fn stats(alias: Option<&str>) -> RyxResult<PoolStats> {
         idle: pool.num_idle() as u32,
     })
 }
-
diff --git a/src/transaction.rs b/src/transaction.rs
index 481584a..c22a754 100644
--- a/src/transaction.rs
+++ b/src/transaction.rs
@@ -31,7 +31,7 @@
 use std::sync::{Arc, Mutex as StdMutex};
 use tokio::sync::Mutex;
 
 use sqlx::{Any, Transaction};
-use tracing::{debug, instrument};
+use tracing::debug;
 
 use crate::errors::{RyxError, RyxResult};
 use crate::pool;
@@ -171,7 +171,7 @@ impl TransactionHandle {
     pub async fn fetch_query(
         &self,
         query: CompiledQuery,
-    ) -> RyxResult<Vec<std::collections::HashMap<String, serde_json::Value>>> {
+    ) -> RyxResult<Vec<std::collections::HashMap<String, SqlValue>>> {
         let mut guard = self.inner.lock().await;
         let tx = guard.as_mut().ok_or_else(|| {
             RyxError::Internal("Transaction already committed or rolled back".into())
@@ -191,19 +191,17 @@ impl TransactionHandle {
             let mut map = std::collections::HashMap::new();
             for col in row.columns() {
                 let name = col.name().to_string();
-                let val: serde_json::Value =
+                let val =
                     if let Ok(b) = row.try_get::<bool, _>(col.ordinal()) {
-                        serde_json::Value::Bool(b)
+                        SqlValue::Bool(b)
                     } else if let Ok(i) = row.try_get::<i64, _>(col.ordinal()) {
-                        serde_json::Value::Number(i.into())
+                        SqlValue::Int(i)
                     } else if let Ok(f) = row.try_get::<f64, _>(col.ordinal()) {
-                        serde_json::Number::from_f64(f)
-                            .map(serde_json::Value::Number)
-                            .unwrap_or(serde_json::Value::Null)
+                        SqlValue::Float(f)
                     } else if let Ok(s) = row.try_get::<String, _>(col.ordinal()) {
-                        serde_json::Value::String(s)
+                        SqlValue::Text(s)
                     } else {
-                        serde_json::Value::Null
+                        SqlValue::Null
                     };
                 map.insert(name, val);
             }