From fd0992b4c2bf84105492571ab583c40d6a947fcd Mon Sep 17 00:00:00 2001 From: ritchie Date: Sun, 20 Oct 2024 11:09:29 +0200 Subject: [PATCH 1/5] wip: gather --- .../src/array/fixed_size_list/mod.rs | 1 + .../src/compute/take/fixed_size_list.rs | 216 +++++++++++++++++- crates/polars-arrow/src/datatypes/mod.rs | 18 ++ .../src/datatypes/physical_type.rs | 4 + crates/polars-arrow/src/util/macros.rs | 1 + .../arithmetic/test_list_arithmetic.py | 4 + 6 files changed, 240 insertions(+), 4 deletions(-) diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index b8340825d0c7..bff41334dcb5 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -144,6 +144,7 @@ impl FixedSizeListArray { /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); self.validity = self .validity .take() diff --git a/crates/polars-arrow/src/compute/take/fixed_size_list.rs b/crates/polars-arrow/src/compute/take/fixed_size_list.rs index 84f15bd44791..a0d18e8cd3e0 100644 --- a/crates/polars-arrow/src/compute/take/fixed_size_list.rs +++ b/crates/polars-arrow/src/compute/take/fixed_size_list.rs @@ -17,10 +17,13 @@ use super::Index; use crate::array::growable::{Growable, GrowableFixedSizeList}; -use crate::array::{FixedSizeListArray, PrimitiveArray}; +use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::legacy::prelude::FromData; +use crate::{with_match_primitive_type}; -/// `take` implementation for FixedSizeListArrays -pub(super) unsafe fn take_unchecked( +pub(super) unsafe fn take_unchecked_slow( values: &FixedSizeListArray, indices: &PrimitiveArray, ) -> FixedSizeListArray { @@ -31,7 +34,7 @@ pub(super) unsafe fn take_unchecked( .iter() .map(|index| { let index = index.to_usize(); - let slice = values.clone().sliced(index, take_len); + let slice = values.clone().sliced_unchecked(index, take_len); capacity += slice.len(); slice }) @@ -62,3 +65,208 @@ pub(super) unsafe fn take_unchecked( growable.into() } } + +fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) { + if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype { + get_stride_and_leaf_type(inner.dtype(), *size_inner * size) + } else { + (size, dtype) + } +} + +fn get_leaves(array: &FixedSizeListArray) -> &dyn Array { + if let Some(array) = array.values().as_any().downcast_ref::() { + get_leaves(array) + } else { + &**array.values() + } +} + +fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) { + match array.dtype().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let arr = array.as_any().downcast_ref::>().unwrap(); + let values = arr.values(); + (bytemuck::cast_slice(values), size_of::<$T>()) + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn from_buffer(mut buf: Vec, dtype: &ArrowDataType) -> ArrayRef { + assert_eq!(buf.as_ptr().align_offset(256), 0); + + match dtype.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let ptr = buf.as_mut_ptr(); + let len_units = buf.len(); + let cap_units = buf.capacity(); + + std::mem::forget(buf); + + let buf = Vec::from_raw_parts( + ptr as *mut $T, + len_units / size_of::<$T>(), + cap_units / size_of::<$T>(), + ); + + PrimitiveArray::from_data_default(buf.into(), None).boxed() + + }), + _ => { + unimplemented!() + }, + } +} + + +// Use an alignedvec so the alignment always fits the actual type +// That way we can operate on bytes and reduce monomorphization. +#[repr(C, align(256))] +struct Align256([u8; 256]); + +unsafe fn aligned_vec(n_bytes: usize) -> Vec { + // Lazy math to ensure we always have enough. + let n_units = (n_bytes / size_of::()) + 1; + + let mut aligned: Vec = Vec::with_capacity(n_units); + + let ptr = aligned.as_mut_ptr(); + let len_units = aligned.len(); + let cap_units = aligned.capacity(); + + std::mem::forget(aligned); + + Vec::from_raw_parts( + ptr as *mut u8, + len_units * size_of::(), + cap_units * size_of::(), + ) +} + +fn replace_leaves(arr: &FixedSizeListArray, leaves: ArrayRef) -> FixedSizeListArray { + if let Some(arr) = arr.values().as_any().downcast_ref::() { + replace_leaves(arr, leaves) + } else { + FixedSizeListArray::new(arr.dtype().clone(), if arr.size() == 0 { 0 } else { leaves.len() / arr.size() }, leaves, None) + } +} + +fn no_inner_validities(values: &ArrayRef) -> bool { + if let Some(arr) = values.as_any().downcast_ref::() { + arr.validity().is_none() && no_inner_validities(arr.values()) + } else { + values.validity().is_none() + } +} + +/// `take` implementation for FixedSizeListArrays +pub(super) unsafe fn take_unchecked( + values: &FixedSizeListArray, + indices: &PrimitiveArray, +) -> FixedSizeListArray { + + let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1); + if leaf_type.to_physical_type().is_primitive() && no_inner_validities(values.values()) { + let leaves = get_leaves(values); + + let (leaves_buf, leave_size) = get_buffer_and_size(leaves); + let bytes_per_element = leave_size * stride; + + let n_idx = indices.len(); + let total_bytes = bytes_per_element * n_idx; + + let mut buf = aligned_vec(total_bytes); + let dst = buf.spare_capacity_mut(); + + let mut count = 0; + let validity = if indices.null_count() == 0 { + dbg!("no-null"); + for i in indices.values().iter() { + let i = i.to_usize(); + + std::ptr::copy_nonoverlapping(leaves_buf.as_ptr().add(i * bytes_per_element), dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, bytes_per_element); + count += 1; + } + None + } else { + dbg!("null"); + let mut new_validity = MutableBitmap::with_capacity(indices.len()); + let validity = indices.validity().unwrap(); + for i in indices.values().iter() { + let i = i.to_usize(); + + if validity.get_bit_unchecked(i) { + new_validity.push_unchecked(true); + std::ptr::copy_nonoverlapping(leaves_buf.as_ptr().add(i * bytes_per_element), dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, bytes_per_element); + } else { + new_validity.push_unchecked(false); + std::ptr::write_bytes(dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, 0, bytes_per_element); + } + + count += 1; + } + Some(new_validity.freeze()) + }; + assert_eq!(count * bytes_per_element, total_bytes); + + buf.set_len(total_bytes); + + + let leaves = from_buffer(buf, leaves.dtype()); + replace_leaves(&values, leaves).with_validity(validity) + + } else { + dbg!("slow"); + take_unchecked_slow(values, indices) + } + + + + + +} + + +#[cfg(test)] +mod test { + use polars_utils::pl_str::PlSmallStr; + use crate::datatypes::Field; + use super::*; + + #[test] + fn test_gather_fixed_size_list() { + + let s = PlSmallStr::EMPTY; + let f = Field::new(s, ArrowDataType::Int16, true); + let dt = ArrowDataType::FixedSizeList(Box::new(f), 2); + + let values = PrimitiveArray::from_data_default(vec![0i16, 1, 2, 3, 4, 5, 6, 7].into(), None); + let arr = FixedSizeListArray::new(dt.clone(), 4, values.boxed(), None); + + + let idx = PrimitiveArray::from_data_default(vec![2u32, 1, 0, 0, 1, 2].into(), None); + + unsafe { + dbg!(take_unchecked(&arr, &idx)); + } + + let f = Field::new(PlSmallStr::EMPTY, dt, true); + let dt = ArrowDataType::FixedSizeList(Box::new(f), 2); + let arr = FixedSizeListArray::new(dt, 2, arr.boxed(), None); + + dbg!(&arr); + let idx = PrimitiveArray::from_data_default(vec![0u32, 1, 0].into(), None); + + unsafe { + dbg!(take_unchecked(&arr, &idx)); + } + + + } +} \ No newline at end of file diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index c609ffbe432f..02b423891667 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -365,6 +365,24 @@ impl ArrowDataType { matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) } + pub fn is_numeric(&self) -> bool { + use ArrowDataType as D; + matches!(self, + D::Int8 + | D::Int16 + | D::Int32 + | D::Int64 + | D::UInt8 + | D::UInt16 + | D::UInt32 + | D::UInt64 + | D::Float32 + | D::Float64 + | D::Decimal(_, _) + | D::Decimal256(_, _) + ) + } + pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType { ArrowDataType::FixedSizeList( Box::new(Field::new( diff --git a/crates/polars-arrow/src/datatypes/physical_type.rs b/crates/polars-arrow/src/datatypes/physical_type.rs index 174c0401ca3f..732a129055a6 100644 --- a/crates/polars-arrow/src/datatypes/physical_type.rs +++ b/crates/polars-arrow/src/datatypes/physical_type.rs @@ -57,6 +57,10 @@ impl PhysicalType { false } } + + pub fn is_primitive(&self) -> bool { + matches!(self, Self::Primitive(_)) + } } /// the set of valid indices types of a dictionary-encoded Array. diff --git a/crates/polars-arrow/src/util/macros.rs b/crates/polars-arrow/src/util/macros.rs index b09a9d5d5473..fb5bd61ebba0 100644 --- a/crates/polars-arrow/src/util/macros.rs +++ b/crates/polars-arrow/src/util/macros.rs @@ -13,6 +13,7 @@ macro_rules! with_match_primitive_type {( UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, + Int128 => __with_ty__! { i128 }, Float32 => __with_ty__! { f32 }, Float64 => __with_ty__! { f64 }, _ => panic!("operator does not support primitive `{:?}`", diff --git a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py index 64ad0e533d8d..0e7af02d8290 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py @@ -101,6 +101,7 @@ def func( BROADCAST_SERIES_COMBINATIONS, ) @pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow def test_list_arithmetic_values( list_side: str, broadcast_series: Callable[ @@ -380,6 +381,7 @@ def test_list_add_supertype( "broadcast_series", BROADCAST_SERIES_COMBINATIONS, ) +@pytest.mark.slow def test_list_numeric_op_validity_combination( broadcast_series: Callable[ [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] @@ -451,6 +453,7 @@ def test_list_add_alignment() -> None: @pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow def test_list_add_empty_lists( exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], ) -> None: @@ -516,6 +519,7 @@ def test_list_add_height_mismatch( ], ) @pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow def test_list_date_to_numeric_arithmetic_raises_error( op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series] ) -> None: From f79b8570bab8881f0fea9a375bf72b46c91f9586 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 21 Oct 2024 10:18:17 +0200 Subject: [PATCH 2/5] wrap-up --- .../src/array/fixed_size_list/mod.rs | 107 ++++++++++++++++- .../polars-arrow/src/array/growable/list.rs | 2 +- .../src/compute/take/fixed_size_list.rs | 109 ++++++------------ crates/polars-arrow/src/compute/take/mod.rs | 2 +- crates/polars-arrow/src/datatypes/mod.rs | 26 +++-- .../src/datatypes/reshape.rs | 0 .../src/chunked_array/ops/append.rs | 2 +- .../src/chunked_array/ops/gather.rs | 48 +++++++- crates/polars-core/src/datatypes/mod.rs | 5 +- crates/polars-core/src/series/ops/reshape.rs | 2 +- .../tests/unit/operations/test_gather.py | 12 ++ 11 files changed, 220 insertions(+), 95 deletions(-) rename crates/{polars-core => polars-arrow}/src/datatypes/reshape.rs (100%) diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index bff41334dcb5..32267cc5a4b7 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -1,4 +1,4 @@ -use super::{new_empty_array, new_null_array, Array, Splitable}; +use super::{new_empty_array, new_null_array, Array, ArrayRef, Splitable}; use crate::bitmap::Bitmap; use crate::datatypes::{ArrowDataType, Field}; @@ -9,8 +9,11 @@ mod iterator; mod mutable; pub use mutable::*; use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_utils::format_tuple; use polars_utils::pl_str::PlSmallStr; +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + /// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. /// Cloning and slicing this struct is `O(1)`. #[derive(Clone)] @@ -120,6 +123,108 @@ impl FixedSizeListArray { let values = new_null_array(field.dtype().clone(), length * size); Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) } + + pub fn from_shape( + leaf_array: ArrayRef, + dimensions: &[ReshapeDimension], + ) -> PolarsResult { + polars_ensure!( + !dimensions.is_empty(), + InvalidOperation: "at least one dimension must be specified" + ); + let size = leaf_array.len(); + + let mut total_dim_size = 1; + let mut num_infers = 0; + for &dim in dimensions { + match dim { + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, + } + } + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + if size == 0 { + polars_ensure!( + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. + for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(prev_array); + } + + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // We pop the outer dimension as that is the height of the series. + for dim in dimensions[1..].iter().rev() { + // Infer dimension if needed + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = FixedSizeListArray::new( + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, + prev_array, + None, + ) + .boxed(); + } + Ok(prev_array) + } + + pub fn get_dims(&self) -> Vec { + let mut dims = vec![ + Dimension::new(self.length as _), + Dimension::new(self.size as _), + ]; + + let mut prev_array = &self.values; + + while let Some(a) = prev_array.as_any().downcast_ref::() { + dims.push(Dimension::new(a.size as _)); + prev_array = &a.values; + } + dims + } } // must use diff --git a/crates/polars-arrow/src/array/growable/list.rs b/crates/polars-arrow/src/array/growable/list.rs index 90e4f15020a6..095f39522da4 100644 --- a/crates/polars-arrow/src/array/growable/list.rs +++ b/crates/polars-arrow/src/array/growable/list.rs @@ -14,7 +14,7 @@ unsafe fn extend_offset_values( start: usize, len: usize, ) { - let array = growable.arrays[index]; + let array = growable.arrays.get_unchecked_release(index); let offsets = array.offsets(); growable diff --git a/crates/polars-arrow/src/compute/take/fixed_size_list.rs b/crates/polars-arrow/src/compute/take/fixed_size_list.rs index a0d18e8cd3e0..2d1e2b082dc3 100644 --- a/crates/polars-arrow/src/compute/take/fixed_size_list.rs +++ b/crates/polars-arrow/src/compute/take/fixed_size_list.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. +use polars_utils::itertools::Itertools; + use super::Index; use crate::array::growable::{Growable, GrowableFixedSizeList}; use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray}; use crate::bitmap::MutableBitmap; +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; use crate::datatypes::{ArrowDataType, PhysicalType}; use crate::legacy::prelude::FromData; -use crate::{with_match_primitive_type}; +use crate::with_match_primitive_type; pub(super) unsafe fn take_unchecked_slow( values: &FixedSizeListArray, @@ -124,7 +127,6 @@ unsafe fn from_buffer(mut buf: Vec, dtype: &ArrowDataType) -> ArrayRef { } } - // Use an alignedvec so the alignment always fits the actual type // That way we can operate on bytes and reduce monomorphization. #[repr(C, align(256))] @@ -149,14 +151,6 @@ unsafe fn aligned_vec(n_bytes: usize) -> Vec { ) } -fn replace_leaves(arr: &FixedSizeListArray, leaves: ArrayRef) -> FixedSizeListArray { - if let Some(arr) = arr.values().as_any().downcast_ref::() { - replace_leaves(arr, leaves) - } else { - FixedSizeListArray::new(arr.dtype().clone(), if arr.size() == 0 { 0 } else { leaves.len() / arr.size() }, leaves, None) - } -} - fn no_inner_validities(values: &ArrayRef) -> bool { if let Some(arr) = values.as_any().downcast_ref::() { arr.validity().is_none() && no_inner_validities(arr.values()) @@ -169,8 +163,7 @@ fn no_inner_validities(values: &ArrayRef) -> bool { pub(super) unsafe fn take_unchecked( values: &FixedSizeListArray, indices: &PrimitiveArray, -) -> FixedSizeListArray { - +) -> ArrayRef { let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1); if leaf_type.to_physical_type().is_primitive() && no_inner_validities(values.values()) { let leaves = get_leaves(values); @@ -186,27 +179,35 @@ pub(super) unsafe fn take_unchecked( let mut count = 0; let validity = if indices.null_count() == 0 { - dbg!("no-null"); for i in indices.values().iter() { let i = i.to_usize(); - std::ptr::copy_nonoverlapping(leaves_buf.as_ptr().add(i * bytes_per_element), dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, bytes_per_element); + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); count += 1; } None } else { - dbg!("null"); let mut new_validity = MutableBitmap::with_capacity(indices.len()); - let validity = indices.validity().unwrap(); - for i in indices.values().iter() { - let i = i.to_usize(); - - if validity.get_bit_unchecked(i) { - new_validity.push_unchecked(true); - std::ptr::copy_nonoverlapping(leaves_buf.as_ptr().add(i * bytes_per_element), dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, bytes_per_element); + new_validity.extend_constant(indices.len(), true); + for i in indices.iter() { + if let Some(i) = i { + let i = i.to_usize(); + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); } else { - new_validity.push_unchecked(false); - std::ptr::write_bytes(dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, 0, bytes_per_element); + new_validity.set_unchecked(count, false); + std::ptr::write_bytes( + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + 0, + bytes_per_element, + ); } count += 1; @@ -217,56 +218,18 @@ pub(super) unsafe fn take_unchecked( buf.set_len(total_bytes); - let leaves = from_buffer(buf, leaves.dtype()); - replace_leaves(&values, leaves).with_validity(validity) - + let mut shape = values.get_dims(); + shape[0] = Dimension::new(indices.len() as _); + let shape = shape + .into_iter() + .map(ReshapeDimension::Specified) + .collect_vec(); + + FixedSizeListArray::from_shape(leaves.clone(), &shape) + .unwrap() + .with_validity(validity) } else { - dbg!("slow"); - take_unchecked_slow(values, indices) + take_unchecked_slow(values, indices).boxed() } - - - - - } - - -#[cfg(test)] -mod test { - use polars_utils::pl_str::PlSmallStr; - use crate::datatypes::Field; - use super::*; - - #[test] - fn test_gather_fixed_size_list() { - - let s = PlSmallStr::EMPTY; - let f = Field::new(s, ArrowDataType::Int16, true); - let dt = ArrowDataType::FixedSizeList(Box::new(f), 2); - - let values = PrimitiveArray::from_data_default(vec![0i16, 1, 2, 3, 4, 5, 6, 7].into(), None); - let arr = FixedSizeListArray::new(dt.clone(), 4, values.boxed(), None); - - - let idx = PrimitiveArray::from_data_default(vec![2u32, 1, 0, 0, 1, 2].into(), None); - - unsafe { - dbg!(take_unchecked(&arr, &idx)); - } - - let f = Field::new(PlSmallStr::EMPTY, dt, true); - let dt = ArrowDataType::FixedSizeList(Box::new(f), 2); - let arr = FixedSizeListArray::new(dt, 2, arr.boxed(), None); - - dbg!(&arr); - let idx = PrimitiveArray::from_data_default(vec![0u32, 1, 0].into(), None); - - unsafe { - dbg!(take_unchecked(&arr, &idx)); - } - - - } -} \ No newline at end of file diff --git a/crates/polars-arrow/src/compute/take/mod.rs b/crates/polars-arrow/src/compute/take/mod.rs index aed14823af1e..bdd782a1d609 100644 --- a/crates/polars-arrow/src/compute/take/mod.rs +++ b/crates/polars-arrow/src/compute/take/mod.rs @@ -68,7 +68,7 @@ pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box { let array = values.as_any().downcast_ref().unwrap(); - Box::new(fixed_size_list::take_unchecked(array, indices)) + fixed_size_list::take_unchecked(array, indices) }, BinaryView => { take_binview_unchecked(values.as_any().downcast_ref().unwrap(), indices).boxed() diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 02b423891667..0c0b7024bc71 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -2,6 +2,7 @@ mod field; mod physical_type; +pub mod reshape; mod schema; use std::collections::BTreeMap; @@ -367,19 +368,20 @@ impl ArrowDataType { pub fn is_numeric(&self) -> bool { use ArrowDataType as D; - matches!(self, + matches!( + self, D::Int8 - | D::Int16 - | D::Int32 - | D::Int64 - | D::UInt8 - | D::UInt16 - | D::UInt32 - | D::UInt64 - | D::Float32 - | D::Float64 - | D::Decimal(_, _) - | D::Decimal256(_, _) + | D::Int16 + | D::Int32 + | D::Int64 + | D::UInt8 + | D::UInt16 + | D::UInt32 + | D::UInt64 + | D::Float32 + | D::Float64 + | D::Decimal(_, _) + | D::Decimal256(_, _) ) } diff --git a/crates/polars-core/src/datatypes/reshape.rs b/crates/polars-arrow/src/datatypes/reshape.rs similarity index 100% rename from crates/polars-core/src/datatypes/reshape.rs rename to crates/polars-arrow/src/datatypes/reshape.rs diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 383c76d63600..1cbf0da390e7 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -132,7 +132,7 @@ where impl ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, for<'a> T::Physical<'a>: TotalOrd, { /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index cb24305f75f6..fc162626bc27 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -143,7 +143,7 @@ unsafe fn gather_idx_array_unchecked( impl + ?Sized> ChunkTakeUnchecked for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &I) -> Self { @@ -178,7 +178,7 @@ pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> impl ChunkTakeUnchecked for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { @@ -312,3 +312,47 @@ impl IdxCa { f(&ca) } } + +#[cfg(feature = "dtype-array")] +impl ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +#[cfg(feature = "dtype-array")] +impl + ?Sized> ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} + +impl ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +impl + ?Sized> ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 712466482ce2..8d84d47be978 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -13,7 +13,6 @@ mod any_value; mod dtype; mod field; mod into_scalar; -mod reshape; #[cfg(feature = "object")] mod static_array_collect; mod time_unit; @@ -26,6 +25,7 @@ use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign}; pub use aliases::*; pub use any_value::*; pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray}; +pub use arrow::datatypes::reshape::*; #[cfg(feature = "dtype-categorical")] use arrow::datatypes::IntegerType; pub use arrow::datatypes::{ArrowDataType, TimeUnit as ArrowTimeUnit}; @@ -42,7 +42,6 @@ use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; -pub use reshape::*; #[cfg(feature = "serde")] use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] @@ -300,7 +299,7 @@ unsafe impl PolarsDataType for ObjectType { type OwnedPhysical = T; type ZeroablePhysical<'a> = Option<&'a T>; type Array = ObjectArray; - type IsNested = TrueT; + type IsNested = FalseT; type HasViews = FalseT; type IsStruct = FalseT; type IsObject = TrueT; diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index 89773abf794a..85998aa54de3 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -116,7 +116,7 @@ impl Series { InvalidOperation: "at least one dimension must be specified" ); - let leaf_array = self.get_leaf_array(); + let leaf_array = self.get_leaf_array().rechunk(); let size = leaf_array.len(); let mut total_dim_size = 1; diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py index ddc891df04f1..c9811b3fd657 100644 --- a/py-polars/tests/unit/operations/test_gather.py +++ b/py-polars/tests/unit/operations/test_gather.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import polars as pl @@ -176,3 +177,14 @@ def test_gather_array_list_null_19302() -> None: assert data.select(pl.col("data").list.get(0)).to_dict(as_series=False) == { "data": [None] } + + +def test_gather_array() -> None: + a = np.arange(16).reshape(-1, 2, 2) + s = pl.Series(a) + + for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]: + assert (s.gather(idx).to_numpy() == a[idx]).all() + + v = s[[0, 1, None, 3]] + assert v[2] is None From 6f0f4207dfa330d6bfe217d15bbf67847d7b7d92 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 21 Oct 2024 10:29:13 +0200 Subject: [PATCH 3/5] mypy --- py-polars/polars/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 894ab9e79346..82eacc12c2c9 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -275,7 +275,7 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: MultiIndexSelector: TypeAlias = Union[ slice, range, - Sequence[int], + Sequence[int | None], "Series", "np.ndarray[Any, Any]", ] From 3e40c37d1cc0ff0848a2975c081b21ef9aaeb1cc Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 21 Oct 2024 10:32:25 +0200 Subject: [PATCH 4/5] from future --- .github/scripts/test_bytecode_parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/test_bytecode_parser.py b/.github/scripts/test_bytecode_parser.py index 2abe9bf5de13..327bddfb5933 100644 --- a/.github/scripts/test_bytecode_parser.py +++ b/.github/scripts/test_bytecode_parser.py @@ -13,6 +13,7 @@ Running it without `PYTHONPATH` set will result in the test failing. """ +from __future__ import annotations import datetime as dt # noqa: F401 import subprocess from datetime import datetime # noqa: F401 From 5111280b4d3519d453c79aac210e1c3644cfe59d Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 21 Oct 2024 10:51:49 +0200 Subject: [PATCH 5/5] ignore --- .github/scripts/test_bytecode_parser.py | 1 - py-polars/polars/_typing.py | 2 +- py-polars/tests/unit/operations/test_gather.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/scripts/test_bytecode_parser.py b/.github/scripts/test_bytecode_parser.py index 327bddfb5933..2abe9bf5de13 100644 --- a/.github/scripts/test_bytecode_parser.py +++ b/.github/scripts/test_bytecode_parser.py @@ -13,7 +13,6 @@ Running it without `PYTHONPATH` set will result in the test failing. """ -from __future__ import annotations import datetime as dt # noqa: F401 import subprocess from datetime import datetime # noqa: F401 diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 82eacc12c2c9..894ab9e79346 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -275,7 +275,7 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: MultiIndexSelector: TypeAlias = Union[ slice, range, - Sequence[int | None], + Sequence[int], "Series", "np.ndarray[Any, Any]", ] diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py index c9811b3fd657..595b2bfee246 100644 --- a/py-polars/tests/unit/operations/test_gather.py +++ b/py-polars/tests/unit/operations/test_gather.py @@ -186,5 +186,5 @@ def test_gather_array() -> None: for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]: assert (s.gather(idx).to_numpy() == a[idx]).all() - v = s[[0, 1, None, 3]] + v = s[[0, 1, None, 3]] # type: ignore[list-item] assert v[2] is None