Image-Classification-Resnet34

Introduction:

Using a ResNet-18 model that I built myself, I used PyTorch to classify images of 10 animal types.

This is my PyTorch implementation of the model described in the ResNet paper.

Note: because I built this model myself, I trained it on a dataset of 10 animal classes (presented in the Dataset section) instead of the COCO dataset.
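The number in a ResNet's name counts its weighted layers. As a hedged sketch using the standard configurations from the ResNet paper (not read from this repository's code), ResNet-18 and ResNet-34 both stack two-convolution "basic blocks" and differ only in how many blocks each of the four stages contains:

```python
# Blocks per stage for the standard basic-block ResNets (ResNet paper).
BLOCKS_PER_STAGE = {
    "resnet18": [2, 2, 2, 2],
    "resnet34": [3, 4, 6, 3],
}

CONVS_PER_BASIC_BLOCK = 2  # each basic block is two 3x3 convolutions


def weighted_layers(name: str) -> int:
    """Count the layers that give a ResNet its name.

    Counts the 7x7 stem convolution, the 3x3 convolutions inside every
    basic block, and the final fully connected layer. The 1x1 projection
    shortcuts are not counted, following the paper's convention.
    """
    blocks = sum(BLOCKS_PER_STAGE[name])
    return 1 + CONVS_PER_BASIC_BLOCK * blocks + 1


for name in BLOCKS_PER_STAGE:
    print(name, weighted_layers(name))  # resnet18 18, then resnet34 34
```

For this project, the final fully connected layer would output 10 logits, one per animal class.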

Dataset:

Statistics of the dataset I used for the experiments are shown below. The dataset can be downloaded from this link.

Class       Train samples   Test samples
butterfly   1902            210
cat         1508            160
chicken     2790            308
cow         1684            182
dog         4373            490
elephant    1306            140
horse       2357            266
sheep       1638            182
spider      4345            476
squirrel    1680            182
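A quick way to sanity-check the table is to total it and look at the train/test split and class imbalance. This is a small standard-library sketch whose only inputs are the counts listed above:

```python
# Train/test counts copied from the dataset table.
counts = {
    "butterfly": (1902, 210),
    "cat": (1508, 160),
    "chicken": (2790, 308),
    "cow": (1684, 182),
    "dog": (4373, 490),
    "elephant": (1306, 140),
    "horse": (2357, 266),
    "sheep": (1638, 182),
    "spider": (4345, 476),
    "squirrel": (1680, 182),
}

train_total = sum(train for train, _ in counts.values())
test_total = sum(test for _, test in counts.values())
test_fraction = test_total / (train_total + test_total)

# Ratio between the largest and smallest training classes (class imbalance).
train_sizes = [train for train, _ in counts.values()]
imbalance = max(train_sizes) / min(train_sizes)

print(f"train={train_total} test={test_total} test_fraction={test_fraction:.3f}")
print(f"largest/smallest train class ratio: {imbalance:.2f}")
```

With these numbers there are 23,583 training and 2,596 test images (a test share of about 9.9%), and the largest training class (dog) is roughly 3.3 times the size of the smallest (elephant).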

Settings:

For optimizer and learning rate, I use:

  • SGD optimizer with different learning rates (0.01 in most cases).
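Plain SGD moves each parameter against its gradient, scaled by the learning rate: w = w - lr * dL/dw. With its default arguments (no momentum, no weight decay), PyTorch's torch.optim.SGD applies exactly this rule. The framework-free sketch below uses lr = 0.01 as in the settings above, on a hypothetical one-parameter toy loss L(w) = (w - 3)^2 that is not part of this project:

```python
def sgd_step(w: float, grad: float, lr: float = 0.01) -> float:
    """One vanilla SGD update: step against the gradient, scaled by lr."""
    return w - lr * grad


def toy_loss_grad(w: float) -> float:
    """Gradient of the hypothetical toy loss L(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)


w = 0.0
for step in range(1000):
    w = sgd_step(w, toy_loss_grad(w), lr=0.01)

print(round(w, 4))  # converges to the minimum at w = 3
```

Each step shrinks the distance to the minimum by a factor of (1 - 2 * lr) = 0.98 here, which is why a small learning rate such as 0.01 needs many updates but stays stable.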

Additionally, I train the model for up to 100 epochs with a batch size of 16. Early stopping ends the training process if the score does not improve for 5 consecutive epochs.
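The early-stopping rule can be expressed as a small patience counter. This is a minimal sketch assuming the monitored score is a per-epoch validation score where higher is better; the function name and the score values are hypothetical, not taken from train_animal.py:

```python
def early_stopping_epoch(scores, patience=5, max_epochs=100):
    """Return (epochs_run, best_epoch) for patience-based early stopping.

    scores[i] is the validation score after epoch i + 1 (higher is better).
    Training stops once `patience` consecutive epochs fail to beat the best
    score seen so far, or when max_epochs is reached.
    """
    best_score = float("-inf")
    best_epoch = 0
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch, score in enumerate(scores[:max_epochs], start=1):
        epochs_run = epoch
        if score > best_score:  # strictly greater counts as an improvement
            best_score = score
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return epochs_run, best_epoch


# Hypothetical validation accuracies, one per epoch.
scores = [0.40, 0.50, 0.55, 0.54, 0.53, 0.55, 0.52, 0.51, 0.60]
print(early_stopping_epoch(scores))  # (8, 3): epoch 9 is never reached
```

The strict ">" means an equal score does not reset the counter, matching "does not improve" in the settings above.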

Training:

If you want to train a model with the default parameters, run:

 python train_animal.py

If you want to adjust the parameters, here are the options you can choose from:

Parameter      Abbreviation  Default                              Description
--batch-size   -b            16                                   Batch size
--data-path    -p            '../../'                             Directory containing the dataset
--lr           (none)        1e-2                                 Learning rate
--epochs       -e            100                                  Number of training epochs
--log-path     -l            tensorboard                          Directory for TensorBoard metrics
--checkpoint   -sc           tensorboard/animals/epoch_last.pt    Path where the trained model is saved

For example:

 python train_animal.py -p dataset_location --log-path directory-name
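The options table maps naturally onto Python's argparse module. The parser below is a hypothetical reconstruction from the table alone (the actual parser in the training script may differ) and shows how each long option, abbreviation and default fit together:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the options table; not the repository's code.
    parser = argparse.ArgumentParser(
        description="Train a ResNet on the 10-class animal dataset")
    parser.add_argument("--batch-size", "-b", type=int, default=16,
                        help="batch size")
    parser.add_argument("--data-path", "-p", type=str, default="../../",
                        help="directory containing the dataset")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--epochs", "-e", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--log-path", "-l", type=str, default="tensorboard",
                        help="directory for TensorBoard metrics")
    parser.add_argument("--checkpoint", "-sc", type=str,
                        default="tensorboard/animals/epoch_last.pt",
                        help="path where the trained model is saved")
    return parser


# Parse a command like the example above instead of reading sys.argv.
args = build_parser().parse_args(
    ["-p", "dataset_location", "--log-path", "directory-name"])
print(args.data_path, args.log_path, args.batch_size)
```

argparse derives each attribute name from the first long option, so --batch-size becomes args.batch_size and --log-path becomes args.log_path.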

How to view TensorBoard:

 tensorboard --logdir path/to/tensorboard/

Evaluating:

You can review my evaluation process in TensorBoard. After 22 epochs of training, the results were:

  • Highest accuracy: 0.7, reached in epoch 19 [Figure: Accuracy score]

  • Confusion matrix: [Figure: Confusion matrix]

  • Cross-entropy loss for both processes (training and evaluation): [Figure: Loss plot]
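For one image with raw model outputs (logits) z over the 10 classes and true class y, the cross-entropy loss is L = -log softmax(z)[y] = log(sum_j exp(z_j)) - z_y. PyTorch's nn.CrossEntropyLoss computes this directly from logits and, by default, averages it over the batch. A standard-library sketch with hypothetical logits:

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample from raw logits (log-softmax + NLL).

    Subtracting the max logit before exponentiating prevents exp() from
    overflowing without changing the result.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, targets):
    """Average loss over a batch, like reduction='mean'."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical logits: all-equal scores over 10 classes (uniform guessing).
uniform = [0.0] * 10
print(round(cross_entropy(uniform, 3), 4))  # 2.3026, i.e. log(10)
```

Uniform guessing over 10 classes gives a loss of log(10), about 2.30, which is a useful reference line when reading the loss plot.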

Testing:

After building the model, I implemented the testing process in test_cnn.py.

According to the confusion matrix above, the model predicts four classes well: spider, chicken, horse and butterfly.

In contrast, it performs worst on the cat class.
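One common way to make per-class comparisons like these quantitative is per-class recall: the diagonal entry of each confusion-matrix row divided by that row's total (rows are true classes, columns are predictions). A minimal standard-library sketch, using hypothetical labels for a 3-class toy problem rather than this model's actual predictions:

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """matrix[t][p] counts samples of true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def accuracy(matrix):
    """Overall accuracy: diagonal (correct) counts over all samples."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


def per_class_recall(matrix):
    """For each true class (row), the fraction predicted correctly."""
    recalls = []
    for i, row in enumerate(matrix):
        support = sum(row)
        recalls.append(row[i] / support if support else 0.0)
    return recalls


# Hypothetical labels for a 3-class toy problem (not this model's output).
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 0, 2, 2, 2, 2, 2]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
recall = per_class_recall(cm)
best = max(range(len(recall)), key=recall.__getitem__)
worst = min(range(len(recall)), key=recall.__getitem__)

print(cm)            # [[2, 1, 0], [1, 1, 1], [0, 0, 4]]
print(accuracy(cm))  # 0.7
print(best, worst)   # 2 1
```

Applied to the real 10-class test predictions, with indices mapped to the class names in the dataset table, the same computation yields one recall value per animal.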