diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 18107f84db4..bde7c86e4fc 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -1,5 +1,119 @@ pub mod vortex_tensor +pub mod vortex_tensor::encodings + +pub mod vortex_tensor::encodings::norm + +pub struct vortex_tensor::encodings::norm::NormVector + +impl core::clone::Clone for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::clone(&self) -> vortex_tensor::encodings::norm::NormVector + +impl core::fmt::Debug for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::vtable::VTable for vortex_tensor::encodings::norm::NormVector + +pub type vortex_tensor::encodings::norm::NormVector::Array = vortex_tensor::encodings::norm::NormVectorArray + +pub type vortex_tensor::encodings::norm::NormVector::Metadata = vortex_array::metadata::EmptyMetadata + +pub type vortex_tensor::encodings::norm::NormVector::OperationsVTable = vortex_tensor::encodings::norm::NormVector + +pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromValidityHelper + +pub fn vortex_tensor::encodings::norm::NormVector::array_eq(array: &vortex_tensor::encodings::norm::NormVectorArray, other: &vortex_tensor::encodings::norm::NormVectorArray, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_tensor::encodings::norm::NormVector::array_hash(array: &vortex_tensor::encodings::norm::NormVectorArray, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_tensor::encodings::norm::NormVector::buffer(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_tensor::encodings::norm::NormVector::buffer_name(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: 
usize) -> core::option::Option + +pub fn vortex_tensor::encodings::norm::NormVector::build(dtype: &vortex_array::dtype::DType, len: usize, _metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::child(array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> vortex_array::array::ArrayRef + +pub fn vortex_tensor::encodings::norm::NormVector::child_name(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> alloc::string::String + +pub fn vortex_tensor::encodings::norm::NormVector::deserialize(_bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::dtype(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::dtype::DType + +pub fn vortex_tensor::encodings::norm::NormVector::execute(array: alloc::sync::Arc, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::id(&self) -> vortex_array::vtable::dyn_::ArrayId + +pub fn vortex_tensor::encodings::norm::NormVector::len(array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::metadata(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::nbuffers(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::nchildren(array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::serialize(_metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn 
vortex_tensor::encodings::norm::NormVector::stats(array: &vortex_tensor::encodings::norm::NormVectorArray) -> vortex_array::stats::array::StatsSetRef<'_> + +pub fn vortex_tensor::encodings::norm::NormVector::vtable(_array: &Self::Array) -> &Self + +pub fn vortex_tensor::encodings::norm::NormVector::with_children(array: &mut vortex_tensor::encodings::norm::NormVectorArray, children: alloc::vec::Vec) -> vortex_error::VortexResult<()> + +impl vortex_array::vtable::operations::OperationsVTable for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::scalar_at(array: &vortex_tensor::encodings::norm::NormVectorArray, index: usize) -> vortex_error::VortexResult + +pub struct vortex_tensor::encodings::norm::NormVectorArray + +impl vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::norms(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef, validity: vortex_array::validity::Validity) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::vector_array(&self) -> &vortex_array::array::ArrayRef + +impl vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::to_array(&self) -> vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::clone(&self) -> vortex_tensor::encodings::norm::NormVectorArray + +impl 
core::convert::AsRef for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::as_ref(&self) -> &dyn vortex_array::array::DynArray + +impl core::convert::From for vortex_array::array::ArrayRef + +pub fn vortex_array::array::ArrayRef::from(value: vortex_tensor::encodings::norm::NormVectorArray) -> vortex_array::array::ArrayRef + +impl core::fmt::Debug for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::ops::deref::Deref for vortex_tensor::encodings::norm::NormVectorArray + +pub type vortex_tensor::encodings::norm::NormVectorArray::Target = dyn vortex_array::array::DynArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::deref(&self) -> &Self::Target + +impl vortex_array::array::IntoArray for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::into_array(self) -> vortex_array::array::ArrayRef + +impl vortex_array::vtable::validity::ValidityHelper for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::validity(&self) -> &vortex_array::validity::Validity + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor @@ -136,7 +250,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: 
&Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -166,7 +280,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self: pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs new file mode 100644 index 00000000000..9cc2b35b972 --- /dev/null +++ b/vortex-tensor/src/encodings/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub mod norm; +// TODO: Spherical coordinate encoding. 
diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs new file mode 100644 index 00000000000..1ea2b7339b5 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::Float; +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::match_each_float_ptype; +use vortex::array::stats::ArrayStats; +use vortex::array::validity::Validity; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::dtype::extension::ExtDType; +use vortex::dtype::extension::ExtDTypeRef; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::expr::root; +use vortex::extension::EmptyMetadata; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ScalarFn; + +use crate::scalar_fns::l2_norm::L2Norm; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; +use crate::vector::Vector; + +/// A normalized array that stores unit-normalized vectors alongside their original L2 norms. +/// +/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The +/// original norms are stored separately so that the original vectors can be reconstructed. +/// +/// The `vector_array` child carries its own validity and nullability, so a nullable input vector +/// array produces a nullable `NormVectorArray`. +#[derive(Debug, Clone)] +pub struct NormVectorArray { + /// The backing vector array that has been unit normalized. 
+ /// + /// The underlying elements of the vector array must be floating-point. This child may be + /// nullable; its validity determines the validity of the `NormVectorArray`. + pub(crate) vector_array: ArrayRef, + + /// The L2 norms of each vector. + /// + /// This must have the same dtype as the elements of the vector array. + pub(crate) norms: ArrayRef, + + /// Stats set owned by this array. + pub(crate) stats_set: ArrayStats, +} + +impl NormVectorArray { + /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms. + /// + /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and + /// `norms` must be a primitive array of the same float type with the same length. The + /// `vector_array` may be nullable. + pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult { + let ext = Self::validate(&vector_array)?; + + let element_ptype = extension_element_ptype(&ext)?; + + let nullability = Nullability::from(vector_array.dtype().is_nullable()); + let expected_norms_dtype = DType::Primitive(element_ptype, nullability); + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype must match vector element type" + ); + + vortex_ensure_eq!( + vector_array.len(), + norms.len(), + "vector_array and norms must have the same length" + ); + + Ok(Self { + vector_array, + norms, + stats_set: ArrayStats::default(), + }) + } + + /// Validates that the given array has the [`Vector`] extension type and returns the extension + /// dtype. 
+ fn validate(vector_array: &ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + Ok(ext.clone()) + } + + /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and + /// dividing each vector by its norm. + /// + /// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs + /// are supported; the validity mask is preserved and the normalized data for null rows is + /// unspecified. + pub fn compress(vector_array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext = Self::validate(&vector_array)?; + + let list_size = extension_list_size(&ext)?; + let row_count = vector_array.len(); + let nullability = Nullability::from(vector_array.dtype().is_nullable()); + + // Compute L2 norms using the scalar function. If the input is nullable, the norms will + // also be nullable (null vectors produce null norms). + let storage = extension_storage(&vector_array)?; + let l2_norm_expr = + Expression::try_new(ScalarFn::new(L2Norm, EmptyOptions).erased(), [root()])?; + let norms_prim: PrimitiveArray = vector_array.apply(&l2_norm_expr)?.execute(ctx)?; + let norms_array = norms_prim.clone().into_array(); + + // Extract flat elements from the (always non-nullable) storage for normalization. + let flat = extract_flat_elements(&storage, list_size, ctx)?; + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let normalized_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let inv_norm = safe_inv_norm(norms_slice[i]); + flat.row::(i).iter().map(move |&v| v * inv_norm) + }) + .collect(); + + // Reconstruct the vector array with the same nullability as the input. 
+ let validity = Validity::from(nullability); + let fsl = FixedSizeListArray::new( + normalized_elems.into_array(), + u32::try_from(list_size)?, + validity, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + + Self::try_new(normalized_vector, norms_array) + }) + } + + /// Returns a reference to the backing vector array that has been unit normalized. + pub fn vector_array(&self) -> &ArrayRef { + &self.vector_array + } + + /// Returns a reference to the L2 norms of each vector. + pub fn norms(&self) -> &ArrayRef { + &self.norms + } + + /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm. + /// + /// The returned array has the same dtype (including nullability) as the original + /// `vector_array` child. + pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult { + let ext = Self::validate(&self.vector_array)?; + let nullability = Nullability::from(self.vector_array.dtype().is_nullable()); + + let list_size = extension_list_size(&ext)?; + let row_count = self.vector_array.len(); + + let storage = extension_storage(&self.vector_array)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; + + let norms_prim: PrimitiveArray = self.norms.clone().execute(ctx)?; + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let result_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let norm = norms_slice[i]; + flat.row::(i).iter().map(move |&v| v * norm) + }) + .collect(); + + let validity = Validity::from(nullability); + let fsl = FixedSizeListArray::new( + result_elems.into_array(), + u32::try_from(list_size)?, + validity, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }) + } +} + +/// Returns 
`1 / norm` if the norm is non-zero, or zero otherwise. +/// +/// This avoids division by zero for zero-length or all-zero vectors. +fn safe_inv_norm(norm: T) -> T { + if norm == T::zero() { + T::zero() + } else { + T::one() / norm + } +} diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs new file mode 100644 index 00000000000..4a060567bbc --- /dev/null +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod array; +pub use array::NormVectorArray; + +// TODO: Compute operations for NormVector. + +mod vtable; +pub use vtable::NormVector; + +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs new file mode 100644 index 00000000000..66b3cdb14dc --- /dev/null +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::IntoArray; +use vortex::array::LEGACY_SESSION; +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::Extension; +use vortex::array::arrays::PrimitiveArray; +use vortex::error::VortexResult; + +use crate::encodings::norm::NormVectorArray; +use crate::utils::test_helpers::assert_close; +use crate::utils::test_helpers::extract_vector_rows; +use crate::utils::test_helpers::vector_array; + +#[test] +fn encode_unit_vectors() -> VortexResult<()> { + // Already unit-length vectors: norms should be 1.0 and vectors unchanged. 
+ let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 0.0, 1.0, 0.0, // norm = 1.0 + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + let norms: PrimitiveArray = norm.norms().clone().execute(&mut ctx)?; + assert_close(norms.as_slice::(), &[1.0, 1.0]); + + let rows = extract_vector_rows(norm.vector_array(), &mut ctx)?; + assert_close(&rows[0], &[1.0, 0.0, 0.0]); + assert_close(&rows[1], &[0.0, 1.0, 0.0]); + + Ok(()) +} + +#[test] +fn encode_non_unit_vectors() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 0.0, 0.0, // norm = 0.0 (zero vector) + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + let norms: PrimitiveArray = norm.norms().clone().execute(&mut ctx)?; + assert_close(norms.as_slice::(), &[5.0, 0.0]); + + let rows = extract_vector_rows(norm.vector_array(), &mut ctx)?; + assert_close(&rows[0], &[3.0 / 5.0, 4.0 / 5.0]); + assert_close(&rows[1], &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + + // Execute to reconstruct the original vectors. + let reconstructed = norm.decompress(&mut ctx)?; + + // The reconstructed array should be a Vector extension array. 
+ assert!(reconstructed.as_opt::().is_some()); + + let rows = extract_vector_rows(&reconstructed, &mut ctx)?; + assert_close(&rows[0], &[3.0, 4.0]); + assert_close(&rows[1], &[6.0, 8.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip_zero_vector() -> VortexResult<()> { + let arr = vector_array(2, &[0.0, 0.0])?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + + let reconstructed = norm.decompress(&mut ctx)?; + + let rows = extract_vector_rows(&reconstructed, &mut ctx)?; + // Zero vector should remain zero after round-trip. + assert_close(&rows[0], &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn scalar_at_returns_original_vector() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let encoded = NormVectorArray::compress(arr, &mut ctx)?; + + // `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result. 
+ let decompressed = encoded.decompress(&mut ctx)?; + + let norm_array = encoded.into_array(); + for i in 0..2 { + assert_eq!(norm_array.scalar_at(i)?, decompressed.scalar_at(i)?); + } + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs new file mode 100644 index 00000000000..51a7dd4d47f --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::hash::Hasher; +use std::sync::Arc; + +use vortex::array::ArrayEq; +use vortex::array::ArrayHash; +use vortex::array::ArrayRef; +use vortex::array::EmptyMetadata; +use vortex::array::ExecutionCtx; +use vortex::array::ExecutionResult; +use vortex::array::Precision; +use vortex::array::buffer::BufferHandle; +use vortex::array::serde::ArrayChildren; +use vortex::array::stats::StatsSetRef; +use vortex::array::vtable; +use vortex::array::vtable::ArrayId; +use vortex::array::vtable::VTable; +use vortex::array::vtable::ValidityVTableFromChild; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::error::vortex_panic; +use vortex::session::VortexSession; + +use crate::encodings::norm::array::NormVectorArray; +use crate::utils::extension_element_ptype; + +mod operations; +mod validity; + +vtable!(NormVector); + +#[derive(Debug, Clone)] +pub struct NormVector; + +impl VTable for NormVector { + type Array = NormVectorArray; + type Metadata = EmptyMetadata; + type OperationsVTable = Self; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::Array) -> &Self { + &NormVector + } + + fn id(&self) -> ArrayId { + ArrayId::new_ref("vortex.tensor.norm_vector") + } + + fn len(array: &NormVectorArray) -> usize { + array.vector_array().len() + } + + fn dtype(array: &NormVectorArray) -> &DType { 
+ array.vector_array().dtype() + } + + fn stats(array: &NormVectorArray) -> StatsSetRef<'_> { + array.stats_set.to_ref(array.as_ref()) + } + + fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { + array.vector_array().array_hash(state, precision); + array.norms().array_hash(state, precision); + } + + fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { + array.norms().array_eq(other.norms(), precision) + && array + .vector_array() + .array_eq(other.vector_array(), precision) + } + + fn nbuffers(_array: &NormVectorArray) -> usize { + 0 + } + + fn buffer(_array: &NormVectorArray, idx: usize) -> BufferHandle { + vortex_panic!("NormVectorArray has no buffers (index {idx})") + } + + fn buffer_name(_array: &NormVectorArray, idx: usize) -> Option { + vortex_panic!("NormVectorArray has no buffers (index {idx})") + } + + fn nchildren(_array: &NormVectorArray) -> usize { + 2 + } + + fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { + match idx { + 0 => array.vector_array().clone(), + 1 => array.norms().clone(), + _ => vortex_panic!("NormVectorArray child index {idx} out of bounds"), + } + } + + fn child_name(_array: &NormVectorArray, idx: usize) -> String { + match idx { + 0 => "vector_array".to_string(), + 1 => "norms".to_string(), + _ => vortex_panic!("NormVectorArray child_name index {idx} out of bounds"), + } + } + + fn metadata(_array: &NormVectorArray) -> VortexResult { + Ok(EmptyMetadata) + } + + fn serialize(_metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + _bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(EmptyMetadata) + } + + fn build( + dtype: &DType, + len: usize, + _metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" 
+ ); + + let vector_array = children.get(0, dtype, len)?; + + let ext = dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}") + })?; + let element_ptype = extension_element_ptype(ext)?; + let nullability = Nullability::from(dtype.is_nullable()); + let norms_dtype = DType::Primitive(element_ptype, nullability); + let norms = children.get(1, &norms_dtype, len)?; + + NormVectorArray::try_new(vector_array, norms) + } + + fn with_children(array: &mut NormVectorArray, children: Vec) -> VortexResult<()> { + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let [vector_array, norms]: [ArrayRef; 2] = children + .try_into() + .map_err(|_| vortex_err!("NormVectorArray requires exactly 2 children"))?; + + *array = NormVectorArray::try_new(vector_array, norms)?; + Ok(()) + } + + fn execute( + array: Arc, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(ExecutionResult::done(array.decompress(ctx)?)) + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs new file mode 100644 index 00000000000..a384501f8d8 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::IntoArray; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::FixedSizeList; +use vortex::array::builtins::ArrayBuiltins; +use vortex::array::vtable::OperationsVTable; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_err; +use vortex::scalar::Scalar; +use vortex::scalar_fn::fns::operators::Operator; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; + +impl OperationsVTable for 
NormVector { + fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { + let ext = array + .vector_array() + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + array.vector_array().dtype() + ) + })?; + let list_size = extension_list_size(ext)?; + + // Get the storage (FixedSizeList) and slice out the elements for this row. + let storage = extension_storage(array.vector_array())?; + let fsl = storage + .as_opt::() + .ok_or_else(|| vortex_err!("expected FixedSizeList storage"))?; + let row_elements = fsl.fixed_size_list_elements_at(index)?; + + // Multiply all elements by the norm using a ConstantArray broadcast. + let norm_scalar = array.norms().scalar_at(index)?; + let norm_broadcast = ConstantArray::new(norm_scalar, list_size).into_array(); + let scaled = row_elements.binary(norm_broadcast, Operator::Mul)?; + + // Rebuild the FSL scalar, then wrap in the extension type. + let element_dtype = ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ) + })?; + + let children: Vec = (0..list_size) + .map(|i| scaled.scalar_at(i)) + .collect::>()?; + + let fsl_scalar = + Scalar::fixed_size_list(element_dtype.clone(), children, Nullability::NonNullable); + + Ok(Scalar::extension_ref(ext.clone(), fsl_scalar)) + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/validity.rs b/vortex-tensor/src/encodings/norm/vtable/validity.rs new file mode 100644 index 00000000000..8925ffc7378 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/validity.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::vtable::ValidityChild; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl ValidityChild for NormVector { + fn 
validity_child(array: &NormVectorArray) -> &ArrayRef { + array.vector_array() + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 56e96488167..c036b9854b2 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,8 +5,12 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +pub mod matcher; +pub mod scalar_fns; + pub mod fixed_shape; pub mod vector; -pub mod matcher; -pub mod scalar_fns; +pub mod encodings; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index cd2f158d719..d2f9b539e6d 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// Cosine similarity between two columns. 
/// @@ -115,7 +115,7 @@ impl ScalarFnVTable for CosineSimilarity { &self, _options: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let lhs = args.get(0)?; let rhs = args.get(1)?; @@ -135,8 +135,8 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_storage = extension_storage(&lhs)?; let rhs_storage = extension_storage(&rhs)?; - let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?; - let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?; + let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?; + let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?; match_each_float_ptype!(lhs_flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) @@ -196,11 +196,11 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::cosine_similarity::CosineSimilarity; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::constant_tensor_array; - use crate::scalar_fns::utils::test_helpers::constant_vector_array; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::constant_tensor_array; + use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. 
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index e0a3bac4143..b605b4a844d 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// L2 norm (Euclidean norm) of a tensor or vector column. /// @@ -98,7 +98,7 @@ impl ScalarFnVTable for L2Norm { &self, _options: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let input = args.get(0)?; let row_count = args.row_count(); @@ -113,7 +113,7 @@ impl ScalarFnVTable for L2Norm { let list_size = extension_list_size(ext)?; let storage = extension_storage(&input)?; - let flat = extract_flat_elements(&storage, list_size)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) @@ -163,9 +163,9 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::l2_norm::L2Norm; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. 
fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2597f1115c8..2f56305cd53 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -5,5 +5,3 @@ pub mod cosine_similarity; pub mod l2_norm; - -mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/utils.rs similarity index 85% rename from vortex-tensor/src/scalar_fns/utils.rs rename to vortex-tensor/src/utils.rs index 0eb3e423ea0..2a99cf06b3e 100644 --- a/vortex-tensor/src/scalar_fns/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -2,10 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; use vortex::array::IntoArray; use vortex::array::arrays::Constant; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::Extension; +use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; use vortex::dtype::DType; use vortex::dtype::NativePType; @@ -91,13 +93,17 @@ impl FlatElements { /// /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is /// materialized to avoid expanding it to the full column length. -pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResult { +pub fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { if let Some(constant) = storage.as_opt::() { // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge // amount of data. 
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl = single.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); + let fsl: FixedSizeListArray = single.execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; return Ok(FlatElements { elems, stride: 0, @@ -106,8 +112,8 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu } // Otherwise we have to fully expand all of the data. - let fsl = storage.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); + let fsl: FixedSizeListArray = storage.clone().execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; Ok(FlatElements { elems, stride: list_size, @@ -118,6 +124,7 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu #[cfg(test)] pub mod test_helpers { use vortex::array::ArrayRef; + use vortex::array::ExecutionCtx; use vortex::array::IntoArray; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::ExtensionArray; @@ -128,9 +135,13 @@ pub mod test_helpers { use vortex::dtype::Nullability; use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; + use vortex::error::vortex_err; use vortex::extension::EmptyMetadata; use vortex::scalar::Scalar; + use super::extension_list_size; + use super::extension_storage; + use super::extract_flat_elements; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::vector::Vector; @@ -210,6 +221,25 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } + /// Extracts the f64 rows from a [`Vector`] extension array. + /// + /// Returns a `Vec<Vec<f64>>` where each inner vec is one vector's elements.
+ pub fn extract_vector_rows( + array: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult>> { + let ext = array + .dtype() + .as_extension_opt() + .ok_or_else(|| vortex_err!("expected Vector extension dtype, got {}", array.dtype()))?; + let list_size = extension_list_size(ext)?; + let storage = extension_storage(array)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; + Ok((0..array.len()) + .map(|i| flat.row::(i).to_vec()) + .collect()) + } + /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` /// value, with support for NaN (NaN == NaN is considered equal). #[track_caller]