@@ -101,25 +101,20 @@ def construct_weights(self) -> nn.Parameter:
101101 return nn .Parameter (w , requires_grad = not self .freeze_weights )
102102
103103 def construct_head (self ) -> nn .Sequential :
104- """Constructs a simple classifier head."""
105- return self .construct_mlp (self .n_layers , self .embed_dim , self .hidden_dim , self .out_dim )
106-
107- @staticmethod
108- def construct_mlp (n_layers : int , embed_dim : int , hidden_dim : int , out_dim : int ) -> nn .Sequential :
109104 """Constructs a simple classifier head."""
110105 modules : list [nn .Module ] = []
111- if n_layers == 0 :
112- modules .append (nn .Linear (embed_dim , out_dim ))
106+ if self . n_layers == 0 :
107+ modules .append (nn .Linear (self . embed_dim , self . out_dim ))
113108 else :
114109 # If we have a hidden layer, we should first project to hidden_dim
115110 modules = [
116- nn .Linear (embed_dim , hidden_dim ),
111+ nn .Linear (self . embed_dim , self . hidden_dim ),
117112 nn .ReLU (),
118113 ]
119- for _ in range (n_layers - 1 ):
120- modules .extend ([nn .Linear (hidden_dim , hidden_dim ), nn .ReLU ()])
114+ for _ in range (self . n_layers - 1 ):
115+ modules .extend ([nn .Linear (self . hidden_dim , self . hidden_dim ), nn .ReLU ()])
121116 # We always have a layer mapping from hidden to out.
122- modules .append (nn .Linear (hidden_dim , out_dim ))
117+ modules .append (nn .Linear (self . hidden_dim , self . out_dim ))
123118
124119 linear_modules = [module for module in modules if isinstance (module , nn .Linear )]
125120 if linear_modules :
@@ -254,11 +249,9 @@ def to_static_model(self) -> StaticModel:
254249 """Convert the model to a static model."""
255250 with torch .no_grad ():
256251 emb = self .embeddings .weight
257- emb = emb .detach ().cpu ().numpy ()
258- if self .w is not None :
259- w = torch .sigmoid (self .w ).detach ().cpu ().numpy ()
260- else :
261- w = np .ones (len (emb ))
252+ emb = emb .cpu ().numpy ()
253+ w = torch .sigmoid (self .w ).cpu ().numpy ()
254+
262255 # If the weights and emb are the same length, the model was not quantized before training.
263256 if len (w ) == len (emb ):
264257 emb = emb * w [:, None ]
0 commit comments