Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions datafusion/common/src/join_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ pub enum JoinType {
/// [1]. This join type is used to decorrelate EXISTS subqueries used inside disjunctive
/// predicates.
///
/// Note: We currently do not implement the full null semantics for the mark join described
/// in [1], which will be needed if we add ANY subqueries. In our version the mark column will
/// only be true when a match was found and false when no match was found, never null.
/// For scalar `NOT IN`, DataFusion can plan a null-aware hash mark join where the
/// mark column is nullable: TRUE for a match, NULL for SQL UNKNOWN, and FALSE
/// otherwise. Row-valued multi-column `NOT IN` and non-hash residual predicate
/// null-aware mark semantics are not implemented.
///
/// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf
LeftMark,
Expand Down
95 changes: 73 additions & 22 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,12 @@ impl DefaultPhysicalPlanner {

let prefer_hash_join =
session_state.config_options().optimizer.prefer_hash_join;
// Null-aware joins are pinned to CollectLeft hash joins (see
// `HashJoinExec::null_aware`): never repartition them, and
// never route them to the sort-merge path below.
let can_repartition_join = session_state.config().target_partitions() > 1
&& session_state.config().repartition_joins()
&& !*null_aware;

// TODO: Allow PWMJ to deal with residual equijoin conditions
let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() {
Expand Down Expand Up @@ -1754,10 +1760,7 @@ impl DefaultPhysicalPlanner {
None,
)?)
}
} else if session_state.config().target_partitions() > 1
&& session_state.config().repartition_joins()
&& !prefer_hash_join
{
} else if can_repartition_join && !prefer_hash_join {
// Use SortMergeJoin if hash join is not preferred
let join_on_len = join_on.len();
Arc::new(SortMergeJoinExec::try_new(
Expand All @@ -1769,32 +1772,23 @@ impl DefaultPhysicalPlanner {
vec![SortOptions::default(); join_on_len],
*null_equality,
)?)
} else if session_state.config().target_partitions() > 1
&& session_state.config().repartition_joins()
&& prefer_hash_join
&& !*null_aware
// Null-aware joins must use CollectLeft
{
Arc::new(HashJoinExec::try_new(
physical_left,
physical_right,
join_on,
join_filter,
join_type,
None,
PartitionMode::Auto,
*null_equality,
*null_aware,
)?)
} else {
// Null-aware joins need global probe-side state, so keep
// them in CollectLeft mode.
let partition_mode = if can_repartition_join {
PartitionMode::Auto
} else {
PartitionMode::CollectLeft
};

Arc::new(HashJoinExec::try_new(
physical_left,
physical_right,
join_on,
join_filter,
join_type,
None,
PartitionMode::CollectLeft,
partition_mode,
*null_equality,
*null_aware,
)?)
Expand Down Expand Up @@ -3605,6 +3599,63 @@ mod tests {
Ok(())
}

/// Checks that a correlated `(id NOT IN (subquery)) IS NULL` predicate is
/// planned as a `LeftMark` hash join in `CollectLeft` mode rather than a
/// sort-merge join, and that executing it returns exactly the rows for which
/// the `NOT IN` result is SQL UNKNOWN (NULL).
#[tokio::test]
async fn correlated_not_in_is_null_uses_null_aware_hash_mark_join() -> Result<()> {
let query = "
SELECT value
FROM (
VALUES
(1, 1, 'a'),
(3, 1, 'b'),
(1, 2, 'c'),
(NULL, 1, 'd'),
(5, 3, 'e'),
(2, 1, 'f'),
(NULL, 2, 'g')
) AS outer_corr_table(id, grp, value)
WHERE (id NOT IN (
SELECT id
FROM (
VALUES
(2, 1),
(NULL, 1),
(1, 2)
) AS inner_corr_table(id, grp)
WHERE inner_corr_table.grp = outer_corr_table.grp
)) IS NULL
ORDER BY value";

// More than one target partition with `prefer_hash_join` disabled steers
// the planner toward its sort-merge branch for ordinary equi-joins
// (assuming `repartition_joins` is enabled by default -- TODO confirm).
// The null-aware join must still be planned as a hash join.
let config = SessionConfig::new()
.with_target_partitions(4)
.set_bool("datafusion.optimizer.prefer_hash_join", false);
let ctx = SessionContext::new_with_config(config);

// Plan-shape check: inspect the rendered physical plan text.
let plan = ctx.sql(query).await?.create_physical_plan().await?;
let formatted = displayable(plan.as_ref()).indent(true).to_string();
assert_contains!(
&formatted,
"HashJoinExec: mode=CollectLeft, join_type=LeftMark"
);
// Include the plan text in the failure message to ease debugging.
assert!(!formatted.contains("SortMergeJoinExec"), "{formatted}");

// Result check. `NOT IN` is NULL for:
//   'a', 'b' -- grp 1: no equal inner id, but the inner set contains NULL;
//   'd', 'g' -- outer id is NULL and the correlated inner set is non-empty.
// 'c' and 'f' have a matching inner id (FALSE); 'e' has an empty inner set
// (TRUE), so none of those three satisfy `IS NULL`.
let batches = ctx.sql(query).await?.collect().await?;
assert_batches_eq!(
&[
"+-------+",
"| value |",
"+-------+",
"| a |",
"| b |",
"| d |",
"| g |",
"+-------+",
],
&batches
);

Ok(())
}

#[tokio::test]
async fn scalar_subquery_in_projection_and_filter_plans() -> Result<()> {
let plan = plan_sql(
Expand Down
27 changes: 18 additions & 9 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1154,8 +1154,12 @@ impl LogicalPlanBuilder {
.zip(right_keys)
.map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
.collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
let join_schema = build_join_schema(
self.plan.schema(),
right.schema(),
&join_type,
null_aware,
)?;

// Inner type without join condition is cross join
if join_type != JoinType::Inner && on.is_empty() && filter.is_none() {
Expand Down Expand Up @@ -1652,7 +1656,7 @@ pub fn unique_field_aliases(fields: &Fields) -> Vec<Option<String>> {
.collect()
}

fn mark_field(schema: &DFSchema) -> (Option<TableReference>, Arc<Field>) {
fn mark_field(schema: &DFSchema, nullable: bool) -> (Option<TableReference>, Arc<Field>) {
let mut table_references = schema
.iter()
.filter_map(|(qualifier, _)| qualifier)
Expand All @@ -1666,16 +1670,21 @@ fn mark_field(schema: &DFSchema) -> (Option<TableReference>, Arc<Field>) {

(
table_reference,
Arc::new(Field::new("mark", DataType::Boolean, false)),
Arc::new(Field::new("mark", DataType::Boolean, nullable)),
)
}

/// Creates a schema for a join operation.
/// The fields from the left side are first
/// The fields from the left side are first.
///
/// When `null_aware` is set, the `LeftMark`/`RightMark` `mark` column is made
/// nullable so it can represent SQL UNKNOWN for null-aware `NOT IN` semantics.
/// `null_aware` has no effect on non-mark join types.
pub fn build_join_schema(
left: &DFSchema,
right: &DFSchema,
join_type: &JoinType,
null_aware: bool,
) -> Result<DFSchema> {
fn nullify_fields<'a>(
fields: impl Iterator<Item = (Option<&'a TableReference>, &'a Arc<Field>)>,
Expand Down Expand Up @@ -1738,7 +1747,7 @@ pub fn build_join_schema(
}
JoinType::LeftMark => left_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(right)))
.chain(once(mark_field(right, null_aware)))
.collect(),
JoinType::RightSemi | JoinType::RightAnti => {
// Only use the right side for the schema
Expand All @@ -1748,7 +1757,7 @@ pub fn build_join_schema(
}
JoinType::RightMark => right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(left)))
.chain(once(mark_field(left, null_aware)))
.collect(),
};
let func_dependencies = left.functional_dependencies().join(
Expand Down Expand Up @@ -2912,13 +2921,13 @@ mod tests {
)?;

let join_schema =
build_join_schema(&left_schema, &right_schema, &JoinType::Left)?;
build_join_schema(&left_schema, &right_schema, &JoinType::Left, false)?;
assert_eq!(
join_schema.metadata(),
&HashMap::from([("key".to_string(), "left".to_string())])
);
let join_schema =
build_join_schema(&left_schema, &right_schema, &JoinType::Right)?;
build_join_schema(&left_schema, &right_schema, &JoinType::Right, false)?;
assert_eq!(
join_schema.metadata(),
&HashMap::from([("key".to_string(), "right".to_string())])
Expand Down
31 changes: 21 additions & 10 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,12 @@ impl LogicalPlan {
null_equality,
null_aware,
}) => {
let schema =
build_join_schema(left.schema(), right.schema(), &join_type)?;
let schema = build_join_schema(
left.schema(),
right.schema(),
&join_type,
null_aware,
)?;

let new_on: Vec<_> = on
.into_iter()
Expand Down Expand Up @@ -944,7 +948,12 @@ impl LogicalPlan {
..
}) => {
let (left, right) = self.only_two_inputs(inputs)?;
let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
let schema = build_join_schema(
left.schema(),
right.schema(),
join_type,
*null_aware,
)?;

let equi_expr_count = on.len() * 2;
assert!(expr.len() >= equi_expr_count);
Expand Down Expand Up @@ -4228,13 +4237,13 @@ pub struct Join {
pub schema: DFSchemaRef,
/// Defines the null equality for the join.
pub null_equality: NullEquality,
/// Whether this is a null-aware anti join (for NOT IN semantics).
/// Whether this join needs null-aware NOT IN semantics.
///
/// Only applies to LeftAnti joins. When true, implements SQL NOT IN semantics where:
/// - If the right side (subquery) contains any NULL in join keys, no rows are output
/// - Left side rows with NULL in join keys are not output
/// For `LeftAnti`, if the right side contains any NULL in join keys, no rows are output and
/// left rows with NULL join keys are also excluded.
///
/// This is required for correct NOT IN subquery behavior with three-valued logic.
/// For `LeftMark`, the generated `mark` column becomes nullable so unmatched rows can produce
/// `NULL` rather than `false` when SQL three-valued logic requires it.
pub null_aware: bool,
}

Expand All @@ -4253,7 +4262,7 @@ impl Join {
/// * `join_type` - Type of join (Inner, Left, Right, etc.)
/// * `join_constraint` - Join constraint (On, Using)
/// * `null_equality` - How to handle nulls in join comparisons
/// * `null_aware` - Whether this is a null-aware anti join (for NOT IN semantics)
/// * `null_aware` - Whether this join needs null-aware NOT IN semantics
///
/// # Returns
///
Expand All @@ -4269,7 +4278,8 @@ impl Join {
null_equality: NullEquality,
null_aware: bool,
) -> Result<Self> {
let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?;
let join_schema =
build_join_schema(left.schema(), right.schema(), &join_type, null_aware)?;

Ok(Join {
left,
Expand Down Expand Up @@ -4324,6 +4334,7 @@ impl Join {
left_sch.schema(),
right_sch.schema(),
&original_join.join_type,
original_join.null_aware,
)?;

Ok((
Expand Down
Loading
Loading