-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
186 lines (151 loc) · 6.28 KB
/
Copy pathmodel.py
File metadata and controls
186 lines (151 loc) · 6.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""Embedding fingerprint model and tokenizer.
A small transformer classifier that identifies which embedding model
produced a given vector, using digit-level tokenization of float values.
"""
import io
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class NumericTokenizer:
    """Turn float vectors into digit-level token sequences.

    Each value is printed with ``"%.{precision}f"`` (``"%.4f"`` by default),
    the values of one vector are joined by single spaces, and every
    character is mapped to a token id:

        0-9 -> digits, 10 -> minus, 11 -> dot, 12 -> SEP (the space between
        values), 13 -> CLS (prepended to every sequence), 14 -> PAD (never
        produced by this class)

    Characters without a mapping (for example the letters of ``nan`` or
    ``inf``) are dropped from the output.
    """

    PAD = 14
    CLS = 13
    SEP = 12
    VOCAB_SIZE = 15

    def __init__(self, precision=4):
        self.precision = precision
        # ASCII byte value -> token id; 255 means "no token" and is removed
        # after lookup in encode_batch.
        table = np.full(128, 255, dtype=np.uint8)
        table[ord("0"):ord("0") + 10] = np.arange(10, dtype=np.uint8)
        table[ord("-")] = 10
        table[ord(".")] = 11
        table[ord(" ")] = self.SEP
        self.lut = table

    def encode_batch(self, embeddings):
        """Encode an ``(N, D)`` float array into ``N`` uint8 token arrays.

        Every returned array starts with ``CLS`` and is followed by the
        character tokens of that row's values, separated by ``SEP``. Array
        lengths vary with the printed width of the values.
        """
        n_rows = embeddings.shape[0]
        text = io.BytesIO()
        # One text line per row; values separated by single spaces.
        np.savetxt(text, embeddings, fmt=f"%.{self.precision}f", delimiter=" ", newline="\n")
        lines = text.getvalue().split(b"\n")
        prefix = np.array([self.CLS], dtype=np.uint8)
        out = []
        for line in lines[:n_rows]:
            ids = self.lut[np.frombuffer(line, dtype=np.uint8)]
            out.append(np.concatenate((prefix, ids[ids != 255])))
        return out

    def encode(self, vec):
        """Encode one vector and return its tokens as a list of Python ints."""
        row = np.asarray(vec, dtype=np.float32).reshape(1, -1)
        (tokens,) = self.encode_batch(row)
        return tokens.tolist()
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    ``y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``

    Args:
        dim: size of the last axis of the input (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Compute the statistic in at least float32. In float16, x ** 2
        # overflows to inf once |x| exceeds ~256, which would make the whole
        # normalized row 0. float32/float64 inputs take exactly the same path
        # as before (the .to() calls are no-ops for them).
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xf = x.to(compute_dtype)
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back before scaling so the output dtype follows the same
        # promotion of x's dtype with weight's dtype as the original formula.
        return normed.to(x.dtype) * self.weight
class RotaryEmbedding(nn.Module):
    """Produce rotary position-embedding (RoPE) cosine/sine tables.

    Args:
        dim: width of the vectors the tables are applied to.
        base: frequency base; frequency ``i`` is ``base ** (-2i / dim)``.
    """

    def __init__(self, dim, base=10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Registered (persistent) buffer, so it appears in the state_dict
        # under the name "inv_freq".
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

    def forward(self, x, seq_len):
        """Return ``(cos, sin)``, each of shape ``(seq_len, 2 * len(inv_freq))``.

        ``x`` is only used for its device; the tables use the buffer's dtype.
        """
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat([angles, angles], dim=-1)
        return table.cos(), table.sin()
def rotate_half(x):
    """Map the last axis ``[a, b]`` (split into two halves) to ``[-b, a]``.

    The halves come from ``torch.chunk(x, 2, dim=-1)``.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors with precomputed RoPE tables.

    ``cos``/``sin`` receive two leading singleton axes so they broadcast
    against ``q``/``k`` (in ``Attention`` that is ``(L, head_dim)`` against
    ``(B, heads, L, head_dim)``).

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def rotate(t):
        return t * cos_b + rotate_half(t) * sin_b

    return rotate(q), rotate(k)
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free.

    Args:
        dim: input and output width.
        hidden_dim: width of the gated hidden layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Creation order (gate, up, down) fixes the RNG draws of the default
        # init and the parameter order; keep it.
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Args:
        dim: model width; split evenly across ``num_heads`` heads.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``, or if the
            per-head width is odd. The rotary tables built by
            ``RotaryEmbedding`` have ``2 * ceil(head_dim / 2)`` columns, so
            they only line up with the head vectors for an even head width.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # Fail at construction instead of with an opaque view/broadcast error
        # on the first forward pass.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim = dim // num_heads = {self.head_dim} must be even "
                "for rotary position embeddings"
            )
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        self.rotary = RotaryEmbedding(self.head_dim)

    def forward(self, x, attention_mask=None):
        """Attend over ``x`` of shape ``(B, L, dim)``; returns the same shape.

        Args:
            x: input of shape ``(B, L, dim)``.
            attention_mask: optional 2-D ``(B, L)`` tensor. Nonzero entries
                mark key positions that may be attended to; zero entries are
                masked out for every query.
        """
        B, L, D = x.shape

        def to_heads(t):
            # (B, L, dim) -> (B, num_heads, L, head_dim)
            return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(x))
        v = to_heads(self.v_proj(x))
        cos, sin = self.rotary(x, L)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        sdpa_mask = None
        if attention_mask is not None:
            # SDPA boolean mask: True = may attend. (B, L) -> (B, 1, 1, L)
            # broadcasts over heads and query positions.
            sdpa_mask = (attention_mask != 0)[:, None, None, :]

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask)
        out = out.transpose(1, 2).reshape(B, L, D)
        return self.o_proj(out)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer with two residual branches.

    ``h = x + attn(attn_norm(x))`` followed by ``h + ff(ff_norm(h))``.

    Args:
        dim: model width.
        num_heads: attention heads.
        ff_hidden: hidden width of the SwiGLU feed-forward.
    """

    def __init__(self, dim, num_heads, ff_hidden):
        super().__init__()
        # Registration order determines RNG draws and state_dict order.
        self.attn_norm = RMSNorm(dim)
        self.attn = Attention(dim, num_heads)
        self.ff_norm = RMSNorm(dim)
        self.ff = SwiGLU(dim, ff_hidden)

    def forward(self, x, attention_mask=None):
        h = x + self.attn(self.attn_norm(x), attention_mask)
        return h + self.ff(self.ff_norm(h))
class EmbeddingFingerprinter(nn.Module):
    """Transformer classifier that predicts which embedding model made a vector.

    Token ids are embedded, passed through ``num_layers`` pre-norm
    ``TransformerBlock``s, normalized, and the hidden state at position 0
    (where ``NumericTokenizer`` puts its CLS token) is projected to logits.

    Args:
        num_classes: number of embedding-model classes (output logits).
        vocab_size: rows of the token embedding table (default 15, equal to
            ``NumericTokenizer.VOCAB_SIZE``).
        dim: model width (default 128).
        num_layers: number of transformer blocks (default 4).
        num_heads: attention heads per block (default 4).
        ff_multiplier: feed-forward width multiplier, scaled by 2/3 and
            rounded up to a multiple of 8 (default 4).
    """

    def __init__(self, num_classes, vocab_size=15, dim=128, num_layers=4,
                 num_heads=4, ff_multiplier=4):
        super().__init__()
        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        # The construction order below fixes the RNG draws of each layer's
        # default init and the parameter order seen by _init_weights.
        self.token_emb = nn.Embedding(vocab_size, dim)
        hidden = int(dim * ff_multiplier * 2 / 3)
        hidden = -(-hidden // 8) * 8  # round up to a multiple of 8
        self.layers = nn.ModuleList(
            TransformerBlock(dim, num_heads, hidden) for _ in range(num_layers)
        )
        self.final_norm = RMSNorm(dim)
        self.classifier = nn.Linear(dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every parameter with more than one axis.

        1-D parameters (RMSNorm scales, the classifier bias) keep the values
        their constructors assigned.
        """
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            input_ids: integer token ids of shape ``(batch, seq)``.
            attention_mask: optional ``(batch, seq)`` mask passed to every
                block; zero entries are excluded from attention.
        """
        hidden = self.token_emb(input_ids)
        for block in self.layers:
            hidden = block(hidden, attention_mask)
        hidden = self.final_norm(hidden)
        cls_state = hidden[:, 0, :]
        return self.classifier(cls_state)