# Top-level model component: the `fsdp2_wrapped` variant applied to the
# `gpt2_cp_model` instance, using the `device_mesh` instance.
model:
  component_key: model
  variant_key: fsdp2_wrapped
  config:
    # Wrapped model and mesh are looked up by instance key, not re-declared.
    model:
      instance_key: gpt2_cp_model
      pass_type: BY_REFERENCE
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE
    mixed_precision_settings:
      # Both dtypes must be FP32 in the test to prevent rounding issues.
      param_dtype: FP_32
      reduce_dtype: FP_32
    block_names:
      - GPT2Block
15+
# `gpt2_cp` model variant built on top of the `initialized_model` instance.
# NOTE(review): "cp" presumably stands for context parallelism (the mesh sets
# `context_parallel_degree`) - confirm against the component registry.
gpt2_cp_model:
  component_key: model
  variant_key: gpt2_cp
  config:
    # Underlying model, referenced by instance key.
    model:
      instance_key: initialized_model
      pass_type: BY_REFERENCE
    # Same mesh instance that the other components reference.
    device_mesh:
      instance_key: device_mesh
      pass_type: BY_REFERENCE
26+
27+
# `model_initialized` variant: pairs the `model_raw` instance with an inline
# `composed` model initializer.
initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    # Model to initialize, referenced by instance key.
    model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    # Initializer is declared inline rather than referenced.
    model_initializer:
      component_key: model_initialization
      variant_key: composed
      config:
        model_type: gpt2
        weight_init_type: scaled
        mean: 0.0
        std: 0.02
        # Interpolated so it always matches the raw model's layer count.
        num_layers: ${model_raw.config.n_layer}
44+
# Raw (not yet initialized) GPT-2 model definition. Several values below are
# reused elsewhere via `${model_raw.config.*}` interpolation.
model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: false
    use_weight_tying: false
    sample_key: input_ids
    poe_type: NOPE
    sequence_length: 1024
    prediction_key: logits
    # GPT-2 vocab_size of 50257, padded up to the nearest multiple of 64 for
    # efficiency.
    vocab_size: 50304
    n_layer: 5
    n_head_q: 8
    n_head_kv: 4
    ffn_hidden: 256
    n_embd: 256
    dropout: 0.0
    # true: bias in Linears and LayerNorms, like GPT-2.
    # false: a bit better and faster.
    bias: false
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            # Must be the query head count (n_head_q), not n_head_kv.
            n_head: ${model_raw.config.n_head_q}
            seq_length_dim: -2
            base_freq: 10000
    attention_implementation: pytorch_flash
    activation_type: swiglu
    # The three norm configs are identical layer norms over the embedding
    # dimension. `eps` is written with a mantissa dot (1.0e-5, not 1e-5)
    # because YAML 1.1 loaders resolve the dot-less form as a string.
    attention_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-5
    ffn_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-5
    lm_head_norm_config:
      norm_type: layer_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-5
88+
# Device mesh instance referenced by the model components via
# `instance_key: device_mesh`.
device_mesh:
  component_key: device_mesh
  variant_key: default
  config:
    device_type: cuda
    data_parallel_replicate_degree: 1
    # NOTE(review): -1 presumably means "derive from world_size and the other
    # degrees" - confirm against the device_mesh component.
    data_parallel_shard_degree: -1
    context_parallel_degree: 2
    tensor_parallel_degree: 1
    world_size: 4
0 commit comments