Image Colorization with GANs

A deep learning project that implements an image colorization system using Generative Adversarial Networks (GANs). It transforms grayscale images into realistic color images using a U-Net generator and a PatchGAN discriminator.

🎯 Project Overview

This project implements an image-to-image translation system that learns to colorize grayscale images. It uses a conditional GAN architecture with:

  • Generator: U-Net architecture with skip connections and FiLM conditioning
  • Discriminator: PatchGAN architecture for high-quality local detail preservation
  • Training: WGAN-GP with additional L1 and mode-seeking losses

🏗️ Architecture

Generator (UNetGenerator)

  • U-Net Structure: Encoder-decoder architecture with skip connections
  • FiLM Conditioning: Feature-wise linear modulation for conditional generation
  • Upsampling: Multiple interpolation modes (bilinear, nearest, bicubic, transposed convolution)
  • Residual Connections: Within upsampling blocks for better gradient flow
  • ResNet34 Encoder: Pre-trained backbone for feature extraction
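
For intuition, here is a minimal sketch of the FiLM conditioning mentioned above: a conditioning vector is mapped to per-channel scale and shift parameters that modulate the feature maps. The layer below is illustrative only (the channel and conditioning sizes are assumptions), not the exact module in generator_model.py.

import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scale and shift feature maps
    using parameters predicted from a conditioning vector."""
    def __init__(self, cond_dim, num_channels):
        super().__init__()
        self.to_gamma = nn.Linear(cond_dim, num_channels)
        self.to_beta = nn.Linear(cond_dim, num_channels)

    def forward(self, features, cond):
        # features: (N, C, H, W), cond: (N, cond_dim)
        gamma = self.to_gamma(cond).unsqueeze(-1).unsqueeze(-1)
        beta = self.to_beta(cond).unsqueeze(-1).unsqueeze(-1)
        return gamma * features + beta

# Example: modulate a 256-channel feature map with a 64-dim conditioning vector
film = FiLM(cond_dim=64, num_channels=256)
out = film(torch.rand(2, 256, 16, 16), torch.rand(2, 64))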

Discriminator (PatchGAN)

  • PatchGAN70: 70x70 receptive field for local detail discrimination
  • Spectral Normalization: Optional for training stability
  • Group Normalization: Alternative to batch normalization
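
As a rough sketch of the PatchGAN idea (not the repository's exact discriminator.py; the 3-channel input is an assumption, since a conditional discriminator may also receive the grayscale input concatenated), stacking strided 4x4 convolutions produces a grid of per-patch real/fake scores, each covering roughly a 70x70 region of the input:

import torch
import torch.nn as nn

def block(in_ch, out_ch, stride):
    # 4x4 convolution + GroupNorm + LeakyReLU, the repeating PatchGAN unit
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1),
        nn.GroupNorm(8, out_ch),
        nn.LeakyReLU(0.2, inplace=True),
    )

patch_critic = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True),
    block(64, 128, stride=2),
    block(128, 256, stride=2),
    block(256, 512, stride=1),
    nn.Conv2d(512, 1, 4, stride=1, padding=1),  # one score per ~70x70 patch
)

scores = patch_critic(torch.rand(1, 3, 128, 128))  # shape: (1, 1, 14, 14)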

📁 Project Structure

image_colorizer/
├── loaders/                  # Data loading utilities
│   ├── data_loader.py        # Tiny ImageNet dataset loader
│   └── __init__.py
├── models/                   # Neural network models
│   ├── generator_model.py    # U-Net generator implementation
│   ├── discriminator.py      # PatchGAN discriminator
│   └── __init__.py
├── train/                    # Training infrastructure
│   ├── runner.py             # Main training script
│   ├── trainer.py            # Training loop and logic
│   └── __init__.py
├── utils/                    # Utility functions
│   ├── config.py             # Configuration classes
│   ├── kaggle_utils.py       # Dataset download utilities
│   ├── logging.py            # TensorBoard logging
│   └── metrics.py            # Evaluation metrics (PSNR, SSIM)
├── playground.ipynb          # Interactive notebook for experimentation
├── requirements.txt          # Python dependencies
└── README.md                 # This file

🚀 Quick Start

Prerequisites

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA-compatible GPU (recommended)

Installation

  1. Clone the repository

    git clone <repository-url>
    cd image_colorizer
  2. Install dependencies

    pip install -r requirements.txt
  3. Download the dataset

    python utils/kaggle_utils.py

    Note: You'll need to configure Kaggle API credentials first:

    • Install Kaggle CLI: pip install kaggle
    • Download your API token from Kaggle Settings
    • Place kaggle.json in ~/.kaggle/

Training

  1. Configure training parameters in train/runner.py:

    config = TrainingArgs(
        epochs=12,
        batch_size=16,
        image_size=128,
        device="cuda",  # or "cpu"
        data_path="data_kaggle/data",
        # ... other parameters
    )
  2. Start training:

    python train/runner.py
  3. Monitor training with TensorBoard:

    tensorboard --logdir logs_big_images_final_v5_tmp

⚙️ Configuration

Training Arguments

Parameter            Description                           Default
epochs               Number of training epochs             12
batch_size           Training batch size                   16
image_size           Input image resolution                128
device               Training device                       "cpu"
generator_lr         Generator learning rate               0.0001
discriminator_lr     Discriminator learning rate           0.0001
lambda_gp            Gradient penalty weight               10
lambda_l1            L1 loss weight                        100
lambda_mode_seeking  Mode-seeking loss weight              1
critic_n             Critic updates per generator update   5
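
As a sketch of how the table maps to code, a dataclass-style configuration with these defaults might look like the following (illustrative only; the actual field definitions live in utils/config.py):

from dataclasses import dataclass

@dataclass
class TrainingArgs:
    epochs: int = 12
    batch_size: int = 16
    image_size: int = 128
    device: str = "cpu"
    generator_lr: float = 1e-4
    discriminator_lr: float = 1e-4
    lambda_gp: float = 10.0
    lambda_l1: float = 100.0
    lambda_mode_seeking: float = 1.0
    critic_n: int = 5
    data_path: str = "data_kaggle/data"

# Override defaults per run, e.g. train on GPU
config = TrainingArgs(device="cuda")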

Model Architecture Options

  • Generator: U-Net with ResNet34 encoder, FiLM conditioning
  • Discriminator: PatchGAN variants (70x70 or 34x34 receptive field)
  • Upsampling: Multiple interpolation modes supported
  • Normalization: BatchNorm, GroupNorm, or Spectral Normalization

🎨 Usage Examples

Interactive Notebook

Use playground.ipynb for experimentation and visualization:

# Load a pre-trained generator (import path per the project structure above)
import torch
from models.generator_model import UNetGenerator

generator = UNetGenerator(...)  # constructor arguments as used during training
generator.load_state_dict(torch.load('checkpoints_good/generator_epoch_11.pth'))
generator.eval()

# Colorize a grayscale image tensor
with torch.no_grad():
    colored = generator(grayscale_image)
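
The grayscale_image above is simply a tensor shaped like the training inputs. A minimal way to build one from an image file is sketched below; the file name, the 128x128 size, and the plain-grayscale preprocessing are assumptions (the exact pipeline, e.g. LAB L-channel extraction, lives in loaders/data_loader.py).

from PIL import Image
from torchvision import transforms

to_gray = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

# Add a batch dimension so the shape is (1, 1, 128, 128)
grayscale_image = to_gray(Image.open("example.jpg")).unsqueeze(0)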

📊 Training Details

Loss Functions

  • Adversarial Loss: WGAN-GP for stable training
  • L1 Loss: Pixel-wise reconstruction loss
  • Mode-Seeking Loss: Prevents mode collapse
  • Gradient Penalty: Enforces Lipschitz constraint
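
For reference, the WGAN-GP gradient penalty can be computed roughly as follows. This is a minimal sketch; the function and argument names are illustrative, not the exact code in train/trainer.py.

import torch

def gradient_penalty(discriminator, real, fake, device="cpu"):
    # Interpolate between real and fake samples
    batch_size = real.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    # Critic scores on the interpolated images
    scores = discriminator(interpolated)

    # Gradient of the scores w.r.t. the interpolated images
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalize deviation of the gradient norm from 1 (Lipschitz constraint)
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()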

Training Strategy

  • Critic Updates: 5 discriminator updates per generator update
  • Learning Rate Scheduling: Cosine annealing for both networks
  • Gradient Accumulation: 4 steps for effective larger batch sizes
  • Validation: Regular evaluation with PSNR metric
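
The update schedule can be sketched as follows, using tiny stand-in networks. The real loop in train/trainer.py also applies gradient accumulation and the full loss terms; this only illustrates the critic-to-generator update ratio and the per-epoch cosine annealing.

import torch
import torch.nn as nn

# Tiny stand-in networks; the real UNetGenerator / PatchGAN live in models/
generator = nn.Conv2d(1, 3, kernel_size=3, padding=1)
critic = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)

g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(critic.parameters(), lr=1e-4)
g_sched = torch.optim.lr_scheduler.CosineAnnealingLR(g_opt, T_max=12)
d_sched = torch.optim.lr_scheduler.CosineAnnealingLR(d_opt, T_max=12)

critic_n = 5
for epoch in range(12):
    for _ in range(10):  # stand-in for iterating the training DataLoader
        gray = torch.rand(4, 1, 64, 64)
        real = torch.rand(4, 3, 64, 64)

        # critic_n critic updates per generator update (Wasserstein loss only;
        # the gradient penalty from the section above is omitted for brevity)
        for _ in range(critic_n):
            d_loss = critic(generator(gray).detach()).mean() - critic(real).mean()
            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()

        # one generator update
        g_loss = -critic(generator(gray)).mean()
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

    # cosine annealing, stepped once per epoch for both networks
    g_sched.step()
    d_sched.step()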

Monitoring

  • TensorBoard Logging: Loss curves, generated samples, metrics
  • Checkpointing: Generator and discriminator weights saved every epoch
  • Validation Metrics: PSNR
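
Logging follows the standard torch.utils.tensorboard pattern; a minimal sketch is shown below (the tag names are assumptions and the log directory matches the one used above; see utils/logging.py for the real setup).

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs_big_images_final_v5_tmp")

# Scalars: loss curves and validation metrics
writer.add_scalar("loss/generator", 1.23, global_step=0)
writer.add_scalar("metrics/psnr", 21.7, global_step=0)

# Images: a grid of colorized samples (N x 3 x H x W in [0, 1])
samples = torch.rand(8, 3, 128, 128)
writer.add_image("samples/colorized", torchvision.utils.make_grid(samples), global_step=0)
writer.close()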

🔧 Customization

Adding New Datasets

Extend TinyImageNetDataset in loaders/data_loader.py:

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Collect sample paths and store transforms here
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # Return the number of samples
        raise NotImplementedError

    def __getitem__(self, idx):
        # Return {'image': grayscale_tensor, 'target': colored_tensor}
        raise NotImplementedError

Modifying Architecture

  • Generator: Adjust U-Net depth, add/remove FiLM layers
  • Discriminator: Change receptive field size, normalization type
  • Loss Functions: Modify loss weights or add custom losses

📈 Performance

Pre-trained Models

  • Generator and discriminator checkpoints for epochs 0-11
  • Trained on 128x128 images from the Tiny ImageNet dataset

Evaluation Metrics

  • PSNR: Peak Signal-to-Noise Ratio for image quality
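
PSNR can be computed directly from the mean squared error; a minimal sketch follows (assuming images scaled to [0, 1], which may differ from the normalization used in utils/metrics.py).

import torch

def psnr(prediction, target, max_val=1.0):
    # Peak Signal-to-Noise Ratio in dB between two images in [0, max_val]
    mse = torch.mean((prediction - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

# Example: compare a colorized output against the ground-truth image
value = psnr(torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128))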

Debugging

Everything is logged via TensorBoard, so loss curves, metrics, and generated samples can be inspected during and after training.

🀝 Contributing

  1. Fork the repository
  2. Create a feature branch
  3. Make your changes
  4. Add tests if applicable
  5. Submit a pull request

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.


Note: This project is for research and educational purposes. For production use, consider additional optimizations and robust evaluation protocols.
