Skip to content

Commit e12db1a

Browse files
committed
chore: Place private methods below the public interface
1 parent b91762a commit e12db1a

1 file changed

Lines changed: 79 additions & 79 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 79 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,85 @@ def __init__(
211211

212212
self.reset_parameters()
213213

214+
def reset_parameters(self):
215+
# If previously initialized on or moved to a device, reuse that device.
216+
# Otherwise, use the default device of the current environment.
217+
device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None
218+
219+
rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default"
220+
221+
if rope_type == "yarn":
222+
inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device)
223+
else:
224+
inv_freq = 1.0 / (
225+
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
226+
)
227+
self.attention_scaling = 1.0
228+
229+
self.register_buffer("inv_freq", inv_freq)
230+
231+
self._seq_len_cached = None
232+
self._cos_cached = None
233+
self._sin_cached = None
234+
235+
def rotate_half(self, x: torch.Tensor):
236+
"""
237+
Rearrange tensor elements.
238+
239+
Args:
240+
x (torch.Tensor): The input tensor.
241+
242+
Returns:
243+
torch.Tensor: The output tensor.
244+
245+
"""
246+
x1, x2 = x.chunk(2, dim=-1)
247+
return torch.cat((-x2, x1), dim=-1)
248+
249+
def apply_rotary_pos_emb(self, x, cos, sin):
250+
"""
251+
Applies rotary positional embedding to the input tensor.
252+
253+
Args:
254+
x (torch.Tensor): Input tensor.
255+
cos (torch.Tensor): Cosine values for rotary positional embedding.
256+
sin (torch.Tensor): Sine values for rotary positional embedding.
257+
258+
Returns:
259+
torch.Tensor: Tensor after applying rotary positional embedding.
260+
"""
261+
# NOTE: This could probably be moved to Triton
262+
263+
# Handle a possible sequence length mismatch in between q and k
264+
cos = cos[:, :, : x.shape[self.seq_length_dim], :]
265+
sin = sin[:, :, : x.shape[self.seq_length_dim], :]
266+
267+
# the rotation is not really a rotation in higher dimensions,
268+
# It merely swaps and negates certain dimensions to make
269+
# the rotation below work
270+
return (x * cos) + (self.rotate_half(x) * sin)
271+
272+
def forward(
273+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
274+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
275+
"""
276+
Forward pass of the RotaryTransform module.
277+
278+
Args:
279+
q (torch.Tensor): Query tensor.
280+
k (torch.Tensor): Key tensor.
281+
v (torch.Tensor): Value tensor.
282+
283+
Returns:
284+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
285+
Tuple containing the modified query tensor, key tensor, and value tensor.
286+
"""
287+
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
288+
q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
289+
k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
290+
291+
return q, k, v
292+
214293
def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
215294
"""Compute YaRN inverse frequencies and the attention scaling factor."""
216295
if not isinstance(self.rope_scaling, YarnRopeScalingConfig):
@@ -299,41 +378,6 @@ def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Te
299378

300379
return inv_freq, float(attention_factor)
301380

302-
def reset_parameters(self):
303-
# If previously initialized on or moved to a device, reuse that device.
304-
# Otherwise, use the default device of the current environment.
305-
device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None
306-
307-
rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default"
308-
309-
if rope_type == "yarn":
310-
inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device)
311-
else:
312-
inv_freq = 1.0 / (
313-
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
314-
)
315-
self.attention_scaling = 1.0
316-
317-
self.register_buffer("inv_freq", inv_freq)
318-
319-
self._seq_len_cached = None
320-
self._cos_cached = None
321-
self._sin_cached = None
322-
323-
def rotate_half(self, x: torch.Tensor):
324-
"""
325-
Rearrange tensor elements.
326-
327-
Args:
328-
x (torch.Tensor): The input tensor.
329-
330-
Returns:
331-
torch.Tensor: The output tensor.
332-
333-
"""
334-
x1, x2 = x.chunk(2, dim=-1)
335-
return torch.cat((-x2, x1), dim=-1)
336-
337381
def _update_cos_sin_tables(self, x):
338382
# Update the cosine and sine tables.
339383
seq_len = x.shape[self.seq_length_dim]
@@ -358,50 +402,6 @@ def _update_cos_sin_tables(self, x):
358402

359403
return self._cos_cached, self._sin_cached
360404

361-
def apply_rotary_pos_emb(self, x, cos, sin):
362-
"""
363-
Applies rotary positional embedding to the input tensor.
364-
365-
Args:
366-
x (torch.Tensor): Input tensor.
367-
cos (torch.Tensor): Cosine values for rotary positional embedding.
368-
sin (torch.Tensor): Sine values for rotary positional embedding.
369-
370-
Returns:
371-
torch.Tensor: Tensor after applying rotary positional embedding.
372-
"""
373-
# NOTE: This could probably be moved to Triton
374-
375-
# Handle a possible sequence length mismatch in between q and k
376-
cos = cos[:, :, : x.shape[self.seq_length_dim], :]
377-
sin = sin[:, :, : x.shape[self.seq_length_dim], :]
378-
379-
# the rotation is not really a rotation in higher dimensions,
380-
# It merely swaps and negates certain dimensions to make
381-
# the rotation below work
382-
return (x * cos) + (self.rotate_half(x) * sin)
383-
384-
def forward(
385-
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
386-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
387-
"""
388-
Forward pass of the RotaryTransform module.
389-
390-
Args:
391-
q (torch.Tensor): Query tensor.
392-
k (torch.Tensor): Key tensor.
393-
v (torch.Tensor): Value tensor.
394-
395-
Returns:
396-
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
397-
Tuple containing the modified query tensor, key tensor, and value tensor.
398-
"""
399-
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
400-
q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
401-
k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
402-
403-
return q, k, v
404-
405405

406406
class QueryKeyValueTransformType(Enum):
407407
"""

0 commit comments

Comments (0)