
[Bug] Incorrect args for gradient_checkpoint_enable #4886

@jonahsamost

Description

Note: Please do not remove the questions. Answer beside them.

  1. Did you update? pip install --upgrade unsloth unsloth_zoo. yes
  2. Colab or Kaggle or local / cloud. cloud
  3. Number GPUs used, use nvidia-smi. 1
  4. Which notebook? Please link!
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?
unsloth                                  2026.4.4
unsloth-zoo                              2026.4.3
trl                                      1.0.0
transformers                             5.5.0
torch                                    2.10.0
  6. Which trainer? SFTTrainer, GRPOTrainer etc. GRPOTrainer
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/root/foo-llm-finetune/src/foo_rl/train/src/phase1/unsloth_train.py", line 152, in <module>
    main()
  File "/root/foo-llm-finetune/src/foo_rl/train/src/phase1/unsloth_train.py", line 146, in main
    trainer.train()
  File "/root/foo-llm-finetune/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 84, in wrapper
    output = f(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/foo-llm-finetune/.venv/lib/python3.12/site-packages/unsloth/models/rl.py", line 142, in _unsloth_train_with_resume_guard
    return original_train(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/foo-llm-finetune/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 1424, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 81, in _fast_inner_training_loop
  File "/root/foo-llm-finetune/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 1734, in _run_epoch
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/foo-llm-finetune/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 3039, in training_step
    output = super().training_step(model, inputs, num_items_in_batch)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 34, in _unsloth_training_step
  File "/root/foo-llm-finetune/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 202, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/foo-llm-finetune/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 3068, in _prepare_inputs
    generation_batch = self._generate_and_score_completions(generation_batch)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/foo-llm-finetune/.venv/lib/python3.12/site-packages/unsloth/models/rl.py", line 524, in wrapped
    return original(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/foo-llm-finetune/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 3844, in _generate_and_score_completions
    with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/share/uv/python/cpython-3.12.13-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "/root/foo-llm-finetune/.venv/lib/python3.12/site-packages/trl/models/utils.py", line 382, in disable_gradient_checkpointing
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
TypeError: FastBaseModel.post_patch_model.<locals>._gc_enable_reentrant() takes 0 positional arguments but 1 was given

With use_gradient_checkpointing="unsloth" in my peft_kwargs, Unsloth replaces gradient_checkpointing_enable with _gc_enable_reentrant, and there is an argument mismatch: TRL's disable_gradient_checkpointing context manager calls model.gradient_checkpointing_enable(gradient_checkpointing_kwargs) on exit, but _gc_enable_reentrant takes no positional arguments.
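The mismatch can be reproduced in isolation (a minimal sketch; the function below only mirrors the shape of Unsloth's replacement, it is not the real implementation):

```python
def _gc_enable_reentrant():
    # Stand-in for Unsloth's replacement: accepts no positional arguments.
    pass

model_gradient_checkpointing_enable = _gc_enable_reentrant

try:
    # TRL effectively does: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
    model_gradient_checkpointing_enable({"use_reentrant": True})
except TypeError as e:
    print(e)  # takes 0 positional arguments but 1 was given
```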

I added a workaround like:

import types
from typing import Any


def _patch_gradient_checkpointing_enable_for_trl(m: Any) -> None:
    """TRL calls gradient_checkpointing_enable(kwargs); Unsloth's replacement takes no args."""
    if not hasattr(m, "gradient_checkpointing_enable"):
        return
    _orig = m.gradient_checkpointing_enable

    def _enable(self, gradient_checkpointing_kwargs=None, **kwargs):  # type: ignore[no-untyped-def]
        # Swallow TRL's positional/keyword arguments and call the zero-arg original.
        return _orig()

    m.gradient_checkpointing_enable = types.MethodType(_enable, m)
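For completeness, here is a self-contained sketch showing the patch absorbing TRL's positional call; the _Dummy class is a stand-in for a model already post-patched by Unsloth, not a real model:

```python
import types
from typing import Any


def _patch_gradient_checkpointing_enable_for_trl(m: Any) -> None:
    """TRL calls gradient_checkpointing_enable(kwargs); Unsloth's replacement takes no args."""
    if not hasattr(m, "gradient_checkpointing_enable"):
        return
    _orig = m.gradient_checkpointing_enable

    def _enable(self, gradient_checkpointing_kwargs=None, **kwargs):
        return _orig()

    m.gradient_checkpointing_enable = types.MethodType(_enable, m)


class _Dummy:
    # Stand-in for a model whose enable method was replaced with a zero-arg callable.
    def gradient_checkpointing_enable(self):
        return "enabled"


m = _Dummy()
_patch_gradient_checkpointing_enable_for_trl(m)
# TRL-style positional call now succeeds instead of raising TypeError:
print(m.gradient_checkpointing_enable({"use_reentrant": True}))
```

Apply the patch to the model after Unsloth's post-patching and before constructing the GRPOTrainer, so the context manager in trl/models/utils.py hits the tolerant wrapper.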

🦥 You can also ask via our Reddit page: https://reddit.com/r/unsloth/
