|
1 | 1 | import pytest |
2 | 2 | import torch.nn as nn |
| 3 | +from pydantic import ValidationError |
| 4 | +from torch.distributed.device_mesh import DeviceMesh |
3 | 5 |
|
| 6 | +from modalities.config.config import GPT2ModelTPConfig |
4 | 7 | from modalities.models.components.layer_norms import LayerNormConfig |
5 | 8 | from modalities.models.gpt2.gpt2_model import ( |
6 | 9 | GPT2LLM, |
|
11 | 14 | PositionTypes, |
12 | 15 | ) |
13 | 16 | from modalities.models.model import ActivationType |
| 17 | +from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig |
| 18 | +from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator |
| 19 | +from modalities.models.weight_tying import has_tied_word_embeddings |
| 20 | +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees |
14 | 21 |
|
15 | 22 | VOCAB_SIZE = 1000 |
16 | 23 | EMBEDDING_DIM = 64 |
@@ -79,9 +86,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: |
79 | 86 | ) |
80 | 87 |
|
81 | 88 |
|
| 89 | +def create_device_mesh_stub(*mesh_dim_names: str) -> DeviceMesh: |
| 90 | + device_mesh = DeviceMesh.__new__(DeviceMesh) |
| 91 | + device_mesh.mesh_dim_names = mesh_dim_names |
| 92 | + return device_mesh |
| 93 | + |
| 94 | + |
82 | 95 | @pytest.mark.parametrize("use_weight_tying", [True, False]) |
83 | 96 | def test_weight_tying_behavior(use_weight_tying): |
84 | 97 | model = create_gpt2_model(use_weight_tying) |
| 98 | + assert model.has_tied_word_embeddings is use_weight_tying |
| 99 | + |
85 | 100 | if use_weight_tying: |
86 | 101 | assert ( |
87 | 102 | model.transformer.wte.weight is model.transformer.lm_head.weight |
@@ -118,3 +133,52 @@ def test_weight_tying_named_parameters(use_weight_tying): |
118 | 133 | assert ( |
119 | 134 | "transformer.lm_head.weight" in named_params |
120 | 135 | ), "transformer.lm_head.weight should appear in named_parameters when weight tying is not used." |
| 136 | + |
| 137 | + |
def test_has_tied_word_embeddings_requires_model_capability():
    """``has_tied_word_embeddings`` is expected to raise ``TypeError`` for a plain ``nn.Linear``."""
    plain_module = nn.Linear(1, 1)

    with pytest.raises(TypeError, match="must define 'has_tied_word_embeddings'"):
        has_tied_word_embeddings(plain_module)
| 141 | + |
| 142 | + |
def test_tp_config_rejects_tied_word_embeddings():
    """Building ``GPT2ModelTPConfig`` with a weight-tied model is expected to fail validation."""
    tp_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
    tied_model = create_gpt2_model(use_weight_tying=True)

    expected_message = "Tied word embeddings are not supported with Tensor Parallelism"
    with pytest.raises(ValidationError, match=expected_message):
        GPT2ModelTPConfig(model=tied_model, device_mesh=tp_mesh)
| 149 | + |
| 150 | + |
def test_tp_config_allows_untied_word_embeddings():
    """Building ``GPT2ModelTPConfig`` with an untied model must not raise.

    There is no explicit assertion: the test passes as long as construction completes.
    """
    tp_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
    untied_model = create_gpt2_model(use_weight_tying=False)

    GPT2ModelTPConfig(model=untied_model, device_mesh=tp_mesh)
| 156 | + |
| 157 | + |
def test_pp_config_rejects_tied_word_embeddings():
    """Building ``StagedPipelineConfig`` with a weight-tied model is expected to fail validation."""
    pp_mesh = create_device_mesh_stub(ParallelismDegrees.PP.value)
    tied_model = create_gpt2_model(use_weight_tying=True)

    expected_message = "Tied word embeddings are not supported with Pipeline Parallelism"
    with pytest.raises(ValidationError, match=expected_message):
        # The stages generator is built inside the guarded block, alongside the config itself.
        stages_generator = GPT2LLMStagesGenerator(num_model_layers=tied_model.n_layer)
        StagedPipelineConfig(
            whole_model=tied_model,
            stages_generator=stages_generator,
            device_mesh=pp_mesh,
            local_rank=0,
            pp_schedule_name="gpipe",
            num_layers_per_stage=1,
        )
| 171 | + |
| 172 | + |
def test_pp_config_allows_untied_word_embeddings():
    """Building ``StagedPipelineConfig`` with an untied model must not raise.

    There is no explicit assertion: the test passes as long as construction completes.
    """
    pp_mesh = create_device_mesh_stub(ParallelismDegrees.PP.value)
    untied_model = create_gpt2_model(use_weight_tying=False)
    stages_generator = GPT2LLMStagesGenerator(num_model_layers=untied_model.n_layer)

    StagedPipelineConfig(
        whole_model=untied_model,
        stages_generator=stages_generator,
        device_mesh=pp_mesh,
        local_rank=0,
        pp_schedule_name="gpipe",
        num_layers_per_stage=1,
    )
0 commit comments