Skip to content

Commit 8c8c5ab

Browse files
committed
feat: Consider pp rank for model seed
1 parent 5f9f50e commit 8c8c5ab

1 file changed

Lines changed: 32 additions & 6 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
RMSLayerNormConfig,
2020
)
2121
from modalities.models.model import ActivationType, NNModel, SwiGLU
22-
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method
22+
from modalities.running_env.fsdp.device_mesh import (
23+
ParallelismDegrees,
24+
get_parallel_degree,
25+
get_parallel_rank,
26+
has_parallelism_method,
27+
)
2328
from modalities.util import parse_enum_by_name
2429

2530
try:
@@ -874,11 +879,9 @@ def __init__(
874879
"embedding": [".wte", ".wpe"],
875880
"layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"],
876881
}
877-
# Set different random seed for each TP rank to ensure diversity
878-
if seed is not None and has_parallelism_method(
879-
device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP
880-
):
881-
seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP)
882+
# Set different random seed for each TP and PP rank to ensure diversity
883+
if seed is not None and device_mesh is not None:
884+
seed = _offset_seed_by_parallel_ranks(seed=seed, device_mesh=device_mesh)
882885
super().__init__(weight_decay_groups=weight_decay_groups, seed=seed)
883886
self.sample_key = sample_key
884887
self.prediction_key = prediction_key
@@ -1069,3 +1072,26 @@ def manual_scaled_dot_product_attention(
10691072
attn_weight = torch.softmax(attn_weight, dim=-1)
10701073
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
10711074
return attn_weight @ value
1075+
1076+
1077+
def _offset_seed_by_parallel_ranks(seed: int, device_mesh: DeviceMesh) -> int:
1078+
"""
1079+
Return a seed shifted by the TP/PP ranks so each TP/PP pair produces a distinct value.
1080+
"""
1081+
tp_rank = None
1082+
pp_rank = None
1083+
pp_degree = 1
1084+
1085+
if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP):
1086+
tp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP)
1087+
if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP):
1088+
pp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
1089+
pp_degree = get_parallel_degree(device_mesh=device_mesh, parallelism_methods=[ParallelismDegrees.PP])
1090+
1091+
if tp_rank is not None and pp_rank is not None:
1092+
return seed + tp_rank * pp_degree + pp_rank
1093+
if tp_rank is not None:
1094+
return seed + tp_rank
1095+
if pp_rank is not None:
1096+
return seed + pp_rank
1097+
return seed

0 commit comments

Comments
 (0)