Skip to content

Commit 3224e0c

Browse files
compheadgruuya
andauthored
[branch-53] fix: use spill writer's schema instead of the first batch schema for … (#21451)
…spill files (cherry picked from commit e133dd3) ## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> Co-authored-by: Marko Grujic <markoog@gmail.com>
1 parent d24faa0 commit 3224e0c

File tree

3 files changed

+242
-1
lines changed

3 files changed

+242
-1
lines changed

datafusion/core/tests/memory_limit/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::sync::{Arc, LazyLock};
2424
#[cfg(feature = "extended_tests")]
2525
mod memory_limit_validation;
2626
mod repartition_mem_limit;
27+
mod union_nullable_spill;
2728
use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray};
2829
use arrow::compute::SortOptions;
2930
use arrow::datatypes::{Int32Type, SchemaRef};
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Array, Int64Array, RecordBatch};
21+
use arrow::compute::SortOptions;
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use datafusion::datasource::memory::MemorySourceConfig;
24+
use datafusion_execution::config::SessionConfig;
25+
use datafusion_execution::memory_pool::FairSpillPool;
26+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
27+
use datafusion_physical_expr::expressions::col;
28+
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
29+
use datafusion_physical_plan::repartition::RepartitionExec;
30+
use datafusion_physical_plan::sorts::sort::sort_batch;
31+
use datafusion_physical_plan::union::UnionExec;
32+
use datafusion_physical_plan::{ExecutionPlan, Partitioning};
33+
use futures::StreamExt;
34+
35+
const NUM_BATCHES: usize = 200;
36+
const ROWS_PER_BATCH: usize = 10;
37+
38+
fn non_nullable_schema() -> Arc<Schema> {
39+
Arc::new(Schema::new(vec![
40+
Field::new("key", DataType::Int64, false),
41+
Field::new("val", DataType::Int64, false),
42+
]))
43+
}
44+
45+
fn nullable_schema() -> Arc<Schema> {
46+
Arc::new(Schema::new(vec![
47+
Field::new("key", DataType::Int64, false),
48+
Field::new("val", DataType::Int64, true),
49+
]))
50+
}
51+
52+
fn non_nullable_batches() -> Vec<RecordBatch> {
53+
(0..NUM_BATCHES)
54+
.map(|i| {
55+
let start = (i * ROWS_PER_BATCH) as i64;
56+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
57+
RecordBatch::try_new(
58+
non_nullable_schema(),
59+
vec![
60+
Arc::new(Int64Array::from(keys)),
61+
Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])),
62+
],
63+
)
64+
.unwrap()
65+
})
66+
.collect()
67+
}
68+
69+
fn nullable_batches() -> Vec<RecordBatch> {
70+
(0..NUM_BATCHES)
71+
.map(|i| {
72+
let start = (i * ROWS_PER_BATCH) as i64;
73+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
74+
let vals: Vec<Option<i64>> = (0..ROWS_PER_BATCH)
75+
.map(|j| if j % 3 == 1 { None } else { Some(j as i64) })
76+
.collect();
77+
RecordBatch::try_new(
78+
nullable_schema(),
79+
vec![
80+
Arc::new(Int64Array::from(keys)),
81+
Arc::new(Int64Array::from(vals)),
82+
],
83+
)
84+
.unwrap()
85+
})
86+
.collect()
87+
}
88+
89+
fn build_task_ctx(pool_size: usize) -> Arc<datafusion_execution::TaskContext> {
90+
let session_config = SessionConfig::new().with_batch_size(2);
91+
let runtime = RuntimeEnvBuilder::new()
92+
.with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
93+
.build_arc()
94+
.unwrap();
95+
Arc::new(
96+
datafusion_execution::TaskContext::default()
97+
.with_session_config(session_config)
98+
.with_runtime(runtime),
99+
)
100+
}
101+
102+
/// Exercises spilling through UnionExec -> RepartitionExec where union children
103+
/// have mismatched nullability (one child's `val` is non-nullable, the other's
104+
/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill.
105+
///
106+
/// UnionExec returns child streams without schema coercion, so batches from
107+
/// different children carry different per-field nullability into the shared
108+
/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable)
109+
/// schema — not the first batch's schema — so readback batches are valid.
110+
///
111+
/// Otherwise, sort_batch will panic with
112+
/// `Column 'val' is declared as non-nullable but contains null values`
113+
#[tokio::test]
114+
async fn test_sort_union_repartition_spill_mixed_nullability() {
115+
let non_nullable_exec = MemorySourceConfig::try_new_exec(
116+
&[non_nullable_batches()],
117+
non_nullable_schema(),
118+
None,
119+
)
120+
.unwrap();
121+
122+
let nullable_exec =
123+
MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None)
124+
.unwrap();
125+
126+
let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap();
127+
assert!(union_exec.schema().field(1).is_nullable());
128+
129+
let repartition = Arc::new(
130+
RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(),
131+
);
132+
133+
let task_ctx = build_task_ctx(200);
134+
let mut stream = repartition.execute(0, task_ctx).unwrap();
135+
136+
let sort_expr = LexOrdering::new(vec![PhysicalSortExpr {
137+
expr: col("key", &nullable_schema()).unwrap(),
138+
options: SortOptions::default(),
139+
}])
140+
.unwrap();
141+
142+
let mut total_rows = 0usize;
143+
let mut total_nulls = 0usize;
144+
while let Some(result) = stream.next().await {
145+
let batch = result.unwrap();
146+
147+
let batch = sort_batch(&batch, &sort_expr, None).unwrap();
148+
149+
total_rows += batch.num_rows();
150+
total_nulls += batch.column(1).null_count();
151+
}
152+
153+
assert_eq!(
154+
total_rows,
155+
NUM_BATCHES * ROWS_PER_BATCH * 2,
156+
"All rows from both UNION branches should be present"
157+
);
158+
assert!(
159+
total_nulls > 0,
160+
"Expected some null values in output (i.e. nullable batches were processed)"
161+
);
162+
}

datafusion/physical-plan/src/spill/in_progress_spill_file.rs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ impl InProgressSpillFile {
6262
));
6363
}
6464
if self.writer.is_none() {
65-
let schema = batch.schema();
65+
// Use the SpillManager's declared schema rather than the batch's schema.
66+
// Individual batches may have different schemas (e.g., different nullability)
67+
// when they come from different branches of a UnionExec. The SpillManager's
68+
// schema represents the canonical schema that all batches should conform to.
69+
let schema = self.spill_writer.schema();
6670
if let Some(in_progress_file) = &mut self.in_progress_file {
6771
self.writer = Some(IPCStreamWriter::new(
6872
in_progress_file.path(),
@@ -138,3 +142,77 @@ impl InProgressSpillFile {
138142
Ok(self.in_progress_file.take())
139143
}
140144
}
145+
146+
#[cfg(test)]
147+
mod tests {
148+
use super::*;
149+
use arrow::array::Int64Array;
150+
use arrow_schema::{DataType, Field, Schema};
151+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
152+
use datafusion_physical_expr_common::metrics::{
153+
ExecutionPlanMetricsSet, SpillMetrics,
154+
};
155+
use futures::TryStreamExt;
156+
157+
#[tokio::test]
158+
async fn test_spill_file_uses_spill_manager_schema() -> Result<()> {
159+
let nullable_schema = Arc::new(Schema::new(vec![
160+
Field::new("key", DataType::Int64, false),
161+
Field::new("val", DataType::Int64, true),
162+
]));
163+
let non_nullable_schema = Arc::new(Schema::new(vec![
164+
Field::new("key", DataType::Int64, false),
165+
Field::new("val", DataType::Int64, false),
166+
]));
167+
168+
let runtime = Arc::new(RuntimeEnvBuilder::new().build()?);
169+
let metrics_set = ExecutionPlanMetricsSet::new();
170+
let spill_metrics = SpillMetrics::new(&metrics_set, 0);
171+
let spill_manager = Arc::new(SpillManager::new(
172+
runtime,
173+
spill_metrics,
174+
Arc::clone(&nullable_schema),
175+
));
176+
177+
let mut in_progress = spill_manager.create_in_progress_file("test")?;
178+
179+
// First batch: non-nullable val (simulates literal-0 UNION branch)
180+
let non_nullable_batch = RecordBatch::try_new(
181+
Arc::clone(&non_nullable_schema),
182+
vec![
183+
Arc::new(Int64Array::from(vec![1, 2, 3])),
184+
Arc::new(Int64Array::from(vec![0, 0, 0])),
185+
],
186+
)?;
187+
in_progress.append_batch(&non_nullable_batch)?;
188+
189+
// Second batch: nullable val with NULLs (simulates table UNION branch)
190+
let nullable_batch = RecordBatch::try_new(
191+
Arc::clone(&nullable_schema),
192+
vec![
193+
Arc::new(Int64Array::from(vec![4, 5, 6])),
194+
Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])),
195+
],
196+
)?;
197+
in_progress.append_batch(&nullable_batch)?;
198+
199+
let spill_file = in_progress.finish()?.unwrap();
200+
201+
let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
202+
203+
// Stream schema should be nullable
204+
assert_eq!(stream.schema(), nullable_schema);
205+
206+
let batches = stream.try_collect::<Vec<_>>().await?;
207+
assert_eq!(batches.len(), 2);
208+
209+
// Both batches must have the SpillManager's nullable schema
210+
assert_eq!(
211+
batches[0],
212+
non_nullable_batch.with_schema(Arc::clone(&nullable_schema))?
213+
);
214+
assert_eq!(batches[1], nullable_batch);
215+
216+
Ok(())
217+
}
218+
}

0 commit comments

Comments
 (0)