Tutorial: Training a Binary ResNet Classifier
Note: You can download an executable version (Jupyter Notebook) of this tutorial and the data required to follow along here.
North Atlantic Right Whale detector - Part 1
This is the first of a two-part tutorial illustrating how to build a deep learning acoustic detector with ketos.
We'll use the database built in the Creating a training database tutorial, in which we converted raw audio files into spectrograms of the North Atlantic Right Whale's stereotypical upcall. If you didn't follow that tutorial, you can find the resulting database in the .zip file linked at the top of this page. There you will also find an executable version of this Jupyter notebook, in case you want to follow along.
Our final goal is to have a detector that can take a long .wav file (e.g., 30 min) and tell us where within that file the right whale upcalls occur.
The core of such a detector is a binary classifier that takes 3-s-long spectrograms and classifies them as either "contains an upcall" or "does not contain an upcall". We will treat these two classes as "1" and "0", respectively. Building this classifier is what we'll cover in this tutorial.
The second part will take this binary classifier and turn it into a detector.
Contents:
1. Importing the packages
2. Creating the data feed
3. Creating and training the Neural Network
The lines below define the random seeds used in the tutorial. This is necessary to ensure that you get precisely the same results every time you run the code.
import numpy as np
np.random.seed(1000)
import tensorflow as tf
tf.random.set_seed(2000)
1. Importing the packages
We start by importing the ketos modules and classes we will use throughout the tutorial.
import ketos.data_handling.database_interface as dbi
from ketos.neural_networks.resnet import ResNetInterface
from ketos.data_handling.data_feeding import BatchGenerator
2. Creating the data feed
The database created in the Creating a training database tutorial organizes the data into "train" and "validation" sets. Ketos' BatchGenerator provides an interface that makes it easy to use this database during the training process: it selects batches of data from the database and feeds them to the neural network. We are dealing with small amounts of data for the purposes of this tutorial, but this approach is very helpful when dealing with larger databases, which is often the case in deep learning.
First, we open a connection to our database:
db = dbi.open_file("database.h5", 'r')
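Optionally, you can verify what the file contains by printing the handle. Since dbi.open_file is a thin wrapper around PyTables, printing the returned object displays the HDF5 object tree, where you should see the /train and /val groups created in the previous tutorial:
# Optional sanity check: display the HDF5 object tree
print(db)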
Then we need to open the tables containing the spectrograms and annotations. All we are doing here is creating handles to indicate where the BatchGenerator can find the spectrograms and annotations; no data is actually loaded into memory at this point.
train_data = dbi.open_table(db, "/train/data")
val_data = dbi.open_table(db, "/val/data")
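The handles are regular PyTables Table objects, so you can, for instance, check how many spectrograms each split holds via the nrows attribute:
# Optional: check how many spectrograms each split contains
print(train_data.nrows)  # number of training samples
print(val_data.nrows)    # number of validation samples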
With the handles ready, we can create the two batch generators, one to load training data and another for the validation data.
There are a few options we need to configure:
batch_size indicates how many data samples (spectrograms) will be loaded into memory at a time.
data_table indicates the table handle we just created.
output_transform_func indicates a function that transforms the data as it is loaded into memory. This can be any Python function. To make the job easier, the neural network architectures available in ketos all have an interface that includes a transformation function to put the data into the right format for that type of neural network.
shuffle indicates whether we want to shuffle the data before creating the batches. That's a good idea in our case because the spectrograms in the database are sorted by label (all the 'upcalls' followed by all the 'backgrounds'), but we want each batch to contain a mix.
refresh_on_epoch_end When we train the neural network, we will show it the whole training dataset several times (each pass is called an 'epoch'). Setting this option to True makes the batch generator reshuffle the data at the end of each epoch, so that the batches contain different examples each time.
# Below is an example of a simple data transformation function.
# (However, in this tutorial we will use the ResNetInterface.transform_batch function provided by ketos.)
def transform_batch(X, Y):
    x = X.reshape(X.shape[0], X.shape[1], X.shape[2], 1)  # add a channel axis for the CNN
    y = tf.one_hot(Y['label'], depth=2, axis=1).numpy()   # one-hot encode the integer labels
    return x, y
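To see what this transformation does, here is a quick illustration with a made-up batch (the spectrogram dimensions below are arbitrary, not those of our database):
import numpy as np
import tensorflow as tf

X = np.random.rand(2, 94, 129)       # a batch of 2 spectrograms (time x frequency)
Y = {'label': np.array([0, 1])}      # one 'background' and one 'upcall' label
x, y = transform_batch(X, Y)
print(x.shape)  # (2, 94, 129, 1) - channel axis added
print(y)        # [[1. 0.] [0. 1.]] - one-hot encoded labels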
train_generator = BatchGenerator(batch_size=128, data_table=train_data,
                                 output_transform_func=ResNetInterface.transform_batch,
                                 shuffle=True, refresh_on_epoch_end=True)
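If you want to verify that the generator produces what the network expects, you can draw a single batch from it; the BatchGenerator works as a Python iterator, so next() returns one (inputs, labels) pair. Note that this consumes a batch, so it is best done before training starts (the shapes in the comments are illustrative):
# Optional check: pull one batch and inspect its shapes
X_batch, y_batch = next(train_generator)
print(X_batch.shape)  # (128, n_time_bins, n_freq_bins, 1)
print(y_batch.shape)  # (128, 2) - one-hot labels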
For the validation generator, we'll just change the table handle. We'll also set refresh_on_epoch_end to False, so that the validation set is shuffled once before the batches are created for the first epoch, but not in subsequent epochs. This way, every time we validate the model (i.e., at the end of each training epoch) it will use the same order for the validation samples.
val_generator = BatchGenerator(batch_size=128, data_table=val_data,
                               output_transform_func=ResNetInterface.transform_batch,
                               shuffle=True, refresh_on_epoch_end=False)
3. Creating and training the neural network
For this exercise we will use a ResNet-like architecture, which is popular for image recognition and has also shown good results for audio recognition using spectral inputs.
Ketos' neural network interfaces can use recipes to create a network. Recipe files are an easy way to let others reproduce the architecture you used. You can find a recipe.json file within the .zip file linked at the top of this page.
resnet = ResNetInterface.build("recipe.json")
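For reference, a ResNet recipe is a plain JSON file along the lines of the snippet below. The field names and values shown are illustrative of the general ketos recipe format, not necessarily identical to the recipe.json shipped with this tutorial; the point is that the file captures the architecture and training hyperparameters in a portable, human-readable form:
{
    "block_sets": [2, 2, 2],
    "n_classes": 2,
    "initial_filters": 16,
    "optimizer": {"recipe_name": "Adam", "parameters": {"learning_rate": 0.005}},
    "loss_function": {"recipe_name": "CategoricalCrossentropy", "parameters": {}},
    "metrics": [{"recipe_name": "CategoricalAccuracy", "parameters": {}}]
}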
Notice that this creates a brand new network. That's what we want, since we are training from scratch. Once the model is trained, we can also save it for later use and share it with others. The saved model will contain not only the recipe for recreating the architecture, but also the weights optimized (learned) during the training process, and can therefore be used without training again (or access to the training data).
Before we start training, we just need to connect the batch generators we created to the network interface, so it can access the data as it needs.
resnet.train_generator = train_generator
resnet.val_generator = val_generator
We also need to set where we want to save our model's checkpoints. By default, ketos saves the model's progress every 5 epochs (this can be adjusted with the checkpoint_freq parameter of the train_loop method, but we'll use the default). If the folder does not yet exist, ketos will create it. Later, when we save the model, ketos will take the latest checkpoint and include it in the model file.
resnet.checkpoint_dir = "checkpoints"
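If you wanted checkpoints more often, you could pass the checkpoint_freq parameter mentioned above when starting the training loop, for example (a variant we won't run here; below we stick with the default of 5):
# Hypothetical variant: save a checkpoint after every epoch
resnet.train_loop(n_epochs=30, verbose=True, checkpoint_freq=1)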
We train our upcall/background classifier by calling the train_loop method of our resnet object. In the example below, we specify the number of epochs, which indicates how many times the network will go through the training dataset in order to learn useful features for classification. We also set the verbose parameter to True, which will print some summary metrics during training.
Given the simple task/database we are using for this tutorial, 30 epochs should give us a reasonably good classifier for building a detector in part 2. If you are following along, note that training might take a while, depending on your computer (about 60 min on an average laptop).
resnet.train_loop(n_epochs=30, verbose=True)
Epoch-by-epoch metrics (acc = CategoricalAccuracy, prec = Precision, rec = Recall; values rounded to three decimals):

Epoch  train_loss  train_acc  train_prec  train_rec  val_loss  val_acc  val_prec  val_rec
1      0.434       0.594      0.580       0.683      0.509     0.500    0.000     0.000
2      0.348       0.689      0.650       0.823      0.498     0.520    0.510     0.980
3      0.309       0.726      0.662       0.923      0.481     0.530    0.515     1.000
4      0.284       0.743      0.678       0.924      0.489     0.515    0.508     1.000
5      0.266       0.757      0.692       0.928      0.477     0.520    0.510     1.000
6      0.239       0.790      0.726       0.933      0.499     0.500    0.500     1.000
7      0.211       0.827      0.780       0.913      0.484     0.515    0.508     1.000
8      0.184       0.854      0.814       0.919      0.484     0.515    0.508     1.000
9      0.175       0.873      0.835       0.929      0.461     0.515    0.508     1.000
10     0.161       0.873      0.838       0.924      0.493     0.505    0.503     1.000
11     0.147       0.894      0.872       0.924      0.484     0.500    0.500     1.000
12     0.141       0.898      0.880       0.921      0.471     0.510    0.505     1.000
13     0.131       0.903      0.891       0.919      0.418     0.545    0.524     1.000
14     0.125       0.912      0.911       0.913      0.493     0.500    0.500     1.000
15     0.121       0.920      0.906       0.937      0.439     0.540    0.521     1.000
16     0.111       0.924      0.912       0.939      0.307     0.670    0.602     1.000
17     0.098       0.933      0.932       0.933      0.297     0.795    0.792     0.800
18     0.090       0.945      0.940       0.950      0.342     0.685    0.785     0.510
19     0.086       0.945      0.944       0.946      0.266     0.810    0.746     0.940
20     0.078       0.950      0.945       0.956      0.441     0.510    0.505     1.000
21     0.063       0.967      0.967       0.967      0.211     0.860    0.810     0.940
22     0.060       0.970      0.970       0.969      0.223     0.880    0.873     0.890
23     0.049       0.975      0.976       0.973      0.215     0.875    0.975     0.770
24     0.042       0.978      0.980       0.976      0.199     0.875    0.963     0.780
25     0.037       0.979      0.982       0.977      0.255     0.785    0.983     0.580
26     0.031       0.983      0.985       0.980      0.361     0.640    0.938     0.300
27     0.030       0.984      0.987       0.982      0.224     0.810    0.984     0.630
28     0.028       0.985      0.989       0.982      0.172     0.890    0.882     0.900
29     0.025       0.987      0.991       0.982      0.189     0.855    0.973     0.730
30     0.024       0.987      0.991       0.982      0.168     0.880    0.922     0.830
After training is done, we can close the database:
db.close()
And finally, save the model:
resnet.save('narw.kt',audio_repr_file='spec_config.json')
The command above will create the narw.kt file in the same directory where you are running this notebook (or the working directory of your session, if you are running the Python interpreter elsewhere). You can also specify a different folder and name, like resnet.save('trained_classifiers/my_classifier.kt').
The audio_repr_file argument can be used to add an audio representation to the model file. This is useful when reusing the model, because it makes the settings used to process the audio available within the model file. If this argument is omitted or set to None, the settings will not be added to the model file.
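When the time comes to reuse the classifier, it can be loaded back through the same interface. A minimal sketch, assuming a recent ketos version in which the interface exposes a load method (older versions named this differently, so check the documentation for your installed version):
from ketos.neural_networks.resnet import ResNetInterface

# Load the saved classifier; new_model_folder is a scratch directory
# where ketos unpacks the contents of the .kt file
resnet_reloaded = ResNetInterface.load('narw.kt', new_model_folder='narw_tmp')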
This classifier takes 3-s-long spectrograms as inputs.
We won't actually use it directly, as our goal is to build a detector that will scan a longer .wav file (e.g., 30 min) and output timestamps indicating when right whale upcalls are present. That's the topic of part 2, in which we will use the classifier we just trained to build our right whale detector.