Skip to content

Commit 3cc6a72

Browse files
committed
test: Add test configs for context parallelism
1 parent de01588 commit 3cc6a72

2 files changed

Lines changed: 185 additions & 0 deletions

File tree

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
model:
2+
component_key: model
3+
variant_key: fsdp2_wrapped
4+
config:
5+
model:
6+
instance_key: gpt2_cp_model
7+
pass_type: BY_REFERENCE
8+
device_mesh:
9+
instance_key: device_mesh
10+
pass_type: BY_REFERENCE
11+
mixed_precision_settings:
12+
param_dtype: FP_32 # must be FP32 in the test to prevent rounding issues
13+
reduce_dtype: FP_32 # must be FP32 in the test to prevent rounding issues
14+
block_names: [GPT2Block]
15+
16+
gpt2_cp_model:
17+
component_key: model
18+
variant_key: gpt2_cp
19+
config:
20+
model:
21+
instance_key: initialized_model
22+
pass_type: BY_REFERENCE
23+
device_mesh:
24+
instance_key: device_mesh
25+
pass_type: BY_REFERENCE
26+
27+
28+
initialized_model:
29+
component_key: model
30+
variant_key: model_initialized
31+
config:
32+
model:
33+
instance_key: model_raw
34+
pass_type: BY_REFERENCE
35+
model_initializer:
36+
component_key: model_initialization
37+
variant_key: composed
38+
config:
39+
model_type: gpt2
40+
weight_init_type: scaled
41+
mean: 0.0
42+
std: 0.02
43+
num_layers: ${model_raw.config.n_layer}
44+
45+
model_raw:
46+
component_key: model
47+
variant_key: gpt2
48+
config:
49+
use_meta_device: false
50+
use_weight_tying: false
51+
sample_key: input_ids
52+
poe_type: NOPE
53+
sequence_length: 1024
54+
prediction_key: logits
55+
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
56+
n_layer: 5
57+
n_head_q: 8
58+
n_head_kv: 4
59+
ffn_hidden: 256
60+
n_embd: 256
61+
dropout: 0.0
62+
bias: false # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
63+
attention_config:
64+
qkv_transforms:
65+
- type_hint: RotaryTransform
66+
config:
67+
n_embd: ${model_raw.config.n_embd}
68+
n_head: ${model_raw.config.n_head_q} # it has to be n_head_q here
69+
seq_length_dim: -2
70+
base_freq: 10000
71+
attention_implementation: pytorch_flash
72+
activation_type: swiglu
73+
attention_norm_config:
74+
norm_type: layer_norm
75+
config:
76+
normalized_shape: ${model_raw.config.n_embd}
77+
eps: 1e-5
78+
ffn_norm_config:
79+
norm_type: layer_norm
80+
config:
81+
normalized_shape: ${model_raw.config.n_embd}
82+
eps: 1e-5
83+
lm_head_norm_config:
84+
norm_type: layer_norm
85+
config:
86+
normalized_shape: ${model_raw.config.n_embd}
87+
eps: 1e-5
88+
89+
device_mesh:
90+
component_key: device_mesh
91+
variant_key: default
92+
config:
93+
device_type: cuda
94+
data_parallel_replicate_degree: 1
95+
data_parallel_shard_degree: -1
96+
context_parallel_degree: 2
97+
tensor_parallel_degree: 1
98+
world_size: 4
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
model:
2+
component_key: model
3+
variant_key: fsdp2_wrapped
4+
config:
5+
model:
6+
instance_key: initialized_model
7+
pass_type: BY_REFERENCE
8+
device_mesh:
9+
instance_key: device_mesh
10+
pass_type: BY_REFERENCE
11+
mixed_precision_settings:
12+
param_dtype: FP_32 # must be FP32 in the test to prevent rounding issues
13+
reduce_dtype: FP_32 # must be FP32 in the test to prevent rounding issues
14+
block_names: [GPT2Block]
15+
16+
17+
initialized_model:
18+
component_key: model
19+
variant_key: model_initialized
20+
config:
21+
model:
22+
instance_key: model_raw
23+
pass_type: BY_REFERENCE
24+
model_initializer:
25+
component_key: model_initialization
26+
variant_key: composed
27+
config:
28+
model_type: gpt2
29+
weight_init_type: scaled
30+
mean: 0.0
31+
std: 0.02
32+
num_layers: ${model_raw.config.n_layer}
33+
34+
model_raw:
35+
component_key: model
36+
variant_key: gpt2
37+
config:
38+
use_meta_device: false
39+
use_weight_tying: false
40+
sample_key: input_ids
41+
poe_type: NOPE
42+
sequence_length: 1024
43+
prediction_key: logits
44+
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
45+
n_layer: 5
46+
n_head_q: 8
47+
n_head_kv: 4
48+
ffn_hidden: 256
49+
n_embd: 256
50+
dropout: 0.0
51+
bias: false # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
52+
attention_config:
53+
qkv_transforms:
54+
- type_hint: RotaryTransform
55+
config:
56+
n_embd: ${model_raw.config.n_embd}
57+
n_head: ${model_raw.config.n_head_q} # it has to be n_head_q here
58+
seq_length_dim: -2
59+
base_freq: 10000
60+
attention_implementation: pytorch_flash
61+
activation_type: swiglu
62+
attention_norm_config:
63+
norm_type: layer_norm
64+
config:
65+
normalized_shape: ${model_raw.config.n_embd}
66+
eps: 1e-5
67+
ffn_norm_config:
68+
norm_type: layer_norm
69+
config:
70+
normalized_shape: ${model_raw.config.n_embd}
71+
eps: 1e-5
72+
lm_head_norm_config:
73+
norm_type: layer_norm
74+
config:
75+
normalized_shape: ${model_raw.config.n_embd}
76+
eps: 1e-5
77+
78+
device_mesh:
79+
component_key: device_mesh
80+
variant_key: default
81+
config:
82+
device_type: cuda
83+
data_parallel_replicate_degree: 1
84+
data_parallel_shard_degree: -1
85+
context_parallel_degree: 1
86+
tensor_parallel_degree: 1
87+
world_size: 4

0 commit comments

Comments
 (0)