Skip to content
164 changes: 157 additions & 7 deletions unsloth_zoo/gradient_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
# Initial sizing for the CPU / GPU staging buffers used by gradient checkpointing.
INITIAL_CPU_BUFFER_SIZE = 128 * 1024 # Initial size per CPU buffer
# Passed directly to torch.empty(...) as the element count, so this is a number
# of elements of the buffer dtype, not a number of bytes.
INITIAL_GPU_BUFFER_SIZE = 2 * 256 * 2048 # Initial size per GPU buffer
INITIAL_CPU_BUFFER_COUNT = 200 # Initial number of CPU buffers
# Compared against the free-memory value reported by torch.cuda.mem_get_info:
# double buffering is only switched on when free memory is strictly greater
# than this threshold.
DOUBLE_BUFFER_HEADROOM = 256 * 1024 * 1024 # 256MB minimum free CUDA memory to enable double buffering

torch_version = torch.__version__
if Version(torch_version) < Version("2.4.0"):
Expand Down Expand Up @@ -300,6 +301,8 @@ def set_device_states(devices, states, *, device_type=None) -> None:
global CPU_BUFFERS
global CPU_INDEX
global GPU_BUFFERS
global GPU_BUFFERS_B
global USE_DOUBLE_BUFFER
global BACKWARD_PASS
global EXTRA_STREAMS
global MAIN_STREAMS
Expand All @@ -308,6 +311,9 @@ def set_device_states(devices, states, *, device_type=None) -> None:
global LAST_GC_INDEX
global FIRST_PASS
global CURRENT_GC_INDEX
global BUFFER_EVENTS_A
global BUFFER_EVENTS_B
global NEXT_BUFFER_SLOT

if DEVICE_TYPE in ("cuda", "hip"):
torch_gpu_stream = torch.cuda.stream
Expand All @@ -322,6 +328,8 @@ def initialize_unsloth_gradient_checkpointing(dtype = None):
global CPU_BUFFERS
global CPU_INDEX
global GPU_BUFFERS
global GPU_BUFFERS_B
global USE_DOUBLE_BUFFER
global BACKWARD_PASS
global EXTRA_STREAMS
global MAIN_STREAMS
Expand All @@ -330,6 +338,9 @@ def initialize_unsloth_gradient_checkpointing(dtype = None):
global LAST_GC_INDEX
global FIRST_PASS
global CURRENT_GC_INDEX
global BUFFER_EVENTS_A
global BUFFER_EVENTS_B
global NEXT_BUFFER_SLOT
CPU_BUFFERS = []
CPU_INDEX = 0

Expand All @@ -351,8 +362,36 @@ def initialize_unsloth_gradient_checkpointing(dtype = None):

# Allocate buffers to how many GPUs
n_gpus = torch.cuda.device_count() if DEVICE_TYPE in ("cuda", "hip") else torch.xpu.device_count()
NEXT_BUFFER_SLOT = [0] * n_gpus
try:
GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:{i}") for i in range(n_gpus)])
GPU_BUFFERS = tuple([torch.empty(INITIAL_GPU_BUFFER_SIZE, dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:{i}") for i in range(n_gpus)])
# Double buffering: try to allocate buffer B (can be disabled via env var)
if os.environ.get("UNSLOTH_DISABLE_DOUBLE_BUFFER", "0") == "1":
GPU_BUFFERS_B = None
USE_DOUBLE_BUFFER = False
BUFFER_EVENTS_A = None
BUFFER_EVENTS_B = None
else:
try:
GPU_BUFFERS_B = tuple([torch.empty(INITIAL_GPU_BUFFER_SIZE, dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:{i}") for i in range(n_gpus)])
USE_DOUBLE_BUFFER = False # set false first, enabled after first pass if CUDA free memory > DOUBLE_BUFFER_HEADROOM
# Per-buffer events prevent race conditions when double buffering.
# Each event marks the point at which compute on that buffer has finished.
if DEVICE_TYPE in ("cuda", "hip"):
event_ctor = torch.cuda.Event
elif DEVICE_TYPE == "xpu":
event_ctor = torch.xpu.Event
else:
raise RuntimeError(f"Double buffering unsupported on {DEVICE_TYPE}")
BUFFER_EVENTS_A = tuple([event_ctor() for _ in range(n_gpus)])
BUFFER_EVENTS_B = tuple([event_ctor() for _ in range(n_gpus)])
except RuntimeError as e:
if "out of memory" not in str(e).lower():
raise
GPU_BUFFERS_B = None
USE_DOUBLE_BUFFER = False
BUFFER_EVENTS_A = None
BUFFER_EVENTS_B = None
except Exception as e:
print("="*10 + "\n")
print("Unsloth: Your setup does not support `PYTORCH_CUDA_ALLOC_CONF`\n")
Expand Down Expand Up @@ -439,6 +478,8 @@ def forward(ctx, run_function, preserve_rng_state, *args):
use_gpu_buffer = True
global CPU_BUFFERS
global GPU_BUFFERS
global GPU_BUFFERS_B
global USE_DOUBLE_BUFFER
global BACKWARD_PASS
global EXTRA_STREAMS
global MAIN_STREAMS
Expand All @@ -452,6 +493,15 @@ def forward(ctx, run_function, preserve_rng_state, *args):
if BACKWARD_PASS:
BACKWARD_PASS = False
CPU_INDEX = 0
if not FIRST_PASS and not USE_DOUBLE_BUFFER and GPU_BUFFERS_B is not None:
free_mem, _ = torch.cuda.mem_get_info(device_index)
if free_mem > DOUBLE_BUFFER_HEADROOM:
USE_DOUBLE_BUFFER = True
print(f"Unsloth: Double buffering enabled (parallel H2D + compute) for backward pass.")
else:
for j in range(len(GPU_BUFFERS_B)):
GPU_BUFFERS_B[j].resize_(0)
GPU_BUFFERS_B = None
pass

# Extend buffer size
Expand All @@ -463,15 +513,50 @@ def forward(ctx, run_function, preserve_rng_state, *args):
x = CPU_BUFFERS[CPU_INDEX]
shape = arg.shape
if new_size > x.numel(): x.resize_(new_size)
if new_size > GPU_BUFFER.numel(): GPU_BUFFER.resize_(new_size)
if new_size > GPU_BUFFER.numel():
try:
GPU_BUFFER.resize_(new_size)
except RuntimeError as e:
if "out of memory" not in str(e).lower():
raise
# Out of memory: free buffer B (if double buffering is on), then retry resizing the primary buffer.
if USE_DOUBLE_BUFFER:
USE_DOUBLE_BUFFER = False
for j in range(len(GPU_BUFFERS_B)):
GPU_BUFFERS_B[j].resize_(0)
GPU_BUFFERS_B = None
print("Unsloth: Disabled double buffering due to insufficient VRAM.")
GPU_BUFFER.resize_(new_size)
else:
raise
# If double buffering is enabled, grow buffer B to the same size. On OOM, free buffer B and disable double buffering.
if USE_DOUBLE_BUFFER:
GPU_BUFFER_B = GPU_BUFFERS_B[device_index]
if new_size > GPU_BUFFER_B.numel():
try:
GPU_BUFFER_B.resize_(new_size)
except RuntimeError as e:
if "out of memory" not in str(e).lower():
raise
# OOM - disable double buffering
USE_DOUBLE_BUFFER = False
# Reclaim buffer B
for j in range(len(GPU_BUFFERS_B)):
GPU_BUFFERS_B[j].resize_(0)
GPU_BUFFERS_B = None
print("Unsloth: Disabled double buffering due to insufficient VRAM.")

x = x[:new_size].view(shape)

# See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams
EXTRA_STREAM.wait_stream(MAIN_STREAM)
with torch_gpu_stream(EXTRA_STREAM):
x.copy_(arg, non_blocking = True)

ctx._saved_metadata = (new_size, shape, CPU_INDEX, device_index, MAIN_STREAM, EXTRA_STREAM,)
global NEXT_BUFFER_SLOT
buffer_slot = NEXT_BUFFER_SLOT[device_index]
NEXT_BUFFER_SLOT[device_index] ^= 1
ctx._saved_metadata = (new_size, shape, CPU_INDEX, device_index, MAIN_STREAM, EXTRA_STREAM, buffer_slot,)
CPU_INDEX += 1
tensor_inputs.append(None)

Expand All @@ -480,7 +565,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
print("Unsloth: Will smartly offload gradients to save VRAM!")
USE_UNSLOTH_GC = False
else:
ctx._saved_metadata = (None, None, None, None, None, None,)
ctx._saved_metadata = (None, None, None, None, None, None, None,)
tensor_inputs.append(arg)
pass
else:
Expand Down Expand Up @@ -520,14 +605,30 @@ def backward(ctx, *args):
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors

new_size, shape, CPU_INDEX, device_index, MAIN_STREAM, EXTRA_STREAM = ctx._saved_metadata
new_size, shape, CPU_INDEX, device_index, MAIN_STREAM, EXTRA_STREAM, buffer_slot = ctx._saved_metadata
if CPU_INDEX is not None:
global GPU_BUFFER
buffer = GPU_BUFFERS[device_index][:new_size].view(shape)
global USE_DOUBLE_BUFFER
global GPU_BUFFERS_B
global BUFFER_EVENTS_A
global BUFFER_EVENTS_B
# Select which buffer to use based on per-device buffer_slot
if USE_DOUBLE_BUFFER and buffer_slot == 1:
buffer = GPU_BUFFERS_B[device_index][:new_size].view(shape)
else:
buffer = GPU_BUFFERS[device_index][:new_size].view(shape)

x = CPU_BUFFERS[CPU_INDEX][:new_size].view(shape)

# See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams
EXTRA_STREAM.wait_stream(MAIN_STREAM)
if USE_DOUBLE_BUFFER:
# Wait for the last compute on THIS SPECIFIC buffer to finish
event_buffer = BUFFER_EVENTS_B if buffer_slot == 1 else BUFFER_EVENTS_A
EXTRA_STREAM.wait_event(event_buffer[device_index])
else:
# Single buffer mode: Must wait for MAIN_STREAM to finish
EXTRA_STREAM.wait_stream(MAIN_STREAM)

with torch_gpu_stream(EXTRA_STREAM):
buffer.copy_(x, non_blocking = True)
else:
Expand Down Expand Up @@ -612,6 +713,11 @@ def backward(ctx, *args):
torch.autograd.backward(outputs_with_grad, args_with_grad)
pass

# Record an event on MAIN_STREAM after the compute that reads this buffer, so the copy stream can wait on it before overwriting the buffer.
if CPU_INDEX is not None and USE_DOUBLE_BUFFER:
event_buffer = BUFFER_EVENTS_B if buffer_slot == 1 else BUFFER_EVENTS_A
event_buffer[device_index].record(MAIN_STREAM)

grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
Expand Down Expand Up @@ -819,14 +925,27 @@ def unpatch_unsloth_smart_gradient_checkpointing():
torch.utils.checkpoint.CheckpointFunction = torch.utils.checkpoint._old_CheckpointFunction
global CPU_BUFFERS
global GPU_BUFFERS
global GPU_BUFFERS_B
global USE_DOUBLE_BUFFER
global BUFFER_EVENTS_A
global BUFFER_EVENTS_B
global NEXT_BUFFER_SLOT
for i in range(len(CPU_BUFFERS)):
if hasattr(CPU_BUFFERS[i], "resize_"): CPU_BUFFERS[i].resize_(0)
if type(CPU_BUFFERS) is list: CPU_BUFFERS[i] = None
for i in range(len(GPU_BUFFERS)):
if hasattr(GPU_BUFFERS[i], "resize_"): GPU_BUFFERS[i].resize_(0)
if type(GPU_BUFFERS) is list: GPU_BUFFERS[i] = None
if GPU_BUFFERS_B is not None:
for i in range(len(GPU_BUFFERS_B)):
if hasattr(GPU_BUFFERS_B[i], "resize_"): GPU_BUFFERS_B[i].resize_(0)
GPU_BUFFERS_B = None
USE_DOUBLE_BUFFER = False
CPU_BUFFERS = None
GPU_BUFFERS = None
BUFFER_EVENTS_A = None
BUFFER_EVENTS_B = None
NEXT_BUFFER_SLOT = None
torch.cuda.empty_cache()
gc.collect()

Expand Down Expand Up @@ -875,6 +994,11 @@ def reset_unsloth_gradient_checkpointing_buffers():
global FIRST_PASS
global CURRENT_GC_INDEX
global USE_UNSLOTH_GC
global NEXT_BUFFER_SLOT
global GPU_BUFFERS_B
global USE_DOUBLE_BUFFER
global BUFFER_EVENTS_A
global BUFFER_EVENTS_B

# Check if buffers exist
if CPU_BUFFERS is None or GPU_BUFFERS is None:
Expand Down Expand Up @@ -913,6 +1037,32 @@ def reset_unsloth_gradient_checkpointing_buffers():
FIRST_PASS = True
CURRENT_GC_INDEX = 0
USE_UNSLOTH_GC = True # Re-enable the "Will smartly offload" message
if NEXT_BUFFER_SLOT is not None:
for i in range(len(NEXT_BUFFER_SLOT)):
NEXT_BUFFER_SLOT[i] = 0

# Reset double buffering if buffer B still exists, or try to re-allocate
if GPU_BUFFERS_B is not None:
for i in range(len(GPU_BUFFERS_B)):
if GPU_BUFFERS_B[i] is not None and hasattr(GPU_BUFFERS_B[i], "resize_"):
GPU_BUFFERS_B[i].resize_(INITIAL_GPU_BUFFER_SIZE)
USE_DOUBLE_BUFFER = False
else:
try:
n_gpus = len(GPU_BUFFERS)
dtype = GPU_BUFFERS[0].dtype
GPU_BUFFERS_B = tuple([torch.empty(INITIAL_GPU_BUFFER_SIZE, dtype=dtype, device=f"{DEVICE_TYPE_TORCH}:{i}") for i in range(n_gpus)])
if DEVICE_TYPE in ("cuda", "hip"):
event_ctor = torch.cuda.Event
elif DEVICE_TYPE == "xpu":
event_ctor = torch.xpu.Event
else:
raise RuntimeError(f"Double buffering unsupported on {DEVICE_TYPE}")
BUFFER_EVENTS_A = tuple([event_ctor() for _ in range(n_gpus)])
BUFFER_EVENTS_B = tuple([event_ctor() for _ in range(n_gpus)])
USE_DOUBLE_BUFFER = False
except RuntimeError:
pass

# Clean up freed memory
torch.cuda.empty_cache()
Expand Down