Skip to content

Commit

Permalink
feat: Add SQL support for RIGHT JOIN, fix an issue with wildcard al…
Browse files Browse the repository at this point in the history
…iasing (pola-rs#19626)
  • Loading branch information
alexander-beedie authored and tylerriccio33 committed Nov 8, 2024
1 parent 5b8c689 commit 5d1a5df
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 29 deletions.
29 changes: 15 additions & 14 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ impl SQLContext {
lf = match &join.join_operator {
op @ (JoinOperator::FullOuter(constraint)
| JoinOperator::LeftOuter(constraint)
| JoinOperator::RightOuter(constraint)
| JoinOperator::Inner(constraint)
| JoinOperator::LeftAnti(constraint)
| JoinOperator::LeftSemi(constraint)
Expand All @@ -585,6 +586,7 @@ impl SQLContext {
match op {
JoinOperator::FullOuter(_) => JoinType::Full,
JoinOperator::LeftOuter(_) => JoinType::Left,
JoinOperator::RightOuter(_) => JoinType::Right,
JoinOperator::Inner(_) => JoinType::Inner,
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftAnti(_) | JoinOperator::RightAnti(_) => JoinType::Anti,
Expand Down Expand Up @@ -1414,14 +1416,14 @@ fn collect_compound_identifiers(
right_name: &str,
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if left.len() == 2 && right.len() == 2 {
let (tbl_a, col_a) = (left[0].value.as_str(), left[1].value.as_str());
let (tbl_b, col_b) = (right[0].value.as_str(), right[1].value.as_str());
let (tbl_a, col_name_a) = (left[0].value.as_str(), left[1].value.as_str());
let (tbl_b, col_name_b) = (right[0].value.as_str(), right[1].value.as_str());

// switch left/right operands if the caller has them in reverse
if left_name == tbl_b || right_name == tbl_a {
Ok((vec![col(col_b)], vec![col(col_a)]))
Ok((vec![col(col_name_b)], vec![col(col_name_a)]))
} else {
Ok((vec![col(col_a)], vec![col(col_b)]))
Ok((vec![col(col_name_a)], vec![col(col_name_b)]))
}
} else {
polars_bail!(SQLInterface: "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len());
Expand Down Expand Up @@ -1461,14 +1463,13 @@ fn process_join_on(
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let SQLExpr::BinaryOp { left, op, right } = expression {
match *op {
BinaryOperator::Eq => {
if let (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) =
(left.as_ref(), right.as_ref())
{
BinaryOperator::Eq => match (left.as_ref(), right.as_ref()) {
(SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => {
collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name)
} else {
polars_bail!(SQLInterface: "JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right);
}
},
_ => {
polars_bail!(SQLInterface: "only equi-join constraints (on identifiers) are currently supported; found lhs={:?}, rhs={:?}", left, right);
},
},
BinaryOperator::And => {
let (mut left_i, mut right_i) = process_join_on(left, tbl_left, tbl_right)?;
Expand All @@ -1479,13 +1480,13 @@ fn process_join_on(
Ok((left_i, right_i))
},
_ => {
polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op);
polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found op = '{:?}'", op);
},
}
} else if let SQLExpr::Nested(expr) = expression {
process_join_on(expr, tbl_left, tbl_right)
} else {
polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression);
polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found expression = {:?}", expression);
}
}

Expand All @@ -1504,7 +1505,7 @@ fn process_join_constraint(
}
if op != &BinaryOperator::Eq {
polars_bail!(SQLInterface:
"only equi-join constraints are supported; found '{:?}' op in\n{:?}", op, constraint)
"only equi-join constraints are currently supported; found '{:?}' op in\n{:?}", op, constraint)
}
match (left.as_ref(), right.as_ref()) {
(SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => {
Expand Down
33 changes: 21 additions & 12 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,24 @@ pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
}
}

fn resolve_column<'a>(
ctx: &'a mut SQLContext,
ident_root: &'a Ident,
name: &'a str,
dtype: &'a DataType,
) -> PolarsResult<(Expr, Option<&'a DataType>)> {
let resolved = ctx.resolve_name(&ident_root.value, name);
let resolved = resolved.as_str();
Ok((
if name != resolved {
col(resolved).alias(name)
} else {
col(name)
},
Some(dtype),
))
}

pub(crate) fn resolve_compound_identifier(
ctx: &mut SQLContext,
idents: &[Ident],
Expand All @@ -1182,20 +1200,11 @@ pub(crate) fn resolve_compound_identifier(
let name = &remaining_idents.next().unwrap().value;
if lf.is_some() && name == "*" {
return Ok(schema
.iter_names()
.map(|name| col(name.clone()))
.iter_names_and_dtypes()
.map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0)
.collect::<Vec<_>>());
} else if let Some((_, name, dtype)) = schema.get_full(name) {
let resolved = ctx.resolve_name(&ident_root.value, name);
let resolved = resolved.as_str();
Ok((
if name != resolved {
col(resolved).alias(name.clone())
} else {
col(name.clone())
},
Some(dtype),
))
resolve_column(ctx, ident_root, name, dtype)
} else if lf.is_none() {
remaining_idents = idents.iter().skip(1);
Ok((
Expand Down
118 changes: 115 additions & 3 deletions py-polars/tests/unit/sql/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from io import BytesIO
from pathlib import Path
from typing import Any

import pytest

Expand Down Expand Up @@ -295,10 +296,11 @@ def test_join_misc_16255() -> None:
)
def test_non_equi_joins(constraint: str) -> None:
# no support (yet) for non equi-joins in polars joins
# TODO: integrate awareness of new IEJoin
with (
pytest.raises(
SQLInterfaceError,
match=r"only equi-join constraints are supported",
match=r"only equi-join constraints are currently supported",
),
pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx,
):
Expand Down Expand Up @@ -335,6 +337,109 @@ def test_implicit_joins() -> None:
)


@pytest.mark.parametrize(
("query", "expected"),
[
# INNER joins
(
"SELECT df1.* FROM df1 INNER JOIN df2 USING (a)",
{"a": [1, 3], "b": ["x", "z"], "c": [100, 300]},
),
(
"SELECT df2.* FROM df1 INNER JOIN df2 USING (a)",
{"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]},
),
(
"SELECT df1.* FROM df2 INNER JOIN df1 USING (a)",
{"a": [1, 3], "b": ["x", "z"], "c": [100, 300]},
),
(
"SELECT df2.* FROM df2 INNER JOIN df1 USING (a)",
{"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]},
),
# LEFT joins
(
"SELECT df1.* FROM df1 LEFT JOIN df2 USING (a)",
{"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]},
),
(
"SELECT df2.* FROM df1 LEFT JOIN df2 USING (a)",
{"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]},
),
(
"SELECT df1.* FROM df2 LEFT JOIN df1 USING (a)",
{"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]},
),
(
"SELECT df2.* FROM df2 LEFT JOIN df1 USING (a)",
{"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]},
),
# RIGHT joins
(
"SELECT df1.* FROM df1 RIGHT JOIN df2 USING (a)",
{"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]},
),
(
"SELECT df2.* FROM df1 RIGHT JOIN df2 USING (a)",
{"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]},
),
(
"SELECT df1.* FROM df2 RIGHT JOIN df1 USING (a)",
{"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]},
),
(
"SELECT df2.* FROM df2 RIGHT JOIN df1 USING (a)",
{"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]},
),
# FULL joins
(
"SELECT df1.* FROM df1 FULL JOIN df2 USING (a)",
{
"a": [1, 2, 3, None],
"b": ["x", "y", "z", None],
"c": [100, 200, 300, None],
},
),
(
"SELECT df2.* FROM df1 FULL JOIN df2 USING (a)",
{
"a": [1, 3, 4, None],
"b": ["qq", "pp", "oo", None],
"c": [400, 500, 600, None],
},
),
(
"SELECT df1.* FROM df2 FULL JOIN df1 USING (a)",
{
"a": [1, 2, 3, None],
"b": ["x", "y", "z", None],
"c": [100, 200, 300, None],
},
),
(
"SELECT df2.* FROM df2 FULL JOIN df1 USING (a)",
{
"a": [1, 3, 4, None],
"b": ["qq", "pp", "oo", None],
"c": [400, 500, 600, None],
},
),
],
)
def test_wildcard_resolution_and_join_order(
query: str, expected: dict[str, Any]
) -> None:
df1 = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}) # noqa: F841
df2 = pl.DataFrame({"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}) # noqa: F841

res = pl.sql(query).collect()
assert_frame_equal(
res,
pl.DataFrame(expected),
check_row_order=False,
)


def test_natural_joins_01() -> None:
df1 = pl.DataFrame(
{
Expand Down Expand Up @@ -481,8 +586,15 @@ def test_natural_joins_02(cols_constraint: str, expect_data: list[tuple[int]]) -
@pytest.mark.parametrize(
"join_clause",
[
"df2 INNER JOIN df3 ON df2.CharacterID=df3.CharacterID",
"df2 INNER JOIN (df3 INNDER JOIN df4 ON df3.CharacterID=df4.CharacterID) ON df2.CharacterID=df3.CharacterID",
"""
df2 JOIN df3 ON
df2.CharacterID = df3.CharacterID
""",
"""
df2 INNER JOIN (
df3 JOIN df4 ON df3.CharacterID = df4.CharacterID
) ON df2.CharacterID = df3.CharacterID
""",
],
)
def test_nested_join(join_clause: str) -> None:
Expand Down

0 comments on commit 5d1a5df

Please sign in to comment.