|
| 1 | +import logging |
| 2 | +import multiprocessing as py_mp |
| 3 | +import os |
| 4 | +import traceback |
| 5 | +from pathlib import Path |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | +import torch.multiprocessing as mp |
| 11 | + |
| 12 | +from modalities.__main__ import Main, load_app_config_dict |
| 13 | +from modalities.batch import EvaluationResultBatch |
| 14 | +from modalities.config.config import ProcessGroupBackendType |
| 15 | +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel |
| 16 | +from modalities.logging_broker.messages import Message |
| 17 | +from tests.end2end_tests.custom_components import ( |
| 18 | + MultiProcessingCudaEnv, |
| 19 | + SaveAllResultSubscriber, |
| 20 | + SaveAllResultSubscriberConfig, |
| 21 | +) |
| 22 | +from tests.utility import find_free_port, monitor_child_processes |
| 23 | + |
| 24 | + |
| 25 | +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="This E2E test requires 4 CUDA devices.") |
| 26 | +class TestMoEEPFSDP2E2E: |
| 27 | + @staticmethod |
| 28 | + def _patch_for_short_test_run(config_dict: dict[str, Any], checkpoint_root_path: Path) -> None: |
| 29 | + # Keep runtime short while preserving EP + FSDP2 wiring. |
| 30 | + config_dict["settings"]["intervals"]["training_log_interval_in_steps"] = 1 |
| 31 | + config_dict["settings"]["intervals"]["checkpointing_interval_in_steps"] = 1 |
| 32 | + config_dict["settings"]["intervals"]["evaluation_interval_in_steps"] = 1000 |
| 33 | + |
| 34 | + config_dict["settings"]["step_profile"]["sequence_length"] = 64 |
| 35 | + config_dict["settings"]["step_profile"]["local_train_micro_batch_size"] = 1 |
| 36 | + config_dict["settings"]["step_profile"]["gradient_accumulation_steps"] = 1 |
| 37 | + |
| 38 | + config_dict["settings"]["training_target"]["num_target_tokens"] = 512 |
| 39 | + config_dict["settings"]["training_target"]["num_target_steps"] = 2 |
| 40 | + config_dict["lr_scheduler"]["config"]["total_steps"] = 2 |
| 41 | + |
| 42 | + config_dict["train_dataset"]["config"]["sequence_length"] = 64 |
| 43 | + config_dict["test_dataset"]["config"]["sequence_length"] = 64 |
| 44 | + config_dict["train_dataloader"]["config"]["num_workers"] = 0 |
| 45 | + config_dict["test_dataloader"]["config"]["num_workers"] = 0 |
| 46 | + config_dict["train_dataloader"]["config"]["pin_memory"] = False |
| 47 | + config_dict["test_dataloader"]["config"]["pin_memory"] = False |
| 48 | + |
| 49 | + config_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_root_path |
| 50 | + config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ |
| 51 | + "checkpoint_path" |
| 52 | + ] = checkpoint_root_path |
| 53 | + |
| 54 | + @staticmethod |
| 55 | + def _worker_wrapper( |
| 56 | + process_id: int, |
| 57 | + world_size: int, |
| 58 | + rdvz_port: int, |
| 59 | + config_file_path: Path, |
| 60 | + tmp_path: Path, |
| 61 | + error_queue: Any, |
| 62 | + ) -> None: |
| 63 | + with MultiProcessingCudaEnv( |
| 64 | + process_group_backend=ProcessGroupBackendType.nccl, |
| 65 | + global_rank=process_id, |
| 66 | + local_rank=process_id, |
| 67 | + world_size=world_size, |
| 68 | + rdvz_port=rdvz_port, |
| 69 | + ): |
| 70 | + try: |
| 71 | + TestMoEEPFSDP2E2E._worker_impl( |
| 72 | + process_id=process_id, |
| 73 | + config_file_path=config_file_path, |
| 74 | + tmp_path=tmp_path, |
| 75 | + ) |
| 76 | + except Exception as exc: |
| 77 | + tb = traceback.format_exc() |
| 78 | + logging.error(f"Process {process_id} failed: {exc}\n{tb}") |
| 79 | + try: |
| 80 | + error_queue.put((process_id, tb)) |
| 81 | + except Exception: |
| 82 | + logging.error("Failed to write child exception to queue.") |
| 83 | + os._exit(1) |
| 84 | + |
| 85 | + @staticmethod |
| 86 | + def _worker_impl(process_id: int, config_file_path: Path, tmp_path: Path) -> None: |
| 87 | + experiment_id = "moe-ep-fsdp2-e2e" |
| 88 | + checkpoint_root_path = tmp_path / experiment_id / "checkpoints" |
| 89 | + cfg = load_app_config_dict( |
| 90 | + config_file_path=config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id |
| 91 | + ) |
| 92 | + TestMoEEPFSDP2E2E._patch_for_short_test_run(cfg, checkpoint_root_path) |
| 93 | + |
| 94 | + main_obj = Main(config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id) |
| 95 | + main_obj.config_dict = cfg |
| 96 | + main_obj.add_custom_component( |
| 97 | + component_key="results_subscriber", |
| 98 | + variant_key="save_all", |
| 99 | + custom_component=SaveAllResultSubscriber, |
| 100 | + custom_config=SaveAllResultSubscriberConfig, |
| 101 | + ) |
| 102 | + main_obj.config_dict["evaluation_subscriber"]["variant_key"] = "save_all" |
| 103 | + main_obj.config_dict["evaluation_subscriber"]["config"] = {} |
| 104 | + |
| 105 | + components: TrainingComponentsInstantiationModel = main_obj.build_components( |
| 106 | + components_model_type=TrainingComponentsInstantiationModel |
| 107 | + ) |
| 108 | + |
| 109 | + assert getattr(components.model_raw, "_ep_wrapped", False), "Expected EP wrapping marker on raw model." |
| 110 | + first_layer = next(iter(components.model_raw.layers.values())) |
| 111 | + assert getattr(first_layer.ffn.experts, "_ep_enabled", False), "Expected experts to be EP-enabled." |
| 112 | + |
| 113 | + main_obj.run(components) |
| 114 | + |
| 115 | + result_messages: list[Message[EvaluationResultBatch]] = components.evaluation_subscriber.message_list |
| 116 | + assert len(result_messages) > 0, "Expected training messages in evaluation subscriber." |
| 117 | + for message in result_messages: |
| 118 | + loss_value = message.payload.losses["train loss avg"].value |
| 119 | + assert torch.isfinite(loss_value), f"Found non-finite train loss: {loss_value}" |
| 120 | + |
| 121 | + if process_id == 0: |
| 122 | + checkpoint_info_file_path = checkpoint_root_path / "last_checkpoint_info.json" |
| 123 | + assert checkpoint_info_file_path.exists(), "Expected checkpoint info file from DCP save." |
| 124 | + |
| 125 | + @staticmethod |
| 126 | + def test_moe_ep_fsdp2_training_and_checkpointing(tmp_path: Path) -> None: |
| 127 | + repo_root = Path(__file__).resolve().parents[2] |
| 128 | + config_file_path = repo_root / "config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml" |
| 129 | + |
| 130 | + world_size = 4 |
| 131 | + rdvz_port = find_free_port() |
| 132 | + |
| 133 | + manager = py_mp.Manager() |
| 134 | + error_queue = manager.Queue() |
| 135 | + proc_ctx = mp.spawn( |
| 136 | + TestMoEEPFSDP2E2E._worker_wrapper, |
| 137 | + args=(world_size, rdvz_port, config_file_path, tmp_path, error_queue), |
| 138 | + nprocs=world_size, |
| 139 | + join=False, |
| 140 | + ) |
| 141 | + |
| 142 | + monitor_child_processes(manager, error_queue, proc_ctx) |
0 commit comments