Refactored auto-microbatching hook handles for FSDP#3843
Conversation
can we run an e2e test to verify it works?
It seems hard to curate an e2e test that would catch this failure more informatively than our current unit tests. We have these two unit tests. In theory, we could create a larger example where a certain MPT module raises a CUDA OOM error at a certain epoch given a specific batch size, or we could use an MPT module with a massive hidden layer in the FFN that runs into OOM at one batch size but not at half of it.
I meant just a general e2e test; it does not have to trigger OOM.
Tested here: mpt-7b-fsdp2-p39FPR, and compared it against the base run (mpt-7b-fsdp2-AKLNwv). The numbers look good based on the tolerances mentioned in the regression testing PR @bowenyang008 (note that it defaults to a microbatch size of 8 when auto is set).
Commits:
- fixed test issues
- formatted
- gated non-wrapped to FSDP1
- updated for FSDP2
- propagated changes to trainer
- added minor test fix
- formatted
- formatted once more
- addressed comments
- formatted
- minor fix
bowenyang008
left a comment
left a few minor comments, LGTM
Refactored auto-microbatching hook handles for FSDP1 with additional documentation.
This PR was originally designed to support FSDP2 auto microbatching as well, but since there are additional issues with FSDP2 state, that work was moved to a draft PR: #3866
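The auto-microbatching behavior discussed above (falling back to a smaller microbatch size when a CUDA OOM is raised) can be sketched roughly as the loop below. This is a minimal illustrative sketch, not Composer's actual implementation; `find_microbatch_size` and `fake_step` are hypothetical names, and real code would also need to re-register the FSDP hooks that this PR refactors before retrying.

```python
def find_microbatch_size(batch_size, step):
    """Halve the microbatch size until `step` stops raising an OOM-like error.

    Hypothetical sketch of catch-and-halve auto microbatching; the real
    trainer must also reset model/optimizer state and hooks between retries.
    """
    micro = batch_size
    while micro >= 1:
        try:
            step(micro)
            return micro
        except RuntimeError as e:
            # Only treat CUDA OOM-style errors as a signal to shrink.
            if "out of memory" not in str(e).lower():
                raise
            micro //= 2  # retry with half the microbatch size
    raise RuntimeError("OOM even at microbatch size 1")


# Toy "train step" that OOMs for microbatch sizes above 4.
def fake_step(micro):
    if micro > 4:
        raise RuntimeError("CUDA out of memory")


print(find_microbatch_size(16, fake_step))  # -> 4 (16 -> 8 -> 4)
```

The unit tests mentioned above effectively exercise this path by injecting an OOM at a chosen batch size and checking that training proceeds at the halved size.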