Predict fruit image types by retraining an existing image classifier

In today’s short tutorial, we are going to classify two types of fruit, apples and bananas, by retraining an existing image classifier model, ResNet. Let’s cover the steps briefly:

  1. We are going to use the ImageAI library to train our model, so the first step is to set up ImageAI: locally on your laptop if it has a decent NVIDIA GPU, or on Google Colab if it doesn't. If you are running the classifier locally, you need to install and set up the libraries and dependencies that ImageAI requires, as specified in the official documentation here.
     On Google Colab, all the required libraries come preinstalled on the Colab machine instance.
  2. Download the Jupyter notebook uploaded on GitHub here: Github
  3. To run the notebook on Google Colab, first create an account here: Colab.
     Next, create a new Python 3 notebook from the options provided here.
  4. Now open the ‘File’ menu at the top and select the upload notebook option. Go to the FruitImagesClassifier repo folder you downloaded in step 2 and select the ImageClassifierShared.ipynb file to open the notebook in Google Colab.
     Alternatively, if you would like to run the notebook locally, follow the steps provided here: Jupyter Notebook Official Docs

Next, we follow the steps in the Colab notebook to set up our fruit image classifier. Each step is explained in detail below.

The steps we follow in the Jupyter notebook are:

First, we download and install the ImageAI library using the pip install command:

pip3 install

Next, we create the required directory structure on our machine by running the mkdir command:

mkdir -p data/fruits/{train/{apple,banana},test/{apple,banana},valid}

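The -p flag tells mkdir to create any missing parent directories (and not to complain about ones that already exist), while the shell's brace expansion turns the {...} pattern into the full list of paths. Together they generate the required folder structure in one go.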
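Brace expansion is a feature of shells such as bash and zsh, not of mkdir itself, so the command above does not work in a plain POSIX sh. As a minimal sketch for readers who prefer Python, the same layout can be created with os.makedirs(..., exist_ok=True), which, like mkdir -p, creates missing parents and does not fail if a folder already exists. The function name make_fruit_dirs is hypothetical:

```python
import os


def make_fruit_dirs(root="data/fruits", fruits=("apple", "banana")):
    """Create root/{train,test}/<fruit> and root/valid, like the mkdir -p command."""
    created = []
    for split in ("train", "test"):
        for fruit in fruits:
            path = os.path.join(root, split, fruit)
            # Creates missing parent folders; no error if the folder already exists
            os.makedirs(path, exist_ok=True)
            created.append(path)
    valid = os.path.join(root, "valid")
    os.makedirs(valid, exist_ok=True)
    created.append(valid)
    return created
```

Calling make_fruit_dirs() from the notebook's working directory produces the same five folders as the mkdir command, and running it twice is harmless.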

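Next, we install the tree and subversion packages, which are needed in later steps: tree prints the folder structure so we can check it, and Subversion lets us download individual folders from the dataset repository.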

apt-get install tree
tree data/
apt-get install subversion

To train our model with sufficient accuracy, we need a good number of apple and banana images, which are hard to find. Luckily, Google Dataset Search came to my rescue. I am using a fruit images dataset whose details are described in a research paper here: ResearchGate, and which is available on GitHub here.

Since the repository contains many kinds of fruit and I needed images of only two of them to keep things simple, I used Subversion to selectively download images from specific folders in the repository with the svn export command.

cd /content/data/fruits/train/apple/
svn export
mv 'Apple Red 1'/* ./
rm -rf 'Apple Red 1'/

The svn export command downloads the ‘Apple Red 1’ folder into the apple training folder; the last two commands move its images up into the apple folder and delete the now-empty temporary folder. After training the model, I realized that we needed more real-world images for proper image classification. Fortunately, the Pixabay API came in handy: it lets us batch download images from the Pixabay website, which has tons of royalty-free images shared by its generous community. The code snippet to download the images is as follows:

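import requests
import urllib.request
# Fetch the first two pages of Pixabay search results for 'apple fruit'
for x in range(1, 3):
    response = requests.get("{pageno}".format(pageno=x))
    response.raise_for_status()  # stop early on an HTTP error
    imagesArray = response.json()["hits"]
    for imageObject in imagesArray:
        # Save each small preview image as <id>.jpg in the current folder
        print("getting file " + str(imageObject["id"]) + ".jpg")
        urllib.request.urlretrieve(imageObject["previewURL"], str(imageObject["id"]) + ".jpg")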

We use Python's requests library to make an API request to Pixabay with the query term ‘apple fruit’, as explained in their documentation here. We fetch images from the first two pages of search results, which is why the outer for loop runs over page numbers 1 and 2. Each response looks like this (abridged):

{
  "totalHits": 500,
  "hits": [
    {
      "largeImageURL": "",
      "pageURL": "",
      "webformatURL": "",
      "previewURL": "",
      ...
    },
    ...
  ]
}

Next, we iterate through each image object in the "hits" array of the API response and download the image from the link in its ‘previewURL’ key using Python's urllib.request module, which saves each image locally as <id>.jpg.
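The download loop only needs two fields from each hit: the previewURL to fetch and the id used as the filename. A minimal sketch that separates this step from the network call, so it can be checked against a response shaped like the one above; the function name plan_downloads and the sample ids and URLs are hypothetical placeholders:

```python
def plan_downloads(response_json):
    """Return (url, filename) pairs for every hit in a Pixabay search response."""
    pairs = []
    for image_object in response_json.get("hits", []):
        # Same naming scheme as the download loop: <id>.jpg
        filename = str(image_object["id"]) + ".jpg"
        pairs.append((image_object["previewURL"], filename))
    return pairs


# Hypothetical sample response with placeholder URLs
sample = {
    "totalHits": 500,
    "hits": [
        {"id": 101, "previewURL": "https://example.com/101_150.jpg"},
        {"id": 202, "previewURL": "https://example.com/202_150.jpg"},
    ],
}

for url, filename in plan_downloads(sample):
    print(url, "->", filename)
```

Each pair can then be passed to urllib.request.urlretrieve(url, filename), exactly as in the loop above.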

Next, we display one of the downloaded images with the matplotlib library to check it:

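import os
import matplotlib.pyplot as plt
# Collect the downloaded .jpg files in the current folder
files = [f for f in os.listdir('./') if f.endswith('.jpg')]
img = plt.imread(f'./{files[6]}')  # read the seventh image into an array
plt.imshow(img)  # plt.imread only loads the image; imshow draws it
plt.show()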

This displays the selected image.

We move one of the images to the validation folder to test the accuracy of our model later.
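The move can be done with mv, or from Python as in the minimal sketch below. It renames the file to apple.jpg so it matches the path ./data/fruits/valid/apple.jpg used in the prediction step later. The helper name move_to_validation is hypothetical, and the default folder assumes the /content/data/fruits layout used by the shell commands in this tutorial:

```python
import os
import shutil


def move_to_validation(src, valid_dir="/content/data/fruits/valid", new_name="apple.jpg"):
    """Move one downloaded image into the validation folder under a new name."""
    os.makedirs(valid_dir, exist_ok=True)
    dest = os.path.join(valid_dir, new_name)
    # Moving (not copying) keeps the image out of the training set
    shutil.move(src, dest)
    return dest
```

For example, move_to_validation("12345.jpg") would move a downloaded file (hypothetical name) to the validation folder as apple.jpg.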

We follow similar steps to download test dataset images for apples and both training and testing images for bananas.
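Between these runs, only the query term and the target folder change. A minimal sketch that builds the request URLs for a given query, assuming Pixabay's documented endpoint https://pixabay.com/api/ and its key, q and page parameters; PIXABAY_KEY is a placeholder for your own API key, and the function name search_urls is hypothetical:

```python
from urllib.parse import urlencode

PIXABAY_ENDPOINT = "https://pixabay.com/api/"  # assumed endpoint, check the Pixabay API docs
PIXABAY_KEY = "YOUR_API_KEY"  # placeholder, replace with your own key


def search_urls(query, pages=(1, 2), key=PIXABAY_KEY):
    """Build one search URL per results page for the given query term."""
    return [
        # urlencode escapes the query, e.g. 'banana fruit' -> 'banana+fruit'
        PIXABAY_ENDPOINT + "?" + urlencode({"key": key, "q": query, "page": page})
        for page in pages
    ]


for url in search_urls("banana fruit"):
    print(url)
```

Each URL can be passed to requests.get in the download loop, run from the train/banana or test/banana folder as needed.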

Next, we import the custom training module from ImageAI and set the base model type to ResNet.

from imageai.Prediction.Custom import ModelTraining
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()  # use ResNet as the base model type

We then specify the data directory from which the model reads the input dataset images, and run the trainModel function to retrain the model on the new images.

model_trainer.setDataDirectory("data/fruits")  # folder containing train/ and test/
model_trainer.trainModel(num_objects=2, num_experiments=10, enhance_data=False, batch_size=32, show_network_summary=True)

The parameters used are as follows:

  • num_objects: The number of different types of fruit (classes) in our dataset, here 2.
  • num_experiments: The number of times the model trainer goes through all the images in the dataset (training epochs); more experiments give the model more chances to improve its accuracy.
  • enhance_data (optional): Tells the model trainer to create modified copies of the images in the dataset (data augmentation), which can improve accuracy, especially on small datasets.
  • batch_size: The number of images the model trainer processes at once; it works through the dataset batch by batch until it has seen every image.
  • show_network_summary (optional): Prints the structure of the model type you are using for training.
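To see how batch_size and num_experiments interact: each experiment passes over the whole training set in batches, so one pass over N images at batch size B takes ceil(N / B) batches, the last one possibly partial. A small worked sketch of that arithmetic; the 1,000-image dataset size is a hypothetical figure, not the size of the dataset used here, and ImageAI's trainer may round the per-epoch step count differently internally:

```python
import math


def training_steps(num_images, batch_size=32, num_experiments=10):
    """Return (batches per experiment, total batches) for one training run."""
    per_experiment = math.ceil(num_images / batch_size)  # last batch may be partial
    return per_experiment, per_experiment * num_experiments


per_epoch, total = training_steps(1000)  # hypothetical dataset of 1,000 images
print(per_epoch, total)  # 32 320
```

Doubling batch_size roughly halves the number of weight updates per experiment, while raising num_experiments multiplies the total linearly.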

The trained models are saved in the models folder; we pick the one with the highest accuracy for making predictions on real-world data.

from imageai.Prediction.Custom import CustomImagePrediction
prediction = CustomImagePrediction()
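The two lines above only create the predictor. Before predictImage can be called, the predictor must also know the model type, which saved model file to use, the class-index JSON file written during training, and the number of classes. A minimal sketch, assuming ImageAI 2.0.x's CustomImagePrediction methods (setModelTypeAsResNet, setModelPath, setJsonPath, loadModel), its default output layout under the data directory (models/ and json/model_class.json), and its model_ex-<n>_acc-<accuracy>.h5 file naming. Both helper names are hypothetical; best_model_path picks the file with the highest accuracy in its name:

```python
import os
import re

# Assumes ImageAI-style model names such as model_ex-007_acc-0.968750.h5
ACC_PATTERN = re.compile(r"acc-(\d+(?:\.\d+)?)\.h5$")


def best_model_path(models_dir):
    """Return the saved model whose filename reports the highest accuracy."""
    scored = []
    for name in os.listdir(models_dir):
        match = ACC_PATTERN.search(name)
        if match:
            scored.append((float(match.group(1)), name))
    if not scored:
        raise FileNotFoundError(f"no model files found in {models_dir}")
    return os.path.join(models_dir, max(scored)[1])


def load_predictor(prediction, data_dir="data/fruits", num_objects=2):
    """Configure an ImageAI CustomImagePrediction object and load the best model."""
    prediction.setModelTypeAsResNet()  # must match the model type used for training
    prediction.setModelPath(best_model_path(os.path.join(data_dir, "models")))
    prediction.setJsonPath(os.path.join(data_dir, "json", "model_class.json"))
    prediction.loadModel(num_objects=num_objects)
    return prediction
```

In the notebook this would be used as prediction = load_predictor(CustomImagePrediction()).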

To check that our model works, we first make predictions on the images we moved to the validation folder earlier.

predictions, probabilities = prediction.predictImage("./data/fruits/valid/banana.jpg", result_count=2)
for eachPrediction, eachProbability in zip(predictions, probabilities):
    print(eachPrediction, " : ", eachProbability)
predictions, probabilities = prediction.predictImage("./data/fruits/valid/apple.jpg", result_count=2)
for eachPrediction, eachProbability in zip(predictions, probabilities):
    print(eachPrediction, " : ", eachProbability)

This prints the probability, in percent, that our classifier assigns to each fruit class. In my test run, the probabilities were:

For the apple image:

apple  :  99.95651841163635
banana  :  0.04348267102614045

For the banana image:

banana  :  99.95608925819397
apple  :  0.043913364061154425
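predictImage returns two parallel lists: the class names and their probabilities in percent. If you only need the winning label, a small helper can pair them up without relying on the order of the lists. The name top_prediction is hypothetical, and the sample values are the apple results above:

```python
def top_prediction(predictions, probabilities):
    """Return the (label, probability) pair with the highest probability."""
    return max(zip(predictions, probabilities), key=lambda pair: pair[1])


label, score = top_prediction(["apple", "banana"], [99.95651841163635, 0.04348267102614045])
print(f"{label}: {score:.2f}%")  # apple: 99.96%
```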

Next, let’s download a new image from Pexels and test our model:

wget "" -O /content/data/fruits/valid/applereal.jpg
predictions, probabilities = prediction.predictImage("./data/fruits/valid/applereal.jpg", result_count=2)
for eachPrediction, eachProbability in zip(predictions, probabilities):
    print(eachPrediction, " : ", eachProbability)

The result for this apple image in my test run was:

apple  :  94.83857154846191
banana  :  5.161432176828384

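The model correctly classifies this unseen real-world photo as an apple with high confidence, which is a good sign, although a single image does not replace measuring accuracy on a full test set.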
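One real-world image is a useful smoke test but not an accuracy measurement. A minimal sketch that scores every image in the test folders, treating each subfolder name (apple, banana) as the true label. The function name folder_accuracy is hypothetical, and predict_fn stands in for whatever returns a label for an image path:

```python
import os


def folder_accuracy(test_dir, predict_fn, extensions=(".jpg", ".jpeg", ".png")):
    """Fraction of images under test_dir/<label>/ whose predicted label matches <label>."""
    correct = total = 0
    for label in sorted(os.listdir(test_dir)):
        label_dir = os.path.join(test_dir, label)
        if not os.path.isdir(label_dir):
            continue  # skip stray files next to the class folders
        for name in os.listdir(label_dir):
            if name.lower().endswith(extensions):
                total += 1
                if predict_fn(os.path.join(label_dir, name)) == label:
                    correct += 1
    return correct / total if total else 0.0
```

With the trained predictor, this could be called as folder_accuracy("data/fruits/test", lambda path: prediction.predictImage(path, result_count=1)[0][0]), which takes the single most likely label for each image.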

That concludes today’s short tutorial on retraining an existing image classifier to predict fruit images. Happy Coding 😀