Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ Model Arguments



- **qk_layernorm_over_heads**: bool

Default = False

Apply QK normalization over [\*, N, H] instead of [\*, H].



- **layernorm_epsilon**: float

Default = 1e-05
Expand Down
9 changes: 6 additions & 3 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,18 @@ def __init__(
self.pos_emb = neox_args.pos_emb

self.use_qk_layernorm = neox_args.use_qk_layernorm
self.qk_layernorm_over_heads = neox_args.qk_layernorm_over_heads
if self.use_qk_layernorm:
norm, eps = get_norm(neox_args)
self.qk_layernorm = norm(
norm_dims = (
[
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
],
eps=eps,
]
if self.qk_layernorm_over_heads
else [self.hidden_size_per_attention_head]
)
self.qk_layernorm = norm(norm_dims, eps=eps)

self.sliding_window_width = neox_args.sliding_window_width

Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
Use QK Normalization
"""

qk_layernorm_over_heads: bool = False
"""
Apply QK normalization over [*, N, H] instead of [*, H].
"""

layernorm_epsilon: float = 1.0e-5
"""
Layer norm epsilon.
Expand Down
Loading