Skip to content

Commit c5aa610

Browse files
committed
feat: add classifier freezing
1 parent 3c09887 commit c5aa610

2 files changed

Lines changed: 23 additions & 11 deletions

File tree

model2vec/train/base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616

1717

1818
class FinetunableStaticModel(nn.Module):
19-
def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
19+
def __init__(
20+
self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0, freeze: bool = False
21+
) -> None:
2022
"""
2123
Initialize a trainable StaticModel from a StaticModel.
2224
2325
:param vectors: The embeddings of the staticmodel.
2426
:param tokenizer: The tokenizer.
2527
:param out_dim: The output dimension of the head.
2628
:param pad_id: The padding id. This is set to 0 in almost all model2vec models
29+
:param freeze: Whether to freeze the embeddings. This should be set to False in most cases.
2730
"""
2831
super().__init__()
2932
self.pad_id = pad_id
@@ -37,8 +40,8 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int
3740
f"Your vectors are {dtype} precision, converting to to torch.float32 to avoid compatibility issues."
3841
)
3942
self.vectors = vectors.float()
40-
41-
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
43+
self.freeze = freeze
44+
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=self.freeze, padding_idx=pad_id)
4245
self.head = self.construct_head()
4346
self.w = self.construct_weights()
4447
self.tokenizer = tokenizer
@@ -157,7 +160,7 @@ def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor,
157160
"""Collate function."""
158161
texts, targets = zip(*batch)
159162

160-
tensors = [torch.LongTensor(x) for x in texts]
163+
tensors: list[torch.Tensor] = [torch.LongTensor(x) for x in texts]
161164
padded = pad_sequence(tensors, batch_first=True, padding_value=0)
162165

163166
return padded, torch.stack(targets)

model2vec/train/classifier.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
hidden_dim: int = 512,
3939
out_dim: int = 2,
4040
pad_id: int = 0,
41+
freeze: bool = False,
4142
) -> None:
4243
"""Initialize a standard classifier model."""
4344
self.n_layers = n_layers
@@ -46,7 +47,7 @@ def __init__(
4647
self.classes_: list[str] = [str(x) for x in range(out_dim)]
4748
# multilabel flag will be set based on the type of `y` passed to fit.
4849
self.multilabel: bool = False
49-
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)
50+
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer, freeze=freeze)
5051

5152
@property
5253
def classes(self) -> np.ndarray:
@@ -124,7 +125,7 @@ def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_siz
124125
pred.append(torch.softmax(logits, dim=1).cpu().numpy())
125126
return np.concatenate(pred, axis=0)
126127

127-
def fit(
128+
def fit( # noqa: C901 # Complexity is bad.
128129
self,
129130
X: list[str],
130131
y: LabelType,
@@ -165,7 +166,7 @@ def fit(
165166
:param device: The device to train on. If this is "auto", the device is chosen automatically.
166167
:param X_val: The texts to be used for validation.
167168
:param y_val: The labels to be used for validation.
168-
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
169+
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
169170
have the same length as the number of classes.
170171
:return: The fitted model.
171172
:raises ValueError: If either X_val or y_val are provided, but not both.
@@ -201,7 +202,7 @@ def fit(
201202
base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
202203
batch_size = int(base_number * 32)
203204
logger.info("Batch size automatically set to %d.", batch_size)
204-
205+
205206
if class_weight is not None:
206207
if len(class_weight) != len(self.classes_):
207208
raise ValueError("class_weight must have the same length as the number of classes.")
@@ -300,7 +301,9 @@ def _initialize(self, y: LabelType) -> None:
300301
self.classes_ = classes
301302
self.out_dim = len(self.classes_) # Update output dimension
302303
self.head = self.construct_head()
303-
self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
304+
self.embeddings = nn.Embedding.from_pretrained(
305+
self.vectors.clone(), freeze=self.freeze, padding_idx=self.pad_id
306+
)
304307
self.w = self.construct_weights()
305308
self.train()
306309

@@ -383,12 +386,18 @@ def to_pipeline(self) -> StaticModelPipeline:
383386

384387

385388
class _ClassifierLightningModule(pl.LightningModule):
386-
def __init__(self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
389+
def __init__(
390+
self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None
391+
) -> None:
387392
"""Initialize the LightningModule."""
388393
super().__init__()
389394
self.model = model
390395
self.learning_rate = learning_rate
391-
self.loss_function = nn.CrossEntropyLoss(weight=class_weight) if not model.multilabel else nn.BCEWithLogitsLoss(pos_weight=class_weight)
396+
self.loss_function = (
397+
nn.CrossEntropyLoss(weight=class_weight)
398+
if not model.multilabel
399+
else nn.BCEWithLogitsLoss(pos_weight=class_weight)
400+
)
392401

393402
def forward(self, x: torch.Tensor) -> torch.Tensor:
394403
"""Simple forward pass."""

0 commit comments

Comments
 (0)