What will we do?

  • We will show an example of creating a multi-label classification model using fastai2.

  • We will compare it to its single-label counterpart.

    You can try the bear classifier app here: Binder

Implementing multi-label classification

In this post, we will train a bear classifier that tells us what kinds of bears it sees in an image, or whether it sees none at all. In our case the bears are black bears, grizzlies, and teddies!

1. Get the data

To get the data (images), we'll use Flickr, accessing its API through an open-source tool.

You will have to restart your runtime after installing the requirements.

# Install the tool as recommended by its documentation
# (the scraper also needs your own Flickr API key and secret; see its README)
!git clone https://github.com/ultralytics/flickr_scraper
%cd flickr_scraper
!pip install -U -r requirements.txt

Now we will download the images for each class. Here we chose 170 images per class, which is an arbitrary number; you can download as many as the Flickr API allows.

  • --search is the keyword(s) you want to search for.

  • --n is the number of images (170 here, just to make training a bit faster).

  • --download is to, well, download the images.

The images will be saved in images/whatever_keyword_you_gave, with spaces turned into underscores (e.g. images/black_bear).

# Downloading grizzly images
!python3 /content/flickr_scraper/flickr_scraper.py --search 'grizzly' --n 170 --download

# Downloading black bear images 
!python3 /content/flickr_scraper/flickr_scraper.py --search 'black bear' --n 170 --download

# Downloading teddy images
!python3 /content/flickr_scraper/flickr_scraper.py --search 'teddy' --n 170 --download
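The three commands above differ only in the keyword, so a loop keeps them in sync if you add or rename classes later. A minimal sketch, assuming the same script path and the --search/--n/--download flags used above; scrape_cmd and download_all are hypothetical helpers, and dry_run defaults to True so nothing is downloaded until you ask for it.

```python
import shlex
import subprocess
from typing import Iterable, List

# Same script path as in the commands above; adjust it if you cloned the repo elsewhere
SCRAPER = "/content/flickr_scraper/flickr_scraper.py"

def scrape_cmd(keyword: str, n: int = 170) -> List[str]:
    """Build the flickr_scraper command for one search keyword.
    Passing the keyword as its own argument means 'black bear' needs no shell quoting."""
    return ["python3", SCRAPER, "--search", keyword, "--n", str(n), "--download"]

def download_all(keywords: Iterable[str], n: int = 170, dry_run: bool = True) -> List[List[str]]:
    """Run one scraper call per keyword, or with dry_run=True only print the commands."""
    cmds = [scrape_cmd(kw, n) for kw in keywords]
    for cmd in cmds:
        if dry_run:
            print(shlex.join(cmd))  # copy-pasteable, correctly quoted command line
        else:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if the scraper fails
    return cmds

cmds = download_all(["grizzly", "black bear", "teddy"])
```

Once the printed commands look right, call download_all(..., dry_run=False) to actually run them.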

Verify that all our images can be opened (and delete any that can't), so we don't run into errors further down the road.

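# Assumed import (not shown in the original post): fastai v2.
# With the old pre-release package this was: from fastai2.vision.all import *
from fastai.vision.all import *

# Assumed: the folder flickr_scraper downloaded into (we're in /content/flickr_scraper after the %cd above)
path = Path('images')

# Get the paths of the images
fns = get_image_files(path)

# Verify that none of them are corrupted
fail = verify_images(fns)
print('Number of corrupted images: ', len(fail))

# Delete any corrupted images we found
if len(fail) > 0: fail.map(Path.unlink)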
Number of corrupted images:  0

Next, prepare our data for training (i.e. label it, split it, and apply standard transformations).

If you have never used fastai2 before and don't know what a DataBlock or DataLoaders is, I recommend my other post, Quick prototyping with fastai2 (Car classifier example), which gives a gentler introduction.

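# Assumed helper (not shown in the original post): MultiCategoryBlock expects a
# *list* of labels per item, so wrap the parent folder name in a list
def parent_label_multi(o): return [parent_label(o)]

# Create the DataBlock
bears_multi = DataBlock(

    # Input is an image, output is a set of categories (images, multiple categories respectively)
    blocks=(ImageBlock, MultiCategoryBlock), 

    # How to get the images
    get_items=get_image_files,

    # Randomly hold out 20% of the images for validation (fixed seed for reproducibility)
    splitter=RandomSplitter(valid_pct=0.2, seed=42),

    # How to get the labels (the name of the folder each image is in)
    get_y=parent_label_multi,

    # Per-item random crop resized to 224px, then data augmentation applied to whole batches
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())

# Create the DataLoaders
dls_multi = bears_multi.dataloaders(path)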

Check that our data looks OK.

# Show a batch of images from the dataset
dls_multi.show_batch()
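Under the hood, the DataBlock labels each image by its parent folder, turns that label list into a multi-hot vector over the class vocabulary, and holds out a random 20% for validation. A plain-Python sketch of those three steps; the helper names and file names are hypothetical, the vocabulary is sorted (as I understand fastai does by default), and since fastai shuffles with PyTorch's RNG this won't reproduce its exact split.

```python
import random
from pathlib import Path
from typing import List, Sequence, Tuple

def parent_labels(path: str) -> List[str]:
    """Label = name of the folder containing the file, wrapped in a list (multi-label form)."""
    return [Path(path).parent.name]

def multi_hot(labels: Sequence[str], vocab: Sequence[str]) -> List[int]:
    """Encode a list of labels as a 0/1 vector over the vocabulary."""
    present = set(labels)
    return [1 if cls in present else 0 for cls in vocab]

def random_split(n_items: int, valid_pct: float = 0.2, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Shuffle indices with a fixed seed and hold out valid_pct of them for validation."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]

files = ["images/grizzly/001.jpg", "images/black_bear/002.jpg", "images/grizzly/003.jpg"]
vocab = sorted({lbl for f in files for lbl in parent_labels(f)})  # ['black_bear', 'grizzly']
encoded = [multi_hot(parent_labels(f), vocab) for f in files]      # [[0, 1], [1, 0], [0, 1]]
train_idx, valid_idx = random_split(10)                            # 8 training, 2 validation indices
```

Because each training image sits in exactly one folder, every multi-hot target here has exactly one 1.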

2. Train the model

# Create Model
learn_multi_firt = cnn_learner(dls_multi, resnet18, metrics=accuracy_multi)

# Train Model
learn_multi_firt.fine_tune(4)
epoch train_loss valid_loss accuracy_multi time
0 0.781898 0.245826 0.900000 00:32
epoch train_loss valid_loss accuracy_multi time
0 0.194182 0.106430 0.970000 00:32
1 0.136712 0.059210 0.980000 00:33
2 0.095133 0.032058 0.986667 00:33
3 0.072287 0.026714 0.990000 00:32
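The accuracy_multi column is a per-label accuracy: to my understanding, fastai's version applies a sigmoid to each output, thresholds it at 0.5, and averages correctness over every (image, class) pair. A minimal pure-Python sketch of that logic, next to a stricter exact-match score for comparison (exact_match is a hypothetical helper, not a fastai metric), shows why the two can differ: with three classes, an image for which the model predicts no label at all still scores 2/3.

```python
import math
from typing import Sequence

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def accuracy_multi(logits: Sequence[Sequence[float]], targets: Sequence[Sequence[int]],
                   thresh: float = 0.5) -> float:
    """Per-label accuracy: fraction of (image, class) pairs predicted correctly.
    A sketch of the logic of fastai's accuracy_multi, not its actual code."""
    correct = total = 0
    for row_logits, row_targets in zip(logits, targets):
        for logit, target in zip(row_logits, row_targets):
            predicted = sigmoid(logit) > thresh
            correct += int(predicted == bool(target))
            total += 1
    return correct / total

def exact_match(logits, targets, thresh: float = 0.5) -> float:
    """Hypothetical stricter metric: an image counts only if all of its labels are right."""
    hits = 0
    for row_logits, row_targets in zip(logits, targets):
        predicted = [sigmoid(x) > thresh for x in row_logits]
        hits += int(predicted == [bool(t) for t in row_targets])
    return hits / len(targets)

# Two images, three classes; the second image gets no label above the threshold
targets = [[1, 0, 0], [0, 1, 0]]
logits = [[3.0, -3.0, -3.0], [-3.0, -3.0, -3.0]]
print(accuracy_multi(logits, targets))  # 5 of 6 pairs correct
print(exact_match(logits, targets))     # 1 of 2 images fully correct
```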

Wow, 99% (per-label) accuracy! Let's check that ourselves by giving the model an image it has never seen before.

We'll give it an image of a kid's toy, which it should not recognize as any of the classes it was trained on.

Image.open('/content/test/test_data_bear_classifier_multi/toy2.jpg').to_thumb(240)
learn_multi_firt.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')[0]
(#1) ['teddys']

Well, that's not a teddy! Let's try a picture of a person.

Image.open('/content/person2.jpg').to_thumb(240)
learn_multi_firt.predict('/content/person2.jpg')[0]
(#1) ['teddys']

That's not a teddy either!

A likely explanation is that our model learned features that separate the classes within the training data but don't carry over outside that context (e.g. it may have learned that a colorful or white background most likely means teddy). On top of that, every training image belongs to exactly one of the three classes, so the model never saw an example whose correct answer is "none of these".

3. Get more data?

To help our classifier learn better (more useful) features, we will add a generic class called other, containing general images as well as the kinds of images the previous model got wrong (toys and portraits on white backgrounds). It will not contain any black bear, grizzly, or teddy images.

So let's do that!

# Download the images described above
!python3 /content/flickr_scraper/flickr_scraper.py --search 'toys' --n 170 --download
!python3 /content/flickr_scraper/flickr_scraper.py --search 'portrait white background' --n 170 --download
!python3 /content/flickr_scraper/flickr_scraper.py --search 'images' --n 170 --download

# Create the directory for the 'other' class
!mkdir images/other

# Move the images downloaded to that folder
!mv images/toys/* images/other
!mv images/images/* images/other
!mv images/portrait_white_background/* images/other

# Remove the folders because they are empty now
!rm -r images/toys
!rm -r images/images
!rm -r images/portrait_white_background
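The same merge can also be done from Python with pathlib and shutil. A minimal sketch: merge_into is a hypothetical helper, and the rename-on-collision rule is my own choice. Unlike the mv commands above, which replace a same-named file without asking, it keeps both copies (I haven't checked whether the scraper can ever produce the same file name for two different searches).

```python
import shutil
import tempfile
from pathlib import Path
from typing import Iterable

def merge_into(root: Path, sources: Iterable[str], target: str = "other") -> Path:
    """Move every file from root/<src> into root/<target>, then delete the emptied source folders."""
    dest = root / target
    dest.mkdir(parents=True, exist_ok=True)
    for src in sources:
        src_dir = root / src
        for f in list(src_dir.iterdir()):   # snapshot the listing before moving files out
            new_path = dest / f.name
            if new_path.exists():           # don't silently overwrite a file with the same name
                new_path = dest / f"{src}_{f.name}"
            shutil.move(str(f), str(new_path))
        src_dir.rmdir()                     # raises OSError if anything was left behind
    return dest

# Usage on a throwaway folder that mimics images/
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ["toys/a.jpg", "images/a.jpg", "portrait_white_background/b.jpg"]:
        (root / rel).parent.mkdir(parents=True, exist_ok=True)
        (root / rel).write_bytes(b"")
    other = merge_into(root, ["toys", "images", "portrait_white_background"])
    merged = sorted(p.name for p in other.iterdir())
    leftovers = sorted(p.name for p in root.iterdir())
```

In the notebook you would call merge_into(Path('images'), ['toys', 'images', 'portrait_white_background']).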

Then load the new data and train a new model on it.

# Same steps as we did before
bears_multi = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock), 
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label_multi,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())

dls_multi = bears_multi.dataloaders(path)

learn_multi = cnn_learner(dls_multi, resnet18, metrics=accuracy_multi)
learn_multi.fine_tune(4)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth

epoch train_loss valid_loss accuracy_multi time
0 0.747898 0.337043 0.852723 01:10
epoch train_loss valid_loss accuracy_multi time
0 0.305860 0.182119 0.943069 01:10
1 0.229317 0.118199 0.967822 01:10
2 0.169207 0.076744 0.978960 01:09
3 0.132857 0.063394 0.982673 01:08

98% accuracy: a bit lower, but still pretty accurate (and the validation set now also contains the other images, so the two numbers aren't measured on the same data). Now for the real test, using the same images as before.

print('Actual: Toy (other)\nPrediction: ', learn_multi.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')[0])
Actual: Toy (other)
Prediction:  ['other']
print('Actual: Person (other)\nPrediction: ', learn_multi.predict('/content/person2.jpg')[0])
Actual: Person (other)
Prediction:  ['other']

Well, that's more like it!

Now let's try it for what it was originally trained for (i.e. detecting bears).

Image.open('/content/test/test_data_bear_classifier_multi/teddy.webp').to_thumb(240)
learn_multi.predict('/content/test/test_data_bear_classifier_multi/teddy.webp')[0]
(#1) ['teddys']
Image.open('/content/test/test_data_bear_classifier_multi/both.jpg').to_thumb(240)
learn_multi.predict('/content/test/test_data_bear_classifier_multi/both.jpg')[0]
(#2) ['black_bear','grizzly']

Compare it with a single-label model

Now let's use the same dataset to train a single-label classification model and compare the results.

Single-label means the model outputs exactly one result: given an image containing both a grizzly and a black bear, it will pick only one of them, the one it is more confident about.
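The difference comes down to how the raw model outputs are turned into labels. As far as I know, fastai's single-label setup applies a softmax and returns the most likely class, while the multi-label setup applies a sigmoid to each class independently and keeps every class above a threshold (0.5 by default). A minimal sketch with made-up output values; the two real models would produce different numbers since they're trained with different losses, so the same values are decoded both ways here purely to show the decoding step.

```python
import math
from typing import List, Sequence

def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def decode_single(logits: Sequence[float], vocab: Sequence[str]) -> str:
    """Single-label: the probabilities compete and only the most likely class is returned."""
    probs = softmax(logits)
    return vocab[probs.index(max(probs))]

def decode_multi(logits: Sequence[float], vocab: Sequence[str], thresh: float = 0.5) -> List[str]:
    """Multi-label: each class is judged on its own, so zero, one or several can pass."""
    return [cls for cls, x in zip(vocab, logits) if sigmoid(x) > thresh]

vocab = ["black_bear", "grizzly", "other", "teddys"]
both_bears = [2.0, 2.5, -3.0, -3.0]   # hypothetical outputs for an image with both bears
nothing = [-2.0, -2.0, -2.0, -2.0]    # hypothetical outputs where no class is likely

print(decode_single(both_bears, vocab))  # grizzly
print(decode_multi(both_bears, vocab))   # ['black_bear', 'grizzly']
print(decode_multi(nothing, vocab))      # []
```

Note that decode_multi can return an empty list, something a softmax-plus-argmax decoder never does.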

# Same steps as before
bears = DataBlock(
    blocks=(ImageBlock, CategoryBlock), 
    get_items=get_image_files, 
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())

dls = bears.dataloaders(path)

learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(4)
epoch train_loss valid_loss accuracy time
0 1.388999 0.194680 0.940594 01:10
epoch train_loss valid_loss accuracy time
0 0.113717 0.079140 0.970297 01:09
1 0.081960 0.087613 0.970297 01:09
2 0.072686 0.062682 0.990099 01:09
3 0.054821 0.057809 0.990099 01:09

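99% accuracy compared to 98%: better than our multi-label model. (Strictly speaking the two aren't directly comparable: accuracy counts an image as correct only if its single predicted class is right, while accuracy_multi averages over every image-class pair.) Let's see how it does on the same test data.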

print('Actual: Toy(other)\nPrediction: ', learn.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')[0])
Actual: Toy(other)
Prediction:  other
print('Actual: Person(other)\nPrediction: ', learn.predict('/content/person2.jpg')[0])
Actual: Person(other)
Prediction:  other
print('Actual: Teddy\nPrediction: ', learn.predict('/content/test/test_data_bear_classifier_multi/teddy.webp')[0])
Actual: Teddy
Prediction:  teddys
print('Actual: Black + grizzly bear\nPrediction: ', learn.predict('/content/test/test_data_bear_classifier_multi/both.jpg')[0])
Actual: Black + grizzly bear
Prediction:  grizzly

It passes every test except the last one. As explained earlier, that's because it's a single-label model, so it outputs a single result; here it is more confident that the image shows a grizzly.

So despite its higher accuracy, the single-label model has a downside: it might not give us all the information we need.