feat: merge_insert update subcolumns (lancedb#2639)
Closes lancedb#2610

* Supports subschemas in `merge_insert` for updates only (see the usage sketch below)
  * Inserts and deletes are left as TODO
* Field id `-2` is now reserved as a field "tombstone". A tombstoned field is
one that is no longer read from a given data file, usually because it has
been rewritten into a different data file.
* Fixed a bug in `Merger` where statistics were reset on each batch.
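
A minimal usage sketch of the new update path (Python; it mirrors the `test_merge_insert_subcols` test added below — the dataset path and values are illustrative):

```python
import lance
import pyarrow as pa

# Write a dataset with three columns.
initial = pa.table({"a": range(10), "b": range(10), "c": range(10, 20)})
dataset = lance.write_dataset(initial, "/tmp/subcol_demo")

# The source only carries the join key "a" and the column to update, "b".
# Column "c" is not rewritten: the new "b" values go into a new data file
# and "b" is tombstoned in the original file.
new_values = pa.table({"a": range(3, 5), "b": range(20, 22)})
dataset.merge_insert("a").when_matched_update_all().execute(new_values)
```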
wjones127 authored Jul 30, 2024
1 parent 588c416 commit 6ebeaa0
Showing 9 changed files with 852 additions and 101 deletions.
5 changes: 5 additions & 0 deletions protos/table.proto
@@ -200,6 +200,11 @@ message DataFile {
string path = 1;
// The ids of the fields/columns in this file.
//
// -1 is used for "unassigned" while in memory. It is not meant to be written
// to disk. -2 is used for "tombstoned", meaning a field that is no longer
// in use. This is often because the original field id was reassigned to a
// different data file.
//
// In Lance v1 IDs are assigned based on position in the file, offset by the max
// existing field id in the table (if any already). So when a fragment is first
// created with one file of N columns, the field ids will be 1, 2, ..., N. If a
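To picture the tombstone: after a subcolumn update rewrites column "b" of a fragment with columns a, b, c, the fragment's data files might look something like this (hypothetical illustration only — the paths and the exact field id numbering are made up, not real file contents or an API call):

```python
fragment_after_update = {
    "files": [
        # Original file: "b" is no longer read from here, so its slot holds -2.
        {"path": "data/original.lance", "fields": [0, -2, 2]},
        # New file produced by merge_insert: holds only the rewritten "b".
        {"path": "data/updated-b.lance", "fields": [1]},
    ]
}
```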
51 changes: 50 additions & 1 deletion python/python/tests/test_dataset.py
@@ -1132,6 +1132,52 @@ def test_merge_insert(tmp_path: Path):
check_merge_stats(merge_dict, (None, None, None))


def test_merge_insert_subcols(tmp_path: Path):
    initial_data = pa.table(
        {
            "a": range(10),
            "b": range(10),
            "c": range(10, 20),
        }
    )
    # Split across two fragments
    dataset = lance.write_dataset(
        initial_data, tmp_path / "dataset", max_rows_per_file=5
    )
    original_fragments = dataset.get_fragments()

    new_values = pa.table(
        {
            "a": range(3, 5),
            "b": range(20, 22),
        }
    )
    (dataset.merge_insert("a").when_matched_update_all().execute(new_values))

    expected = pa.table(
        {
            "a": range(10),
            "b": [0, 1, 2, 20, 21, 5, 6, 7, 8, 9],
            "c": range(10, 20),
        }
    )
    assert dataset.to_table().sort_by("a") == expected

    # First fragment has new file
    fragments = dataset.get_fragments()
    assert fragments[0].fragment_id == original_fragments[0].fragment_id
    assert fragments[1].fragment_id == original_fragments[1].fragment_id

    assert len(fragments[0].data_files()) == 2
    assert str(fragments[0].data_files()[0]) == str(
        original_fragments[0].data_files()[0]
    )
    assert len(fragments[1].data_files()) == 1
    assert str(fragments[1].data_files()[0]) == str(
        original_fragments[1].data_files()[0]
    )


def test_flat_vector_search_with_delete(tmp_path: Path):
table = pa.Table.from_pydict(
{
@@ -1312,7 +1358,10 @@ def test_merge_insert_incompatible_schema(tmp_path: Path):

with pytest.raises(OSError):
merge_dict = (
dataset.merge_insert("a").when_matched_update_all().execute(new_table)
dataset.merge_insert("a")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(new_table)
)
check_merge_stats(merge_dict, (None, None, None))

29 changes: 28 additions & 1 deletion rust/lance-arrow/src/lib.rs
@@ -15,7 +15,7 @@ use arrow_array::{
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
use arrow_select::take::take;
use arrow_select::{interleave::interleave, take::take};
use rand::prelude::*;

pub mod deepcopy;
@@ -608,6 +608,33 @@ fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a Arr
.and_then(|arr| get_sub_array(arr, &components[1..]))
}

/// Interleave multiple RecordBatches into a single RecordBatch.
///
/// Behaves like [`arrow::compute::interleave`], but for RecordBatches.
pub fn interleave_batches(
    batches: &[RecordBatch],
    indices: &[(usize, usize)],
) -> Result<RecordBatch> {
    let first_batch = batches.first().ok_or_else(|| {
        ArrowError::InvalidArgumentError("Cannot interleave zero RecordBatches".to_string())
    })?;
    let schema = first_batch.schema().clone();
    let num_columns = first_batch.num_columns();
    let mut columns = Vec::with_capacity(num_columns);
    let mut chunks = Vec::with_capacity(batches.len());

    for i in 0..num_columns {
        for batch in batches {
            chunks.push(batch.column(i).as_ref());
        }
        let new_column = interleave(&chunks, indices)?;
        columns.push(new_column);
        chunks.clear();
    }

    RecordBatch::try_new(schema, columns)
}
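
The indices follow the same convention as arrow's `interleave` kernel: each entry is a `(batch_index, row_index)` pair, and output rows are taken in that order. A rough Python/pyarrow sketch of the same semantics (illustration only; it does not call the Rust helper):

```python
import pyarrow as pa

# Two input batches with the same schema.
batches = [
    pa.RecordBatch.from_pydict({"x": [1, 2, 3]}),
    pa.RecordBatch.from_pydict({"x": [10, 20, 30]}),
]
# Each index is a (batch_index, row_index) pair; output rows follow this order.
indices = [(0, 0), (1, 2), (0, 1)]

result = pa.RecordBatch.from_pydict(
    {
        name: [batches[b][name][r].as_py() for b, r in indices]
        for name in batches[0].schema.names
    },
    schema=batches[0].schema,
)
assert result["x"].to_pylist() == [1, 30, 2]
```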

#[cfg(test)]
mod tests {
use super::*;
4 changes: 2 additions & 2 deletions rust/lance-datafusion/src/dataframe.rs
@@ -116,8 +116,8 @@ impl BatchStreamGrouper {
}

/// Get the output schema of the stream.
pub fn schema(&self) -> Arc<Schema> {
self.input.schema()
pub fn schema(&self) -> &Arc<Schema> {
&self.schema
}

/// Given a record batch, find the distinct ranges of partition values.
26 changes: 15 additions & 11 deletions rust/lance-datafusion/src/exec.rs
@@ -11,7 +11,7 @@ use datafusion::{
dataframe::DataFrame,
datasource::streaming::StreamingTable,
execution::{
context::{SessionConfig, SessionContext, SessionState},
context::{SessionConfig, SessionContext},
disk_manager::DiskManagerConfig,
memory_pool::FairSpillPool,
runtime_env::{RuntimeConfig, RuntimeEnv},
@@ -195,13 +195,7 @@ impl LanceExecutionOptions {
}
}

/// Executes a plan using default session & runtime configuration
///
/// Only executes a single partition. Panics if the plan has more than one partition.
pub fn execute_plan(
plan: Arc<dyn ExecutionPlan>,
options: LanceExecutionOptions,
) -> Result<SendableRecordBatchStream> {
pub fn new_session_context(options: LanceExecutionOptions) -> SessionContext {
let session_config = SessionConfig::new();
let mut runtime_config = RuntimeConfig::new();
if options.use_spilling() {
@@ -210,12 +204,22 @@ pub fn execute_plan(
options.mem_pool_size() as usize
)));
}
let runtime_env = Arc::new(RuntimeEnv::new(runtime_config)?);
let session_state = SessionState::new_with_config_rt(session_config, runtime_env);
let runtime_env = Arc::new(RuntimeEnv::new(runtime_config).unwrap());
SessionContext::new_with_config_rt(session_config, runtime_env)
}

/// Executes a plan using default session & runtime configuration
///
/// Only executes a single partition. Panics if the plan has more than one partition.
pub fn execute_plan(
plan: Arc<dyn ExecutionPlan>,
options: LanceExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let session_ctx = new_session_context(options);
// NOTE: we are only executing the first partition here. Therefore, if
// the plan has more than one partition, we will be missing data.
assert_eq!(plan.properties().partitioning.partition_count(), 1);
Ok(plan.execute(0, session_state.task_ctx())?)
Ok(plan.execute(0, session_ctx.task_ctx())?)
}

pub trait SessionContextExt {
47 changes: 28 additions & 19 deletions rust/lance/src/datafusion/dataframe.rs
@@ -6,7 +6,7 @@ use std::{
sync::{Arc, Mutex},
};

use arrow_schema::{DataType, Field, Schema, SchemaRef};
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::{
dataframe::DataFrame,
@@ -19,35 +19,36 @@ use datafusion::{
logical_expr::{Expr, TableProviderFilterPushDown, TableType},
physical_plan::{streaming::PartitionStream, ExecutionPlan, SendableRecordBatchStream},
};
use lance_core::ROW_ID;
use lance_arrow::SchemaExt;
use lance_core::{ROW_ADDR_FIELD, ROW_ID_FIELD};

use crate::Dataset;

pub struct LanceTableProvider {
dataset: Arc<Dataset>,
full_schema: Arc<Schema>,
row_id_idx: Option<usize>,
row_addr_idx: Option<usize>,
}

impl LanceTableProvider {
fn new(dataset: Arc<Dataset>, with_row_id: bool) -> Self {
let full_schema = if with_row_id {
let mut full_schema = dataset.schema().clone();
full_schema
.extend(&[Field::new(ROW_ID, DataType::UInt64, false)])
.unwrap();
full_schema
} else {
dataset.schema().clone()
};
fn new(dataset: Arc<Dataset>, with_row_id: bool, with_row_addr: bool) -> Self {
let mut full_schema = Schema::from(dataset.schema());
let mut row_id_idx = None;
let mut row_addr_idx = None;
if with_row_id {
full_schema = full_schema.try_with_column(ROW_ID_FIELD.clone()).unwrap();
row_id_idx = Some(full_schema.fields.len() - 1);
}
if with_row_addr {
full_schema = full_schema.try_with_column(ROW_ADDR_FIELD.clone()).unwrap();
row_addr_idx = Some(full_schema.fields.len() - 1);
}
Self {
dataset,
full_schema: Arc::new(Schema::from(&full_schema)),
row_id_idx: if with_row_id {
Some(full_schema.fields.len() - 1)
} else {
None
},
full_schema: Arc::new(full_schema),
row_id_idx,
row_addr_idx,
}
}
}
@@ -79,6 +80,8 @@ impl TableProvider for LanceTableProvider {
for field_idx in projection {
if Some(*field_idx) == self.row_id_idx {
scan.with_row_id();
} else if Some(*field_idx) == self.row_addr_idx {
scan.with_row_address();
} else {
columns.push(self.full_schema.field(*field_idx).name());
}
@@ -126,6 +129,7 @@ pub trait SessionContextExt {
&self,
dataset: Arc<Dataset>,
with_row_id: bool,
with_row_addr: bool,
) -> datafusion::common::Result<DataFrame>;
/// Creates a DataFrame for reading a stream of data
///
@@ -169,8 +173,13 @@ impl SessionContextExt for SessionContext {
&self,
dataset: Arc<Dataset>,
with_row_id: bool,
with_row_addr: bool,
) -> datafusion::common::Result<DataFrame> {
self.read_table(Arc::new(LanceTableProvider::new(dataset, with_row_id)))
self.read_table(Arc::new(LanceTableProvider::new(
dataset,
with_row_id,
with_row_addr,
)))
}

fn read_one_shot(
15 changes: 13 additions & 2 deletions rust/lance/src/dataset/fragment.rs
@@ -1035,11 +1035,22 @@ impl FileFragment {
schemas: Option<(Schema, Schema)>,
) -> Result<Updater> {
let mut schema = self.dataset.schema().clone();

let mut with_row_addr = false;
if let Some(columns) = columns {
schema = schema.project(columns)?;
let mut projection = Vec::new();
for column in columns {
if column.as_ref() == ROW_ADDR {
with_row_addr = true;
} else {
projection.push(column.as_ref());
}
}
schema = schema.project(&projection)?;
}
// If there is no projection, we at least need to read the row addresses
let with_row_addr = schema.fields.is_empty();
with_row_addr |= schema.fields.is_empty();

let reader = self.open(&schema, false, with_row_addr, None);
let deletion_vector = read_deletion_file(
&self.dataset.base,