Skip to content

Commit e8d217a

Browse files
authored
perf: use DynComparator in sort-merge join (SMJ), microbenchmark queries up to 12% faster, TPC-H overall ~5% faster (#21484)
## Which issue does this PR close? Partially addresses #20910. ## Rationale for this change Sort merge join comparisons (`compare_join_arrays`, `is_join_arrays_equal`) do a `match DataType` + `downcast_ref` on every call, per column. These are called per-row in hot join loops across SMJ, semi/anti/mark SMJ, and piecewise merge join. `arrow_ord::ord::make_comparator` does the type dispatch once at construction and returns a `DynComparator` closure that goes straight to typed value comparison. Arrow's own `LexicographicalComparator` uses this pattern for sorting — we should use it for joins too. ## What changes are included in this PR? Adds `JoinKeyComparator` to `joins/utils.rs`: a thin wrapper around `Vec<DynComparator>` built once per batch pair. Null handling (`NullEqualsNothing` both-null -> `Less` override) is baked into the closures at construction time so `compare()` is a branchless loop. Integrated into all hot-path call sites: - `materializing_stream.rs`: `streamed_buffered_cmp` (streamed vs buffered) and `buffered_equality_cmp` (head vs tail equality) - `bitwise_stream.rs`: `outer_inner_cmp`, `outer_self_cmp`, `inner_self_cmp`; simplified `find_key_group_end` signature (takes `&JoinKeyComparator`, returns `usize` instead of `Result<usize>` since type errors are now caught at construction) - `piecewise_merge_join/classic_join.rs`: single comparator built per batch pair `compare_join_arrays` is kept for the one-off `keys_match` call (once per batch boundary). Deleted `is_join_arrays_equal` (75-line per-row type dispatch function), replaced by `JoinKeyComparator::is_equal`. ## Are these changes tested? - 4 unit tests for `JoinKeyComparator`: multi-column mixed types, `NullEqualsNull`, `NullEqualsNothing`, `nulls_first` ordering - Existing SMJ test suites pass - Existing sqllogictest join tests pass ## Are there any user-facing changes? No.
1 parent 44af0a1 commit e8d217a

File tree

5 files changed

+393
-181
lines changed

5 files changed

+393
-181
lines changed

datafusion/core/tests/memory_limit/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ async fn sort_merge_join_spill() {
213213
.with_config(config)
214214
.with_disk_manager_builder(DiskManagerBuilder::default())
215215
.with_scenario(Scenario::AccessLogStreaming)
216+
.with_expected_success()
216217
.run()
217218
.await
218219
}

datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::handle_state;
3838
use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState};
3939
use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final;
4040
use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult};
41-
use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap};
41+
use crate::joins::utils::{JoinKeyComparator, get_final_indices_from_shared_bitmap};
4242

4343
pub(super) enum PiecewiseMergeJoinStreamState {
4444
WaitBufferedSide,
@@ -460,6 +460,14 @@ fn resolve_classic_join(
460460
let buffered_len = buffered_side.buffered_data.values().len();
461461
let stream_values = stream_batch.compare_key_values();
462462

463+
// Build comparator once for the batch pair
464+
let cmp = JoinKeyComparator::new(
465+
&[Arc::clone(&stream_values[0])],
466+
&[Arc::clone(buffered_side.buffered_data.values())],
467+
&[sort_options],
468+
NullEquality::NullEqualsNothing,
469+
)?;
470+
463471
let mut buffer_idx = batch_process_state.start_buffer_idx;
464472
let mut stream_idx = batch_process_state.start_stream_idx;
465473

@@ -475,17 +483,7 @@ fn resolve_classic_join(
475483
// in the previous stream row.
476484
for row_idx in stream_idx..stream_batch.batch.num_rows() {
477485
while buffer_idx < buffered_len {
478-
let compare = {
479-
let buffered_values = buffered_side.buffered_data.values();
480-
compare_join_arrays(
481-
&[Arc::clone(&stream_values[0])],
482-
row_idx,
483-
&[Arc::clone(buffered_values)],
484-
buffer_idx,
485-
&[sort_options],
486-
NullEquality::NullEqualsNothing,
487-
)?
488-
};
486+
let compare = cmp.compare(row_idx, buffer_idx);
489487

490488
// If we find a match we append all indices and move to the next stream row index
491489
match operator {

datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs

Lines changed: 94 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ use std::sync::Arc;
126126
use std::task::{Context, Poll};
127127

128128
use crate::RecordBatchStream;
129-
use crate::joins::utils::{JoinFilter, compare_join_arrays};
129+
use crate::joins::utils::{JoinFilter, JoinKeyComparator, compare_join_arrays};
130130
use crate::metrics::{
131131
BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder,
132132
};
@@ -162,70 +162,40 @@ fn evaluate_join_keys(
162162
}
163163

164164
/// Find the first index in `key_arrays` starting from `from` where the key
165-
/// differs from the key at `from`. Uses `compare_join_arrays` for zero-alloc
166-
/// ordinal comparison.
165+
/// differs from the key at `from`. Uses a pre-built `JoinKeyComparator` for
166+
/// zero-alloc ordinal comparison without per-row type dispatch.
167167
///
168168
/// Optimized for join workloads: checks adjacent and boundary keys before
169169
/// falling back to binary search, since most key groups are small (often 1).
170-
fn find_key_group_end(
171-
key_arrays: &[ArrayRef],
172-
from: usize,
173-
len: usize,
174-
sort_options: &[SortOptions],
175-
null_equality: NullEquality,
176-
) -> Result<usize> {
170+
fn find_key_group_end(cmp: &JoinKeyComparator, from: usize, len: usize) -> usize {
177171
let next = from + 1;
178172
if next >= len {
179-
return Ok(len);
173+
return len;
180174
}
181175

182176
// Fast path: single-row group (common with unique keys).
183-
if compare_join_arrays(
184-
key_arrays,
185-
from,
186-
key_arrays,
187-
next,
188-
sort_options,
189-
null_equality,
190-
)? != Ordering::Equal
191-
{
192-
return Ok(next);
177+
if cmp.compare(from, next) != Ordering::Equal {
178+
return next;
193179
}
194180

195181
// Check if the entire remaining batch shares this key.
196182
let last = len - 1;
197-
if compare_join_arrays(
198-
key_arrays,
199-
from,
200-
key_arrays,
201-
last,
202-
sort_options,
203-
null_equality,
204-
)? == Ordering::Equal
205-
{
206-
return Ok(len);
183+
if cmp.compare(from, last) == Ordering::Equal {
184+
return len;
207185
}
208186

209187
// Binary search the interior: key at `next` matches, key at `last` doesn't.
210188
let mut lo = next + 1;
211189
let mut hi = last;
212190
while lo < hi {
213191
let mid = lo + (hi - lo) / 2;
214-
if compare_join_arrays(
215-
key_arrays,
216-
from,
217-
key_arrays,
218-
mid,
219-
sort_options,
220-
null_equality,
221-
)? == Ordering::Equal
222-
{
192+
if cmp.compare(from, mid) == Ordering::Equal {
223193
lo = mid + 1;
224194
} else {
225195
hi = mid;
226196
}
227197
}
228-
Ok(lo)
198+
lo
229199
}
230200

231201
/// When an outer key group spans a batch boundary, the boundary loop emits
@@ -328,6 +298,14 @@ pub(crate) struct BitwiseSortMergeJoinStream {
328298
runtime_env: Arc<datafusion_execution::runtime_env::RuntimeEnv>,
329299
inner_buffer_size: usize,
330300

301+
// Cached comparators — pre-built to avoid per-row type dispatch.
302+
/// Comparator for outer vs inner key comparison
303+
outer_inner_cmp: Option<JoinKeyComparator>,
304+
/// Comparator for outer self-comparison (find_key_group_end on outer)
305+
outer_self_cmp: Option<JoinKeyComparator>,
306+
/// Comparator for inner self-comparison (find_key_group_end on inner)
307+
inner_self_cmp: Option<JoinKeyComparator>,
308+
331309
// True once the current outer batch has been emitted. The Equal
332310
// branch's inner loops call emit then `ready!(poll_next_outer_batch)`.
333311
// If that poll returns Pending, poll_join re-enters from the top
@@ -413,6 +391,9 @@ impl BitwiseSortMergeJoinStream {
413391
spill_manager,
414392
runtime_env,
415393
inner_buffer_size: 0,
394+
outer_inner_cmp: None,
395+
outer_self_cmp: None,
396+
inner_self_cmp: None,
416397
batch_emitted: false,
417398
})
418399
}
@@ -425,6 +406,45 @@ impl BitwiseSortMergeJoinStream {
425406
Ok(())
426407
}
427408

409+
/// Get or build the outer vs inner key comparator.
410+
fn get_outer_inner_cmp(&mut self) -> Result<&JoinKeyComparator> {
411+
if self.outer_inner_cmp.is_none() {
412+
self.outer_inner_cmp = Some(JoinKeyComparator::new(
413+
&self.outer_key_arrays,
414+
&self.inner_key_arrays,
415+
&self.sort_options,
416+
self.null_equality,
417+
)?);
418+
}
419+
Ok(self.outer_inner_cmp.as_ref().unwrap())
420+
}
421+
422+
/// Get or build the outer self-comparison comparator.
423+
fn get_outer_self_cmp(&mut self) -> Result<&JoinKeyComparator> {
424+
if self.outer_self_cmp.is_none() {
425+
self.outer_self_cmp = Some(JoinKeyComparator::new(
426+
&self.outer_key_arrays,
427+
&self.outer_key_arrays,
428+
&self.sort_options,
429+
self.null_equality,
430+
)?);
431+
}
432+
Ok(self.outer_self_cmp.as_ref().unwrap())
433+
}
434+
435+
/// Get or build the inner self-comparison comparator.
436+
fn get_inner_self_cmp(&mut self) -> Result<&JoinKeyComparator> {
437+
if self.inner_self_cmp.is_none() {
438+
self.inner_self_cmp = Some(JoinKeyComparator::new(
439+
&self.inner_key_arrays,
440+
&self.inner_key_arrays,
441+
&self.sort_options,
442+
self.null_equality,
443+
)?);
444+
}
445+
Ok(self.inner_self_cmp.as_ref().unwrap())
446+
}
447+
428448
/// Spill the in-memory inner key buffer to disk and clear it.
429449
fn spill_inner_key_buffer(&mut self) -> Result<()> {
430450
let spill_file = self
@@ -468,6 +488,8 @@ impl BitwiseSortMergeJoinStream {
468488
self.outer_batch = Some(batch);
469489
self.outer_offset = 0;
470490
self.outer_key_arrays = keys;
491+
self.outer_inner_cmp = None;
492+
self.outer_self_cmp = None;
471493
self.batch_emitted = false;
472494
self.matched = BooleanBufferBuilder::new(batch_num_rows);
473495
self.matched.append_n(batch_num_rows, false);
@@ -494,6 +516,8 @@ impl BitwiseSortMergeJoinStream {
494516
self.inner_batch = Some(batch);
495517
self.inner_offset = 0;
496518
self.inner_key_arrays = keys;
519+
self.outer_inner_cmp = None;
520+
self.inner_self_cmp = None;
497521
return Poll::Ready(Ok(true));
498522
}
499523
}
@@ -555,13 +579,12 @@ impl BitwiseSortMergeJoinStream {
555579
let outer_batch = self.outer_batch.as_ref().unwrap();
556580
let num_outer = outer_batch.num_rows();
557581

582+
self.get_outer_self_cmp()?;
558583
let outer_group_end = find_key_group_end(
559-
&self.outer_key_arrays,
584+
self.outer_self_cmp.as_ref().unwrap(),
560585
self.outer_offset,
561586
num_outer,
562-
&self.sort_options,
563-
self.null_equality,
564-
)?;
587+
);
565588

566589
for i in self.outer_offset..outer_group_end {
567590
self.matched.set_bit(i, true);
@@ -584,13 +607,12 @@ impl BitwiseSortMergeJoinStream {
584607
};
585608
let num_inner = inner_batch.num_rows();
586609

610+
self.get_inner_self_cmp()?;
587611
let group_end = find_key_group_end(
588-
&self.inner_key_arrays,
612+
self.inner_self_cmp.as_ref().unwrap(),
589613
self.inner_offset,
590614
num_inner,
591-
&self.sort_options,
592-
self.null_equality,
593-
)?;
615+
);
594616

595617
if group_end < num_inner {
596618
self.inner_offset = group_end;
@@ -642,20 +664,19 @@ impl BitwiseSortMergeJoinStream {
642664
}
643665

644666
loop {
645-
let inner_batch = match &self.inner_batch {
646-
Some(b) => b,
647-
None => return Poll::Ready(Ok(true)),
648-
};
649-
let num_inner = inner_batch.num_rows();
667+
if self.inner_batch.is_none() {
668+
return Poll::Ready(Ok(true));
669+
}
670+
let num_inner = self.inner_batch.as_ref().unwrap().num_rows();
671+
self.get_inner_self_cmp()?;
650672
let group_end = find_key_group_end(
651-
&self.inner_key_arrays,
673+
self.inner_self_cmp.as_ref().unwrap(),
652674
self.inner_offset,
653675
num_inner,
654-
&self.sort_options,
655-
self.null_equality,
656-
)?;
676+
);
657677

658678
if !resume_from_poll {
679+
let inner_batch = self.inner_batch.as_ref().unwrap();
659680
let slice =
660681
inner_batch.slice(self.inner_offset, group_end - self.inner_offset);
661682
self.inner_buffer_size += slice.get_array_memory_size();
@@ -719,6 +740,7 @@ impl BitwiseSortMergeJoinStream {
719740
/// key group, evaluates the filter against the outer key group and ORs
720741
/// the results into the matched bitset using u64-chunked bitwise ops.
721742
fn process_key_match_with_filter(&mut self) -> Result<()> {
743+
self.get_outer_self_cmp()?;
722744
let filter = self.filter.as_ref().unwrap();
723745
let outer_batch = self.outer_batch.as_ref().unwrap();
724746
let num_outer = outer_batch.num_rows();
@@ -738,12 +760,10 @@ impl BitwiseSortMergeJoinStream {
738760
);
739761

740762
let outer_group_end = find_key_group_end(
741-
&self.outer_key_arrays,
763+
self.outer_self_cmp.as_ref().unwrap(),
742764
self.outer_offset,
743765
num_outer,
744-
&self.sort_options,
745-
self.null_equality,
746-
)?;
766+
);
747767
let outer_group_len = outer_group_end - self.outer_offset;
748768
let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len);
749769

@@ -959,34 +979,30 @@ impl BitwiseSortMergeJoinStream {
959979
}
960980

961981
// 4. Compare keys at current positions
962-
let cmp = compare_join_arrays(
963-
&self.outer_key_arrays,
964-
self.outer_offset,
965-
&self.inner_key_arrays,
966-
self.inner_offset,
967-
&self.sort_options,
968-
self.null_equality,
969-
)?;
982+
self.get_outer_inner_cmp()?;
983+
let cmp = self
984+
.outer_inner_cmp
985+
.as_ref()
986+
.unwrap()
987+
.compare(self.outer_offset, self.inner_offset);
970988

971989
match cmp {
972990
Ordering::Less => {
991+
self.get_outer_self_cmp()?;
973992
let group_end = find_key_group_end(
974-
&self.outer_key_arrays,
993+
self.outer_self_cmp.as_ref().unwrap(),
975994
self.outer_offset,
976995
num_outer,
977-
&self.sort_options,
978-
self.null_equality,
979-
)?;
996+
);
980997
self.outer_offset = group_end;
981998
}
982999
Ordering::Greater => {
1000+
self.get_inner_self_cmp()?;
9831001
let group_end = find_key_group_end(
984-
&self.inner_key_arrays,
1002+
self.inner_self_cmp.as_ref().unwrap(),
9851003
self.inner_offset,
9861004
num_inner,
987-
&self.sort_options,
988-
self.null_equality,
989-
)?;
1005+
);
9901006
if group_end >= num_inner {
9911007
let saved_keys =
9921008
slice_keys(&self.inner_key_arrays, num_inner - 1);

0 commit comments

Comments
 (0)