CIFAR-10 CNN Classifier with PyTorch

A custom convolutional neural network (MyNet) built from scratch in PyTorch to classify images from the CIFAR-10 dataset. The project includes a complete training pipeline, data augmentation, model checkpointing, plots of training results, and a confusion matrix for detailed performance analysis. The trained model reaches 87% accuracy on the official CIFAR-10 test set.


📌 Features

  • Custom CNN (MyNet) with 6 convolutional layers, BatchNorm, and Dropout
  • Stratified train/validation split using StratifiedShuffleSplit
  • Advanced data augmentation: AutoAugment, RandomCrop, RandomHorizontalFlip, and ColorJitter
  • Training loop with real-time loss and accuracy tracking
  • Learning rate scheduling with StepLR
  • Final evaluation on the test dataset
  • Visualizations: training vs. validation loss, accuracy, and confusion matrix
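The training loop and StepLR scheduling listed above can be sketched roughly as follows. This is not the code from MyNet.py: the SGD optimizer, learning rate, momentum, weight decay, `step_size`, `gamma`, and the helper names `run_epoch` and `fit` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


def run_epoch(model, loader, criterion, optimizer=None, device="cpu"):
    """One pass over `loader`: trains if an optimizer is given, otherwise evaluates.

    Returns (mean loss, accuracy) so both can be tracked per epoch.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


def fit(model, train_loader, val_loader, epochs=30, lr=0.1,
        step_size=10, gamma=0.1, device="cpu"):
    """Hypothetical training driver; optimizer choice and all hyperparameters are assumptions."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    for epoch in range(epochs):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device)
        va_loss, va_acc = run_epoch(model, val_loader, criterion, None, device)
        scheduler.step()  # multiplies the LR by `gamma` every `step_size` epochs
        for key, value in zip(history, (tr_loss, tr_acc, va_loss, va_acc)):
            history[key].append(value)
        print(f"epoch {epoch + 1}: train {tr_loss:.3f}/{tr_acc:.3f}  "
              f"val {va_loss:.3f}/{va_acc:.3f}")
    return history
```

The returned `history` dict holds exactly the per-epoch series needed for the loss and accuracy plots.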

πŸ—οΈ Model Architecture

  • Convolutional Layers: 6 layers with increasing depth (8 β†’ 256 channels)
  • Pooling: MaxPooling applied after specific layers
  • Batch Normalization: Applied to the first 5 convolutional layers
  • Activation Function: ReLU used throughout the network
  • Classifier: Global Average Pooling β†’ Fully Connected (512) β†’ Fully Connected (10)
  • Dropout: 50% dropout before the final fully connected layer
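A minimal PyTorch sketch consistent with the architecture bullets above. The exact widths between 8 and 256 (assumed here to double each layer: 8, 16, 32, 64, 128, 256), the 3×3 kernels, and the placement of max pooling after the 2nd, 4th, and 6th convolutions are assumptions, and `MyNetSketch` is a hypothetical name, not the class in MyNet.py.

```python
import torch
import torch.nn as nn


class MyNetSketch(nn.Module):
    """Hypothetical reconstruction of MyNet from the README's architecture list."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        widths = [3, 8, 16, 32, 64, 128, 256]  # assumed channel progression (RGB input)
        pool_after = {1, 3, 5}                 # assumed: pool after conv 2, 4 and 6
        layers = []
        for i in range(6):
            layers.append(nn.Conv2d(widths[i], widths[i + 1], kernel_size=3, padding=1))
            if i < 5:  # BatchNorm on the first 5 conv layers only
                layers.append(nn.BatchNorm2d(widths[i + 1]))
            layers.append(nn.ReLU(inplace=True))
            if i in pool_after:
                layers.append(nn.MaxPool2d(2))  # halves height and width
        self.features = nn.Sequential(*layers)
        self.gap = nn.AdaptiveAvgPool2d(1)  # global average pooling -> (N, 256, 1, 1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),  # 50% dropout before the final FC layer
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.gap(self.features(x)))
```

With 32×32 CIFAR-10 inputs, the three assumed pooling stages reduce the feature map to 4×4 before global average pooling, so `MyNetSketch()(torch.randn(4, 3, 32, 32))` yields logits of shape `(4, 10)`.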

🧪 Data Pipeline

  • Dataset: CIFAR-10, automatically downloaded using torchvision
  • Splits:
    • Training: 80%
    • Validation: 20% (stratified)
    • Test: Official CIFAR-10 test set
  • Augmentation:
    • AutoAugment policy for CIFAR-10
    • Random cropping and flipping
    • Color jitter for brightness, contrast, and saturation

🚀 How to Run

  1. Install dependencies:

    pip install torch torchvision matplotlib numpy scikit-learn tqdm
  2. Train the model:

    python3 MyNet.py

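📊 Visual Output

🔻 Loss Plot

[Figure: training vs. validation loss (figures/loss_fig.png)]

🔺 Accuracy Plot

[Figure: training vs. validation accuracy (figures/accuracy_fig.png)]

📉 Confusion Matrix

[Figure: confusion matrix (figures/confusion_matrix.png)]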
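A confusion matrix like the one above can be computed from test-set predictions with scikit-learn. This is a sketch, not the plotting code in MyNet.py; the function name `evaluate_confusion` is hypothetical.

```python
import numpy as np
import torch
from sklearn.metrics import confusion_matrix


@torch.no_grad()
def evaluate_confusion(model, loader, num_classes=10, device="cpu"):
    """Collect predictions over `loader` and return (accuracy, confusion matrix).

    Rows are true classes and columns are predicted classes (scikit-learn's convention).
    """
    model.eval()
    y_true, y_pred = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        y_pred.append(logits.argmax(dim=1).cpu().numpy())
        y_true.append(labels.cpu().numpy())
    y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    accuracy = np.trace(cm) / cm.sum()  # correct predictions lie on the diagonal
    return accuracy, cm
```

To render the matrix, `sklearn.metrics.ConfusionMatrixDisplay(cm, display_labels=...)` can be plotted with Matplotlib and saved with `plt.savefig(...)`; the class-name labels and output path are up to the caller.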


πŸ“ Project Structure

β”œβ”€β”€ MyNet.py                     # Main script containing model, training, etc.
β”œβ”€β”€ data/                         # CIFAR-10 dataset (auto-downloaded)
β”œβ”€β”€ figures/                      # Directory for saved plots
β”‚   β”œβ”€β”€ loss_fig.png              # Loss plot
β”‚   β”œβ”€β”€ accuracy_fig.png          # Accuracy plot
β”‚   └── confusion_matrix.png      # Confusion matrix
β”œβ”€β”€ cifar_mynet_final.pt          # Final trained model
β”œβ”€β”€ cifar_mynet_epoch_*.pt        # Model checkpoints
└── README.md                     # Project documentation