Skip to content

Commit a99d6e0

Browse files
maomaodevpan3793
authored andcommitted
[KYUUBI #7422] [KSHC] Fix FileWriterFactory using the same TaskAttemptId for different task attempts
### Why are the changes needed? Port SPARK-48484 to KSHC. Fix #7421. In the KSHC, `FileWriterFactory` is forked from Spark's `org.apache.spark.sql.execution.datasources.v2.FileWriterFactory`. However, it still contains a bug later fixed on the Spark side by apache/spark#46811. This PR ports that upstream fix to KSHC. ### How was this patch tested? UT ### Was this patch authored or co-authored using generative AI tooling? NO Closes #7422 from maomaodev/kyuubi-7421. Closes #7422 1272f87 [lifumao] [KSHC] Fix FileWriterFactory using the same TaskAttemptId for different task attempts Authored-by: lifumao <lifumao@tencent.com> Signed-off-by: Cheng Pan <chengpan@apache.org>
1 parent 100446c commit a99d6e0

3 files changed

Lines changed: 59 additions & 3 deletions

File tree

extensions/spark/kyuubi-spark-connector-hive/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@
6969
<scope>test</scope>
7070
</dependency>
7171

72+
<dependency>
73+
<groupId>org.scalatestplus</groupId>
74+
<artifactId>mockito-4-11_${scala.binary.version}</artifactId>
75+
<scope>test</scope>
76+
</dependency>
77+
7278
<dependency>
7379
<groupId>org.apache.spark</groupId>
7480
<artifactId>spark-core_${scala.binary.version}</artifactId>

extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/write/FileWriterFactory.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ case class FileWriterFactory(
4343
@transient private lazy val jobId = createJobID(jobTrackerID, 0)
4444

4545
override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = {
46-
val taskAttemptContext = createTaskAttemptContext(partitionId)
46+
val taskAttemptContext = createTaskAttemptContext(partitionId, realTaskId.toInt & Int.MaxValue)
4747
committer.setupTask(taskAttemptContext)
4848
if (description.partitionColumns.isEmpty) {
4949
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
@@ -52,9 +52,11 @@ case class FileWriterFactory(
5252
}
5353
}
5454

55-
private def createTaskAttemptContext(partitionId: Int): TaskAttemptContextImpl = {
55+
private def createTaskAttemptContext(
56+
partitionId: Int,
57+
realTaskId: Int): TaskAttemptContextImpl = {
5658
val taskId = new TaskID(jobId, TaskType.MAP, partitionId)
57-
val taskAttemptId = new TaskAttemptID(taskId, 0)
59+
val taskAttemptId = new TaskAttemptID(taskId, realTaskId)
5860
// Set up the configuration object
5961
val hadoopConf = description.serializableHadoopConf.value
6062
hadoopConf.set("mapreduce.job.id", jobId.toString)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.spark.connector.hive.write
19+
20+
import org.apache.hadoop.conf.Configuration
21+
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
22+
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.internal.io.FileCommitProtocol
24+
import org.apache.spark.sql.execution.datasources.WriteJobDescription
25+
import org.apache.spark.util.SerializableConfiguration
26+
import org.mockito.Mockito._
27+
import org.scalatest.PrivateMethodTester
28+
29+
class FileWriterFactorySuite extends SparkFunSuite with PrivateMethodTester {
30+
31+
test("V2Write uses different TaskAttemptIds for different task attempts") {
32+
val jobDescription = mock(classOf[WriteJobDescription])
33+
when(jobDescription.serializableHadoopConf).thenReturn(
34+
new SerializableConfiguration(new Configuration(false)))
35+
val committer = mock(classOf[FileCommitProtocol])
36+
37+
val writerFactory = FileWriterFactory(jobDescription, committer)
38+
val createTaskAttemptContext =
39+
PrivateMethod[TaskAttemptContextImpl](Symbol("createTaskAttemptContext"))
40+
41+
val attemptContext =
42+
writerFactory.invokePrivate(createTaskAttemptContext(0, 1))
43+
val attemptContext1 =
44+
writerFactory.invokePrivate(createTaskAttemptContext(0, 2))
45+
assert(attemptContext.getTaskAttemptID.getTaskID == attemptContext1.getTaskAttemptID.getTaskID)
46+
assert(attemptContext.getTaskAttemptID.getId != attemptContext1.getTaskAttemptID.getId)
47+
}
48+
}

0 commit comments

Comments
 (0)