In the previous two posts, we learned how to use pre-trained models and how to extract features from them for training a model for a different task. In this tutorial, we will learn how to fine-tune a pre-trained model for a different task than it was originally trained for.
We will try to improve on the problem of classifying pumpkin, watermelon, and tomato discussed in the previous post. We will be using the same data for this tutorial.
What is Fine-tuning of a network
We explained the importance of pre-trained networks in the previous article. To recap, training a network from scratch runs into the following two limitations:
- Huge data required – Since the network has millions of parameters, to get an optimal set of parameters, we need to have a lot of data.
- Huge computing power required – Even if we have a lot of data, training generally requires multiple iterations and it takes a toll on the computing resources.
Fine-tuning a network means tweaking the parameters of an already trained network so that it adapts to the new task at hand. The initial layers of a network learn very general features, and as we go higher up the network, the layers tend to learn patterns more specific to the task the network was trained on. Thus, for fine-tuning, we want to keep the initial layers intact (that is, freeze them) and retrain the later layers for our task.
Thus, fine-tuning avoids both the limitations discussed above.
- The amount of data required for training is modest, for two reasons. First, we are not training the entire network. Second, the part that is being trained is not trained from scratch.
- Since fewer parameters need to be updated, training also takes less time.
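To put rough numbers on the second point, the sketch below counts the weights in the convolutional layers of VGG16 (the model used later in this post) in plain Python. The layer widths are those of the standard VGG16 architecture, and unfreezing only the last three convolutional layers is an assumption borrowed from the experiments at the end of this post.

```python
# Output channels of the 13 convolutional layers in the standard VGG16
# architecture (pooling layers have no weights, so they are left out).
VGG16_CONV_CHANNELS = [64, 64, 128, 128, 256, 256, 256,
                       512, 512, 512, 512, 512, 512]


def conv_params(in_channels, out_channels, kernel_size=3):
    """Weights plus biases of one convolutional layer."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def vgg16_conv_param_counts(input_channels=3):
    """Return the parameter count of every convolutional layer, in order."""
    counts = []
    in_channels = input_channels
    for out_channels in VGG16_CONV_CHANNELS:
        counts.append(conv_params(in_channels, out_channels))
        in_channels = out_channels
    return counts


counts = vgg16_conv_param_counts()
total_params = sum(counts)
last_three_params = sum(counts[-3:])

print("All convolutional layers:", total_params)
print("Last three convolutional layers:", last_three_params)
print("Fraction updated: {:.0%}".format(last_three_params / total_params))
```

Under these assumptions, a little under half of the convolutional weights are updated, and none of them start from random values.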
Fine-tuning in Keras
Let us dive directly into the code. We will be using the same data that we used in the previous post. If you have a GPU, you can choose to use a larger dataset; on a CPU, training on a large dataset will take much longer. We will use the VGG model for fine-tuning.
Load the pre-trained model
First, we will load a VGG model without the top (which consists of the fully connected layers).
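    from tensorflow.keras.applications import vgg16

    # Assumption: 224 is VGG16's default input size; keep your own value
    # if image_size is already defined earlier in your script
    image_size = 224

    # Init the VGG model
    vgg_conv = vgg16.VGG16(weights='imagenet',
                           include_top=False,
                           input_shape=(image_size, image_size, 3))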
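To see what loading the model without the top leaves us with, the following sketch prints the last few layer names and the output shape of the convolutional base. It assumes an input size of 224 (VGG16's default) and passes `weights=None` only so that the sketch runs without downloading the ImageNet weights; the tutorial itself uses `weights='imagenet'`.

```python
from tensorflow.keras.applications import vgg16

# Assumption: 224 is VGG16's default input size; use your own image_size
image_size = 224

# weights=None skips the ImageNet download so the sketch runs quickly;
# for actual fine-tuning use weights='imagenet'
vgg_conv = vgg16.VGG16(weights=None,
                       include_top=False,
                       input_shape=(image_size, image_size, 3))

# Without the top, the network ends with the last pooling layer
layer_names = [layer.name for layer in vgg_conv.layers]
output_shape = tuple(vgg_conv.output.shape)

print("Last layers:", layer_names[-4:])
print("Output shape:", output_shape)
```

The base ends in a pooling layer that outputs a 3D feature map per image, which is why a `Flatten` layer is needed before any fully connected classifier.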
Freeze the required layers
In Keras, each layer has an attribute called “trainable”. To freeze the weights of a particular layer, we set this attribute to False, indicating that this layer should not be trained. That’s it! We go over each layer and select which layers we want to train.
    # Freeze all the layers
    for layer in vgg_conv.layers[:]:
        layer.trainable = False

    # Check the trainable status of the individual layers
    for layer in vgg_conv.layers:
        print(layer, layer.trainable)
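Freezing every layer corresponds to the first experiment listed at the end of this post. To train the last three convolutional layers instead, one option is to freeze everything except the last four layers of the base (three convolutional layers plus the final pooling layer). The sketch below is a minimal illustration under that assumption; it uses `weights=None` and a 224x224 input only to stay self-contained.

```python
from tensorflow.keras.applications import vgg16

# weights=None avoids the ImageNet download (use 'imagenet' in practice)
vgg_conv = vgg16.VGG16(weights=None,
                       include_top=False,
                       input_shape=(224, 224, 3))

# Freeze every layer except the last four
for layer in vgg_conv.layers[:-4]:
    layer.trainable = False

# Collect the layers that will still be updated during training
trainable_names = [layer.name for layer in vgg_conv.layers if layer.trainable]

# Count the weights that remain trainable
trainable_count = sum(w.numpy().size for w in vgg_conv.trainable_weights)

print("Trainable layers:", trainable_names)
print("Trainable weights:", trainable_count)
```

The pooling layer has no weights, so only the three convolutional layers of the last block contribute to the trainable count.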

Create a new model
Now that we have set which layers of our base network are trainable, we would like to add a classifier on top of the convolutional base. We will simply add a fully connected layer with dropout, followed by a softmax layer with 3 outputs. This is done as given below.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, Dropout, Flatten

    # Create the model
    model = Sequential()

    # Add the vgg convolutional base model
    model.add(vgg_conv)

    # Add new layers
    model.add(Flatten())
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(3, activation='softmax'))

    # Show a summary of the model. Check the number of trainable parameters
    model.summary()
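As a sanity check on the summary, the size of the new classifier can be worked out by hand. The sketch below does the arithmetic in plain Python; the helper name `classifier_head_params` is made up for illustration, and the input size of 224 is an assumption (substitute your own `image_size`).

```python
def classifier_head_params(image_size, hidden_units=1024, num_classes=3):
    """Return (flattened feature size, parameters in the two Dense layers)."""
    # VGG16 has five 2x2 max-pooling layers, so each side shrinks by 2**5 = 32
    side = image_size // 32
    # The last convolutional block outputs 512 channels
    flattened = side * side * 512
    # Dense layer: one weight per input per unit, plus one bias per unit
    dense_hidden = flattened * hidden_units + hidden_units
    dense_output = hidden_units * num_classes + num_classes
    return flattened, dense_hidden + dense_output


flattened, head_params = classifier_head_params(image_size=224)
print("Flattened features:", flattened)
print("Parameters in the new classifier:", head_params)
```

With all VGG layers frozen, the number of trainable parameters reported by the summary should match the classifier count, since `Flatten` and `Dropout` have no weights.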
Setup the data generators
We have already separated the data into train and validation and kept it in the “train” and “validation” folders. We can use ImageDataGenerator available in Keras to read images in batches directly from these folders and optionally perform data augmentation. We will use two different data generators for train and validation folders.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    # train_dir and validation_dir are the paths to the "train" and
    # "validation" folders, assumed to be defined earlier

    # Rescale pixel values to the [0, 1] range
    train_datagen = ImageDataGenerator(rescale=1./255)
    validation_datagen = ImageDataGenerator(rescale=1./255)

    # Change the batchsize according to your system RAM
    train_batchsize = 100
    val_batchsize = 10

    # Data generator for training data
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(image_size, image_size),
        batch_size=train_batchsize,
        class_mode='categorical')

    # Data generator for validation data
    validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=(image_size, image_size),
        batch_size=val_batchsize,
        class_mode='categorical',
        shuffle=False)
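For the data augmentation mentioned above, the training generator can be given random transformations while the validation generator only rescales. The sketch below shows one possible configuration; the specific ranges are illustrative assumptions, not values confirmed by this post, and are worth tuning for your own data.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training images: rescale and apply random transformations.
# The ranges below are illustrative assumptions.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,        # rotate by up to 20 degrees
    width_shift_range=0.2,    # shift horizontally by up to 20% of the width
    height_shift_range=0.2,   # shift vertically by up to 20% of the height
    horizontal_flip=True,     # randomly mirror images left to right
    fill_mode='nearest')      # fill newly exposed pixels with the nearest value

# Validation images: rescale only, so the evaluation stays comparable
validation_datagen = ImageDataGenerator(rescale=1./255)
```

These objects can be passed to `flow_from_directory` exactly as before. Keeping augmentation out of the validation generator means the reported validation accuracy is measured on unmodified images.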
Train the model
So far, we have created the model and set up the data for training. Next, we specify the loss, the optimizer and its learning rate, and start training using the `model.fit()` function. After the training is over, we will save the model.
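    import math
    from tensorflow.keras import optimizers

    # Configure the model for training
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.RMSprop(learning_rate=1e-4),
                  metrics=['acc'])

    # The number of batches per epoch must be a whole number
    train_steps = math.ceil(train_generator.samples / train_generator.batch_size)
    val_steps = math.ceil(validation_generator.samples / validation_generator.batch_size)

    # Train the model
    history = model.fit(
        train_generator,
        steps_per_epoch=train_steps,
        epochs=20,
        validation_data=validation_generator,
        validation_steps=val_steps,
        verbose=1)

    # Save the trained model (the file name is an arbitrary example)
    model.save('fine_tuned_vgg16.h5')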
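The number of steps per epoch is simply the number of batches needed to see every image once, and it should be a whole number. The sketch below shows the arithmetic with a hypothetical helper `steps_for`; the training set size of 600 is made up for illustration, while 150 validation images with a batch size of 10 matches the setup in this post.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to cover num_samples once."""
    # Round up so that a final, smaller batch is still counted
    return math.ceil(num_samples / batch_size)


# Hypothetical training set of 600 images with batches of 100
train_steps = steps_for(num_samples=600, batch_size=100)

# 150 validation images with batches of 10
val_steps = steps_for(num_samples=150, batch_size=10)

# When the sizes do not divide evenly, the last batch is partial
uneven_steps = steps_for(num_samples=610, batch_size=100)

print(train_steps, val_steps, uneven_steps)
```

Plain division would give 6.1 in the uneven case, which is why rounding up is the safer choice.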
Check Performance
We obtained an accuracy of 90% with the transfer learning approach discussed in our previous blog. Here we are getting a much better accuracy of 98%.
Let us see the loss and accuracy curves using the visualize_results(history) function:
    import matplotlib.pyplot as plt

    # Utility function for plotting the model results
    def visualize_results(history):
        # Read the accuracy and loss values recorded for each epoch
        acc = history.history['acc']
        val_acc = history.history['val_acc']
        loss = history.history['loss']
        val_loss = history.history['val_loss']

        epochs = range(len(acc))

        # Plot the accuracy curves
        plt.plot(epochs, acc, 'b', label='Training acc')
        plt.plot(epochs, val_acc, 'r', label='Validation acc')
        plt.title('Training and validation accuracy')
        plt.legend()

        # Plot the loss curves in a separate figure
        plt.figure()
        plt.plot(epochs, loss, 'b', label='Training loss')
        plt.plot(epochs, val_loss, 'r', label='Validation loss')
        plt.title('Training and validation loss')
        plt.legend()

        plt.show()

    # Run the function to illustrate accuracy and loss
    visualize_results(history)
The output is:

Also, let us look at the errors that we got. To avoid duplicating code across experiments, we will introduce two utility functions, obtain_errors and show_errors:
    import numpy as np
    import matplotlib.pyplot as plt
    from tensorflow.keras.preprocessing.image import load_img

    # Utility function for obtaining the errors
    def obtain_errors(val_generator, predictions):
        # Get the filenames from the generator
        fnames = val_generator.filenames

        # Get the ground truth from the generator
        ground_truth = val_generator.classes

        # Get the dictionary of classes
        label2index = val_generator.class_indices

        # Obtain the list of the classes
        idx2label = list(label2index.keys())
        print("The list of classes: ", idx2label)

        # Get the predicted class index for every image
        predicted_classes = np.argmax(predictions, axis=1)

        # Indices of the images whose prediction differs from the ground truth
        errors = np.where(predicted_classes != ground_truth)[0]
        print("Number of errors = {}/{}".format(len(errors), val_generator.samples))

        return idx2label, errors, fnames

    # Utility function for visualization of the errors
    def show_errors(idx2label, errors, predictions, fnames):
        # Show the errors
        for i in range(len(errors)):
            pred_class = np.argmax(predictions[errors[i]])
            pred_label = idx2label[pred_class]

            # The class folder name in the file path is the original label
            title = 'Original label:{}, Prediction :{}, confidence : {:.3f}'.format(
                fnames[errors[i]].split('/')[0],
                pred_label,
                predictions[errors[i]][pred_class])

            original = load_img('{}/{}'.format(validation_dir, fnames[errors[i]]))
            plt.figure(figsize=[7, 7])
            plt.axis('off')
            plt.title(title)
            plt.imshow(original)
            plt.show()
Thus, the pipeline of error analysis is the following:
    import math

    # Get the predictions from the model using the generator
    predictions = model.predict(
        validation_generator,
        steps=math.ceil(validation_generator.samples / validation_generator.batch_size),
        verbose=1)

    # Run the function to get the list of classes and errors
    idx2label, errors, fnames = obtain_errors(validation_generator, predictions)

    # Run the function to illustrate the error cases
    show_errors(idx2label, errors, predictions, fnames)
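The core of the error search is two NumPy calls: `np.argmax` picks the most probable class for each image, and `np.where` returns the indices where that class differs from the ground truth. The toy example below uses made-up softmax outputs for four images and three classes to show the mechanics in isolation.

```python
import numpy as np

# Hypothetical softmax outputs: one row per image, one column per class
predictions = np.array([[0.8, 0.1, 0.1],
                        [0.2, 0.7, 0.1],
                        [0.3, 0.3, 0.4],
                        [0.6, 0.3, 0.1]])

# Hypothetical true class index of each image
ground_truth = np.array([0, 1, 1, 2])

# Most probable class per image
predicted_classes = np.argmax(predictions, axis=1)

# Indices of the misclassified images
errors = np.where(predicted_classes != ground_truth)[0]

# Confidence the model assigned to each wrong prediction
confidences = predictions[errors, predicted_classes[errors]]

print("Predicted classes:", predicted_classes)
print("Misclassified image indices:", errors)
print("Confidence of the wrong predictions:", confidences)
```

Because the validation generator is created with `shuffle=False`, the row order of the predictions matches the order of `filenames` and `classes`, which is what makes this index-based comparison valid.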
Let’s explore some output examples:
Experiments
We ran 3 experiments to see the effect of fine-tuning and data augmentation. We kept the validation set the same as in the previous post, i.e. 50 images per class.
- Freezing all layers and learning a classifier on top of it – similar to transfer learning. The number of errors was 15 out of 150 images which is similar to what we got in the previous post.
- Training the last 3 convolutional layers – We got 9 errors out of 150.
- Training the last 3 convolutional layers with data augmentation – The number of errors reduced to 3 out of 150.
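The error counts above translate directly into validation accuracy. The short sketch below does the conversion for the 150 validation images; the helper name `accuracy_from_errors` is made up for illustration.

```python
def accuracy_from_errors(num_errors, num_images=150):
    """Validation accuracy in percent, given the number of misclassified images."""
    return 100.0 * (num_images - num_errors) / num_images


# Error counts reported for the three experiments
experiments = {
    "All layers frozen": 15,
    "Last 3 conv layers trained": 9,
    "Last 3 conv layers trained + augmentation": 3,
}

accuracies = {name: accuracy_from_errors(errors)
              for name, errors in experiments.items()}

for name, accuracy in accuracies.items():
    print("{}: {:.0f}%".format(name, accuracy))
```

The first and last values correspond to the 90% and 98% figures quoted in the performance section.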
I hope you find this useful. Try doing your own experiments and post your findings in the comments section.