# SupCon Autoencoder

A PyTorch library that combines Supervised Contrastive Learning with Autoencoder architectures. This hybrid approach trains autoencoders that not only reconstruct input data but also organize the latent space so that samples from the same class cluster together.


## Overview

SupCon Autoencoder integrates two complementary objectives:

  1. Supervised Contrastive Loss — Pulls embeddings from the same class closer while pushing different classes apart in latent space
  2. Reconstruction Loss — Ensures the autoencoder can faithfully reconstruct its input

Hybrid Loss Formula:

$$\mathcal{L} = \lambda \cdot \mathcal{L}_{\text{SupCon}} + (1 - \lambda) \cdot \mathcal{L}_{\text{reconstruction}}$$
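The formula is a convex combination of the two loss terms. As a quick numeric check, here is a minimal sketch (the function name `hybrid_loss` and plain-float inputs are illustrative; `lam` plays the role of $\lambda$):

```python
def hybrid_loss(supcon_term, recon_term, lam):
    """Convex combination of the SupCon and reconstruction losses."""
    assert 0.0 <= lam <= 1.0
    return lam * supcon_term + (1.0 - lam) * recon_term

# lam = 0.5 weights both objectives equally:
# 0.5 * 2.0 + 0.5 * 0.5 = 1.25
print(hybrid_loss(2.0, 0.5, lam=0.5))  # → 1.25
```

With `lam = 1.0` only the contrastive term matters; with `lam = 0.0` the model reduces to a plain autoencoder.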

## Using the Loss Function Independently

You can use the loss functions without the built-in trainer; they follow a simple call interface:

```python
import torch.nn as nn

from supcon_autoencoder.core.loss import HybridLoss, SupConLoss

# SupConLoss: takes embeddings and labels
supcon_loss = SupConLoss(temperature=0.5)
loss = supcon_loss(embeddings, labels)

# HybridLoss: takes embeddings, labels, original, reconstructed
hybrid_loss = HybridLoss(supcon_loss, nn.MSELoss(), lambda_=0.5)
loss = hybrid_loss(embeddings, labels, original, reconstructed)
```
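For intuition about what `SupConLoss` computes, here is a generic reference sketch of the supervised contrastive objective: for each anchor, it averages the log-probability of its same-class positives under a temperature-scaled softmax over all other samples. This is not the library's implementation, just a standard formulation for illustration:

```python
import torch
import torch.nn.functional as F

def supcon_reference(embeddings, labels, temperature=0.5):
    # Normalize so dot products are cosine similarities
    z = F.normalize(embeddings, dim=1)
    sim = z @ z.T / temperature
    n = z.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    # Exclude each sample's similarity to itself from the softmax
    sim = sim.masked_fill(self_mask, float("-inf"))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    # Positives: other samples sharing the anchor's label
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    pos_log_prob = torch.where(pos_mask, log_prob, torch.zeros_like(log_prob))
    # Mean log-probability over positives, averaged over anchors
    return -(pos_log_prob.sum(1) / pos_mask.sum(1).clamp(min=1)).mean()
```

Embeddings that cluster by class yield a lower loss than embeddings where classes are intermingled, which is exactly the pressure this term puts on the latent space.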

## Built-in Trainer Requirements (Optional)

If you use the built-in Trainer, your model and data must follow these protocols:

**Model** — Must expose `encoder` and `decoder` properties:

```python
class MyAutoencoder(nn.Module):
    @property
    def encoder(self) -> nn.Module: ...

    @property
    def decoder(self) -> nn.Module: ...
```
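A minimal concrete model satisfying this protocol might look like the following. The class name and layer sizes are illustrative, not part of the library:

```python
import torch
import torch.nn as nn

class MLPAutoencoder(nn.Module):
    """Small fully-connected autoencoder exposing encoder/decoder properties."""

    def __init__(self, in_dim: int = 784, latent_dim: int = 32):
        super().__init__()
        self._encoder = nn.Sequential(
            nn.Linear(in_dim, 128), nn.ReLU(), nn.Linear(128, latent_dim)
        )
        self._decoder = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, in_dim)
        )

    @property
    def encoder(self) -> nn.Module:
        return self._encoder

    @property
    def decoder(self) -> nn.Module:
        return self._decoder
```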

**Data** — Each sample must be a dictionary with `features` and `labels` keys:

```python
sample = {
    "features": torch.Tensor,  # Input data
    "labels": torch.Tensor,    # Class labels
}
```
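One way to satisfy this is a thin `Dataset` wrapper around feature and label tensors (the class name `DictDataset` is illustrative). PyTorch's default collate function then stacks the per-sample dicts into batched tensors:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    """Wraps feature and label tensors into dict samples."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> dict:
        return {"features": self.features[idx], "labels": self.labels[idx]}

dataset = DictDataset(torch.randn(8, 16), torch.randint(0, 3, (8,)))
loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))
# batch["features"] has shape (4, 16); batch["labels"] has shape (4,)
```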

## Quick Start

```python
import torch.nn as nn

from supcon_autoencoder.core.loss import HybridLoss, SupConLoss
from supcon_autoencoder.core.training import Trainer

loss_fn = HybridLoss(
    sup_con_loss=SupConLoss(temperature=0.5),
    reconstruction_loss=nn.MSELoss(),
    lambda_=0.5,
)

# model, optimizer, train_loader, and device are defined by your application
trainer = Trainer(model=model, optimizer=optimizer, loss_fn=loss_fn)
history = trainer.train(train_loader=train_loader, device=device, epochs=50)
```

## Installation

```shell
# Add this package to your project
uv add git+https://github.com/timurci/supcon-autoencoder.git

# Or, from a clone of the repository, install dependencies to run the examples
uv sync
```

## Examples

- Fashion-MNIST: `examples/fashion_mnist/`
- Gene Expression: `examples/gene_expression/`

## References

This implementation is based on:

## License

MIT License
