In the previous two posts, we learned how to use pre-trained models and how to extract features from them to train a model for a different task. In this tutorial, we will learn how to fine-tune a pre-trained model for a task different from the one it was originally trained on.
We will try to improve on the problem of classifying pumpkin, watermelon, and tomato discussed in the previous post. We will be using the same data for this tutorial.
What is Fine-tuning of a network
We have already explained the importance of using pre-trained networks in our previous article. To recap, when we train a network from scratch, we run into the following two limitations:
- Huge data required – Since the network has millions of parameters, to get an optimal set of parameters, we need to have a lot of data.
- Huge computing power required – Even if we have a lot of data, training generally requires multiple iterations and it takes a toll on the computing resources.
The task of fine-tuning a network is to tweak the parameters of an already trained network so that it adapts to the new task at hand. As explained here, the initial layers learn very general features, and as we go higher up the network, the layers tend to learn patterns more specific to the task the network is being trained on. Therefore, for fine-tuning, we keep the initial layers intact (freeze them) and retrain the later layers for our task.
Fine-tuning therefore mitigates both of the limitations discussed above:
- Much less training data is needed, for two reasons. First, we are not training the entire network. Second, the part that is being trained does not start from scratch.
- Since fewer parameters need to be updated, training also takes less time.
Fine-tuning in Keras
Let us dive directly into the code without further ado. We will use the same data as in the previous post. If you have a GPU, you can use a larger dataset; on a CPU, training on a large dataset will take much longer. We will use the VGG16 model for fine-tuning.
Load the pre-trained model
First, we will load a VGG16 model without its top layers (the fully connected classifier).
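from tensorflow.keras.applications import vgg16

# Input resolution; 224 is VGG16's native size (an assumption, the value is not given in this post)
image_size = 224

# Init the VGG model without its fully connected top layers
vgg_conv = vgg16.VGG16(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))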
Freeze the required layers
In Keras, each layer has an attribute called “trainable”. To freeze the weights of a particular layer, we set this attribute to False, indicating that the layer should not be trained. That’s it! We go over each layer and select which layers we want to train.
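# Freeze all layers except the last 4 (block5_conv1-3 and block5_pool),
# so that only the last 3 convolutional layers are fine-tuned.
# Use vgg_conv.layers[:] instead to freeze everything (plain transfer learning).
for layer in vgg_conv.layers[:-4]:
    layer.trainable = False

# Check the trainable status of the individual layers
for layer in vgg_conv.layers:
    print(layer.name, layer.trainable)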
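Which layers a slice actually freezes is easy to get wrong, so here is a small sketch that mimics the freezing loop on plain layer names instead of a real model. The name list is an assumption based on how Keras names the VGG16 layers with include_top=False (the input layer's name differs between Keras versions); print `[layer.name for layer in vgg_conv.layers]` to check it against your installation. The helper `split_frozen` is hypothetical, written only for this illustration.

```python
# Assumed layer names of VGG16(include_top=False); the input layer's name
# varies by Keras version, so verify with [l.name for l in vgg_conv.layers].
VGG16_BASE_LAYERS = [
    "input_1",
    "block1_conv1", "block1_conv2", "block1_pool",
    "block2_conv1", "block2_conv2", "block2_pool",
    "block3_conv1", "block3_conv2", "block3_conv3", "block3_pool",
    "block4_conv1", "block4_conv2", "block4_conv3", "block4_pool",
    "block5_conv1", "block5_conv2", "block5_conv3", "block5_pool",
]


def split_frozen(layer_names, n_trainable):
    """Hypothetical helper: return (frozen, trainable) name lists, mimicking
    `for layer in layers[:-n_trainable]: layer.trainable = False`."""
    if n_trainable == 0:
        # layers[:-0] is an empty slice, so handle "freeze everything" explicitly
        return list(layer_names), []
    return list(layer_names[:-n_trainable]), list(layer_names[-n_trainable:])


frozen, trainable = split_frozen(VGG16_BASE_LAYERS, 4)
print("Frozen layers:", len(frozen))
print("Trainable layers:", trainable)
```

With n_trainable=4, the last three convolutional layers plus the final pooling layer (which has no weights) stay trainable, matching the "last 3 convolutional layers" setting in the experiments. The explicit zero case guards against a Python pitfall: `layers[:-0]` is an empty slice, so it would freeze nothing rather than everything.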
Create a new model
Now that we have set the trainable parameters of our base network, we add a classifier on top of the convolutional base: a fully connected layer with 1024 units, a dropout layer, and a softmax output layer with 3 units, one per class. This is done as follows.
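from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout

# Create the model
model = Sequential()
# Declare the input shape so that model.summary() can report parameter counts before training
model.add(Input(shape=(image_size, image_size, 3)))
# Add the vgg convolutional base model
model.add(vgg_conv)
# Add new layers
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))
# Show a summary of the model. Check the number of trainable parameters
model.summary()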
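To sanity-check the numbers `model.summary()` prints, the parameter counts can be worked out by hand. This sketch assumes image_size = 224, so the VGG16 base outputs a 7 x 7 x 512 feature map; the helpers `conv_params` and `dense_params` are hypothetical names used only here.

```python
def conv_params(k, c_in, c_out):
    # one k x k kernel per (input, output) channel pair, plus one bias per output channel
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    # weight matrix plus bias vector
    return n_in * n_out + n_out


# (input channels, output channels) of the 13 VGG16 convolutions, block by block
VGG16_CONVS = [
    (3, 64), (64, 64),                       # block1
    (64, 128), (128, 128),                   # block2
    (128, 256), (256, 256), (256, 256),      # block3
    (256, 512), (512, 512), (512, 512),      # block4
    (512, 512), (512, 512), (512, 512),      # block5
]

base = sum(conv_params(3, c_in, c_out) for c_in, c_out in VGG16_CONVS)
block5 = sum(conv_params(3, c_in, c_out) for c_in, c_out in VGG16_CONVS[-3:])
# Flatten and Dropout have no parameters; 7 * 7 * 512 assumes a 224 x 224 input
head = dense_params(7 * 7 * 512, 1024) + dense_params(1024, 3)

print(f"VGG16 convolutional base:       {base:,}")
print(f"New classifier head:            {head:,}")
print(f"Trainable with base frozen:     {head:,}")
print(f"Trainable with block5 unfrozen: {head + block5:,}")
```

One thing the count makes visible: the Flatten-to-Dense(1024) connection alone gives the new head more parameters than the entire convolutional base, so most of the trainable weights live in the classifier regardless of how many convolutional layers are frozen.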
Setup the data generators
We have already split the data into training and validation sets, stored in the “train” and “validation” folders. Keras's ImageDataGenerator can read images in batches directly from these folders and optionally perform data augmentation. We will use two different data generators for the train and validation folders.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Folders with one sub-folder per class (paths assumed relative to the working directory)
train_dir = 'train'
validation_dir = 'validation'

# Load the normalized images
train_datagen = ImageDataGenerator(rescale=1./255)
validation_datagen = ImageDataGenerator(rescale=1./255)

# Change the batch size according to your system RAM
train_batchsize = 100
val_batchsize = 10

# Data generator for training data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(image_size, image_size),
    batch_size=train_batchsize,
    class_mode='categorical')

# Data generator for validation data (not shuffled, so predictions stay aligned with filenames)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(image_size, image_size),
    batch_size=val_batchsize,
    class_mode='categorical',
    shuffle=False)
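The third experiment below uses data augmentation, but the post does not list the exact settings. As a sketch, an augmenting generator for the training folder could look like this; the specific ranges are assumptions chosen for illustration, and `augmented_train_datagen` is a hypothetical name. Only the training data should be augmented, so the validation generator stays as it is.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assumed augmentation settings, for illustration only
augmented_train_datagen = ImageDataGenerator(
    rescale=1./255,          # same normalization as the validation generator
    rotation_range=20,       # random rotations of up to +/-20 degrees
    width_shift_range=0.2,   # random horizontal shifts of up to 20% of the width
    height_shift_range=0.2,  # random vertical shifts of up to 20% of the height
    horizontal_flip=True,    # random left-right flips
    fill_mode='nearest')     # fill pixels exposed by shifts and rotations
```

To use it, call `augmented_train_datagen.flow_from_directory(train_dir, ...)` with the same arguments as the plain training generator. Each epoch then sees randomly transformed versions of the same images, which acts as extra training data without collecting new photos.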
Train the model
So far, we have created the model and set up the data for training, so we can proceed with training and check the performance. We need to specify the optimizer and the learning rate, and then start training with the `model.fit()` function. After training is over, we save the model.
from tensorflow.keras import optimizers

# Configure the model for training
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=1e-4),
              metrics=['acc'])

# Train the model; len(generator) is the number of batches needed to cover the folder once
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=20,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    verbose=1)

# Save the trained model (the file name is just an example)
model.save('fine_tuned_vgg16.h5')
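Keras expects `steps_per_epoch` and `validation_steps` to be whole numbers of batches, while samples / batch_size is a float whenever the batch size does not divide the number of images. The step count is the ceiling of that ratio, which is also what `len()` of a Keras directory iterator returns. A minimal sketch of the arithmetic, with `steps_for` as a hypothetical helper:

```python
import math


def steps_for(num_samples, batch_size):
    """Batches needed to see every sample once; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)


# The validation set used in this post: 50 images per class, 3 classes
print(steps_for(150, 10))  # 15 full batches
print(150 / 7)             # a float, not a valid step count
print(steps_for(150, 7))   # 22: 21 full batches plus one batch of 3 images
```

Rounding up rather than down matters: truncating to 21 steps with a batch size of 7 would silently skip the last 3 images every epoch.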
Check Performance
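We obtained an accuracy of 90% with the transfer learning approach discussed in our previous article. With fine-tuning, we get a much better accuracy of 98% (the best configuration in the experiments below: training the last 3 convolutional layers with data augmentation).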
Let us look at the loss and accuracy curves using the `visualize_results(history)` function:
import matplotlib.pyplot as plt

# Utility function for plotting the model results
def visualize_results(history):
    # Plot the accuracy and loss curves
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))

    plt.plot(epochs, acc, 'b', label='Training acc')
    plt.plot(epochs, val_acc, 'r', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()

    plt.figure()
    plt.plot(epochs, loss, 'b', label='Training loss')
    plt.plot(epochs, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()

# Run the function to illustrate accuracy and loss
visualize_results(history)
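Besides reading the plots, the best epoch can be pulled straight out of `history.history`, which maps each metric name to a list with one value per epoch. A minimal sketch with a hypothetical helper `best_epoch`; the dummy values below are invented for illustration and are not results from this post.

```python
def best_epoch(history_dict, key='val_acc'):
    """Return (1-based epoch number, value) of the highest value recorded for `key`."""
    values = history_dict[key]
    # max() returns the first index on ties, i.e. the earliest best epoch
    best = max(range(len(values)), key=lambda i: values[i])
    return best + 1, values[best]


# Dummy history for illustration only; pass history.history in practice
dummy_history = {'val_acc': [0.80, 0.92, 0.96, 0.95]}
print(best_epoch(dummy_history))  # (3, 0.96)
```

Note that `visualize_results` numbers epochs from 0 on the x-axis, while this helper reports them from 1, as Keras does in its training log.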
The output is:
Let us also look at the errors visually. To avoid duplicating code across the different experiments, we introduce two utility functions, `obtain_errors` and `show_errors`:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img

# Utility function for obtaining the errors
def obtain_errors(val_generator, predictions):
    # Get the filenames from the generator
    fnames = val_generator.filenames
    # Get the ground truth from the generator
    ground_truth = val_generator.classes
    # Get the dictionary of classes
    label2index = val_generator.class_indices
    # Obtain the list of the classes (class_indices is ordered by index)
    idx2label = list(label2index.keys())
    print("The list of classes: ", idx2label)
    # Get the predicted class index for each image
    predicted_classes = np.argmax(predictions, axis=1)
    errors = np.where(predicted_classes != ground_truth)[0]
    print("Number of errors = {}/{}".format(len(errors), val_generator.samples))
    return idx2label, errors, fnames

# Utility function for visualization of the errors
def show_errors(idx2label, errors, predictions, fnames):
    # Show the errors
    for i in range(len(errors)):
        pred_class = np.argmax(predictions[errors[i]])
        pred_label = idx2label[pred_class]
        # The class folder is the parent directory of the relative file name
        # (os.path works with both '/' and Windows separators)
        title = 'Original label:{}, Prediction :{}, confidence : {:.3f}'.format(
            os.path.dirname(fnames[errors[i]]),
            pred_label,
            predictions[errors[i]][pred_class])

        original = load_img(os.path.join(validation_dir, fnames[errors[i]]))
        plt.figure(figsize=[7, 7])
        plt.axis('off')
        plt.title(title)
        plt.imshow(original)
        plt.show()
The error analysis pipeline is then:
# Get the predictions from the model using the generator
predictions = model.predict(validation_generator, steps=len(validation_generator), verbose=1)
# Run the function to get the list of classes and errors
idx2label, errors, fnames = obtain_errors(validation_generator, predictions)
# Run the function to illustrate the error cases
show_errors(idx2label, errors, predictions, fnames)
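The heart of `obtain_errors` is two NumPy calls: `argmax` turns each row of softmax outputs into a predicted class index, and `np.where` collects the positions where that index differs from the ground truth. Here is the same logic on a tiny hand-made batch; the probabilities and labels are invented for illustration, and the class names assume folders called pumpkin, tomato and watermelon (Keras assigns class indices in alphanumeric order of the folder names).

```python
import numpy as np

# Made-up softmax outputs for 4 validation images (illustration only)
toy_predictions = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.70, 0.20],
    [0.30, 0.30, 0.40],
    [0.60, 0.25, 0.15],
])
toy_ground_truth = np.array([0, 1, 1, 2])          # true class index per image
idx2label = ['pumpkin', 'tomato', 'watermelon']    # hypothetical folder names

predicted_classes = np.argmax(toy_predictions, axis=1)       # highest-probability class per row
errors = np.where(predicted_classes != toy_ground_truth)[0]  # indices of misclassified images

for i in errors:
    pred = predicted_classes[i]
    print(f"image {i}: true={idx2label[toy_ground_truth[i]]}, "
          f"predicted={idx2label[pred]}, confidence={toy_predictions[i][pred]:.3f}")
```

Because the comparison is positional, this only works when predictions come back in the same order as `generator.classes`, which is why the validation generator is created with `shuffle=False`.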
Let’s explore some output examples:
Experiments
We ran 3 experiments to see the effect of fine-tuning and data augmentation. We kept the same validation set as in the previous post, i.e. 50 images per class (150 images in total).
- Freezing all layers and learning a classifier on top of them, as in transfer learning – 15 errors out of 150 images, which is similar to what we got in the previous post.
- Training the last 3 convolutional layers – We got 9 errors out of 150.
- Training the last 3 convolutional layers with data augmentation – The number of errors reduced to 3 out of 150.
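Converting these error counts into validation accuracy makes the link to the 90% and 98% figures quoted earlier explicit. The arithmetic is simply (total - errors) / total:

```python
TOTAL = 150  # validation images: 50 per class, 3 classes

# Error counts reported for the three experiments above
error_counts = {
    'all layers frozen': 15,
    'last 3 conv layers trained': 9,
    'last 3 conv layers + data augmentation': 3,
}

accuracies = {name: (TOTAL - n) / TOTAL for name, n in error_counts.items()}

for name, acc in accuracies.items():
    print(f"{name}: {acc:.0%}")
```

This prints 90%, 94% and 98%. With only 150 validation images, each misclassified image is worth about 0.67 percentage points, so differences of one or two errors between runs should be read with some caution.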
I hope you find this useful. Try doing your own experiments and post your findings in the comments section.