@@ -181,6 +181,34 @@ def validation_step(self, batch: Any, batch_idx: int): # type: ignore[override]
181181
182182 return {'val_loss' : total_loss }
183183
184+ def test_step (self , batch : Any , batch_idx : int ): # type: ignore[override]
185+ result = self ._forward_pass_with_teacher (batch )
186+
187+ total_loss = result ['loss' ]
188+ recon_loss = result ['recon_loss' ]
189+ mmd_loss = result ['mmd_loss' ]
190+ sw_loss = result ['sw_loss' ]
191+ cls_loss = result ['cls_loss' ]
192+ cls_acc = result ['cls_acc' ]
193+
194+ self .log ('test_loss' , total_loss , on_epoch = True , prog_bar = True , sync_dist = True )
195+ self .log ('test/recon_loss' , recon_loss , on_epoch = True , sync_dist = True )
196+ self .log ('test/mmd_loss' , mmd_loss , on_epoch = True , sync_dist = True )
197+ self .log ('test/sw_loss' , sw_loss , on_epoch = True , sync_dist = True )
198+ self .log ('test/cls_loss' , cls_loss , on_epoch = True , sync_dist = True )
199+ self .log ('test/cls_acc' , cls_acc , on_epoch = True , sync_dist = True )
200+
201+ if 'masked_mae' in result :
202+ self .log ('test/masked_mae' , result ['masked_mae' ], on_epoch = True , sync_dist = True )
203+ if 'masked_corr' in result :
204+ self .log ('test/masked_corr' , result ['masked_corr' ], on_epoch = True , sync_dist = True )
205+ if 'sw_predict' in result :
206+ self .log ('test/sw_predict' , result ['sw_predict' ], on_epoch = True , sync_dist = True )
207+ if 'mask_rate' in result :
208+ self .log ('test/mask_rate' , result ['mask_rate' ], on_epoch = True , sync_dist = True )
209+
210+ return {'test_loss' : total_loss }
211+
184212 def configure_optimizers (self ): # type: ignore[override]
185213 optimizer = torch .optim .AdamW (self .model .parameters (), lr = self .learning_rate , weight_decay = self .weight_decay )
186214 config : Dict [str , Any ] = {'optimizer' : optimizer }
0 commit comments