Skip to content

Commit 2a3b81a

Browse files
committed
test: Add e2e moe test
1 parent baf94e9 commit 2a3b81a

1 file changed

Lines changed: 142 additions & 0 deletions

File tree

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import logging
2+
import multiprocessing as py_mp
3+
import os
4+
import traceback
5+
from pathlib import Path
6+
from typing import Any
7+
8+
import pytest
9+
import torch
10+
import torch.multiprocessing as mp
11+
12+
from modalities.__main__ import Main, load_app_config_dict
13+
from modalities.batch import EvaluationResultBatch
14+
from modalities.config.config import ProcessGroupBackendType
15+
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel
16+
from modalities.logging_broker.messages import Message
17+
from tests.end2end_tests.custom_components import (
18+
MultiProcessingCudaEnv,
19+
SaveAllResultSubscriber,
20+
SaveAllResultSubscriberConfig,
21+
)
22+
from tests.utility import find_free_port, monitor_child_processes
23+
24+
25+
@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="This E2E test requires 4 CUDA devices.")
class TestMoEEPFSDP2E2E:
    """End-to-end MoE test covering EP + FSDP2 wiring, a short training run and checkpointing.

    The public test spawns four worker processes; each worker builds the training
    components from a patched config, checks the EP markers on the model, runs
    training and validates the recorded losses. Skipped when fewer than 4 CUDA
    devices are visible.
    """

    @staticmethod
    def _patch_for_short_test_run(config_dict: dict[str, Any], checkpoint_root_path: Path) -> None:
        """Shrink the loaded config in place so the run finishes after two steps.

        Args:
            config_dict: Application config dictionary; mutated in place.
            checkpoint_root_path: Directory written into both checkpoint path entries.
        """
        # Keep runtime short while preserving EP + FSDP2 wiring.
        short_sequence_length = 64
        num_steps = 2

        settings = config_dict["settings"]

        intervals = settings["intervals"]
        intervals["training_log_interval_in_steps"] = 1
        intervals["checkpointing_interval_in_steps"] = 1
        intervals["evaluation_interval_in_steps"] = 1000

        step_profile = settings["step_profile"]
        step_profile["sequence_length"] = short_sequence_length
        step_profile["local_train_micro_batch_size"] = 1
        step_profile["gradient_accumulation_steps"] = 1

        training_target = settings["training_target"]
        training_target["num_target_tokens"] = 512
        training_target["num_target_steps"] = num_steps
        config_dict["lr_scheduler"]["config"]["total_steps"] = num_steps

        for dataset_key in ("train_dataset", "test_dataset"):
            config_dict[dataset_key]["config"]["sequence_length"] = short_sequence_length

        # Option-major order: both loaders get num_workers first, then pin_memory.
        for option_name, option_value in (("num_workers", 0), ("pin_memory", False)):
            for loader_key in ("train_dataloader", "test_dataloader"):
                config_dict[loader_key]["config"][option_name] = option_value

        settings["paths"]["checkpoint_saving_path"] = checkpoint_root_path
        saving_execution = config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]
        saving_execution["config"]["checkpoint_path"] = checkpoint_root_path

    @staticmethod
    def _worker_wrapper(
        process_id: int,
        world_size: int,
        rdvz_port: int,
        config_file_path: Path,
        tmp_path: Path,
        error_queue: Any,
    ) -> None:
        """Per-process entry point: run the worker body inside the CUDA env and report failures.

        Args:
            process_id: Index of this process; used as both global and local rank.
            world_size: Total number of spawned processes.
            rdvz_port: Rendezvous port handed to ``MultiProcessingCudaEnv``.
            config_file_path: Path of the training config file.
            tmp_path: Root directory for experiment outputs.
            error_queue: Queue receiving ``(process_id, traceback_text)`` on failure.
        """
        cuda_env = MultiProcessingCudaEnv(
            process_group_backend=ProcessGroupBackendType.nccl,
            global_rank=process_id,
            local_rank=process_id,
            world_size=world_size,
            rdvz_port=rdvz_port,
        )
        with cuda_env:
            try:
                TestMoEEPFSDP2E2E._worker_impl(
                    process_id=process_id,
                    config_file_path=config_file_path,
                    tmp_path=tmp_path,
                )
            except Exception as failure:
                trace_text = traceback.format_exc()
                logging.error(f"Process {process_id} failed: {failure}\n{trace_text}")
                # Best effort: a broken queue must not prevent the hard exit below.
                try:
                    error_queue.put((process_id, trace_text))
                except Exception:
                    logging.error("Failed to write child exception to queue.")
                # Terminate immediately with a non-zero code; the context manager
                # exit is intentionally not run on this path.
                os._exit(1)

    @staticmethod
    def _worker_impl(process_id: int, config_file_path: Path, tmp_path: Path) -> None:
        """Build the components, verify the EP markers, train and validate the results.

        Args:
            process_id: Index of this process; process 0 also checks the checkpoint info file.
            config_file_path: Path of the training config file.
            tmp_path: Root directory for experiment outputs.
        """
        experiment_id = "moe-ep-fsdp2-e2e"
        checkpoint_root_path = tmp_path / experiment_id / "checkpoints"

        patched_config = load_app_config_dict(
            config_file_path=config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id
        )
        TestMoEEPFSDP2E2E._patch_for_short_test_run(patched_config, checkpoint_root_path)

        # Replace the config Main loaded on its own with the shortened one.
        main = Main(config_file_path, experiments_root_path=tmp_path, experiment_id=experiment_id)
        main.config_dict = patched_config
        main.add_custom_component(
            component_key="results_subscriber",
            variant_key="save_all",
            custom_component=SaveAllResultSubscriber,
            custom_config=SaveAllResultSubscriberConfig,
        )
        # Route results into the in-memory subscriber registered above.
        subscriber_config = main.config_dict["evaluation_subscriber"]
        subscriber_config["variant_key"] = "save_all"
        subscriber_config["config"] = {}

        components: TrainingComponentsInstantiationModel = main.build_components(
            components_model_type=TrainingComponentsInstantiationModel
        )

        raw_model = components.model_raw
        assert getattr(raw_model, "_ep_wrapped", False), "Expected EP wrapping marker on raw model."
        first_layer = next(iter(raw_model.layers.values()))
        assert getattr(first_layer.ffn.experts, "_ep_enabled", False), "Expected experts to be EP-enabled."

        main.run(components)

        result_messages: list[Message[EvaluationResultBatch]] = components.evaluation_subscriber.message_list
        assert len(result_messages) > 0, "Expected training messages in evaluation subscriber."
        for message in result_messages:
            loss_value = message.payload.losses["train loss avg"].value
            assert torch.isfinite(loss_value), f"Found non-finite train loss: {loss_value}"

        if process_id == 0:
            checkpoint_info_file_path = checkpoint_root_path / "last_checkpoint_info.json"
            assert checkpoint_info_file_path.exists(), "Expected checkpoint info file from DCP save."

    @staticmethod
    def test_moe_ep_fsdp2_training_and_checkpointing(tmp_path: Path) -> None:
        """Spawn four workers on the MoE EP + FSDP2 config and monitor them.

        Args:
            tmp_path: pytest-provided temporary directory used as the experiments root.
        """
        world_size = 4
        repo_root = Path(__file__).resolve().parents[2]
        config_file_path = repo_root / "config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml"
        rdvz_port = find_free_port()

        # Manager-backed queue so workers can hand their tracebacks to the parent.
        manager = py_mp.Manager()
        error_queue = manager.Queue()

        # join=False returns the process context right away; the monitoring
        # helper receives it together with the manager and the queue.
        spawn_context = mp.spawn(
            TestMoEEPFSDP2E2E._worker_wrapper,
            args=(world_size, rdvz_port, config_file_path, tmp_path, error_queue),
            nprocs=world_size,
            join=False,
        )

        monitor_child_processes(manager, error_queue, spawn_context)

0 commit comments

Comments
 (0)