Functorch support#1184

Open
roycho96 wants to merge 30 commits into linkedin:main from roycho96:functorch-support
Conversation

Contributor

@roycho96 roycho96 commented Apr 3, 2026

Summary

Adds setup_context to torch.autograd.Function subclasses that use the legacy forward(ctx, ...) pattern, enabling compatibility with torch.func transforms (torch.func.grad, torch.func.grad_and_value, etc.).

Currently, any code that uses torch.func.grad_and_value() with Liger-Kernel enabled crashes:

RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.

This affects kernels that use torch.func.grad_and_value() internally for chunked gradient computation.
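A minimal repro of the crash (a hypothetical toy Function, not an actual Liger kernel): a legacy-style Function that takes ctx in forward() and defines no setup_context works under plain autograd, but raises the RuntimeError quoted above under any torch.func transform.

```python
import torch

# Legacy pattern: forward receives ctx directly and no setup_context exists.
class LegacySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x

# Plain autograd still works fine...
x = torch.tensor(3.0, requires_grad=True)
LegacySquare.apply(x).backward()
print(x.grad)  # tensor(6.)

# ...but torch.func transforms reject the legacy pattern.
try:
    torch.func.grad(lambda t: LegacySquare.apply(t))(torch.tensor(3.0))
except RuntimeError as e:
    print("setup_context" in str(e))  # True
```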

Details

For each autograd.Function:

  1. Removed ctx from forward() signature
  2. Added setup_context(ctx, inputs, output) to handle context setup
  3. Updated backward signature to match new forward output count
  4. Updated all callers to handle extra return values
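The steps above can be sketched on a toy Function (a hypothetical example to illustrate the pattern, not one of the modified Liger kernels):

```python
import torch

class ScaledSquare(torch.autograd.Function):
    @staticmethod
    def forward(x, scale):  # step 1: ctx removed from the signature
        return scale * x * x

    @staticmethod
    def setup_context(ctx, inputs, output):  # step 2: context setup moved here
        x, scale = inputs
        ctx.save_for_backward(x)
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_out):  # step 3: one gradient per forward input
        (x,) = ctx.saved_tensors
        return grad_out * 2 * ctx.scale * x, None

# Now compatible with torch.func transforms as well as plain autograd.
g = torch.func.grad(lambda t: ScaledSquare.apply(t, 2.0))(torch.tensor(3.0))
print(g)  # tensor(12.)
```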

Functions with @amp_custom_fwd / @amp_custom_bwd decorators were skipped; these need separate handling because the decorators assume args[0] is ctx.
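To illustrate the conflict, here is a hedged sketch of such a decorator (a simplified stand-in, not the real amp_custom_fwd implementation, which wraps torch.amp machinery). It unconditionally treats the first positional argument as ctx:

```python
import functools
import torch

def amp_custom_fwd_sketch(fwd):
    # Simplified stand-in: record autocast state on ctx before calling forward.
    @functools.wraps(fwd)
    def wrapper(*args, **kwargs):
        ctx = args[0]  # assumes the legacy forward(ctx, ...) layout
        ctx._fwd_used_autocast = torch.is_autocast_enabled()
        return fwd(*args, **kwargs)
    return wrapper
```

With the new forward(x, ...) signature, args[0] is an input tensor rather than ctx, so the decorator would silently stash its state on the wrong object, which is why these Functions need separate handling.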

Testing Done

  • Hardware Type: RTX 5060 Ti and H100
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
  • Added functorch compatibility tests for each modified Function

roycho96 added 30 commits April 4, 2026 02:51
Remove the ctx parameter assumption from the ensure_contiguous wrapper so it works with both the legacy forward(ctx, ...) signature and the new forward(...) signature required by the setup_context pattern.
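A signature-agnostic version of such a wrapper might look like the following hedged sketch (the actual Liger implementation may differ): instead of special-casing the first argument as ctx, it simply makes every tensor argument contiguous and passes everything else through.

```python
import functools
import torch

def ensure_contiguous(fn):
    # Make every tensor argument contiguous, whether or not the first
    # positional argument happens to be a ctx object.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        def fix(a):
            return a.contiguous() if isinstance(a, torch.Tensor) else a
        return fn(*[fix(a) for a in args],
                  **{k: fix(v) for k, v in kwargs.items()})
    return wrapper

@ensure_contiguous
def rowsum(x):
    assert x.is_contiguous()  # the wrapper guarantees this
    return x.sum(dim=-1)

out = rowsum(torch.ones(3, 4).t())  # .t() yields a non-contiguous view
print(out.shape)  # torch.Size([4])
```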
