|
16 | 16 | */ |
17 | 17 | package org.apache.gluten.utils |
18 | 18 |
|
| 19 | +import org.apache.gluten.config.GlutenConfig |
19 | 20 | import org.apache.gluten.sql.shims.SparkShimLoader |
| 21 | +import org.apache.gluten.utils.PartitionsUtil.regeneratePartition |
20 | 22 |
|
21 | 23 | import org.apache.spark.Partition |
22 | 24 | import org.apache.spark.internal.Logging |
23 | 25 | import org.apache.spark.sql.catalyst.expressions.Attribute |
24 | | -import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, HadoopFsRelation, PartitionDirectory} |
| 26 | +import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile} |
25 | 27 | import org.apache.spark.sql.types.StructType |
26 | 28 | import org.apache.spark.util.collection.BitSet |
27 | 29 |
|
28 | 30 | import org.apache.hadoop.fs.Path |
29 | 31 |
|
| 32 | +import scala.collection.mutable |
| 33 | + |
30 | 34 | case class PartitionsUtil( |
31 | 35 | relation: HadoopFsRelation, |
32 | 36 | requiredSchema: StructType, |
@@ -96,7 +100,10 @@ case class PartitionsUtil( |
96 | 100 | } |
97 | 101 | .sortBy(_.length)(implicitly[Ordering[Long]].reverse) |
98 | 102 |
|
99 | | - FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) |
| 103 | + val inputPartitions = |
| 104 | + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) |
| 105 | + |
| 106 | + regeneratePartition(inputPartitions, GlutenConfig.get.smallFileThreshold) |
100 | 107 | } |
101 | 108 |
|
102 | 109 | private def genBucketedPartitionSeq(): Seq[Partition] = { |
@@ -140,3 +147,90 @@ case class PartitionsUtil( |
140 | 147 | output.find(_.name == colName) |
141 | 148 | } |
142 | 149 | } |
| 150 | + |
| 151 | +object PartitionsUtil { |
| 152 | + |
| 153 | + /** |
| 154 | + * Regenerate the partitions by balancing the number of files per partition and total size per |
| 155 | + * partition. |
| 156 | + */ |
| 157 | + def regeneratePartition( |
| 158 | + inputPartitions: Seq[FilePartition], |
| 159 | + smallFileThreshold: Double): Seq[FilePartition] = { |
| 160 | + |
| 161 | + // Flatten and sort descending by file size. |
| 162 | + val filesSorted: Seq[(PartitionedFile, Long)] = |
| 163 | + inputPartitions |
| 164 | + .flatMap(_.files) |
| 165 | + .map(f => (f, f.length)) |
| 166 | + .sortBy(_._2)(Ordering.Long.reverse) |
| 167 | + |
| 168 | + val partitions = Array.fill(inputPartitions.size)(mutable.ArrayBuffer.empty[PartitionedFile]) |
| 169 | + |
| 170 | + def addToBucket( |
| 171 | + heap: mutable.PriorityQueue[(Long, Int, Int)], |
| 172 | + file: PartitionedFile, |
| 173 | + sz: Long): Unit = { |
| 174 | + val (load, numFiles, idx) = heap.dequeue() |
| 175 | + partitions(idx) += file |
| 176 | + heap.enqueue((load + sz, numFiles + 1, idx)) |
| 177 | + } |
| 178 | + |
| 179 | + // First by load, then by numFiles. |
| 180 | + val heapByFileSize = |
| 181 | + mutable.PriorityQueue.empty[(Long, Int, Int)]( |
| 182 | + Ordering |
| 183 | + .by[(Long, Int, Int), (Long, Int)] { |
| 184 | + case (load, numFiles, _) => |
| 185 | + (load, numFiles) |
| 186 | + } |
| 187 | + .reverse |
| 188 | + ) |
| 189 | + |
| 190 | + if (smallFileThreshold > 0) { |
| 191 | + val smallFileTotalSize = filesSorted.map(_._2).sum * smallFileThreshold |
| 192 | + // First by numFiles, then by load. |
| 193 | + val heapByFileNum = |
| 194 | + mutable.PriorityQueue.empty[(Long, Int, Int)]( |
| 195 | + Ordering |
| 196 | + .by[(Long, Int, Int), (Int, Long)] { |
| 197 | + case (load, numFiles, _) => |
| 198 | + (numFiles, load) |
| 199 | + } |
| 200 | + .reverse |
| 201 | + ) |
| 202 | + |
| 203 | + inputPartitions.indices.foreach(i => heapByFileNum.enqueue((0L, 0, i))) |
| 204 | + |
| 205 | + var numSmallFiles = 0 |
| 206 | + var smallFileSize = 0L |
| 207 | + // Enqueue small files to the least number of files and the least load. |
| 208 | + filesSorted.reverse.takeWhile(f => f._2 + smallFileSize <= smallFileTotalSize).foreach { |
| 209 | + case (file, sz) => |
| 210 | + addToBucket(heapByFileNum, file, sz) |
| 211 | + numSmallFiles += 1 |
| 212 | + smallFileSize += sz |
| 213 | + } |
| 214 | + |
| 215 | + // Move buckets from heapByFileNum to heapByFileSize. |
| 216 | + while (heapByFileNum.nonEmpty) { |
| 217 | + heapByFileSize.enqueue(heapByFileNum.dequeue()) |
| 218 | + } |
| 219 | + |
| 220 | + // Finally, enqueue remaining files. |
| 221 | + filesSorted.take(filesSorted.size - numSmallFiles).foreach { |
| 222 | + case (file, sz) => |
| 223 | + addToBucket(heapByFileSize, file, sz) |
| 224 | + } |
| 225 | + } else { |
| 226 | + inputPartitions.indices.foreach(i => heapByFileSize.enqueue((0L, 0, i))) |
| 227 | + |
| 228 | + filesSorted.foreach { |
| 229 | + case (file, sz) => |
| 230 | + addToBucket(heapByFileSize, file, sz) |
| 231 | + } |
| 232 | + } |
| 233 | + |
| 234 | + partitions.zipWithIndex.map { case (p, idx) => FilePartition(idx, p.toArray) } |
| 235 | + } |
| 236 | +} |
0 commit comments