A deep learning project that implements an image colorization system using Generative Adversarial Networks (GANs). It transforms grayscale images into realistically colored images using a U-Net generator and a PatchGAN discriminator.
This project implements an image-to-image translation system that learns to colorize grayscale images. It uses a conditional GAN architecture with:
- Generator: U-Net architecture with skip connections and FiLM conditioning
- Discriminator: PatchGAN architecture for high-quality local detail preservation
- Training: WGAN-GP with additional L1 and mode-seeking losses
- U-Net Structure: Encoder-decoder architecture with skip connections
- FiLM Conditioning: Feature-wise linear modulation for conditional generation
- Upsampling: Multiple interpolation modes (bilinear, nearest, bicubic, transposed convolution)
- Residual Connections: Within upsampling blocks for better gradient flow
- ResNet34 Encoder: Pre-trained backbone for feature extraction
- PatchGAN70: 70x70 receptive field for local detail discrimination
- Spectral Normalization: Optional for training stability
- Group Normalization: Alternative to batch normalization
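FiLM conditioning, as used in the generator, can be sketched roughly as follows. This is a minimal illustration of feature-wise linear modulation, not the project's exact module; the class and parameter names are hypothetical:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scales and shifts feature maps
    with per-channel (gamma, beta) predicted from a conditioning vector."""
    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        # A single linear layer predicts both gamma and beta
        self.proj = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, features: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # features: (N, C, H, W), cond: (N, cond_dim)
        gamma, beta = self.proj(cond).chunk(2, dim=1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)  # (N, C, 1, 1) for broadcasting
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return gamma * features + beta
```

The per-channel scale and shift let a conditioning signal modulate every feature map without changing the convolutional weights themselves.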
```
image_colorizer/
├── loaders/                # Data loading utilities
│   ├── data_loader.py      # Tiny ImageNet dataset loader
│   └── __init__.py
├── models/                 # Neural network models
│   ├── generator_model.py  # U-Net generator implementation
│   ├── discriminator.py    # PatchGAN discriminator
│   └── __init__.py
├── train/                  # Training infrastructure
│   ├── runner.py           # Main training script
│   ├── trainer.py          # Training loop and logic
│   └── __init__.py
├── utils/                  # Utility functions
│   ├── config.py           # Configuration classes
│   ├── kaggle_utils.py     # Dataset download utilities
│   ├── logging.py          # TensorBoard logging
│   └── metrics.py          # Evaluation metrics (PSNR, SSIM)
├── playground.ipynb        # Interactive notebook for experimentation
├── requirements.txt        # Python dependencies
└── README.md               # This file
```
- Python 3.8+
- PyTorch 2.0+
- CUDA-compatible GPU (recommended)
1. Clone the repository:

   ```bash
   git clone <repository-url>
   cd image_colorizer
   ```

2. Install dependencies:

   ```bash
   pip install -r requirements.txt
   ```

3. Download the dataset:

   ```bash
   python utils/kaggle_utils.py
   ```

   Note: You'll need to configure Kaggle API credentials first:
   - Install the Kaggle CLI: `pip install kaggle`
   - Download your API token from Kaggle Settings
   - Place `kaggle.json` in `~/.kaggle/`

4. Configure training parameters in `train/runner.py`:

   ```python
   config = TrainingArgs(
       epochs=12,
       batch_size=16,
       image_size=128,
       device="cuda",  # or "cpu"
       data_path="data_kaggle/data",
       # ... other parameters
   )
   ```

5. Start training:

   ```bash
   python train/runner.py
   ```

6. Monitor training with TensorBoard:

   ```bash
   tensorboard --logdir logs_big_images_final_v5_tmp
   ```
| Parameter | Description | Default |
|---|---|---|
| `epochs` | Number of training epochs | 12 |
| `batch_size` | Training batch size | 16 |
| `image_size` | Input image resolution | 128 |
| `device` | Training device | `"cpu"` |
| `generator_lr` | Generator learning rate | 0.0001 |
| `discriminator_lr` | Discriminator learning rate | 0.0001 |
| `lambda_gp` | Gradient penalty weight | 10 |
| `lambda_l1` | L1 loss weight | 100 |
| `lambda_mode_seeking` | Mode-seeking loss weight | 1 |
| `critic_n` | Critic updates per generator update | 5 |
- Generator: U-Net with ResNet34 encoder, FiLM conditioning
- Discriminator: PatchGAN variants (70x70 or 34x34 receptive field)
- Upsampling: Multiple interpolation modes supported
- Normalization: BatchNorm, GroupNorm, or Spectral Normalization
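The normalization options above can be selected per layer. The following is a hypothetical discriminator conv block (not the project's actual code) showing how GroupNorm and spectral normalization plug in via standard PyTorch utilities:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

def conv_block(in_ch: int, out_ch: int, norm: str = "group") -> nn.Sequential:
    """Illustrative discriminator block with a configurable normalization."""
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
    if norm == "spectral":
        # Spectral normalization constrains the layer's Lipschitz constant,
        # which helps stabilize adversarial training
        return nn.Sequential(spectral_norm(conv), nn.LeakyReLU(0.2))
    elif norm == "group":
        # GroupNorm avoids BatchNorm's dependence on batch statistics
        return nn.Sequential(conv, nn.GroupNorm(8, out_ch), nn.LeakyReLU(0.2))
    return nn.Sequential(conv, nn.BatchNorm2d(out_ch), nn.LeakyReLU(0.2))
```

GroupNorm is often preferred over BatchNorm in GAN discriminators trained with small batches, since its statistics do not vary with batch size.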
Use `playground.ipynb` for experimentation and visualization:

```python
import torch
from models.generator_model import UNetGenerator

# Load a pre-trained generator
generator = UNetGenerator(...)
generator.load_state_dict(torch.load('checkpoints_good/generator_epoch_11.pth'))
generator.eval()

# Colorize a grayscale image
with torch.no_grad():
    colored = generator(grayscale_image)
```

- Adversarial Loss: WGAN-GP for stable training
- L1 Loss: Pixel-wise reconstruction loss
- Mode-Seeking Loss: Prevents mode collapse
- Gradient Penalty: Enforces Lipschitz constraint
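The gradient penalty term can be sketched as below. This is a generic WGAN-GP implementation, assuming an unconditional critic signature; the project's trainer may differ in details:

```python
import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    """WGAN-GP: penalize deviation of the critic's gradient norm from 1,
    evaluated at random interpolates between real and fake samples."""
    n = real.size(0)
    eps = torch.rand(n, 1, 1, 1, device=device)  # per-sample mixing weight
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )[0]
    grad_norm = grads.view(n, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()
```

The penalty softly enforces the 1-Lipschitz constraint that the Wasserstein formulation requires of the critic.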
- Critic Updates: 5 discriminator updates per generator update
- Learning Rate Scheduling: Cosine annealing for both networks
- Gradient Accumulation: 4 steps for effective larger batch sizes
- Validation: Regular evaluation with PSNR metric
- TensorBoard Logging: Loss curves, generated samples, metrics
- Checkpointing: Model saves every epoch
- Validation Metrics: PSNR
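For reference, PSNR reduces to a few lines; this is a minimal sketch assuming both tensors share the same value range, not necessarily the implementation in `utils/metrics.py`:

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """Peak Signal-to-Noise Ratio in dB; higher is better.
    Assumes pred and target both lie in [0, max_val]."""
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)
```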
Extend `TinyImageNetDataset` in `loaders/data_loader.py`:

```python
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Implement your dataset logic
        pass

    def __len__(self):
        # Return the number of samples
        pass

    def __getitem__(self, idx):
        # Return {'image': grayscale, 'target': colored}
        pass
```

- Generator: Adjust U-Net depth, add/remove FiLM layers
- Discriminator: Change receptive field size, normalization type
- Loss Functions: Modify loss weights or add custom losses
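As one example of the upsampling customization, a decoder block combining a configurable interpolation mode with a residual connection might look like this. The class name and structure are illustrative, not the project's actual `generator_model.py`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpBlock(nn.Module):
    """Hypothetical decoder block: interpolation-based upsampling
    followed by convolutions wrapped in a residual connection."""
    def __init__(self, channels: int, mode: str = "bilinear"):
        super().__init__()
        self.mode = mode
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # align_corners is only valid for interpolating modes
        align = False if self.mode in ("bilinear", "bicubic") else None
        x = F.interpolate(x, scale_factor=2, mode=self.mode, align_corners=align)
        # Residual connection around the two convolutions aids gradient flow
        return x + self.conv2(F.relu(self.conv1(x)))
```

Interpolation-plus-convolution is a common alternative to transposed convolution, which can introduce checkerboard artifacts.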
- Generator and discriminator checkpoints for epochs 0-11
- Trained on 128x128 images with Tiny ImageNet dataset
- PSNR: Peak Signal-to-Noise Ratio for image quality
Everything is logged via TensorBoard.
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests if applicable
- Submit a pull request
- U-Net: U-Net: Convolutional Networks for Biomedical Image Segmentation
- PatchGAN: Image-to-Image Translation with Conditional Adversarial Networks
- WGAN-GP: Improved Training of Wasserstein GANs
- FiLM: FiLM: Visual Reasoning with a General Conditioning Layer
- MSGAN: Mode Seeking Generative Adversarial Networks for Diverse Image Synthesis
This project is licensed under the MIT License - see the LICENSE file for details.
Note: This project is for research and educational purposes. For production use, consider additional optimizations and robust evaluation protocols.