Skip to content

Commit 33c55a4

Browse files
committed
feat: hardened weight tying against misconfigurations
1 parent ec1ac4f commit 33c55a4

6 files changed

Lines changed: 102 additions & 1 deletion

File tree

src/modalities/config/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
PydanticTokenizerIFType,
3535
)
3636
from modalities.config.utils import parse_torch_device
37+
from modalities.models.weight_tying import has_tied_word_embeddings
3738
from modalities.running_env.env_utils import (
3839
FSDP2MixedPrecisionSettings,
3940
MixedPrecisionSettings,
@@ -342,6 +343,13 @@ def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig":
342343
raise ValueError("data_parallel_replicate_degree > 1 cannot be used with Tensor Parallelism.")
343344
return self
344345

346+
@model_validator(mode="after")
347+
def validate_untied_word_embeddings(self) -> "GPT2ModelTPConfig":
348+
models = self.model if isinstance(self.model, list) else [self.model]
349+
if any(has_tied_word_embeddings(model) for model in models):
350+
raise ValueError("Tied word embeddings are not supported with Tensor Parallelism.")
351+
return self
352+
345353

346354
class CompiledModelConfig(BaseModel):
347355
model: PydanticPytorchModuleOrListType

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,12 @@ def __init__(
938938
self.transformer.lm_head.weight
939939
) # https://paperswithcode.com/method/weight-tying
940940

941+
@property
942+
def has_tied_word_embeddings(self) -> bool:
943+
token_embedding_weight = getattr(self.transformer.wte, "weight", None)
944+
lm_head_weight = getattr(self.transformer.lm_head, "weight", None)
945+
return token_embedding_weight is not None and token_embedding_weight is lm_head_weight
946+
941947
@overload
942948
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
943949
"""

src/modalities/models/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def weight_decay_groups(self) -> WeightDecayGroups:
4646
"""
4747
return self._weight_decay_groups
4848

49+
@property
50+
def has_tied_word_embeddings(self) -> bool:
51+
"""Whether the model currently uses tied token embedding and output weights."""
52+
return False
53+
4954
@abstractmethod
5055
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
5156
"""

src/modalities/models/parallelism/pipeline_parallelism_configs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Annotated
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, model_validator
44

55
from modalities.config.pydantic_if_types import (
66
PydanticDeviceMeshIFType,
@@ -11,6 +11,7 @@
1111
PydanticStagesGeneratorType,
1212
)
1313
from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes
14+
from modalities.models.weight_tying import has_tied_word_embeddings
1415
from modalities.utils.deprecated_alias import add_deprecated_alias
1516

1617

@@ -26,6 +27,12 @@ class StagedPipelineConfig(BaseModel):
2627
pp_schedule_name: str
2728
num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)]
2829

30+
@model_validator(mode="after")
31+
def validate_untied_word_embeddings(self) -> "StagedPipelineConfig":
32+
if has_tied_word_embeddings(self.whole_model):
33+
raise ValueError("Tied word embeddings are not supported with Pipeline Parallelism.")
34+
return self
35+
2936

3037
class ScheduledPipelineConfig(BaseModel):
3138
loss_fn: PydanticLossIFType
src/modalities/models/weight_tying.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch.nn as nn
2+
3+
4+
def has_tied_word_embeddings(model: nn.Module) -> bool:
5+
model_has_tied_word_embeddings = getattr(model, "has_tied_word_embeddings", None)
6+
if model_has_tied_word_embeddings is None:
7+
raise TypeError(
8+
f"{type(model).__name__} must define 'has_tied_word_embeddings' to be used with tied-embedding validation."
9+
)
10+
11+
return bool(model_has_tied_word_embeddings)

tests/test_weight_tying.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import pytest
22
import torch.nn as nn
3+
from pydantic import ValidationError
4+
from torch.distributed.device_mesh import DeviceMesh
35

6+
from modalities.config.config import GPT2ModelTPConfig
47
from modalities.models.components.layer_norms import LayerNormConfig
58
from modalities.models.gpt2.gpt2_model import (
69
GPT2LLM,
@@ -11,6 +14,10 @@
1114
PositionTypes,
1215
)
1316
from modalities.models.model import ActivationType
17+
from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig
18+
from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator
19+
from modalities.models.weight_tying import has_tied_word_embeddings
20+
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees
1421

1522
VOCAB_SIZE = 1000
1623
EMBEDDING_DIM = 64
@@ -79,9 +86,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
7986
)
8087

8188

89+
def create_device_mesh_stub(*mesh_dim_names: str) -> DeviceMesh:
90+
device_mesh = DeviceMesh.__new__(DeviceMesh)
91+
device_mesh.mesh_dim_names = mesh_dim_names
92+
return device_mesh
93+
94+
8295
@pytest.mark.parametrize("use_weight_tying", [True, False])
8396
def test_weight_tying_behavior(use_weight_tying):
8497
model = create_gpt2_model(use_weight_tying)
98+
assert model.has_tied_word_embeddings is use_weight_tying
99+
85100
if use_weight_tying:
86101
assert (
87102
model.transformer.wte.weight is model.transformer.lm_head.weight
@@ -118,3 +133,52 @@ def test_weight_tying_named_parameters(use_weight_tying):
118133
assert (
119134
"transformer.lm_head.weight" in named_params
120135
), "transformer.lm_head.weight should appear in named_parameters when weight tying is not used."
136+
137+
138+
def test_has_tied_word_embeddings_requires_model_capability():
139+
with pytest.raises(TypeError, match="must define 'has_tied_word_embeddings'"):
140+
has_tied_word_embeddings(nn.Linear(1, 1))
141+
142+
143+
def test_tp_config_rejects_tied_word_embeddings():
144+
model = create_gpt2_model(use_weight_tying=True)
145+
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
146+
147+
with pytest.raises(ValidationError, match="Tied word embeddings are not supported with Tensor Parallelism"):
148+
GPT2ModelTPConfig(model=model, device_mesh=device_mesh)
149+
150+
151+
def test_tp_config_allows_untied_word_embeddings():
152+
model = create_gpt2_model(use_weight_tying=False)
153+
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
154+
155+
GPT2ModelTPConfig(model=model, device_mesh=device_mesh)
156+
157+
158+
def test_pp_config_rejects_tied_word_embeddings():
159+
model = create_gpt2_model(use_weight_tying=True)
160+
device_mesh = create_device_mesh_stub(ParallelismDegrees.PP.value)
161+
162+
with pytest.raises(ValidationError, match="Tied word embeddings are not supported with Pipeline Parallelism"):
163+
StagedPipelineConfig(
164+
whole_model=model,
165+
stages_generator=GPT2LLMStagesGenerator(num_model_layers=model.n_layer),
166+
device_mesh=device_mesh,
167+
local_rank=0,
168+
pp_schedule_name="gpipe",
169+
num_layers_per_stage=1,
170+
)
171+
172+
173+
def test_pp_config_allows_untied_word_embeddings():
174+
model = create_gpt2_model(use_weight_tying=False)
175+
device_mesh = create_device_mesh_stub(ParallelismDegrees.PP.value)
176+
177+
StagedPipelineConfig(
178+
whole_model=model,
179+
stages_generator=GPT2LLMStagesGenerator(num_model_layers=model.n_layer),
180+
device_mesh=device_mesh,
181+
local_rank=0,
182+
pp_schedule_name="gpipe",
183+
num_layers_per_stage=1,
184+
)

0 commit comments

Comments
 (0)