Multi-label classification
1. Get the data
To get the data (images) we'll use Flickr, with an open-source tool to access the API.
- Tool used: flickr_scraper
You will have to restart your runtime after installing the requirements.
# Install the tool as recommended by its documentation
!git clone https://github.com/ultralytics/flickr_scraper
%cd flickr_scraper
!pip install -U -r requirements.txt
Now we will download the images for each class. Here we chose 170 images per class, which is just an arbitrary number; you can download as many as the Flickr API allows.
- --search is the keyword(s) you want to search for.
- --n is the number of images (170 here, just to make the training a bit faster).
- --download is to, well, download the images.
The images will be saved in images/whatever_keyword_you_gave.
# Downloading grizzly images
!python3 /content/flickr_scraper/flickr_scraper.py --search 'grizzly' --n 170 --download
# Downloading black bear images
!python3 /content/flickr_scraper/flickr_scraper.py --search 'black bear' --n 170 --download
# Downloading teddy images
!python3 /content/flickr_scraper/flickr_scraper.py --search 'teddy' --n 170 --download
Verify that all our images can be opened (and delete any that can't), so we don't run into errors further down the road.
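Note: the snippet below assumes fastai's vision API has already been imported and that path points to the folder the scraper saved the images into. If that isn't the case in your runtime, something like the following will do (this setup is my assumption, adjust the folder to your notebook):
# Assumed setup (not in the original notebook): fastai's vision API plus
# a `path` pointing at the scraper's output folder.
from fastai.vision.all import *
path = Path('images')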
# Get the paths of the images
fns = get_image_files(path)
# Verify that there are no corrupted ones
fail = verify_images(fns)
print('Number of corrupted images: ', len(fail))
# If any corrupted images were found, delete them
if len(fail) > 0: fail.map(Path.unlink)
Prepare our data for training (i.e., label it, split it, and apply standard transformations).
If you have never used fastai2 before and have no idea what DataBlock or DataLoader mean, I would recommend my other post, Quick prototyping with fastai2 (Car classifier example), which gives a gentler introduction.
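The DataBlock below uses parent_label_multi as its labelling function. As far as I know this is not a fastai built-in (fastai ships parent_label for the single-label case), so if it isn't already defined earlier in your notebook, a minimal version just wraps the parent-folder name in a list so MultiCategoryBlock can one-hot encode it:
# Minimal sketch (assumption: parent_label_multi is not a fastai built-in).
# MultiCategoryBlock expects a collection of labels per item, so we return
# the parent folder name wrapped in a list.
def parent_label_multi(o):
    return [Path(o).parent.name]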
# Create the datablock
bears_multi = DataBlock(
# Set the type of input and output (images, multiple categories respectively)
blocks=(ImageBlock, MultiCategoryBlock),
# How to get the images
get_items=get_image_files,
# Split to train and valid
splitter=RandomSplitter(valid_pct=0.2, seed=42),
# How to get the labels
get_y=parent_label_multi,
# Do the transformations
item_tfms=RandomResizedCrop(224, min_scale=0.5),
batch_tfms=aug_transforms())
# Create the Dataloader
dls_multi = bears_multi.dataloaders(path)
Check that our data looks OK.
# Show a batch of images from the dataset
dls_multi.show_batch()
# Create the model
learn_multi_first = cnn_learner(dls_multi, resnet18, metrics=accuracy_multi)
# Train the model
learn_multi_first.fine_tune(4)
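A quick note on the metric before reading too much into the number below: accuracy_multi applies a sigmoid to each output and thresholds it at 0.5, so the figure is the share of correct per-class decisions rather than the share of perfectly labelled images. The threshold can be changed if you want a stricter cut-off (the helper name below is mine, not part of the original notebook):
from functools import partial
# Illustrative only: accuracy_multi with a stricter 0.8 cut-off; pass it as
# the metric if you want tougher scoring, e.g.
# cnn_learner(dls_multi, resnet18, metrics=accuracy_strict)
accuracy_strict = partial(accuracy_multi, thresh=0.8)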
Wow, 99% accuracy! Let's check that ourselves and give it an image that it has never seen before.
We will give it an image of a kid's toy, which it should not recognize as any of the classes it was trained on.
Image.open('/content/test/test_data_bear_classifier_multi/toy2.jpg').to_thumb(240)
learn_multi_first.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')[0]
Well, that's not a teddy! Let's try a picture of a person.
Image.open('/content/person2.jpg').to_thumb(240)
learn_multi_first.predict('/content/person2.jpg')[0]
Well, that's not a teddy either!
This can be explained by the fact that our model perhaps learned features that separate the classes in the training data but do not apply outside that context (e.g., it learned that an image with a colorful or white background is most likely a teddy).
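One way to dig into this (my addition, using the same toy image as above) is to print the raw per-class probabilities that predict returns as its third element:
# Inspect the sigmoid output for each class. In a multi-label setup no class
# has to "win", so an out-of-domain image can still push one of the
# bear/teddy scores past the threshold.
labels, _, probs = learn_multi_first.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')
print('Decoded labels:', labels)
for c, p in zip(dls_multi.vocab, probs):
    print(f'{c}: {float(p):.2f}')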
3. Get more data?
Now, to make our classifier learn some better (more useful) features, we will add a generic class (that we will call other), where we put some general images and images like the ones the previous model got wrong (it will not contain any black bear, grizzly, or teddy images).
So let's do that!
# Download the images described above
!python3 /content/flickr_scraper/flickr_scraper.py --search 'toys' --n 170 --download
!python3 /content/flickr_scraper/flickr_scraper.py --search 'portrait white background' --n 170 --download
!python3 /content/flickr_scraper/flickr_scraper.py --search 'images' --n 170 --download
# Create the directory for the 'other' class
!mkdir images/other
# Move the images downloaded to that folder
!mv images/toys/* images/other
!mv images/images/* images/other
!mv images/portrait_white_background/* images/other
# Remove the folders because they are empty now
!rm -r images/toys
!rm -r images/images
!rm -r images/portrait_white_background
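To double-check that the moves worked, we can count the images that ended up in each class folder (a quick check I'm adding here, reusing the path variable from earlier):
# Count the images in each class folder after the reorganisation; we expect
# the three bear/teddy folders plus the new 'other' one.
for d in path.ls():
    if d.is_dir():
        print(d.name, len(get_image_files(d)))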
Then we load our new data and train a new model on it.
# Same steps as we did before
bears_multi = DataBlock(
blocks=(ImageBlock, MultiCategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(valid_pct=0.2, seed=42),
get_y=parent_label_multi,
item_tfms=RandomResizedCrop(224, min_scale=0.5),
batch_tfms=aug_transforms())
dls_multi = bears_multi.dataloaders(path)
learn_multi = cnn_learner(dls_multi, resnet18, metrics=accuracy_multi)
learn_multi.fine_tune(4)
98% accuracy: it's a bit lower, but still pretty accurate. Now let's run the real test (we'll be using the same images as before).
print('Actual: Toy (other)\nPrediction: ', learn_multi.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')[0])
print('Actual: Person (other)\nPrediction: ', learn_multi.predict('/content/person2.jpg')[0])
Well, that's more like it!
Now let's try it for what it was originally trained for (i.e. detecting bears).
Image.open('/content/test/test_data_bear_classifier_multi/teddy.webp').to_thumb(240)
learn_multi.predict('/content/test/test_data_bear_classifier_multi/teddy.webp')[0]
Image.open('/content/test/test_data_bear_classifier_multi/both.jpg').to_thumb(240)
learn_multi.predict('/content/test/test_data_bear_classifier_multi/both.jpg')[0]
Compare it with a single-label model
Now let's use the same dataset to train a single-label classification model, and compare the results.
Single-label means it outputs one single result (e.g., for an image containing both a grizzly and a black bear, it will choose only one of them, the one it is more confident about).
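To make the difference concrete, here is a tiny illustration (my addition, with made-up logits, not outputs from either model): CategoryBlock trains with a softmax, so the class scores compete and always sum to 1, while MultiCategoryBlock trains with per-class sigmoids, so each class is scored independently and several (or none) can pass the threshold.
import torch

# Made-up logits for (grizzly, black bear, other, teddy), purely to show the
# two activations; not taken from the trained models.
logits = torch.tensor([2.0, 1.8, -3.0, -3.0])
print(torch.softmax(logits, dim=0))  # competes: one class has to come out on top
print(torch.sigmoid(logits))         # independent: both bears can score high at once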
# Same steps as before
bears = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(valid_pct=0.2, seed=42),
get_y=parent_label,
item_tfms=RandomResizedCrop(224, min_scale=0.5),
batch_tfms=aug_transforms())
dls = bears.dataloaders(path)
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.fine_tune(4)
99% accuracy compared to 98%: that's better than our multi-label model. Let's see how it does on the same test data.
print('Actual: Toy(other)\nPrediction: ', learn.predict('/content/test/test_data_bear_classifier_multi/toy2.jpg')[0])
print('Actual: Person(other)\nPrediction: ', learn.predict('/content/person2.jpg')[0])
print('Actual: Teddy\nPrediction: ', learn.predict('/content/test/test_data_bear_classifier_multi/teddy.webp')[0])
print('Actual: Black + grizzly bear\nPrediction: ', learn.predict('/content/test/test_data_bear_classifier_multi/both.jpg')[0])
Well, it passes all the tests except the last one. That's because, as we introduced earlier, it's a single-label model (i.e., it outputs one single result), and here it is more confident the image contains a grizzly.
As you can see, despite being more accurate, it has the downside that it might not give us all the information we need.
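If you still want that extra information out of the single-label model, one workaround (my addition) is to look at the full probability vector that predict returns instead of just the top class:
# The third element of predict's output holds a probability for every class,
# so the runner-up classes are still visible even though only one label is
# returned.
_, _, probs = learn.predict('/content/test/test_data_bear_classifier_multi/both.jpg')
for c, p in zip(dls.vocab, probs):
    print(f'{c}: {float(p):.2f}')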