|
19 | 19 | RMSLayerNormConfig, |
20 | 20 | ) |
21 | 21 | from modalities.models.model import ActivationType, NNModel, SwiGLU |
22 | | -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method |
| 22 | +from modalities.running_env.fsdp.device_mesh import ( |
| 23 | + ParallelismDegrees, |
| 24 | + get_parallel_degree, |
| 25 | + get_parallel_rank, |
| 26 | + has_parallelism_method, |
| 27 | +) |
23 | 28 | from modalities.util import parse_enum_by_name |
24 | 29 |
|
25 | 30 | try: |
@@ -874,11 +879,9 @@ def __init__( |
874 | 879 | "embedding": [".wte", ".wpe"], |
875 | 880 | "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], |
876 | 881 | } |
877 | | - # Set different random seed for each TP rank to ensure diversity |
878 | | - if seed is not None and has_parallelism_method( |
879 | | - device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP |
880 | | - ): |
881 | | - seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) |
| 882 | + # Set different random seed for each TP and PP rank to ensure diversity |
| 883 | + if seed is not None and device_mesh is not None: |
| 884 | + seed = _offset_seed_by_parallel_ranks(seed=seed, device_mesh=device_mesh) |
882 | 885 | super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) |
883 | 886 | self.sample_key = sample_key |
884 | 887 | self.prediction_key = prediction_key |
@@ -1069,3 +1072,26 @@ def manual_scaled_dot_product_attention( |
1069 | 1072 | attn_weight = torch.softmax(attn_weight, dim=-1) |
1070 | 1073 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) |
1071 | 1074 | return attn_weight @ value |
| 1075 | + |
| 1076 | + |
| 1077 | +def _offset_seed_by_parallel_ranks(seed: int, device_mesh: DeviceMesh) -> int: |
| 1078 | + """ |
| 1079 | + Return a seed shifted by the TP/PP ranks so each TP/PP pair produces a distinct value. |
| 1080 | + """ |
| 1081 | + tp_rank = None |
| 1082 | + pp_rank = None |
| 1083 | + pp_degree = 1 |
| 1084 | + |
| 1085 | + if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP): |
| 1086 | + tp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) |
| 1087 | + if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP): |
| 1088 | + pp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) |
| 1089 | + pp_degree = get_parallel_degree(device_mesh=device_mesh, parallelism_methods=[ParallelismDegrees.PP]) |
| 1090 | + |
| 1091 | + if tp_rank is not None and pp_rank is not None: |
| 1092 | + return seed + tp_rank * pp_degree + pp_rank |
| 1093 | + if tp_rank is not None: |
| 1094 | + return seed + tp_rank |
| 1095 | + if pp_rank is not None: |
| 1096 | + return seed + pp_rank |
| 1097 | + return seed |
0 commit comments