fix: Improve histogram bin logic #18761

Merged (4 commits) on Nov 19, 2024
295 changes: 193 additions & 102 deletions crates/polars-ops/src/chunked_array/hist.rs
@@ -3,143 +3,234 @@ use std::fmt::Write;
use num_traits::ToPrimitive;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use polars_utils::total_ord::ToTotalOrd;

const DEFAULT_BIN_COUNT: usize = 10;

fn get_breaks<T>(
    ca: &ChunkedArray<T>,
    bin_count: Option<usize>,
    bins: Option<&[f64]>,
) -> PolarsResult<(Vec<f64>, bool, bool)>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let mut pad_lower = false;
    let (bins, uniform) = match (bin_count, bins) {
        (Some(_), Some(_)) => {
            return Err(PolarsError::ComputeError(
                "can only provide one of `bin_count` or `bins`".into(),
            ));
        },
        (None, Some(bins)) => {
            // User-supplied bins. Note these are actually bin edges. Check for monotonicity.
            // If we only have one edge, we have no bins.
            let bin_len = bins.len();
            // We also check for uniformity of bins. We declare uniformity if the difference
            // between the largest and smallest bin is < 0.00001 of the average bin size.
            if bin_len > 1 {
                let mut smallest = bins[1] - bins[0];
                let mut largest = smallest;
                let mut avg_bin_size = smallest;
                for i in 1..bins.len() {
                    let d = bins[i] - bins[i - 1];
                    if d <= 0.0 {
                        return Err(PolarsError::ComputeError(
                            "bins must increase monotonically".into(),
                        ));
                    }
                    if d > largest {
                        largest = d;
                    } else if d < smallest {
                        smallest = d;
                    }
                    avg_bin_size += d;
                }
                let uniform = (largest - smallest) / (avg_bin_size / bin_len as f64) < 0.00001;
                (bins.to_vec(), uniform)
            } else {
                (Vec::<f64>::new(), false) // uniformity doesn't matter here
            }
        },
        (bin_count, None) => {
            // User-supplied bin count, or 10 by default. Compute edges from the data.
            let bin_count = bin_count.unwrap_or(DEFAULT_BIN_COUNT);
            let n = ca.len() - ca.null_count();
            let (offset, width) = if n == 0 {
                // No non-null items; supply unit interval.
                (0.0, 1.0 / bin_count as f64)
            } else if n == 1 {
                // Unit interval around single point
                let idx = ca.first_non_null().unwrap();
                // SAFETY: idx is guaranteed to contain an element.
                let center = unsafe { ca.get_unchecked(idx) }.unwrap().to_f64().unwrap();
                (center - 0.5, 1.0 / bin_count as f64)
            } else {
                // Determine outer bin edges from the data itself
                let min_value = ca.min().unwrap().to_f64().unwrap();
                let max_value = ca.max().unwrap().to_f64().unwrap();
                pad_lower = true;
                (min_value, (max_value - min_value) / bin_count as f64)
            };
            let out = (0..bin_count + 1)
                .map(|x| (x as f64 * width) + offset)
                .collect::<Vec<f64>>();
            (out, true)
        },
    };
    Ok((bins, uniform, pad_lower))
}
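
// Editorial sketch (hypothetical helper, not in this file): the default-edge
// arithmetic above, applied to data spanning [1.0, 8.0] with the default
// 10 bins. `offset` is the data minimum and `width` the uniform step.
fn example_default_edges() -> Vec<f64> {
    let (min, max, bin_count) = (1.0f64, 8.0f64, 10usize);
    let (offset, width) = (min, (max - min) / bin_count as f64);
    // 11 edges -> 10 bins: [1.0, 1.7, 2.4, ..., 8.0]
    (0..bin_count + 1).map(|x| x as f64 * width + offset).collect()
}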

// O(n) implementation when buckets are fixed-size.
// We deposit items directly into their buckets.
fn uniform_hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>, include_lower: bool) -> Vec<IdxSize>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let num_bins = breaks.len() - 1;
    let mut count: Vec<IdxSize> = vec![0; num_bins];
    let min_break: f64 = breaks[0];
    let max_break: f64 = breaks[num_bins];
    let width = breaks[1] - min_break; // guaranteed at least one bin
    let is_integer = !T::get_dtype().is_float();

    for chunk in ca.downcast_iter() {
        for item in chunk.non_null_values_iter() {
            let item = item.to_f64().unwrap();
            if include_lower && item == min_break {
                count[0] += 1;
            } else if item > min_break && item <= max_break {
                let idx = (item - min_break) / width;
                // This is needed for numeric stability for integers.
                // We can fall directly on a boundary with an integer.
                let idx = if is_integer && (idx.round() - idx).abs() < 0.0000001 {
                    idx.round() - 1.0
                } else {
                    idx.ceil() - 1.0
                };
                count[idx as usize] += 1;
            }
        }
    }
    count
}
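
// Editorial sketch of the uniform-path index math (hypothetical helper, not
// in this file): bins are half-open (lower, upper], so an item strictly
// above `min` in a grid of width `w` lands at ceil((item - min) / w) - 1.
fn bucket_index_sketch(item: f64, min: f64, w: f64) -> usize {
    debug_assert!(item > min); // item == min is handled separately above
    ((item - min) / w).ceil() as usize - 1
}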

// Variable-width bucketing. We sort the items and then move linearly through buckets.
fn hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>, include_lower: bool) -> Vec<IdxSize>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let exclude_lower = !include_lower;
    let num_bins = breaks.len() - 1;
    let mut breaks_iter = breaks.iter().skip(1); // Skip the first lower bound
    let (min_break, max_break) = (breaks[0], breaks[breaks.len() - 1]);
    let mut upper_bound = *breaks_iter.next().unwrap();
    let sorted = ca.sort(false).rechunk();
    let mut current_count: IdxSize = 0;
    let chunk = sorted.downcast_iter().next().unwrap();
    let mut count: Vec<IdxSize> = Vec::with_capacity(num_bins);

    'item: for item in chunk.non_null_values_iter() {
        let item = item.to_f64().unwrap();

        // Cycle through items until we hit the first bucket.
        if item < min_break || (exclude_lower && item == min_break) {
            continue;
        }

        while item > upper_bound {
            if item > max_break {
                // No more items will fit in any buckets.
                break 'item;
            }

            // Finished with prior bucket; push, reset, and move to next.
            count.push(current_count);
            current_count = 0;
            upper_bound = *breaks_iter.next().unwrap();
        }

        // Item is in bounds.
        current_count += 1;
    }
    count.push(current_count);
    count.resize(num_bins, 0); // If we left early, fill the remainder with 0.
    count
}
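
// Hypothetical unit test (not part of this PR) tracing the variable-width
// path: breaks [0.0, 1.0, 10.0] define bins (0, 1] and (1, 10]. With
// include_lower = false the item equal to the lowest edge is dropped, and
// the item above the last edge exits the loop early.
#[test]
fn hist_count_sketch() {
    let ca = Float64Chunked::from_slice("x".into(), &[0.0, 0.5, 2.0, 3.0, 11.0]);
    let counts = hist_count(&[0.0, 1.0, 10.0], &ca, false);
    assert_eq!(counts, vec![1, 2]);
}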

fn compute_hist<T>(
    ca: &ChunkedArray<T>,
    bin_count: Option<usize>,
    bins: Option<&[f64]>,
    include_category: bool,
    include_breakpoint: bool,
) -> PolarsResult<Series>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let (breaks, uniform, pad_lower) = get_breaks(ca, bin_count, bins)?;
    let num_bins = std::cmp::max(breaks.len(), 1) - 1;
    let count = if num_bins > 0 && ca.len() > ca.null_count() {
        if uniform {
            uniform_hist_count(&breaks, ca, pad_lower)
        } else {
            hist_count(&breaks, ca, pad_lower)
        }
    } else {
        vec![0; num_bins]
    };

    // Generate output: breakpoint (optional), category (optional), count.
    let mut fields = Vec::with_capacity(3);

    if include_breakpoint {
        let breakpoints = if num_bins > 0 {
            Series::new(PlSmallStr::from_static("breakpoint"), &breaks[1..])
        } else {
            let empty: &[f64; 0] = &[];
            Series::new(PlSmallStr::from_static("breakpoint"), empty)
        };
        fields.push(breakpoints)
    }

    if include_category {
        // Use AnyValue for formatting.
        let mut categories =
            StringChunkedBuilder::new(PlSmallStr::from_static("category"), breaks.len());
        if num_bins > 0 {
            let mut lower = AnyValue::Float64(if pad_lower {
                breaks[0] - (breaks[num_bins] - breaks[0]) * 0.001
            } else {
                breaks[0]
            });
            let mut buf = String::new();
            for br in &breaks[1..] {
                let br = AnyValue::Float64(*br);
                buf.clear();
                write!(buf, "({lower}, {br}]").unwrap();
                categories.append_value(buf.as_str());
                lower = br;
            }
        }
        let categories = categories
            .finish()
            .cast(&DataType::Categorical(None, Default::default()))
            .unwrap();
        fields.push(categories);
    };

    let count = Series::new(PlSmallStr::from_static("count"), count);
    fields.push(count);

    Ok(if fields.len() == 1 {
        fields.pop().unwrap().with_name(ca.name().clone())
    } else {
        StructChunked::from_series(ca.name().clone(), fields[0].len(), fields.iter())
            .unwrap()
            .into_series()
    })
}
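
// Editorial sketch of the category labels assembled above (hypothetical
// helper; the real code formats through AnyValue, so 1.0 prints as "1.0",
// while plain f64 formatting below prints "1"). With pad_lower = true the
// first lower edge is additionally shifted down by 0.1% of the total range
// so that the data minimum falls inside the first bin.
fn category_labels_sketch(breaks: &[f64]) -> Vec<String> {
    breaks
        .windows(2)
        .map(|w| format!("({}, {}]", w[0], w[1]))
        .collect() // e.g. [1.0, 2.0, 3.0] -> ["(1, 2]", "(2, 3]"]
}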

@@ -165,7 +256,7 @@ pub fn hist_series(

    let out = with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
        let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
        compute_hist(ca, bin_count, bins_arg, include_category, include_breakpoint)?
    });
    Ok(out)
}
24 changes: 10 additions & 14 deletions py-polars/polars/expr/expr.py
@@ -10078,33 +10078,29 @@ def hist(
--------
>>> df = pl.DataFrame({"a": [1, 3, 8, 8, 2, 1, 3]})
>>> df.select(pl.col("a").hist(bins=[1, 2, 3]))
shape: (2, 1)
┌─────┐
│ a   │
│ --- │
│ u32 │
╞═════╡
│ 1   │
│ 2   │
└─────┘
>>> df.select(
... pl.col("a").hist(
... bins=[1, 2, 3], include_breakpoint=True, include_category=True
... )
... )
shape: (2, 1)
┌──────────────────────┐
│ a                    │
│ ---                  │
│ struct[3]            │
╞══════════════════════╡
│ {2.0,"(1.0, 2.0]",1} │
│ {3.0,"(2.0, 3.0]",2} │
└──────────────────────┘
"""
if bins is not None:
if isinstance(bins, list):