From 0fceac5d946f3439e7fe94e5c9d86c46ce9f5377 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 16:38:28 -0400 Subject: [PATCH 1/6] boilerplate with `NormVector` encoding Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/mod.rs | 5 + vortex-tensor/src/encodings/norm/array.rs | 33 +++++ vortex-tensor/src/encodings/norm/mod.rs | 13 ++ .../src/encodings/norm/vtable/mod.rs | 121 ++++++++++++++++++ .../src/encodings/norm/vtable/operations.rs | 15 +++ .../src/encodings/norm/vtable/validity.rs | 14 ++ vortex-tensor/src/lib.rs | 6 +- 7 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 vortex-tensor/src/encodings/mod.rs create mode 100644 vortex-tensor/src/encodings/norm/array.rs create mode 100644 vortex-tensor/src/encodings/norm/mod.rs create mode 100644 vortex-tensor/src/encodings/norm/vtable/mod.rs create mode 100644 vortex-tensor/src/encodings/norm/vtable/operations.rs create mode 100644 vortex-tensor/src/encodings/norm/vtable/validity.rs diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs new file mode 100644 index 00000000000..0d9feeb17ce --- /dev/null +++ b/vortex-tensor/src/encodings/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub mod norm; +// mod spherical; diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs new file mode 100644 index 00000000000..9e6e19dd3a8 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; + +/// A normalized array that stores unit-normalized vectors alongside their original L2 norms. +/// +/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The +/// original norms are stored separately so that the original vectors can be reconstructed. +#[derive(Debug, Clone)] +pub struct NormVectorArray { + /// The backing vector array that has been unit normalized. + /// + /// The underlying elements of the vector array must be floating-point. + vector_array: ArrayRef, + + /// The L2 (Frobenius) norms of each vector. + /// + /// This must have the same dtype as the elements of the vector array. + norms: ArrayRef, +} + +impl NormVectorArray { + /// Returns a reference to the backing vector array that has been unit normalized. + pub fn vector_array(&self) -> &ArrayRef { + &self.vector_array + } + + /// Returns a reference to the L2 (Frobenius) norms of each vector. + pub fn norms(&self) -> &ArrayRef { + &self.norms + } +} diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs new file mode 100644 index 00000000000..252f867de9c --- /dev/null +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod array; +pub use array::NormVectorArray; + +// pub(crate) mod compute; + +mod vtable; +pub use vtable::NormVector; + +// #[cfg(test)] +// mod tests; diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs new file mode 100644 index 00000000000..55f7cd026dd --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::hash::Hasher; + +use vortex::array::ArrayRef; +use vortex::array::EmptyMetadata; +use vortex::array::ExecutionCtx; +use vortex::array::ExecutionStep; +use vortex::array::Precision; +use vortex::array::buffer::BufferHandle; +use vortex::array::serde::ArrayChildren; +use vortex::array::stats::StatsSetRef; +use vortex::array::vtable; +use vortex::array::vtable::ArrayId; +use vortex::array::vtable::VTable; +use vortex::array::vtable::ValidityVTableFromChild; +use vortex::dtype::DType; +use vortex::error::VortexResult; +use vortex::session::VortexSession; + +use crate::encodings::norm::array::NormVectorArray; + +mod operations; +mod validity; + +vtable!(NormVector); + +#[derive(Debug)] +pub struct NormVector; + +impl VTable for NormVector { + type Array = NormVectorArray; + type Metadata = EmptyMetadata; + type OperationsVTable = Self; + type ValidityVTable = ValidityVTableFromChild; + + fn id(_array: &NormVectorArray) -> ArrayId { + ArrayId::new_ref("vortex.tensor.norm_vector") + } + + fn len(array: &NormVectorArray) -> usize { + array.vector_array().len() + } + + fn dtype(array: &NormVectorArray) -> &DType { + array.vector_array().dtype() + } + + fn stats(array: &NormVectorArray) -> StatsSetRef<'_> { + array.vector_array().statistics() + } + + fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { + todo!() + } + + fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { + todo!() + } + + fn nbuffers(array: &NormVectorArray) -> usize { + todo!() + } + + fn buffer(array: &NormVectorArray, idx: usize) -> BufferHandle { + todo!() + } + + fn buffer_name(array: &NormVectorArray, idx: usize) -> Option { + todo!() + } + + fn nchildren(array: &NormVectorArray) -> usize { + todo!() + } + + fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { + todo!() + } + + fn child_name(array: &NormVectorArray, idx: usize) -> String { + todo!() + } + + fn metadata(array: &NormVectorArray) -> VortexResult { + todo!() + } + + fn serialize(metadata: Self::Metadata) -> VortexResult>> { + todo!() + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + todo!() + } + + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + todo!() + } + + fn with_children(array: &mut NormVectorArray, children: Vec) -> VortexResult<()> { + todo!() + } + + fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { + todo!() + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs new file mode 100644 index 00000000000..5c743f1d01d --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::vtable::OperationsVTable; +use vortex::error::VortexResult; +use vortex::scalar::Scalar; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl OperationsVTable for NormVector { + fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { + todo!() + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/validity.rs b/vortex-tensor/src/encodings/norm/vtable/validity.rs new file mode 100644 index 00000000000..8925ffc7378 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/validity.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::vtable::ValidityChild; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl ValidityChild for NormVector { + fn validity_child(array: &NormVectorArray) -> &ArrayRef { + array.vector_array() + } +} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 56e96488167..7aca9f54ab3 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,8 +5,10 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. +pub mod matcher; +pub mod scalar_fns; + pub mod fixed_shape; pub mod vector; -pub mod matcher; -pub mod scalar_fns; +pub mod encodings; From 1f0cd46fec8b355d74b4416d87bd8d5774601aeb Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 17:13:02 -0400 Subject: [PATCH 2/6] add most implementation Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/norm/array.rs | 58 +++++++++++- .../src/encodings/norm/vtable/mod.rs | 94 ++++++++++++++----- .../src/encodings/norm/vtable/operations.rs | 2 +- vortex-tensor/src/lib.rs | 2 + .../src/scalar_fns/cosine_similarity.rs | 18 ++-- vortex-tensor/src/scalar_fns/l2_norm.rs | 14 +-- vortex-tensor/src/scalar_fns/mod.rs | 2 - vortex-tensor/src/{scalar_fns => }/utils.rs | 0 8 files changed, 145 insertions(+), 45 deletions(-) rename vortex-tensor/src/{scalar_fns => }/utils.rs (100%) diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs index 9e6e19dd3a8..275d8b0f885 100644 --- a/vortex-tensor/src/encodings/norm/array.rs +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -2,6 +2,16 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; + +use crate::utils::extension_element_ptype; +use crate::vector::Vector; /// A normalized array that stores unit-normalized vectors alongside their original L2 norms. /// @@ -12,15 +22,54 @@ pub struct NormVectorArray { /// The backing vector array that has been unit normalized. /// /// The underlying elements of the vector array must be floating-point. - vector_array: ArrayRef, + pub(crate) vector_array: ArrayRef, /// The L2 (Frobenius) norms of each vector. /// /// This must have the same dtype as the elements of the vector array. - norms: ArrayRef, + pub(crate) norms: ArrayRef, } impl NormVectorArray { + /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms. + /// + /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and + /// `norms` must be a primitive array of the same float type with the same length. + pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + let element_ptype = extension_element_ptype(ext)?; + + let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype must match vector element type" + ); + + vortex_ensure_eq!( + vector_array.len(), + norms.len(), + "vector_array and norms must have the same length" + ); + + Ok(Self { + vector_array, + norms, + }) + } + /// Returns a reference to the backing vector array that has been unit normalized. pub fn vector_array(&self) -> &ArrayRef { &self.vector_array @@ -30,4 +79,9 @@ impl NormVectorArray { pub fn norms(&self) -> &ArrayRef { &self.norms } + + // TODO docs + pub(super) fn execute_into_vector(&self, ctx: &mut ExecutionCtx) -> VortexResult { + todo!() + } } diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index 55f7cd026dd..783b225ad75 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -3,6 +3,8 @@ use std::hash::Hasher; +use vortex::array::ArrayEq; +use vortex::array::ArrayHash; use vortex::array::ArrayRef; use vortex::array::EmptyMetadata; use vortex::array::ExecutionCtx; @@ -16,10 +18,15 @@ use vortex::array::vtable::ArrayId; use vortex::array::vtable::VTable; use vortex::array::vtable::ValidityVTableFromChild; use vortex::dtype::DType; +use vortex::dtype::Nullability; use vortex::error::VortexResult; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::error::vortex_panic; use vortex::session::VortexSession; use crate::encodings::norm::array::NormVectorArray; +use crate::utils::extension_element_ptype; mod operations; mod validity; @@ -52,70 +59,109 @@ impl VTable for NormVector { } fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { - todo!() + array.vector_array().array_hash(state, precision); + array.norms().array_hash(state, precision); } fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { - todo!() + array.norms().array_eq(other.norms(), precision) + && array + .vector_array() + .array_eq(other.vector_array(), precision) } - fn nbuffers(array: &NormVectorArray) -> usize { - todo!() + fn nbuffers(_array: &NormVectorArray) -> usize { + 0 } - fn buffer(array: &NormVectorArray, idx: usize) -> BufferHandle { - todo!() + fn buffer(_array: &NormVectorArray, idx: usize) -> BufferHandle { + vortex_panic!("NormVectorArray has no buffers (index {idx})") } - fn buffer_name(array: &NormVectorArray, idx: usize) -> Option { - todo!() + fn buffer_name(_array: &NormVectorArray, idx: usize) -> Option { + vortex_panic!("NormVectorArray has no buffers (index {idx})") } - fn nchildren(array: &NormVectorArray) -> usize { - todo!() + fn nchildren(_array: &NormVectorArray) -> usize { + 2 } fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { - todo!() + match idx { + 0 => array.vector_array().clone(), + 1 => array.norms().clone(), + _ => vortex_panic!("NormVectorArray child index {idx} out of bounds"), + } } - fn child_name(array: &NormVectorArray, idx: usize) -> String { - todo!() + fn child_name(_array: &NormVectorArray, idx: usize) -> String { + match idx { + 0 => "vector_array".to_string(), + 1 => "norms".to_string(), + _ => vortex_panic!("NormVectorArray child_name index {idx} out of bounds"), + } } - fn metadata(array: &NormVectorArray) -> VortexResult { - todo!() + fn metadata(_array: &NormVectorArray) -> VortexResult { + Ok(EmptyMetadata) } - fn serialize(metadata: Self::Metadata) -> VortexResult>> { - todo!() + fn serialize(_metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(vec![])) } fn deserialize( - bytes: &[u8], + _bytes: &[u8], _dtype: &DType, _len: usize, _buffers: &[BufferHandle], _session: &VortexSession, ) -> VortexResult { - todo!() + Ok(EmptyMetadata) } fn build( dtype: &DType, len: usize, - metadata: &Self::Metadata, - buffers: &[BufferHandle], + _metadata: &Self::Metadata, + _buffers: &[BufferHandle], children: &dyn ArrayChildren, ) -> VortexResult { - todo!() + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let vector_array = children.get(0, dtype, len)?; + + let ext = dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}") + })?; + let element_ptype = extension_element_ptype(ext)?; + let norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let norms = children.get(1, &norms_dtype, len)?; + + NormVectorArray::try_new(vector_array, norms) } fn with_children(array: &mut NormVectorArray, children: Vec) -> VortexResult<()> { - todo!() + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let [vector_array, norms]: [ArrayRef; 2] = children + .try_into() + .map_err(|_| vortex_err!("NormVectorArray requires exactly 2 children"))?; + + array.vector_array = vector_array; + array.norms = norms; + Ok(()) } fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { - todo!() + Ok(ExecutionStep::Done(array.execute_into_vector(ctx)?)) } } diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs index 5c743f1d01d..0ecec996d9e 100644 --- a/vortex-tensor/src/encodings/norm/vtable/operations.rs +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -10,6 +10,6 @@ use crate::encodings::norm::vtable::NormVector; impl OperationsVTable for NormVector { fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { - todo!() + array.vector_array().scalar_at(index) } } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 7aca9f54ab3..c036b9854b2 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -12,3 +12,5 @@ pub mod fixed_shape; pub mod vector; pub mod encodings; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index cd2f158d719..2b922649307 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// Cosine similarity between two columns. /// @@ -196,11 +196,11 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::cosine_similarity::CosineSimilarity; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::constant_tensor_array; - use crate::scalar_fns::utils::test_helpers::constant_vector_array; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::constant_tensor_array; + use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index e0a3bac4143..1535c7b7463 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// L2 norm (Euclidean norm) of a tensor or vector column. /// @@ -163,9 +163,9 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::l2_norm::L2Norm; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2597f1115c8..2f56305cd53 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -5,5 +5,3 @@ pub mod cosine_similarity; pub mod l2_norm; - -mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/utils.rs similarity index 100% rename from vortex-tensor/src/scalar_fns/utils.rs rename to vortex-tensor/src/utils.rs From 35d057b628508f99e326bbd6754dd0e1ba5d5382 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 17:48:17 -0400 Subject: [PATCH 3/6] implement compress and decompress Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/norm/array.rs | 130 +++++++++++++++++- vortex-tensor/src/encodings/norm/mod.rs | 4 +- vortex-tensor/src/encodings/norm/tests.rs | 110 +++++++++++++++ .../src/encodings/norm/vtable/mod.rs | 2 +- 4 files changed, 240 insertions(+), 6 deletions(-) create mode 100644 vortex-tensor/src/encodings/norm/tests.rs diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs index 275d8b0f885..55b9efd7ac2 100644 --- a/vortex-tensor/src/encodings/norm/array.rs +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -1,16 +1,33 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::Float; use vortex::array::ArrayRef; use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::ToCanonical; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::arrays::ScalarFnArray; +use vortex::array::match_each_float_ptype; +use vortex::array::validity::Validity; use vortex::dtype::DType; use vortex::dtype::Nullability; +use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; use vortex::error::vortex_ensure; use vortex::error::vortex_ensure_eq; use vortex::error::vortex_err; +use vortex::extension::EmptyMetadata; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ScalarFn; +use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; use crate::vector::Vector; /// A normalized array that stores unit-normalized vectors alongside their original L2 norms. @@ -70,6 +87,63 @@ impl NormVectorArray { }) } + /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and + /// dividing each vector by its norm. + /// + /// The input must be a [`Vector`] extension array with floating-point elements. + pub fn compress(vector_array: ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + let list_size = extension_list_size(ext)?; + let row_count = vector_array.len(); + + // Compute L2 norms using the scalar function. + let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); + let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)? + .to_primitive() + .into_array(); + + // Divide each vector element by its corresponding norm. + let storage = extension_storage(&vector_array)?; + let flat = extract_flat_elements(&storage, list_size)?; + let norms_prim = norms.to_canonical()?.into_primitive(); + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let normalized_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let inv_norm = safe_inv_norm(norms_slice[i]); + flat.row::(i).iter().map(move |&v| v * inv_norm) + }) + .collect(); + + let fsl = FixedSizeListArray::new( + normalized_elems.into_array(), + u32::try_from(list_size)?, + Validity::NonNullable, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + + Self::try_new(normalized_vector, norms) + }) + } + /// Returns a reference to the backing vector array that has been unit normalized. pub fn vector_array(&self) -> &ArrayRef { &self.vector_array @@ -80,8 +154,58 @@ impl NormVectorArray { &self.norms } - // TODO docs - pub(super) fn execute_into_vector(&self, ctx: &mut ExecutionCtx) -> VortexResult { - todo!() + /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm. + pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult { + let ext_dtype = self + .vector_array + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + self.vector_array.dtype() + ) + })?; + + let list_size = extension_list_size(ext_dtype)?; + let row_count = self.vector_array.len(); + + let storage = extension_storage(&self.vector_array)?; + let flat = extract_flat_elements(&storage, list_size)?; + + let norms_prim = self.norms.to_canonical()?.into_primitive(); + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let result_elems: PrimitiveArray = (0..row_count) + .flat_map(|i| { + let norm = norms_slice[i]; + flat.row::(i).iter().map(move |&v| v * norm) + }) + .collect(); + + let fsl = FixedSizeListArray::new( + result_elems.into_array(), + u32::try_from(list_size)?, + Validity::NonNullable, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + }) + } +} + +/// Returns `1 / norm` if the norm is non-zero, or zero otherwise. +/// +/// This avoids division by zero for zero-length or all-zero vectors. +fn safe_inv_norm(norm: T) -> T { + if norm == T::zero() { + T::zero() + } else { + T::one() / norm } } diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs index 252f867de9c..9cd20e5cdac 100644 --- a/vortex-tensor/src/encodings/norm/mod.rs +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -9,5 +9,5 @@ pub use array::NormVectorArray; mod vtable; pub use vtable::NormVector; -// #[cfg(test)] -// mod tests; +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs new file mode 100644 index 00000000000..652e916577b --- /dev/null +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::Extension; +use vortex::error::VortexResult; + +use crate::encodings::norm::NormVectorArray; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; +use crate::utils::test_helpers::assert_close; +use crate::utils::test_helpers::vector_array; + +#[test] +fn encode_unit_vectors() -> VortexResult<()> { + // Already unit-length vectors: norms should be 1.0 and vectors unchanged. + let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 0.0, 1.0, 0.0, // norm = 1.0 + ], + )?; + + let norm = NormVectorArray::compress(arr)?; + let norms = norm.norms().to_canonical()?.into_primitive(); + assert_close(norms.as_slice::(), &[1.0, 1.0]); + + let vectors = norm.vector_array(); + let ext = vectors.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(vectors)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[1.0, 0.0, 0.0]); + assert_close(flat.row::(1), &[0.0, 1.0, 0.0]); + + Ok(()) +} + +#[test] +fn encode_non_unit_vectors() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 0.0, 0.0, // norm = 0.0 (zero vector) + ], + )?; + + let norm = NormVectorArray::compress(arr)?; + let norms = norm.norms().to_canonical()?.into_primitive(); + assert_close(norms.as_slice::(), &[5.0, 0.0]); + + let vectors = norm.vector_array(); + let ext = vectors.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(vectors)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[3.0 / 5.0, 4.0 / 5.0]); + assert_close(flat.row::(1), &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip() -> VortexResult<()> { + let original_elements = &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ]; + let arr = vector_array(2, original_elements)?; + + let norm = NormVectorArray::compress(arr)?; + + // Execute to reconstruct the original vectors. + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let reconstructed = norm.decompress(&mut ctx)?; + + // The reconstructed array should be a Vector extension array. + assert!(reconstructed.as_opt::().is_some()); + + let ext = reconstructed.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(&reconstructed)?; + let flat = extract_flat_elements(&storage, list_size)?; + assert_close(flat.row::(0), &[3.0, 4.0]); + assert_close(flat.row::(1), &[6.0, 8.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip_zero_vector() -> VortexResult<()> { + let arr = vector_array(2, &[0.0, 0.0])?; + + let norm = NormVectorArray::compress(arr)?; + + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let reconstructed = norm.decompress(&mut ctx)?; + + let ext = reconstructed.dtype().as_extension_opt().unwrap(); + let list_size = extension_list_size(ext)?; + let storage = extension_storage(&reconstructed)?; + let flat = extract_flat_elements(&storage, list_size)?; + // Zero vector should remain zero after round-trip. + assert_close(flat.row::(0), &[0.0, 0.0]); + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index 783b225ad75..83facf80739 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -162,6 +162,6 @@ impl VTable for NormVector { } fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionStep::Done(array.execute_into_vector(ctx)?)) + Ok(ExecutionStep::Done(array.decompress(ctx)?)) } } From c056c9994d330ed76a39f83e65e49c3ca6fc93d1 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 20 Mar 2026 18:03:27 -0400 Subject: [PATCH 4/6] fix scalars Signed-off-by: Connor Tsui --- vortex-tensor/src/encodings/norm/tests.rs | 25 +++++++++ .../src/encodings/norm/vtable/mod.rs | 6 ++- .../src/encodings/norm/vtable/operations.rs | 53 ++++++++++++++++++- 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs index 652e916577b..ef87e18d912 100644 --- a/vortex-tensor/src/encodings/norm/tests.rs +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex::array::IntoArray; use vortex::array::VortexSessionExecute; use vortex::array::arrays::Extension; use vortex::error::VortexResult; @@ -108,3 +109,27 @@ fn execute_round_trip_zero_vector() -> VortexResult<()> { Ok(()) } + +#[test] +fn scalar_at_returns_original_vector() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let encoded = NormVectorArray::compress(arr)?; + + // `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result. + let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); + let decompressed = encoded.decompress(&mut ctx)?; + + let norm_array = encoded.into_array(); + for i in 0..2 { + assert_eq!(norm_array.scalar_at(i)?, decompressed.scalar_at(i)?); + } + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index 83facf80739..e612ac851e4 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -42,7 +42,11 @@ impl VTable for NormVector { type OperationsVTable = Self; type ValidityVTable = ValidityVTableFromChild; - fn id(_array: &NormVectorArray) -> ArrayId { + fn vtable(_array: &Self::Array) -> &Self { + &NormVector + } + + fn id(&self) -> ArrayId { ArrayId::new_ref("vortex.tensor.norm_vector") } diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs index 0ecec996d9e..a384501f8d8 100644 --- a/vortex-tensor/src/encodings/norm/vtable/operations.rs +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -1,15 +1,66 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex::array::IntoArray; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::FixedSizeList; +use vortex::array::builtins::ArrayBuiltins; use vortex::array::vtable::OperationsVTable; +use vortex::dtype::Nullability; use vortex::error::VortexResult; +use vortex::error::vortex_err; use vortex::scalar::Scalar; +use vortex::scalar_fn::fns::operators::Operator; use crate::encodings::norm::array::NormVectorArray; use crate::encodings::norm::vtable::NormVector; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; impl OperationsVTable for NormVector { fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult { - array.vector_array().scalar_at(index) + let ext = array + .vector_array() + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + array.vector_array().dtype() + ) + })?; + let list_size = extension_list_size(ext)?; + + // Get the storage (FixedSizeList) and slice out the elements for this row. + let storage = extension_storage(array.vector_array())?; + let fsl = storage + .as_opt::() + .ok_or_else(|| vortex_err!("expected FixedSizeList storage"))?; + let row_elements = fsl.fixed_size_list_elements_at(index)?; + + // Multiply all elements by the norm using a ConstantArray broadcast. + let norm_scalar = array.norms().scalar_at(index)?; + let norm_broadcast = ConstantArray::new(norm_scalar, list_size).into_array(); + let scaled = row_elements.binary(norm_broadcast, Operator::Mul)?; + + // Rebuild the FSL scalar, then wrap in the extension type. + let element_dtype = ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ) + })?; + + let children: Vec = (0..list_size) + .map(|i| scaled.scalar_at(i)) + .collect::>()?; + + let fsl_scalar = + Scalar::fixed_size_list(element_dtype.clone(), children, Nullability::NonNullable); + + Ok(Scalar::extension_ref(ext.clone(), fsl_scalar)) } } From c9ff85e0d29b161d94b006a70393e0c616e87a27 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Mon, 23 Mar 2026 18:42:44 -0600 Subject: [PATCH 5/6] rebase Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 114 ++++++++++++++++++ .../src/encodings/norm/vtable/mod.rs | 12 +- 2 files changed, 122 insertions(+), 4 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 18107f84db4..83818d74c3c 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -1,5 +1,119 @@ pub mod vortex_tensor +pub mod vortex_tensor::encodings + +pub mod vortex_tensor::encodings::norm + +pub struct vortex_tensor::encodings::norm::NormVector + +impl core::clone::Clone for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::clone(&self) -> vortex_tensor::encodings::norm::NormVector + +impl core::fmt::Debug for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::vtable::VTable for vortex_tensor::encodings::norm::NormVector + +pub type vortex_tensor::encodings::norm::NormVector::Array = vortex_tensor::encodings::norm::NormVectorArray + +pub type vortex_tensor::encodings::norm::NormVector::Metadata = vortex_array::metadata::EmptyMetadata + +pub type vortex_tensor::encodings::norm::NormVector::OperationsVTable = vortex_tensor::encodings::norm::NormVector + +pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild + +pub fn vortex_tensor::encodings::norm::NormVector::array_eq(array: &vortex_tensor::encodings::norm::NormVectorArray, other: &vortex_tensor::encodings::norm::NormVectorArray, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_tensor::encodings::norm::NormVector::array_hash(array: &vortex_tensor::encodings::norm::NormVectorArray, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_tensor::encodings::norm::NormVector::buffer(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_tensor::encodings::norm::NormVector::buffer_name(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> core::option::Option + +pub fn vortex_tensor::encodings::norm::NormVector::build(dtype: &vortex_array::dtype::DType, len: usize, _metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::child(array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> vortex_array::array::ArrayRef + +pub fn vortex_tensor::encodings::norm::NormVector::child_name(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> alloc::string::String + +pub fn vortex_tensor::encodings::norm::NormVector::deserialize(_bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::dtype(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::dtype::DType + +pub fn vortex_tensor::encodings::norm::NormVector::execute(array: alloc::sync::Arc, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::id(&self) -> vortex_array::vtable::dyn_::ArrayId + +pub fn vortex_tensor::encodings::norm::NormVector::len(array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::metadata(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::nbuffers(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::nchildren(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::serialize(_metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_tensor::encodings::norm::NormVector::stats(array: &vortex_tensor::encodings::norm::NormVectorArray) -> vortex_array::stats::array::StatsSetRef<'_> + +pub fn vortex_tensor::encodings::norm::NormVector::vtable(_array: &Self::Array) -> &Self + +pub fn vortex_tensor::encodings::norm::NormVector::with_children(array: &mut vortex_tensor::encodings::norm::NormVectorArray, children: alloc::vec::Vec) -> vortex_error::VortexResult<()> + +impl vortex_array::vtable::operations::OperationsVTable for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::scalar_at(array: &vortex_tensor::encodings::norm::NormVectorArray, index: usize) -> vortex_error::VortexResult + +impl vortex_array::vtable::validity::ValidityChild for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::validity_child(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::array::ArrayRef + +pub struct vortex_tensor::encodings::norm::NormVectorArray + +impl vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::norms(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::vector_array(&self) -> &vortex_array::array::ArrayRef + +impl vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::to_array(&self) -> vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::clone(&self) -> vortex_tensor::encodings::norm::NormVectorArray + +impl core::convert::AsRef for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::as_ref(&self) -> &dyn vortex_array::array::DynArray + +impl core::convert::From for vortex_array::array::ArrayRef + +pub fn vortex_array::array::ArrayRef::from(value: vortex_tensor::encodings::norm::NormVectorArray) -> vortex_array::array::ArrayRef + +impl core::fmt::Debug for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::ops::deref::Deref for vortex_tensor::encodings::norm::NormVectorArray + +pub type vortex_tensor::encodings::norm::NormVectorArray::Target = dyn vortex_array::array::DynArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::deref(&self) -> &Self::Target + +impl vortex_array::array::IntoArray for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::into_array(self) -> vortex_array::array::ArrayRef + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index e612ac851e4..14d16caba24 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -2,13 +2,14 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::hash::Hasher; +use std::sync::Arc; use vortex::array::ArrayEq; use vortex::array::ArrayHash; use vortex::array::ArrayRef; use vortex::array::EmptyMetadata; use vortex::array::ExecutionCtx; -use vortex::array::ExecutionStep; +use vortex::array::ExecutionResult; use vortex::array::Precision; use vortex::array::buffer::BufferHandle; use vortex::array::serde::ArrayChildren; @@ -33,7 +34,7 @@ mod validity; vtable!(NormVector); -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct NormVector; impl VTable for NormVector { @@ -165,7 +166,10 @@ impl VTable for NormVector { Ok(()) } - fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionStep::Done(array.decompress(ctx)?)) + fn execute( + array: Arc, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(ExecutionResult::done(array.decompress(ctx)?)) } } From 99fc4fcf25147d7b44b60cb827068b08bcdade15 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Mon, 23 Mar 2026 21:46:03 -0600 Subject: [PATCH 6/6] fix problems Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 22 ++-- vortex-tensor/src/encodings/mod.rs | 2 +- vortex-tensor/src/encodings/norm/array.rs | 116 ++++++++++-------- vortex-tensor/src/encodings/norm/mod.rs | 2 +- vortex-tensor/src/encodings/norm/tests.rs | 76 +++++------- .../src/encodings/norm/vtable/mod.rs | 8 +- .../src/scalar_fns/cosine_similarity.rs | 6 +- vortex-tensor/src/scalar_fns/l2_norm.rs | 4 +- vortex-tensor/src/utils.rs | 40 +++++- 9 files changed, 153 insertions(+), 123 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 83818d74c3c..bde7c86e4fc 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -22,7 +22,7 @@ pub type vortex_tensor::encodings::norm::NormVector::Metadata = vortex_array::me pub type vortex_tensor::encodings::norm::NormVector::OperationsVTable = vortex_tensor::encodings::norm::NormVector -pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild +pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromValidityHelper pub fn vortex_tensor::encodings::norm::NormVector::array_eq(array: &vortex_tensor::encodings::norm::NormVectorArray, other: &vortex_tensor::encodings::norm::NormVectorArray, precision: vortex_array::hash::Precision) -> bool @@ -52,7 +52,7 @@ pub fn vortex_tensor::encodings::norm::NormVector::metadata(_array: &vortex_tens pub fn vortex_tensor::encodings::norm::NormVector::nbuffers(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize -pub fn vortex_tensor::encodings::norm::NormVector::nchildren(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize +pub fn vortex_tensor::encodings::norm::NormVector::nchildren(array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize pub fn vortex_tensor::encodings::norm::NormVector::serialize(_metadata: Self::Metadata) -> vortex_error::VortexResult>> @@ -66,21 +66,17 @@ impl vortex_array::vtable::operations::OperationsVTable vortex_error::VortexResult -impl vortex_array::vtable::validity::ValidityChild for vortex_tensor::encodings::norm::NormVector - -pub fn vortex_tensor::encodings::norm::NormVector::validity_child(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::array::ArrayRef - pub struct vortex_tensor::encodings::norm::NormVectorArray impl vortex_tensor::encodings::norm::NormVectorArray -pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::encodings::norm::NormVectorArray::norms(&self) -> &vortex_array::array::ArrayRef -pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef, validity: vortex_array::validity::Validity) -> vortex_error::VortexResult pub fn vortex_tensor::encodings::norm::NormVectorArray::vector_array(&self) -> &vortex_array::array::ArrayRef @@ -114,6 +110,10 @@ impl vortex_array::array::IntoArray for vortex_tensor::encodings::norm::NormVect pub fn vortex_tensor::encodings::norm::NormVectorArray::into_array(self) -> vortex_array::array::ArrayRef +impl vortex_array::vtable::validity::ValidityHelper for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::validity(&self) -> &vortex_array::validity::Validity + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor @@ -250,7 +250,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -280,7 +280,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self: pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 0d9feeb17ce..9cc2b35b972 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -2,4 +2,4 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors pub mod norm; -// mod spherical; +// TODO: Spherical coordinate encoding. diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs index 55b9efd7ac2..1ea2b7339b5 100644 --- a/vortex-tensor/src/encodings/norm/array.rs +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -5,20 +5,22 @@ use num_traits::Float; use vortex::array::ArrayRef; use vortex::array::ExecutionCtx; use vortex::array::IntoArray; -use vortex::array::ToCanonical; use vortex::array::arrays::ExtensionArray; use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; -use vortex::array::arrays::ScalarFnArray; use vortex::array::match_each_float_ptype; +use vortex::array::stats::ArrayStats; use vortex::array::validity::Validity; use vortex::dtype::DType; use vortex::dtype::Nullability; use vortex::dtype::extension::ExtDType; +use vortex::dtype::extension::ExtDTypeRef; use vortex::error::VortexResult; use vortex::error::vortex_ensure; use vortex::error::vortex_ensure_eq; use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::expr::root; use vortex::extension::EmptyMetadata; use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; @@ -34,41 +36,39 @@ use crate::vector::Vector; /// /// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The /// original norms are stored separately so that the original vectors can be reconstructed. +/// +/// The `vector_array` child carries its own validity and nullability, so a nullable input vector +/// array produces a nullable `NormVectorArray`. #[derive(Debug, Clone)] pub struct NormVectorArray { /// The backing vector array that has been unit normalized. /// - /// The underlying elements of the vector array must be floating-point. + /// The underlying elements of the vector array must be floating-point. This child may be + /// nullable; its validity determines the validity of the `NormVectorArray`. pub(crate) vector_array: ArrayRef, - /// The L2 (Frobenius) norms of each vector. + /// The L2 norms of each vector. /// /// This must have the same dtype as the elements of the vector array. pub(crate) norms: ArrayRef, + + /// Stats set owned by this array. + pub(crate) stats_set: ArrayStats, } impl NormVectorArray { /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms. /// /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and - /// `norms` must be a primitive array of the same float type with the same length. + /// `norms` must be a primitive array of the same float type with the same length. The + /// `vector_array` may be nullable. pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult { - let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { - vortex_err!( - "vector_array dtype must be an extension type, got {}", - vector_array.dtype() - ) - })?; - - vortex_ensure!( - ext.is::(), - "vector_array must have the Vector extension type, got {}", - vector_array.dtype() - ); + let ext = Self::validate(&vector_array)?; - let element_ptype = extension_element_ptype(ext)?; + let element_ptype = extension_element_ptype(&ext)?; - let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let nullability = Nullability::from(vector_array.dtype().is_nullable()); + let expected_norms_dtype = DType::Primitive(element_ptype, nullability); vortex_ensure_eq!( *norms.dtype(), expected_norms_dtype, @@ -84,14 +84,13 @@ impl NormVectorArray { Ok(Self { vector_array, norms, + stats_set: ArrayStats::default(), }) } - /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and - /// dividing each vector by its norm. - /// - /// The input must be a [`Vector`] extension array with floating-point elements. - pub fn compress(vector_array: ArrayRef) -> VortexResult { + /// Validates that the given array has the [`Vector`] extension type and returns the extension + /// dtype. + fn validate(vector_array: &ArrayRef) -> VortexResult { let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { vortex_err!( "vector_array dtype must be an extension type, got {}", @@ -105,19 +104,32 @@ impl NormVectorArray { vector_array.dtype() ); - let list_size = extension_list_size(ext)?; - let row_count = vector_array.len(); + Ok(ext.clone()) + } - // Compute L2 norms using the scalar function. - let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); - let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)? - .to_primitive() - .into_array(); + /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and + /// dividing each vector by its norm. + /// + /// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs + /// are supported; the validity mask is preserved and the normalized data for null rows is + /// unspecified. + pub fn compress(vector_array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext = Self::validate(&vector_array)?; + + let list_size = extension_list_size(&ext)?; + let row_count = vector_array.len(); + let nullability = Nullability::from(vector_array.dtype().is_nullable()); - // Divide each vector element by its corresponding norm. + // Compute L2 norms using the scalar function. If the input is nullable, the norms will + // also be nullable (null vectors produce null norms). let storage = extension_storage(&vector_array)?; - let flat = extract_flat_elements(&storage, list_size)?; - let norms_prim = norms.to_canonical()?.into_primitive(); + let l2_norm_expr = + Expression::try_new(ScalarFn::new(L2Norm, EmptyOptions).erased(), [root()])?; + let norms_prim: PrimitiveArray = vector_array.apply(&l2_norm_expr)?.execute(ctx)?; + let norms_array = norms_prim.clone().into_array(); + + // Extract flat elements from the (always non-nullable) storage for normalization. + let flat = extract_flat_elements(&storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { let norms_slice = norms_prim.as_slice::(); @@ -129,10 +141,12 @@ impl NormVectorArray { }) .collect(); + // Reconstruct the vector array with the same nullability as the input. + let validity = Validity::from(nullability); let fsl = FixedSizeListArray::new( normalized_elems.into_array(), u32::try_from(list_size)?, - Validity::NonNullable, + validity, row_count, ); @@ -140,7 +154,7 @@ impl NormVectorArray { ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); - Self::try_new(normalized_vector, norms) + Self::try_new(normalized_vector, norms_array) }) } @@ -149,31 +163,26 @@ impl NormVectorArray { &self.vector_array } - /// Returns a reference to the L2 (Frobenius) norms of each vector. + /// Returns a reference to the L2 norms of each vector. pub fn norms(&self) -> &ArrayRef { &self.norms } /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm. - pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult { - let ext_dtype = self - .vector_array - .dtype() - .as_extension_opt() - .ok_or_else(|| { - vortex_err!( - "expected Vector extension dtype, got {}", - self.vector_array.dtype() - ) - })?; - - let list_size = extension_list_size(ext_dtype)?; + /// + /// The returned array has the same dtype (including nullability) as the original + /// `vector_array` child. + pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult { + let ext = Self::validate(&self.vector_array)?; + let nullability = Nullability::from(self.vector_array.dtype().is_nullable()); + + let list_size = extension_list_size(&ext)?; let row_count = self.vector_array.len(); let storage = extension_storage(&self.vector_array)?; - let flat = extract_flat_elements(&storage, list_size)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; - let norms_prim = self.norms.to_canonical()?.into_primitive(); + let norms_prim: PrimitiveArray = self.norms.clone().execute(ctx)?; match_each_float_ptype!(flat.ptype(), |T| { let norms_slice = norms_prim.as_slice::(); @@ -185,10 +194,11 @@ impl NormVectorArray { }) .collect(); + let validity = Validity::from(nullability); let fsl = FixedSizeListArray::new( result_elems.into_array(), u32::try_from(list_size)?, - Validity::NonNullable, + validity, row_count, ); diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs index 9cd20e5cdac..4a060567bbc 100644 --- a/vortex-tensor/src/encodings/norm/mod.rs +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -4,7 +4,7 @@ mod array; pub use array::NormVectorArray; -// pub(crate) mod compute; +// TODO: Compute operations for NormVector. mod vtable; pub use vtable::NormVector; diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs index ef87e18d912..66b3cdb14dc 100644 --- a/vortex-tensor/src/encodings/norm/tests.rs +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -2,15 +2,15 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::array::IntoArray; +use vortex::array::LEGACY_SESSION; use vortex::array::VortexSessionExecute; use vortex::array::arrays::Extension; +use vortex::array::arrays::PrimitiveArray; use vortex::error::VortexResult; use crate::encodings::norm::NormVectorArray; -use crate::utils::extension_list_size; -use crate::utils::extension_storage; -use crate::utils::extract_flat_elements; use crate::utils::test_helpers::assert_close; +use crate::utils::test_helpers::extract_vector_rows; use crate::utils::test_helpers::vector_array; #[test] @@ -24,17 +24,14 @@ fn encode_unit_vectors() -> VortexResult<()> { ], )?; - let norm = NormVectorArray::compress(arr)?; - let norms = norm.norms().to_canonical()?.into_primitive(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + let norms: PrimitiveArray = norm.norms().clone().execute(&mut ctx)?; assert_close(norms.as_slice::(), &[1.0, 1.0]); - let vectors = norm.vector_array(); - let ext = vectors.dtype().as_extension_opt().unwrap(); - let list_size = extension_list_size(ext)?; - let storage = extension_storage(vectors)?; - let flat = extract_flat_elements(&storage, list_size)?; - assert_close(flat.row::(0), &[1.0, 0.0, 0.0]); - assert_close(flat.row::(1), &[0.0, 1.0, 0.0]); + let rows = extract_vector_rows(norm.vector_array(), &mut ctx)?; + assert_close(&rows[0], &[1.0, 0.0, 0.0]); + assert_close(&rows[1], &[0.0, 1.0, 0.0]); Ok(()) } @@ -49,44 +46,40 @@ fn encode_non_unit_vectors() -> VortexResult<()> { ], )?; - let norm = NormVectorArray::compress(arr)?; - let norms = norm.norms().to_canonical()?.into_primitive(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + let norms: PrimitiveArray = norm.norms().clone().execute(&mut ctx)?; assert_close(norms.as_slice::(), &[5.0, 0.0]); - let vectors = norm.vector_array(); - let ext = vectors.dtype().as_extension_opt().unwrap(); - let list_size = extension_list_size(ext)?; - let storage = extension_storage(vectors)?; - let flat = extract_flat_elements(&storage, list_size)?; - assert_close(flat.row::(0), &[3.0 / 5.0, 4.0 / 5.0]); - assert_close(flat.row::(1), &[0.0, 0.0]); + let rows = extract_vector_rows(norm.vector_array(), &mut ctx)?; + assert_close(&rows[0], &[3.0 / 5.0, 4.0 / 5.0]); + assert_close(&rows[1], &[0.0, 0.0]); Ok(()) } #[test] fn execute_round_trip() -> VortexResult<()> { - let original_elements = &[ - 3.0, 4.0, // norm = 5.0 - 6.0, 8.0, // norm = 10.0 - ]; - let arr = vector_array(2, original_elements)?; + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; - let norm = NormVectorArray::compress(arr)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; // Execute to reconstruct the original vectors. - let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); let reconstructed = norm.decompress(&mut ctx)?; // The reconstructed array should be a Vector extension array. assert!(reconstructed.as_opt::().is_some()); - let ext = reconstructed.dtype().as_extension_opt().unwrap(); - let list_size = extension_list_size(ext)?; - let storage = extension_storage(&reconstructed)?; - let flat = extract_flat_elements(&storage, list_size)?; - assert_close(flat.row::(0), &[3.0, 4.0]); - assert_close(flat.row::(1), &[6.0, 8.0]); + let rows = extract_vector_rows(&reconstructed, &mut ctx)?; + assert_close(&rows[0], &[3.0, 4.0]); + assert_close(&rows[1], &[6.0, 8.0]); Ok(()) } @@ -95,17 +88,14 @@ fn execute_round_trip() -> VortexResult<()> { fn execute_round_trip_zero_vector() -> VortexResult<()> { let arr = vector_array(2, &[0.0, 0.0])?; - let norm = NormVectorArray::compress(arr)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; - let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); let reconstructed = norm.decompress(&mut ctx)?; - let ext = reconstructed.dtype().as_extension_opt().unwrap(); - let list_size = extension_list_size(ext)?; - let storage = extension_storage(&reconstructed)?; - let flat = extract_flat_elements(&storage, list_size)?; + let rows = extract_vector_rows(&reconstructed, &mut ctx)?; // Zero vector should remain zero after round-trip. - assert_close(flat.row::(0), &[0.0, 0.0]); + assert_close(&rows[0], &[0.0, 0.0]); Ok(()) } @@ -120,10 +110,10 @@ fn scalar_at_returns_original_vector() -> VortexResult<()> { ], )?; - let encoded = NormVectorArray::compress(arr)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let encoded = NormVectorArray::compress(arr, &mut ctx)?; // `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result. - let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx(); let decompressed = encoded.decompress(&mut ctx)?; let norm_array = encoded.into_array(); diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs index 14d16caba24..51a7dd4d47f 100644 --- a/vortex-tensor/src/encodings/norm/vtable/mod.rs +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -60,7 +60,7 @@ impl VTable for NormVector { } fn stats(array: &NormVectorArray) -> StatsSetRef<'_> { - array.vector_array().statistics() + array.stats_set.to_ref(array.as_ref()) } fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { @@ -144,7 +144,8 @@ impl VTable for NormVector { vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}") })?; let element_ptype = extension_element_ptype(ext)?; - let norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); + let nullability = Nullability::from(dtype.is_nullable()); + let norms_dtype = DType::Primitive(element_ptype, nullability); let norms = children.get(1, &norms_dtype, len)?; NormVectorArray::try_new(vector_array, norms) @@ -161,8 +162,7 @@ impl VTable for NormVector { .try_into() .map_err(|_| vortex_err!("NormVectorArray requires exactly 2 children"))?; - array.vector_array = vector_array; - array.norms = norms; + *array = NormVectorArray::try_new(vector_array, norms)?; Ok(()) } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 2b922649307..d2f9b539e6d 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -115,7 +115,7 @@ impl ScalarFnVTable for CosineSimilarity { &self, _options: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let lhs = args.get(0)?; let rhs = args.get(1)?; @@ -135,8 +135,8 @@ impl ScalarFnVTable for CosineSimilarity { let lhs_storage = extension_storage(&lhs)?; let rhs_storage = extension_storage(&rhs)?; - let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?; - let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?; + let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?; + let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?; match_each_float_ptype!(lhs_flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 1535c7b7463..b605b4a844d 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -98,7 +98,7 @@ impl ScalarFnVTable for L2Norm { &self, _options: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let input = args.get(0)?; let row_count = args.row_count(); @@ -113,7 +113,7 @@ impl ScalarFnVTable for L2Norm { let list_size = extension_list_size(ext)?; let storage = extension_storage(&input)?; - let flat = extract_flat_elements(&storage, list_size)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 0eb3e423ea0..2a99cf06b3e 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -2,10 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; use vortex::array::IntoArray; use vortex::array::arrays::Constant; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::Extension; +use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; use vortex::dtype::DType; use vortex::dtype::NativePType; @@ -91,13 +93,17 @@ impl FlatElements { /// /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is /// materialized to avoid expanding it to the full column length. -pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResult { +pub fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { if let Some(constant) = storage.as_opt::() { // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge // amount of data. let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl = single.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); + let fsl: FixedSizeListArray = single.execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; return Ok(FlatElements { elems, stride: 0, @@ -106,8 +112,8 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu } // Otherwise we have to fully expand all of the data. - let fsl = storage.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); + let fsl: FixedSizeListArray = storage.clone().execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; Ok(FlatElements { elems, stride: list_size, @@ -118,6 +124,7 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu #[cfg(test)] pub mod test_helpers { use vortex::array::ArrayRef; + use vortex::array::ExecutionCtx; use vortex::array::IntoArray; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::ExtensionArray; @@ -128,9 +135,13 @@ pub mod test_helpers { use vortex::dtype::Nullability; use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; + use vortex::error::vortex_err; use vortex::extension::EmptyMetadata; use vortex::scalar::Scalar; + use super::extension_list_size; + use super::extension_storage; + use super::extract_flat_elements; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::vector::Vector; @@ -210,6 +221,25 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } + /// Extracts the f64 rows from a [`Vector`] extension array. + /// + /// Returns a `Vec>` where each inner vec is one vector's elements. + pub fn extract_vector_rows( + array: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult>> { + let ext = array + .dtype() + .as_extension_opt() + .ok_or_else(|| vortex_err!("expected Vector extension dtype, got {}", array.dtype()))?; + let list_size = extension_list_size(ext)?; + let storage = extension_storage(array)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; + Ok((0..array.len()) + .map(|i| flat.row::(i).to_vec()) + .collect()) + } + /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` /// value, with support for NaN (NaN == NaN is considered equal). #[track_caller]