Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ sqlx = { version = "0.8.6", features = [
# Full tokio runtime. "full" is fine for a library crate — callers can restrict
# features if they need a lighter binary.
tokio = { version = "1.40", features = ["full"] }
smallvec = "1.13"

# ── Serialization ─────────────────────────────────────────────────────────────
# serde + serde_json: used to pass structured data between Rust and Python
Expand Down
1 change: 1 addition & 0 deletions ryx-query/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ serde_json = "1"
thiserror = "2"
once_cell = "1"
tracing = "0.1"
smallvec = "1.13"

[dev-dependencies]
criterion = { version = "0.5", features = ["async_tokio"] }
Expand Down
2 changes: 1 addition & 1 deletion ryx-query/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub enum SqlValue {
Text(String),
/// Used by `__in` and `__range` lookups. The compiler expands it into
/// multiple bind placeholders.
List(Vec<SqlValue>),
List(smallvec::SmallVec<[Box<SqlValue>; 4]>),
}

impl SqlValue {
Expand Down
43 changes: 24 additions & 19 deletions ryx-query/src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::errors::{QueryError, QueryResult};
use crate::lookups::date_lookups as date;
use crate::lookups::json_lookups as json;
use crate::lookups::{self, LookupContext};
use smallvec::SmallVec;

pub use super::helpers::{apply_like_wrapping, qualified_col, split_qualified, KNOWN_TRANSFORMS};

Expand All @@ -24,12 +25,12 @@ use super::helpers;
#[derive(Debug, Clone)]
pub struct CompiledQuery {
pub sql: String,
pub values: Vec<SqlValue>,
pub values: SmallVec<[SqlValue; 8]>,
pub db_alias: Option<String>,
}

pub fn compile(node: &QueryNode) -> QueryResult<CompiledQuery> {
let mut values: Vec<SqlValue> = Vec::new();
let mut values: SmallVec<[SqlValue; 8]> = SmallVec::new();
let sql = match &node.operation {
QueryOperation::Select { columns } => {
compile_select(node, columns.as_deref(), &mut values)?
Expand All @@ -53,7 +54,7 @@ pub fn compile(node: &QueryNode) -> QueryResult<CompiledQuery> {
fn compile_select(
node: &QueryNode,
columns: Option<&[String]>,
values: &mut Vec<SqlValue>,
values: &mut SmallVec<[SqlValue; 8]>,
) -> QueryResult<String> {
let base_cols = match columns {
None => "*".to_string(),
Expand Down Expand Up @@ -134,7 +135,7 @@ fn compile_select(
Ok(sql)
}

fn compile_aggregate(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<String> {
fn compile_aggregate(node: &QueryNode, values: &mut SmallVec<[SqlValue; 8]>) -> QueryResult<String> {
if node.annotations.is_empty() {
return Err(QueryError::Internal(
"aggregate() called with no aggregate expressions".into(),
Expand All @@ -158,7 +159,7 @@ fn compile_aggregate(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResul
Ok(sql)
}

fn compile_count(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<String> {
fn compile_count(node: &QueryNode, values: &mut SmallVec<[SqlValue; 8]>) -> QueryResult<String> {
let mut sql = format!("SELECT COUNT(*) FROM {}", helpers::quote_col(&node.table));
if !node.joins.is_empty() {
sql.push(' ');
Expand All @@ -173,7 +174,7 @@ fn compile_count(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<St
Ok(sql)
}

fn compile_delete(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<String> {
fn compile_delete(node: &QueryNode, values: &mut SmallVec<[SqlValue; 8]>) -> QueryResult<String> {
let mut sql = format!("DELETE FROM {}", helpers::quote_col(&node.table));
let where_sql =
compile_where_combined(&node.filters, node.q_filter.as_ref(), values, node.backend)?;
Expand All @@ -187,7 +188,7 @@ fn compile_delete(node: &QueryNode, values: &mut Vec<SqlValue>) -> QueryResult<S
fn compile_update(
node: &QueryNode,
assignments: &[(String, SqlValue)],
values: &mut Vec<SqlValue>,
values: &mut SmallVec<[SqlValue; 8]>,
) -> QueryResult<String> {
if assignments.is_empty() {
return Err(QueryError::Internal("UPDATE with no assignments".into()));
Expand Down Expand Up @@ -217,7 +218,7 @@ fn compile_insert(
node: &QueryNode,
cols_vals: &[(String, SqlValue)],
returning_id: bool,
values: &mut Vec<SqlValue>,
values: &mut SmallVec<[SqlValue; 8]>,
) -> QueryResult<String> {
if cols_vals.is_empty() {
return Err(QueryError::Internal("INSERT with no values".into()));
Expand Down Expand Up @@ -340,7 +341,7 @@ pub fn compile_order_by(clauses: &[crate::ast::OrderByClause]) -> String {
fn compile_where_combined(
filters: &[FilterNode],
q: Option<&QNode>,
values: &mut Vec<SqlValue>,
values: &mut SmallVec<[SqlValue; 8]>,
backend: Backend,
) -> QueryResult<String> {
let flat = if filters.is_empty() {
Expand All @@ -361,7 +362,11 @@ fn compile_where_combined(
})
}

pub fn compile_q(q: &QNode, values: &mut Vec<SqlValue>, backend: Backend) -> QueryResult<String> {
pub fn compile_q(
q: &QNode,
values: &mut SmallVec<[SqlValue; 8]>,
backend: Backend,
) -> QueryResult<String> {
match q {
QNode::Leaf {
field,
Expand Down Expand Up @@ -392,7 +397,7 @@ pub fn compile_q(q: &QNode, values: &mut Vec<SqlValue>, backend: Backend) -> Que

fn compile_filters(
filters: &[FilterNode],
values: &mut Vec<SqlValue>,
values: &mut SmallVec<[SqlValue; 8]>,
backend: Backend,
) -> QueryResult<String> {
let parts: Vec<String> = filters
Expand All @@ -407,7 +412,7 @@ fn compile_single_filter(
lookup: &str,
value: &SqlValue,
negated: bool,
values: &mut Vec<SqlValue>,
values: &mut SmallVec<[SqlValue; 8]>,
backend: Backend,
) -> QueryResult<String> {
let (base_column, applied_transforms, json_key) = if field.contains("__") {
Expand Down Expand Up @@ -474,9 +479,9 @@ fn compile_single_filter(
}

if lookup == "in" {
let items = match value {
SqlValue::List(v) => v.clone(),
other => vec![other.clone()],
let items: SmallVec<[SqlValue; 4]> = match value {
SqlValue::List(v) => v.iter().map(|x| (**x).clone()).collect(),
other => smallvec::smallvec![(*other).clone()],
};
if items.is_empty() {
return Ok("(1 = 0)".into());
Expand All @@ -495,9 +500,9 @@ fn compile_single_filter(
}

if lookup == "has_any" || lookup == "has_all" {
let items = match value {
SqlValue::List(v) => v.clone(),
other => vec![other.clone()],
let items: SmallVec<[SqlValue; 4]> = match value {
SqlValue::List(v) => v.iter().map(|x| (**x).clone()).collect(),
other => smallvec::smallvec![(*other).clone()],
};
if items.is_empty() {
return Ok("(1 = 0)".into());
Expand Down Expand Up @@ -537,7 +542,7 @@ fn compile_single_filter(

if lookup == "range" {
let (lo, hi) = match value {
SqlValue::List(v) if v.len() == 2 => (v[0].clone(), v[1].clone()),
SqlValue::List(v) if v.len() == 2 => (v[0].as_ref().clone(), v[1].as_ref().clone()),
_ => return Err(QueryError::Internal("range needs exactly 2 values".into())),
};
values.push(lo);
Expand Down
82 changes: 29 additions & 53 deletions ryx/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

from __future__ import annotations

# import asyncio
# import itertools
from typing import List, Sequence, Type, TYPE_CHECKING

if TYPE_CHECKING:
from ryx.models import Model

from ryx import ryx_core as _core


def _detect_backend() -> str:
"""Detect the database backend from the RYX_DATABASE_URL env var.
Expand Down Expand Up @@ -113,12 +113,20 @@ async def bulk_create(

pk_field = model._meta.pk_field

# Process in batches
# Process in batches — all SQL and execution handled in Rust
for batch in _chunked(instances, batch_size):
pks = await _insert_batch(model, batch, fields, col_names, ignore_conflicts)
# Assign returned PKs to instances
for inst, pk in zip(batch, pks):
object.__setattr__(inst, pk_field.attname, pk)
rows = [[f.to_db(getattr(inst, f.attname)) for f in fields] for inst in batch]
res = await _core.bulk_insert(
model._meta.table_name,
col_names,
rows,
True, # returning_id
ignore_conflicts,
)
# On PostgreSQL/SQLite res is list of ids; on MySQL res is rows_affected
if pk_field and isinstance(res, list):
for inst, pk in zip(batch, res):
object.__setattr__(inst, pk_field.attname, pk)

return list(instances)

Expand Down Expand Up @@ -255,11 +263,7 @@ async def bulk_update(
}
total = 0

from ryx import ryx_core as _core
from ryx.pool_ext import execute_with_params

for batch in _chunked(instances, batch_size):
# Collect valid instances (with pk set)
valid = [inst for inst in batch if inst.pk is not None]
if not valid:
continue
Expand All @@ -268,55 +272,27 @@ async def bulk_update(
pk_col = pk_field.column
table = model._meta.table_name

# Build CASE WHEN clauses.
# Strategy: inline integers directly in SQL (zero FFI cost),
# use ? placeholders only for non-integer values.
case_clauses = []
all_values = []

# Collect values per column in the order of pks
col_names: List[str] = []
field_values: List[List[object]] = []
for fname in update_fields:
if fname not in field_objs:
continue
fobj = field_objs[fname]
col = fobj.column
case_parts = [f'"{col}" = CASE "{pk_col}"']
for inst in valid:
val = fobj.to_db(getattr(inst, fname))
if isinstance(val, int) and not isinstance(val, bool):
# Inline integers — zero FFI overhead
case_parts.append(f"WHEN {inst.pk} THEN {val}")
else:
case_parts.append("WHEN ? THEN ?")
all_values.append(inst.pk)
all_values.append(val)
case_parts.append("END")
case_clauses.append(" ".join(case_parts))

if not case_clauses:
continue
col_names.append(fobj.column)
vals = [fobj.to_db(getattr(inst, fname)) for inst in valid]
field_values.append(vals)

# WHERE IN — inline integer PKs
pk_parts = []
for pk in pks:
if isinstance(pk, int):
pk_parts.append(str(pk))
else:
pk_parts.append("?")
all_values.append(pk)
if not col_names:
continue

sql = (
f'UPDATE "{table}" SET '
f"{', '.join(case_clauses)} "
f'WHERE "{pk_col}" IN ({", ".join(pk_parts)})'
result = await _core.bulk_update(
table,
pk_col,
list(zip(col_names,field_values)),
pks,
)

if all_values:
await execute_with_params(sql, all_values)
else:
from ryx.executor_helpers import raw_execute

await raw_execute(sql)
total += len(valid)
total += result

return total

Expand Down
7 changes: 2 additions & 5 deletions ryx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ def __init__(self, alias: Optional[str] = None) -> None:
def contribute_to_class(self, model: type, name: str) -> None:
    """Bind this manager to the ``Model`` subclass it is declared on.

    Called during model-class construction; stores the owning model so
    later queryset construction can reference it. ``name`` (the attribute
    name the manager is bound under) is accepted for hook compatibility
    but not used here.
    """
    self._model = model

def contribute_to_class(self, model: type, name: str) -> None:
self._model = model

def get_queryset(self):
from ryx.queryset import QuerySet

Expand Down Expand Up @@ -292,9 +289,9 @@ async def bulk_create(self, instances: list[Model], batch_size: int = 500) -> li
async def bulk_update(
    self, instances: list, fields: list, batch_size: int = 500
) -> int:
    """Batch-update ``fields`` on ``instances`` of this manager's model.

    Thin delegator to :func:`ryx.bulk.bulk_update`; returns that helper's
    integer result (the update count).
    """
    # Imported lazily and under an alias: a module-level import would be
    # circular, and the unaliased name would shadow this method.
    from ryx.bulk import bulk_update as _bulk_update_impl

    updated = await _bulk_update_impl(
        self._model, instances, fields, batch_size=batch_size
    )
    return updated

async def bulk_delete(
self, instances: Optional[list] = None, batch_size: int = 500
Expand Down
Loading
Loading