Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,16 +814,9 @@ def pad_input_tensors(tensor, batch_size, num_processes, dim=0):
"""

def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):
remainder = batch_size // num_processes
last_inputs = batch_size - (remainder * num_processes)
if batch_size // num_processes == 0:
to_pad = num_processes - batch_size
else:
to_pad = num_processes - (batch_size // num_processes)
# In the rare case that `to_pad` is negative,
# we need to pad the last inputs - the found `to_pad`
if last_inputs > to_pad & to_pad < 1:
to_pad = last_inputs - to_pad
# Pad dim-0 up to the next multiple of num_processes so the batch can be split
# evenly across processes (0 padding when batch_size is already divisible).
to_pad = (num_processes - batch_size % num_processes) % num_processes
old_size = tensor.shape
new_size = list(old_size)
new_size[0] = batch_size + to_pad
Expand Down
13 changes: 13 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,19 @@ def test_slice_and_concatenate(self):
# We should expect there to be 66 items now
assert result.shape == torch.Size([66, 4, 4])

def test_pad_input_tensors_divisibility(self):
    # Contract under test: dim-0 is grown to the next multiple of num_processes
    # (and left untouched when it already is one), so the batch splits evenly.
    # Regression cases: with 4 processes the previous implementation produced
    # 8 -> 10 and 6 -> 9 rows, neither of which is divisible by 4.
    for n_rows, expected_rows in ((8, 8), (6, 8)):
        result = pad_input_tensors(torch.rand(n_rows, 4), n_rows, 4)
        assert result.shape == torch.Size([expected_rows, 4])
    # Sweep many (num_processes, batch_size) pairs and check the two invariants:
    # padding never shrinks the batch, and the padded batch divides evenly.
    for world_size in range(1, 9):
        for n_rows in range(1, 30):
            source = torch.zeros(n_rows, 2)
            padded_rows = pad_input_tensors(source, n_rows, world_size).shape[0]
            assert padded_rows >= n_rows
            assert padded_rows % world_size == 0

def test_send_to_device_compiles(self):
compiled_send_to_device = torch.compile(send_to_device, fullgraph=True)
compiled_send_to_device(torch.zeros([1], dtype=torch.bfloat16), "cpu")
Expand Down