@@ -38,6 +38,7 @@ def __init__(
3838 hidden_dim : int = 512 ,
3939 out_dim : int = 2 ,
4040 pad_id : int = 0 ,
41+ freeze : bool = False ,
4142 ) -> None :
4243 """Initialize a standard classifier model."""
4344 self .n_layers = n_layers
@@ -46,7 +47,7 @@ def __init__(
4647 self .classes_ : list [str ] = [str (x ) for x in range (out_dim )]
4748 # multilabel flag will be set based on the type of `y` passed to fit.
4849 self .multilabel : bool = False
49- super ().__init__ (vectors = vectors , out_dim = out_dim , pad_id = pad_id , tokenizer = tokenizer )
50+ super ().__init__ (vectors = vectors , out_dim = out_dim , pad_id = pad_id , tokenizer = tokenizer , freeze = freeze )
5051
5152 @property
5253 def classes (self ) -> np .ndarray :
@@ -124,7 +125,7 @@ def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_siz
124125 pred .append (torch .softmax (logits , dim = 1 ).cpu ().numpy ())
125126 return np .concatenate (pred , axis = 0 )
126127
127- def fit (
128+ def fit ( # noqa: C901 # Complexity is bad.
128129 self ,
129130 X : list [str ],
130131 y : LabelType ,
@@ -165,7 +166,7 @@ def fit(
165166 :param device: The device to train on. If this is "auto", the device is chosen automatically.
166167 :param X_val: The texts to be used for validation.
167168 :param y_val: The labels to be used for validation.
168- :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
169+ :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
169170 have the same length as the number of classes.
170171 :return: The fitted model.
171172 :raises ValueError: If either X_val or y_val are provided, but not both.
@@ -201,7 +202,7 @@ def fit(
201202 base_number = int (min (max (1 , (len (train_texts ) / 30 ) // 32 ), 16 ))
202203 batch_size = int (base_number * 32 )
203204 logger .info ("Batch size automatically set to %d." , batch_size )
204-
205+
205206 if class_weight is not None :
206207 if len (class_weight ) != len (self .classes_ ):
207208 raise ValueError ("class_weight must have the same length as the number of classes." )
@@ -300,7 +301,9 @@ def _initialize(self, y: LabelType) -> None:
300301 self .classes_ = classes
301302 self .out_dim = len (self .classes_ ) # Update output dimension
302303 self .head = self .construct_head ()
303- self .embeddings = nn .Embedding .from_pretrained (self .vectors .clone (), freeze = False , padding_idx = self .pad_id )
304+ self .embeddings = nn .Embedding .from_pretrained (
305+ self .vectors .clone (), freeze = self .freeze , padding_idx = self .pad_id
306+ )
304307 self .w = self .construct_weights ()
305308 self .train ()
306309
@@ -383,12 +386,18 @@ def to_pipeline(self) -> StaticModelPipeline:
383386
384387
385388class _ClassifierLightningModule (pl .LightningModule ):
386- def __init__ (self , model : StaticModelForClassification , learning_rate : float , class_weight : torch .Tensor | None = None ) -> None :
389+ def __init__ (
390+ self , model : StaticModelForClassification , learning_rate : float , class_weight : torch .Tensor | None = None
391+ ) -> None :
387392 """Initialize the LightningModule."""
388393 super ().__init__ ()
389394 self .model = model
390395 self .learning_rate = learning_rate
391- self .loss_function = nn .CrossEntropyLoss (weight = class_weight ) if not model .multilabel else nn .BCEWithLogitsLoss (pos_weight = class_weight )
396+ self .loss_function = (
397+ nn .CrossEntropyLoss (weight = class_weight )
398+ if not model .multilabel
399+ else nn .BCEWithLogitsLoss (pos_weight = class_weight )
400+ )
392401
393402 def forward (self , x : torch .Tensor ) -> torch .Tensor :
394403 """Simple forward pass."""
0 commit comments