Add Automicrobatching for Non-Powers-of-2 + Fixes to FSDP deadlocks using Adaptive Sync Hooks #3503
JackZ-db wants to merge 23 commits into mosaicml:main from
Conversation
mvpatel2000
left a comment
Re-request once the test passes!
mvpatel2000
left a comment
First pass: the design looks right, but the code needs some cleanup.
@no_type_check
def unshard(self):
    """Run the unshard logic.

    This is an unpatched method from PyTorch, meant to be reverted to
    whenever automicrobatching turns off its hooks for increased throughput.

    This includes all-gathering the flat parameter and switching to using
    the unsharded flat parameter. If the handle does not need unsharding,
    then this only switches to using the unsharded flat parameter. For
    ``NO_SHARD``, this is a no-op.

    If FSDP is in :meth:`summon_full_params` and the handle uses parameter
This should probably be in the if torch 2.3.1 section
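For illustration, a minimal sketch of what gating this behind the torch-version check could look like; the helper names (patch_unshard, restore_unshard, patched_unshard) are hypothetical and not the PR's actual API:

# Sketch only, assuming torch >= 2.1 where FlatParamHandle lives in
# torch.distributed.fsdp._flat_param; names below are illustrative.
from packaging import version

import torch
from torch.distributed.fsdp._flat_param import FlatParamHandle

# Keep a reference to the stock method so automicrobatching can revert to it
# (for throughput) once a stable microbatch size has been found.
_ORIGINAL_UNSHARD = FlatParamHandle.unshard

def patch_unshard(patched_unshard) -> None:
    """Install the automicrobatching-aware unshard, gated on the torch version."""
    if version.parse(torch.__version__) >= version.parse('2.3.1'):
        FlatParamHandle.unshard = patched_unshard

def restore_unshard() -> None:
    """Revert to the unpatched unshard when the sync hooks are turned off."""
    FlatParamHandle.unshard = _ORIGINAL_UNSHARD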
if auto_microbatching:
can you add a comment on what this is doing?
def _double_device_train_microbatch_size(state: State):
    """Double device_train_microbatch_size when automicrobatching searches upward for a higher non-OOM microbatch size.
should this go into the automicrobatching utils folder?
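As an aside, a minimal sketch of the doubling step under discussion (the upper_bound argument here is an assumption; the PR may bound the upward search differently):

from composer.core import State

def _double_device_train_microbatch_size(state: State, upper_bound: int) -> None:
    """Double the microbatch size during the upward phase of the search.

    Sketch only: `upper_bound` (e.g. the per-device train batch size) is assumed
    to be tracked by the caller so the search never overshoots the full batch.
    """
    state.device_train_microbatch_size = min(state.device_train_microbatch_size * 2, upper_bound)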
    num_consecutive_thrashes = 0
    return num_consecutive_thrashes


def _handle_downward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, lower_bound_microbatch_size: int, num_search_steps: int, max_search_steps: int):
same comment on moving to utils?
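For reference, the signature above reads like a bounded binary search between the highest size known to fit and the lowest size known to OOM; a rough sketch of that step (illustrative only, not the PR's exact logic):

def _next_microbatch_size_to_try(
    lowest_oom_microbatch_size: int,
    highest_non_oom_microbatch_size: int,
    num_search_steps: int,
    max_search_steps: int,
) -> int:
    """Pick the next candidate size, midway between the known bounds.

    Sketch only: once the step budget is exhausted or the bounds meet,
    settle on the highest size that did not OOM.
    """
    if num_search_steps >= max_search_steps or highest_non_oom_microbatch_size + 1 >= lowest_oom_microbatch_size:
        return highest_non_oom_microbatch_size
    return (lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2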
if parallelism_config is not None:
    # Patch PyTorch to fix distributed bugs
    patch_pytorch()
    patch_unshard_for_automicrobatching(self.auto_microbatch_size_found)
this should just be part of patch_pytorch to simplify the interface
we need to pass in a boolean variable telling it how to patch this one specific method though - i feel like it would be less readable if we passed self.auto_microbatch_size_found directly into patch_pytorch
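To make the trade-off concrete, a sketch of the folded-in interface being discussed (names are illustrative, not composer's actual API):

def _patch_unshard(auto_microbatch_size_found: bool) -> None:
    """Placeholder for the unshard patch/restore logic (see the earlier sketch)."""

def patch_pytorch(auto_microbatch_size_found: bool = False) -> None:
    """Apply the version-gated distributed patches in one place."""
    # ... existing patches would run here ...
    _patch_unshard(auto_microbatch_size_found)

# The trainer call site would then collapse to a single call:
# patch_pytorch(auto_microbatch_size_found=self.auto_microbatch_size_found)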
# Sync for OOMs
found_cuda_oom = _found_ooms_across_ranks(self.state, found_cuda_oom)
this block is really complicated. let's move it to a helper fn
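For context, agreeing on an OOM across ranks is usually done with an all-reduce of a flag tensor; a hedged sketch of what a helper like this could look like (not necessarily the PR's implementation, which takes the trainer state instead):

import torch
import torch.distributed as dist

def _found_ooms_across_ranks(found_cuda_oom_local: bool, device: torch.device) -> bool:
    """Return True on every rank if any rank hit a CUDA OOM this microbatch.

    Sketch only: every rank must call this, otherwise the collective hangs,
    which is exactly the deadlock the adaptive sync hooks are meant to avoid.
    """
    flag = torch.tensor([1 if found_cuda_oom_local else 0], dtype=torch.int64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())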
with torch.no_grad(), model_eval_mode(self.state.model):
    if self.state.fsdp_enabled and self.first_batch_complete:
        print("readd hooks for eval")
    extract_hparams,
)
from composer.utils.automicrobatching import (
    # _create_sync_hook,

    'validate_credentials',
    'build_remote_backend',
    'RemoteFilesExistingCheckStatus',
    # '_create_sync_hook',