Tutorial: Training a Binary ResNet Classifier

Note

You can download an executable version (Jupyter Notebook) of this tutorial and the data required to follow along here.


North Atlantic Right Whale detector - Part 1

This is the first part of a two-part tutorial illustrating how to build a deep learning acoustic detector with Ketos.

We'll use the database built in the Creating a training database tutorial, in which we converted raw audio files into spectrograms of the North Atlantic Right Whale's stereotypical upcall. If you didn't follow that tutorial, you can find the resulting database in the .zip file linked at the top of this page, along with an executable version of this Jupyter notebook, in case you want to follow along.

Our final goal is to have a detector that can take a long .wav file (e.g., 30 min) and tell us where within that file the right whale upcalls are.

The core of such a detector will be a binary classifier that takes 3-s-long spectrograms and classifies them as "contains an upcall" or "does not contain an upcall". We will treat these two classes as "1" and "0", respectively. Building this classifier is what we'll cover in this tutorial.

The second part will take this binary classifier and turn it into a detector.

Contents:

1. Importing the packages
2. Creating the data feed
3. Creating and training the neural network

The lines below set the random seeds used in the tutorial. This is necessary to ensure that you get precisely the same results every time you run the code.

In [1]:
import numpy as np
np.random.seed(1000)

import tensorflow as tf
tf.random.set_seed(2000)

1. Importing the packages

We start by importing the ketos modules and classes we will use throughout the tutorial.

In [2]:
import ketos.data_handling.database_interface as dbi
from ketos.neural_networks.resnet import ResNetInterface
from ketos.data_handling.data_feeding import BatchGenerator

2. Creating the data feed

The database created in the Creating a training database tutorial organizes the data into "train" and "validation" sets. Ketos' BatchGenerator provides an interface that makes it easy to use this database during the training process: it selects batches of data from the database and feeds them to the neural network. We are dealing with small amounts of data for the purposes of this tutorial, but this approach becomes essential with databases too large to fit in memory, which is often the case in deep learning.

First, we open a connection to our database:

In [3]:
db = dbi.open_file("database.h5", 'r')

Then, we need to open the tables containing the spectrograms and annotations. All we are doing here is creating handles to indicate where the BatchGenerator can find the spectrograms and annotations; no data is actually loaded into memory at this point.

In [4]:
train_data = dbi.open_table(db, "/train/data")
val_data = dbi.open_table(db, "/val/data")
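
If you want to confirm that the handles point where you expect, you can inspect the tables without loading any data. This is an optional check, which assumes the handles behave as standard PyTables tables (exposing an nrows attribute):

print(train_data.nrows)  # number of spectrograms in the training table
print(val_data.nrows)    # number of spectrograms in the validation table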

With the handles ready, we can create the two batch generators, one to load the training data and another to load the validation data.

There are a few options we need to configure:
batch_size indicates how many data samples (spectrograms) will be loaded into memory at a time

data_table indicates the table handle we just created

output_transform_func indicates a function that transforms the data as it is loaded into memory. This can be any Python function. To make the job easier, the neural network architectures available in Ketos all have an interface that includes a transformation function to put the data into the right format for that type of neural network

shuffle indicates whether we want to shuffle the data before creating the batches. That's a good idea for our case because in the database our spectrograms are sorted by labels (all the 'upcalls' followed by all 'backgrounds'), but we want each batch to contain a mix

refresh_on_epoch_end indicates whether the data should be reshuffled at the end of each epoch. When we train the neural network, we will show it the whole training dataset several times (each pass is called an 'epoch'). Setting this option to True makes the batch generator reshuffle the data at the end of each epoch, so that the batches contain different examples each time.

In [5]:
# Below is an example of a simple data transformation function.
# (However, in this tutorial we will use the ResNetInterface.transform_batch function provided by Ketos.)
def transform_batch(X, Y):
    # Add a channel axis, giving each spectrogram the shape (time, frequency, 1)
    x = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)
    # One-hot encode the integer labels ("0" -> [1, 0], "1" -> [0, 1])
    y = tf.one_hot(Y['label'], depth=2, axis=1).numpy()
    return x, y
In [6]:
train_generator = BatchGenerator(batch_size=128, data_table=train_data, 
                                  output_transform_func=ResNetInterface.transform_batch,
                                  shuffle=True, refresh_on_epoch_end=True)
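
If you are curious about what the generator produces, you can pull a single batch and inspect its shapes. This is an optional sketch, which assumes the BatchGenerator supports Python's next() protocol; note that it consumes a batch and advances the generator's internal state, so recreate the generator afterwards if you want to reproduce the seeded results shown below.

X, Y = next(train_generator)
print(X.shape)  # e.g. (128, time_bins, freq_bins, 1): spectrograms with a channel axis
print(Y.shape)  # (128, 2): one-hot encoded labels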

For the validation generator, we'll just change the table handle. We'll also set refresh_on_epoch_end to False, so that the validation set is shuffled once before the batches are created for the first epoch, but not in subsequent epochs. This way, every time we validate the model (i.e., at the end of each training epoch), it will see the validation samples in the same order.

In [7]:
val_generator = BatchGenerator(batch_size=128, data_table=val_data,
                                 output_transform_func=ResNetInterface.transform_batch,
                                 shuffle=True, refresh_on_epoch_end=False)

3. Creating and training the neural network

For this exercise we will use a ResNet-like architecture, which is popular for image recognition and has also shown good results for audio recognition using spectral inputs.

Ketos' neural network interfaces can use recipes to create a network. Recipe files are an easy way to let others reproduce the architecture you used. You can find a recipe.json file within the .zip file linked at the top of this page.
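
If you'd like to see what the recipe specifies before building from it, it is just a JSON file, so you can load and print it with the standard library. This snippet only reads the file; the exact fields depend on the recipe shipped with the tutorial:

import json

# Print the architecture and training hyperparameters defined in the recipe
with open("recipe.json") as f:
    recipe = json.load(f)
print(json.dumps(recipe, indent=4))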

In [8]:
resnet = ResNetInterface.build_from_recipe_file("recipe.json")

Notice that this creates a brand new network. That's what we want, since we are training from scratch. Once the model is trained, we can also save it for later use and share it with others. The saved model will contain not only the recipe for recreating the architecture, but also the weights optimized (or learned) during the training process, and can therefore be used without training again (or access to the training data).

Before we start training, we just need to connect the batch generators we created to the network interface, so that it can access the data as needed.

In [9]:
resnet.train_generator = train_generator
resnet.val_generator = val_generator

We also need to set where we want to save our model's checkpoints. By default, Ketos saves the model's progress every 5 epochs (this can be adjusted through the checkpoint_freq parameter of the train_loop method, but we'll use the default). If the folder does not yet exist, Ketos will create it. Later, when we save the model, Ketos will take the latest checkpoint and include it in the model file.

In [10]:
resnet.checkpoint_dir = "checkpoints"
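
If you prefer a different checkpoint frequency, you can pass the checkpoint_freq parameter mentioned above when calling train_loop. For example (shown commented out, since we'll stick with the default of 5 in this tutorial):

# Save a checkpoint every 10 epochs instead of every 5:
# resnet.train_loop(n_epochs=30, verbose=True, checkpoint_freq=10)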

We train our upcall/background classifier by calling the train_loop method on our resnet object. In the example below, we specify the number of epochs, which indicates how many times the network will go through the training dataset in order to learn useful features for classification. We also set the verbose parameter to True, which will print some summary metrics during training.

Given the simple task/database we are using for this tutorial, 30 epochs should give us a reasonably good classifier to build a detector on in part 2. If you are following along, note that training might take a while, depending on your computer (about 60 min on an average laptop).

In [11]:
resnet.train_loop(n_epochs=30, verbose=True)
====================================================================================
Epoch: 1 
train_loss: 0.4338880777359009
train_CategoricalAccuracy: 0.595 train_Precision: 0.580 train_Recall: 0.684 
val_loss: 0.5085288882255554
val_CategoricalAccuracy: 0.500 val_Precision: 0.000 val_Recall: 0.000 

====================================================================================

====================================================================================
Epoch: 2 
train_loss: 0.34146538376808167
train_CategoricalAccuracy: 0.702 train_Precision: 0.655 train_Recall: 0.857 
val_loss: 0.49516749382019043
val_CategoricalAccuracy: 0.530 val_Precision: 0.515 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 3 
train_loss: 0.31047946214675903
train_CategoricalAccuracy: 0.725 train_Precision: 0.669 train_Recall: 0.887 
val_loss: 0.483070433139801
val_CategoricalAccuracy: 0.530 val_Precision: 0.515 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 4 
train_loss: 0.28095531463623047
train_CategoricalAccuracy: 0.752 train_Precision: 0.686 train_Recall: 0.931 
val_loss: 0.4828239679336548
val_CategoricalAccuracy: 0.515 val_Precision: 0.508 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 5 
train_loss: 0.2622312307357788
train_CategoricalAccuracy: 0.766 train_Precision: 0.706 train_Recall: 0.913 
val_loss: 0.4832149147987366
val_CategoricalAccuracy: 0.515 val_Precision: 0.508 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 6 
train_loss: 0.26083073019981384
train_CategoricalAccuracy: 0.761 train_Precision: 0.695 train_Recall: 0.927 
val_loss: 0.49218130111694336
val_CategoricalAccuracy: 0.505 val_Precision: 0.503 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 7 
train_loss: 0.21898069977760315
train_CategoricalAccuracy: 0.817 train_Precision: 0.768 train_Recall: 0.910 
val_loss: 0.483643114566803
val_CategoricalAccuracy: 0.515 val_Precision: 0.508 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 8 
train_loss: 0.1959889680147171
train_CategoricalAccuracy: 0.844 train_Precision: 0.796 train_Recall: 0.925 
val_loss: 0.4783859848976135
val_CategoricalAccuracy: 0.520 val_Precision: 0.510 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 9 
train_loss: 0.19118483364582062
train_CategoricalAccuracy: 0.839 train_Precision: 0.804 train_Recall: 0.896 
val_loss: 0.47925305366516113
val_CategoricalAccuracy: 0.525 val_Precision: 0.513 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 10 
train_loss: 0.17318378388881683
train_CategoricalAccuracy: 0.860 train_Precision: 0.820 train_Recall: 0.923 
val_loss: 0.48716098070144653
val_CategoricalAccuracy: 0.505 val_Precision: 0.503 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 11 
train_loss: 0.15976482629776
train_CategoricalAccuracy: 0.885 train_Precision: 0.856 train_Recall: 0.924 
val_loss: 0.4935968518257141
val_CategoricalAccuracy: 0.510 val_Precision: 0.505 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 12 
train_loss: 0.15248368680477142
train_CategoricalAccuracy: 0.891 train_Precision: 0.867 train_Recall: 0.922 
val_loss: 0.31768137216567993
val_CategoricalAccuracy: 0.670 val_Precision: 0.605 val_Recall: 0.980 

====================================================================================

====================================================================================
Epoch: 13 
train_loss: 0.13877111673355103
train_CategoricalAccuracy: 0.898 train_Precision: 0.873 train_Recall: 0.932 
val_loss: 0.40557682514190674
val_CategoricalAccuracy: 0.570 val_Precision: 0.538 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 14 
train_loss: 0.12895086407661438
train_CategoricalAccuracy: 0.904 train_Precision: 0.888 train_Recall: 0.924 
val_loss: 0.3949286937713623
val_CategoricalAccuracy: 0.540 val_Precision: 0.521 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 15 
train_loss: 0.1229528933763504
train_CategoricalAccuracy: 0.912 train_Precision: 0.903 train_Recall: 0.923 
val_loss: 0.3685142993927002
val_CategoricalAccuracy: 0.615 val_Precision: 0.565 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 16 
train_loss: 0.11468056589365005
train_CategoricalAccuracy: 0.915 train_Precision: 0.909 train_Recall: 0.924 
val_loss: 0.3127107620239258
val_CategoricalAccuracy: 0.665 val_Precision: 0.599 val_Recall: 1.000 

====================================================================================

====================================================================================
Epoch: 17 
train_loss: 0.11124808341264725
train_CategoricalAccuracy: 0.922 train_Precision: 0.909 train_Recall: 0.938 
val_loss: 0.26711761951446533
val_CategoricalAccuracy: 0.745 val_Precision: 0.669 val_Recall: 0.970 

====================================================================================

====================================================================================
Epoch: 18 
train_loss: 0.09989237785339355
train_CategoricalAccuracy: 0.934 train_Precision: 0.931 train_Recall: 0.938 
val_loss: 0.2679952383041382
val_CategoricalAccuracy: 0.765 val_Precision: 0.693 val_Recall: 0.950 

====================================================================================

====================================================================================
Epoch: 19 
train_loss: 0.09425995498895645
train_CategoricalAccuracy: 0.939 train_Precision: 0.930 train_Recall: 0.950 
val_loss: 0.23850911855697632
val_CategoricalAccuracy: 0.775 val_Precision: 0.704 val_Recall: 0.950 

====================================================================================

====================================================================================
Epoch: 20 
train_loss: 0.08219350129365921
train_CategoricalAccuracy: 0.949 train_Precision: 0.946 train_Recall: 0.952 
val_loss: 0.255126953125
val_CategoricalAccuracy: 0.775 val_Precision: 0.704 val_Recall: 0.950 

====================================================================================

====================================================================================
Epoch: 21 
train_loss: 0.07940079271793365
train_CategoricalAccuracy: 0.951 train_Precision: 0.948 train_Recall: 0.955 
val_loss: 0.24770861864089966
val_CategoricalAccuracy: 0.755 val_Precision: 0.683 val_Recall: 0.950 

====================================================================================

====================================================================================
Epoch: 22 
train_loss: 0.07545625418424606
train_CategoricalAccuracy: 0.954 train_Precision: 0.947 train_Recall: 0.962 
val_loss: 0.22864598035812378
val_CategoricalAccuracy: 0.785 val_Precision: 0.718 val_Recall: 0.940 

====================================================================================

====================================================================================
Epoch: 23 
train_loss: 0.058807797729969025
train_CategoricalAccuracy: 0.968 train_Precision: 0.964 train_Recall: 0.972 
val_loss: 0.28637802600860596
val_CategoricalAccuracy: 0.840 val_Precision: 0.798 val_Recall: 0.910 

====================================================================================

====================================================================================
Epoch: 24 
train_loss: 0.05741170421242714
train_CategoricalAccuracy: 0.970 train_Precision: 0.967 train_Recall: 0.974 
val_loss: 0.32357555627822876
val_CategoricalAccuracy: 0.680 val_Precision: 0.611 val_Recall: 0.990 

====================================================================================

====================================================================================
Epoch: 25 
train_loss: 0.05459459871053696
train_CategoricalAccuracy: 0.971 train_Precision: 0.969 train_Recall: 0.974 
val_loss: 0.23912644386291504
val_CategoricalAccuracy: 0.800 val_Precision: 0.734 val_Recall: 0.940 

====================================================================================

====================================================================================
Epoch: 26 
train_loss: 0.045924343168735504
train_CategoricalAccuracy: 0.974 train_Precision: 0.973 train_Recall: 0.974 
val_loss: 0.22424530982971191
val_CategoricalAccuracy: 0.850 val_Precision: 0.818 val_Recall: 0.900 

====================================================================================

====================================================================================
Epoch: 27 
train_loss: 0.03830939903855324
train_CategoricalAccuracy: 0.980 train_Precision: 0.983 train_Recall: 0.978 
val_loss: 0.20659667253494263
val_CategoricalAccuracy: 0.820 val_Precision: 0.762 val_Recall: 0.930 

====================================================================================

====================================================================================
Epoch: 28 
train_loss: 0.03324630483984947
train_CategoricalAccuracy: 0.982 train_Precision: 0.983 train_Recall: 0.980 
val_loss: 0.20578312873840332
val_CategoricalAccuracy: 0.875 val_Precision: 0.864 val_Recall: 0.890 

====================================================================================

====================================================================================
Epoch: 29 
train_loss: 0.029962901026010513
train_CategoricalAccuracy: 0.984 train_Precision: 0.983 train_Recall: 0.984 
val_loss: 0.18220269680023193
val_CategoricalAccuracy: 0.875 val_Precision: 0.871 val_Recall: 0.880 

====================================================================================

====================================================================================
Epoch: 30 
train_loss: 0.029690532013773918
train_CategoricalAccuracy: 0.983 train_Precision: 0.984 train_Recall: 0.982 
val_loss: 0.19279241561889648
val_CategoricalAccuracy: 0.890 val_Precision: 0.906 val_Recall: 0.870 

====================================================================================

After training is done, we can close the database:

In [12]:
db.close()

And, finally, save the model:

In [ ]:
resnet.save_model('narw.kt', audio_repr_file='spec_config.json')

The command above will create the narw.kt file in the same directory where you are running this notebook (or the working directory for your session, if you are running the Python interpreter elsewhere). You can also specify a different folder and name, like resnet.save_model('trained_classifiers/my_classifier.kt'). The audio_repr_file argument can be used to add an audio specification to the model file. This is useful when reusing the model, because it makes the settings used to process the audio available within the model file. If this argument is omitted or set to None, the settings will not be added to the model file.
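
When you want to use the saved model later (as we will in part 2), it can be restored through the same interface, without access to the training data. Below is a minimal sketch, assuming Ketos' load_model_file method, which unpacks the saved model into a working folder:

from ketos.neural_networks.resnet import ResNetInterface

# Restore the trained classifier; new_model_folder is a temporary folder
# where Ketos unpacks the contents of the .kt file
classifier = ResNetInterface.load_model_file('narw.kt', new_model_folder='narw_tmp_folder')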

This classifier takes 3-s-long spectrograms as input.

We won't actually use it directly, as our goal is to build a detector that will scan a longer .wav file (e.g., 30 min) and output timestamps indicating when right whales are present. That's the topic of part 2, in which we will use the classifier we just trained to build our right whale detector.