Skip to content

Commit 6fc92ad

Browse files
authored
Merge pull request #13 from ArcInstitute/dhruv/step-fix
Add test step for consistency with model test
2 parents 755bb32 + fb66610 commit 6fc92ad

3 files changed

Lines changed: 41 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "arc-stack"
7-
version = "0.1.1"
7+
version = "0.1.2"
88
description = "Stack is a single-cell foundation model that enables in-context learning at inference time."
99
readme = "README.md"
1010
license = { file = "LICENSE" }

src/stack/finetune/lightning.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,34 @@ def validation_step(self, batch: Any, batch_idx: int): # type: ignore[override]
181181

182182
return {'val_loss': total_loss}
183183

184+
def test_step(self, batch: Any, batch_idx: int): # type: ignore[override]
185+
result = self._forward_pass_with_teacher(batch)
186+
187+
total_loss = result['loss']
188+
recon_loss = result['recon_loss']
189+
mmd_loss = result['mmd_loss']
190+
sw_loss = result['sw_loss']
191+
cls_loss = result['cls_loss']
192+
cls_acc = result['cls_acc']
193+
194+
self.log('test_loss', total_loss, on_epoch=True, prog_bar=True, sync_dist=True)
195+
self.log('test/recon_loss', recon_loss, on_epoch=True, sync_dist=True)
196+
self.log('test/mmd_loss', mmd_loss, on_epoch=True, sync_dist=True)
197+
self.log('test/sw_loss', sw_loss, on_epoch=True, sync_dist=True)
198+
self.log('test/cls_loss', cls_loss, on_epoch=True, sync_dist=True)
199+
self.log('test/cls_acc', cls_acc, on_epoch=True, sync_dist=True)
200+
201+
if 'masked_mae' in result:
202+
self.log('test/masked_mae', result['masked_mae'], on_epoch=True, sync_dist=True)
203+
if 'masked_corr' in result:
204+
self.log('test/masked_corr', result['masked_corr'], on_epoch=True, sync_dist=True)
205+
if 'sw_predict' in result:
206+
self.log('test/sw_predict', result['sw_predict'], on_epoch=True, sync_dist=True)
207+
if 'mask_rate' in result:
208+
self.log('test/mask_rate', result['mask_rate'], on_epoch=True, sync_dist=True)
209+
210+
return {'test_loss': total_loss}
211+
184212
def configure_optimizers(self): # type: ignore[override]
185213
optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
186214
config: Dict[str, Any] = {'optimizer': optimizer}

src/stack/training/lightning.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ def validation_step(self, batch, batch_idx): # type: ignore[override]
5555
self.log("val/mask_rate", result["mask_rate"], on_epoch=True, sync_dist=True)
5656
return {"val_loss": total_loss}
5757

58+
def test_step(self, batch, batch_idx): # type: ignore[override]
59+
features, _ = batch
60+
result = self.model(features, return_loss=True)
61+
total_loss = result["loss"]
62+
self.log("test_loss", total_loss, on_epoch=True, prog_bar=True, sync_dist=True)
63+
self.log("test/recon_loss", result["recon_loss"], on_epoch=True, sync_dist=True)
64+
self.log("test/sw_loss", result["sw_loss"], on_epoch=True, sync_dist=True)
65+
self.log("test/masked_mae", result["masked_mae"], on_epoch=True, sync_dist=True)
66+
self.log("test/masked_corr", result["masked_corr"], on_epoch=True, sync_dist=True)
67+
self.log("test/mask_rate", result["mask_rate"], on_epoch=True, sync_dist=True)
68+
return {"test_loss": total_loss}
69+
5870
def configure_optimizers(self): # type: ignore[override]
5971
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
6072
scheduler_dict = configure_scheduler(optimizer, self.scheduler_config)

0 commit comments

Comments
 (0)