Skip to content

Commit a391191

Browse files
committed
simplify
1 parent fae4b11 commit a391191

1 file changed

Lines changed: 9 additions & 16 deletions

File tree

model2vec/train/base.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,20 @@ def construct_weights(self) -> nn.Parameter:
101101
return nn.Parameter(w, requires_grad=not self.freeze_weights)
102102

103103
def construct_head(self) -> nn.Sequential:
104-
"""Constructs a simple classifier head."""
105-
return self.construct_mlp(self.n_layers, self.embed_dim, self.hidden_dim, self.out_dim)
106-
107-
@staticmethod
108-
def construct_mlp(n_layers: int, embed_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
109104
"""Constructs a simple classifier head."""
110105
modules: list[nn.Module] = []
111-
if n_layers == 0:
112-
modules.append(nn.Linear(embed_dim, out_dim))
106+
if self.n_layers == 0:
107+
modules.append(nn.Linear(self.embed_dim, self.out_dim))
113108
else:
114109
# If we have a hidden layer, we should first project to hidden_dim
115110
modules = [
116-
nn.Linear(embed_dim, hidden_dim),
111+
nn.Linear(self.embed_dim, self.hidden_dim),
117112
nn.ReLU(),
118113
]
119-
for _ in range(n_layers - 1):
120-
modules.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
114+
for _ in range(self.n_layers - 1):
115+
modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
121116
# We always have a layer mapping from hidden to out.
122-
modules.append(nn.Linear(hidden_dim, out_dim))
117+
modules.append(nn.Linear(self.hidden_dim, self.out_dim))
123118

124119
linear_modules = [module for module in modules if isinstance(module, nn.Linear)]
125120
if linear_modules:
@@ -254,11 +249,9 @@ def to_static_model(self) -> StaticModel:
254249
"""Convert the model to a static model."""
255250
with torch.no_grad():
256251
emb = self.embeddings.weight
257-
emb = emb.detach().cpu().numpy()
258-
if self.w is not None:
259-
w = torch.sigmoid(self.w).detach().cpu().numpy()
260-
else:
261-
w = np.ones(len(emb))
252+
emb = emb.cpu().numpy()
253+
w = torch.sigmoid(self.w).cpu().numpy()
254+
262255
# If the weights and emb are the same length, the model was not quantized before training.
263256
if len(w) == len(emb):
264257
emb = emb * w[:, None]

0 commit comments

Comments
 (0)