Tutorial: Training a Binary ResNet Classifier¶
Note
You can download an executable version (Jupyter Notebook) of this tutorial and the data required to follow along here.
North Atlantic Right Whale detector - part 1¶
This is the first part of a two-part tutorial illustrating how to build a deep learning acoustic detector with ketos.
We'll use the database built in the Creating a training database tutorial, in which we converted raw audio files to spectrograms of the North Atlantic Right Whale's stereotypical upcall. If you didn't follow that tutorial, you can find the resulting database in the .zip file linked at the top of this page. There you will also find an executable version of this Jupyter Notebook, in case you want to follow along.
Our final goal is to have a detector that can take a long .wav file (e.g., 30 min) and tell us where within that file the right whale upcalls are.
The core of such a detector will be a binary classifier that takes 3-s-long spectrograms and classifies them as "contains an upcall" or "does not contain an upcall". We will treat these two classes as "1" and "0", respectively. This is what we'll cover in this tutorial.
The second part will take this binary classifier and turn it into a detector.
Contents:¶
1. Importing the packages
2. Creating the data feed
3. Creating and training the Neural Network
The lines below define the random seeds used in the tutorial. This is necessary to ensure that you get precisely the same results every time you run the code.
import numpy as np
np.random.seed(1000)
import tensorflow as tf
tf.random.set_seed(2000)
1. Importing the packages¶
We start by importing the ketos modules and classes we will use throughout the tutorial.
import ketos.data_handling.database_interface as dbi
from ketos.neural_networks.resnet import ResNetInterface
from ketos.data_handling.data_feeding import BatchGenerator
2. Creating the data feed¶
The database created in the Creating a training database tutorial organizes the data into "train" and "validation" sets. Ketos' BatchGenerator
provides an interface that makes it easy to use this database during the training process. It selects batches of data from the database and feeds them to the neural network. We are dealing with small amounts of data for the purposes of this tutorial, but this is very helpful when dealing with larger databases, which is often the case in deep learning.
First, we open a connection to our database:
db = dbi.open_file("database.h5", 'r')
Then, we need to open the tables containing the spectrograms and annotations. All we are doing here is creating a handle to indicate where the BatchGenerator can find the spectrograms and annotations, but no data is actually loaded into memory at this point.
train_data = dbi.open_table(db, "/train/data")
val_data = dbi.open_table(db, "/val/data")
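Ketos databases are HDF5 files managed with PyTables, so the handles we just created are PyTables Table objects. As a quick sanity check, we can use their nrows attribute to see how many spectrograms each set contains (the exact counts depend on the database you built in the previous tutorial; reading this attribute does not load the spectrograms into memory):
print(f"training samples: {train_data.nrows}")
print(f"validation samples: {val_data.nrows}")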
With the handles ready, we can create the two batch generators, one to load training data and another for the validation data.
There are a few options we need to configure:
batch_size indicates how many data samples (spectrograms) will be loaded into memory at a time
data_table indicates the table handle we just created
output_transform_func indicates a function that transforms the data as it is loaded into memory. This can be any Python function. To make the job easier, the neural network architectures available in ketos all have an interface that includes a transformation function to put the data into the right format for that type of neural network
shuffle indicates whether we want to shuffle the data before creating the batches. That's a good idea for our case because in the database our spectrograms are sorted by labels (all the 'upcalls' followed by all 'backgrounds'), but we want each batch to contain a mix
refresh_on_epoch_end When we train the neural network, we will show it the whole training dataset several times (each time is called an 'epoch'). Setting this option to True makes the batch generator reshuffle the data at the end of each epoch, so that the batches contain different examples each time.
# Below is an example of a simple data transformation function.
# (In this tutorial, however, we will use the ResNetInterface.transform_batch function provided by Ketos.)
def transform_batch(X, Y):
    # Add a channel axis: the network expects inputs of shape (batch, rows, cols, channels)
    x = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)
    # One-hot encode the labels: 0 -> [1, 0], 1 -> [0, 1]
    y = tf.one_hot(Y['label'], depth=2, axis=1).numpy()
    return x, y
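To see what this transformation does, we can apply it to a small dummy batch. (This check is purely illustrative; the array shapes below are arbitrary stand-ins for the spectrograms and labels that the BatchGenerator would load.)
# Two fake 128 x 128 "spectrograms" and their labels (illustrative only)
X_dummy = np.random.rand(2, 128, 128)
Y_dummy = {'label': np.array([0, 1])}
x, y = transform_batch(X_dummy, Y_dummy)
print(x.shape)  # (2, 128, 128, 1): a channel axis has been added
print(y)        # [[1. 0.] [0. 1.]]: the labels are now one-hot encoded
We can now create the training generator: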
train_generator = BatchGenerator(batch_size=128, data_table=train_data,
output_transform_func=ResNetInterface.transform_batch,
shuffle=True, refresh_on_epoch_end=True)
For the validation generator, we'll just change the table handles. We'll also set refresh_on_epoch_end to False, so that the validation set is shuffled once before creating the batches for the first epoch but not in subsequent epochs. This way, every time we validate the model (i.e., at the end of each training epoch), it will use the same order for the validation samples.
val_generator = BatchGenerator(batch_size=128, data_table=val_data,
output_transform_func=ResNetInterface.transform_batch,
shuffle=True, refresh_on_epoch_end=False)
3. Creating and training the Neural Network¶
For this exercise we will use a ResNet-like architecture, which is popular for image recognition and has also shown good results for audio recognition with spectral inputs.
Ketos' Neural Network interfaces can use recipes to create a network. The recipe files are an easy way to let others reproduce the architecture you used. You can find a recipe.json file within the .zip file.
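If you are curious about what the recipe defines, you can inspect the file with Python's standard json module (the exact fields depend on the recipe.json shipped with the tutorial data; typically they describe things like the number of residual blocks, the number of classes, and optimizer settings):
import json
with open("recipe.json") as f:
    print(json.dumps(json.load(f), indent=4))
With the recipe in hand, we can build the network: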
resnet = ResNetInterface.build_from_recipe_file("recipe.json")
Notice that this creates a brand new network. That's what we want since we are training from scratch. Once the model is trained, we can also save it for later use and share it with others. The saved model will contain not only the recipe for recreating the architecture but also the weights optimized (or learned) during the training process, and it can therefore be used without the need for training again (or access to the training data).
Before we start training, we just need to connect the batch generators we created to the network interface, so it can access the data as it needs.
resnet.train_generator = train_generator
resnet.val_generator = val_generator
We also need to set where we want to save our model's checkpoints. By default, ketos will save the model's progress every 5 epochs (this can be adjusted with the checkpoint_freq parameter in the train_loop method, but we'll use the default). If the folder does not yet exist, Ketos will create it. Later, when we save the model, Ketos will take the latest checkpoint and include it in the model file.
resnet.checkpoint_dir = "checkpoints"
We train our upcall/background classifier by calling the train_loop method on our resnet object. In the example below, we specify the number of epochs, which indicates how many times the network will go through the training dataset in order to learn useful features for classification. We also set the verbose parameter to True, which will print some summary metrics during training.
Given the simple task and small database we are using for this tutorial, 30 epochs should give us a reasonably good classifier on which to build a detector in part 2. If you are following along, please note that training might take a while, depending on your computer (about 60 min on an average laptop).
resnet.train_loop(n_epochs=30, verbose=True)
After training is done, we can close the database:
db.close()
And finally, save the model:
resnet.save_model('narw.kt',audio_repr_file='spec_config.json')
The command above will create the narw.kt file in the same directory where you are running this notebook (or the working directory for your session if you are running the Python interpreter elsewhere). You can also specify a different folder and name, like resnet.save_model('trained_classifiers/my_classifier.kt').
The audio_repr_file argument can be used to add an audio specification to the model file. This is useful when reusing the model because it makes the settings used to process the audio available within the model file. If this argument is omitted or set to None, the settings will not be added to the model file.
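When you want to reuse the saved model, for instance in part 2 of this tutorial, you can load it back through the same interface. The call below assumes the load_model_file method available in recent Ketos versions; check the documentation for your version for the exact signature.
# Hypothetical reload of the saved model; 'narw_tmp' is just a scratch folder
# where Ketos can unpack the model's contents
resnet_reloaded = ResNetInterface.load_model_file('narw.kt', 'narw_tmp')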
This classifier works with 3-s-long spectrograms as inputs.
We won't actually use it directly, though, as our goal is to build a detector that will scan a longer .wav file (e.g., 30 min) and output timestamps indicating when right whales are present. That's the topic of our next step, in which we will use the classifier we just trained to build our right whale detector.