
Refactored auto-microbatching hook handles for FSDP#3843

Merged
rithwik-db merged 9 commits into main from hookhandles on Jun 4, 2025
Conversation

@rithwik-db
Contributor

@rithwik-db rithwik-db commented May 2, 2025

Refactored auto-microbatching hook handles for FSDP1 with additional documentation.

This PR was originally designed to support FSDP2 auto microbatching, but since there are additional issues with FSDP2 state there, we moved that to a draft PR: #3866
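For background, the "hook handles" being refactored follow the same pattern as PyTorch's `torch.utils.hooks.RemovableHandle`: registering a hook returns a handle object, and calling `.remove()` on that handle detaches the hook so it no longer fires. A minimal pure-Python sketch of the pattern (hypothetical `Module` and `RemovableHandle` names for illustration, not Composer's actual implementation):

```python
class RemovableHandle:
    """Detaches a registered hook when .remove() is called
    (mirrors torch.utils.hooks.RemovableHandle)."""

    def __init__(self, hooks_dict, hook_id):
        self._hooks = hooks_dict
        self._id = hook_id

    def remove(self):
        # Idempotent: removing twice is a no-op.
        self._hooks.pop(self._id, None)


class Module:
    """Toy module that invokes registered forward hooks after each call."""

    def __init__(self):
        self._forward_hooks = {}
        self._next_id = 0

    def register_forward_hook(self, fn):
        handle = RemovableHandle(self._forward_hooks, self._next_id)
        self._forward_hooks[self._next_id] = fn
        self._next_id += 1
        return handle

    def forward(self, x):
        out = x * 2
        for fn in list(self._forward_hooks.values()):
            fn(self, x, out)
        return out
```

In auto-microbatching, hooks like this are registered while probing for a workable microbatch size and then cleaned up via the returned handles, which is why keeping track of the handles (rather than the hooks themselves) matters.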

@rithwik-db rithwik-db changed the title Added hook handles for FSDP2 to address automicrobatching [WIP] Added hook handles for FSDP2 to address automicrobatching May 2, 2025
@rithwik-db rithwik-db changed the title [WIP] Added hook handles for FSDP2 to address automicrobatching Added hook handles for FSDP2 to supported auto microbatching May 2, 2025
@rithwik-db rithwik-db changed the title Added hook handles for FSDP2 to supported auto microbatching Added hook handles for FSDP2 to support auto microbatching May 2, 2025
Comment thread composer/distributed/fsdp2.py Outdated
Comment thread tests/common/models.py Outdated
Comment thread tests/common/models.py Outdated
Comment thread tests/trainer/test_fsdp2.py Outdated
Comment thread composer/distributed/fsdp2.py Outdated
@rithwik-db rithwik-db force-pushed the hookhandles branch 2 times, most recently from b126818 to 2a1cfb4 Compare May 5, 2025 22:11
Comment thread composer/distributed/shared_utils.py
Comment thread composer/distributed/shared_utils.py Outdated
Comment thread composer/distributed/shared_utils.py Outdated
Comment thread composer/distributed/shared_utils.py Outdated
@rithwik-db rithwik-db force-pushed the hookhandles branch 2 times, most recently from fa82a6b to 647ad56 Compare May 21, 2025 00:32
@rithwik-db rithwik-db requested a review from bowenyang008 May 21, 2025 20:27
Comment thread composer/distributed/fsdp2.py Outdated
Comment thread composer/distributed/prepare_distributed.py Outdated
Comment thread composer/distributed/prepare_distributed.py Outdated
Comment thread composer/distributed/shared_utils.py Outdated
Comment thread composer/distributed/shared_utils.py
Comment thread composer/distributed/shared_utils.py Outdated
Comment thread composer/distributed/shared_utils.py Outdated
Comment thread composer/distributed/shared_utils.py Outdated
Comment thread tests/trainer/fsdp2_context.py Outdated
Comment thread tests/trainer/test_fsdp2.py Outdated
Comment thread tests/trainer/test_fsdp2.py Outdated
Comment thread tests/trainer/test_fsdp2.py Outdated
@bowenyang008
Contributor

can we run an e2e test to verify it works?

@rithwik-db
Contributor Author

> can we run an e2e test to verify it works?

It seems hard to curate an e2e test to catch this failure that would be more informative than our current unit tests. We have these two unit tests:

  1. FSDP1
  2. FSDP2

I guess in theory, we could create a larger example where a certain MPT module raises a CUDA OOM error at a certain epoch given a specific batch size, or we could use an MPT module with a massive hidden layer in the FFN that runs into OOM at one batch size but not at half of it...

@rithwik-db rithwik-db requested a review from bowenyang008 May 22, 2025 06:50
@bowenyang008
Contributor

> can we run an e2e test to verify it works?
>
> It seems hard to curate an e2e test to catch this failure that would be more informative than our current unit tests. We have these two unit tests:
>
>   1. FSDP1
>   2. FSDP2
>
> I guess in theory, we could create a larger example where a certain MPT module raises a CUDA OOM error at a certain epoch given a specific batch size, or we could use an MPT module with a massive hidden layer in the FFN that runs into OOM at one batch size but not at half of it...

I meant just a general e2e test, does not have to trigger OOM

@rithwik-db
Contributor Author

rithwik-db commented May 22, 2025

Tested here: mpt-7b-fsdp2-p39FPR, compared against the base run (mpt-7b-fsdp2-AKLNwv); the numbers look good based on the tolerances mentioned in the regression testing PR @bowenyang008 (note that it defaults to a microbatch size of 8 when auto is set).

Comment thread composer/distributed/fsdp2.py Outdated
@rithwik-db rithwik-db changed the title Added hook handles for FSDP2 to support auto microbatching Refactored auto-microbatching hook handles for FSDP May 29, 2025
@rithwik-db rithwik-db requested a review from bowenyang008 May 29, 2025 21:04
Commits:

  - fixed test issues
  - formatted
  - gated non-wrapped to FSDP1
  - updated for FSDP2
  - propagated changes to trainer
  - added minor test fix
  - formatted
  - formatted once more
  - addressed comments
  - formatted
  - minor fix
Comment thread composer/trainer/trainer.py
Comment thread composer/trainer/trainer.py Outdated
Comment thread composer/trainer/trainer.py
Contributor

@bowenyang008 bowenyang008 left a comment


left a few minor comments, LGTM

@rithwik-db rithwik-db merged commit ce08ff0 into main Jun 4, 2025
13 checks passed
@rithwik-db rithwik-db deleted the hookhandles branch June 4, 2025 18:15

3 participants