BatchGenerator

class ketos.data_handling.data_feeding.BatchGenerator(batch_size, data_table=None, annot_in_data_table=True, annot_table=None, x=None, y=None, select_indices=None, output_transform_func=None, x_field='data', y_field='label', shuffle=False, refresh_on_epoch_end=False, return_batch_ids=False, filter=None, n_extend=0)[source]

Creates batches to be fed to a model

Instances of this class are Python generators. They load one batch at a time from an HDF5 database, which is particularly useful when working with larger-than-memory datasets.

It is also possible to load the entire dataset into memory and provide it to the BatchGenerator via the arguments x and y. This can be convenient when working with smaller datasets.
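
For example, a generator can be created directly from arrays held in memory (a minimal sketch; the shapes and values below are illustrative and not taken from the ketos test assets):

import numpy as np
x = np.random.rand(10, 12, 12)            # ten 12x12 spectrogram-like arrays
y = np.arange(10)                         # one label per instance
mem_generator = BatchGenerator(x=x, y=y, batch_size=4)
X, Y = next(mem_generator)                # X: shape (4, 12, 12), Y: shape (4,)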

Yields: (X,Y) or (ids,X,Y) if ‘return_batch_ids’ is True.

X is a batch of data in the form of an np.array of shape (batch_size,mx,nx), where (mx,nx) is the shape of one instance of X in the database. The number of dimensions in addition to 'batch_size' will not necessarily be 2, but will correspond to the instance shape (1 for 1d instances, 3 for 3d, etc.).

It is also possible to load multiple data objects per instance by specifying multiple x_field values, e.g., x_field=['spectrogram', 'waveform']. In such cases, the return argument X is an np.array of shape (batch_size,) and each element is a np.void array with length equal to the number of x fields. Each element in this array is an np.array and can be accessed either through integer indices or the x_field names, e.g., the first spectrogram in the batch can be accessed as X[0][0] or X[0]['spectrogram'].

Similarly, Y is an np.array of shape (batch_size,) with the corresponding labels. Each item in the array is a named array of shape (n_fields,), where n_fields is the number of fields specified in the 'y_field' argument. For instance, if y_field=['label', 'start', 'end'], you can access the first label with Y[0]['label']. Notice that even if y_field=['label'], you would still use the Y[0]['label'] syntax.

Important note: The above remarks regarding the shapes of X and Y assume that the output transform function output_transform_func only modifies the contents and not the shapes of X and Y, which may not always be the case.
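
As a sketch of the field-based access described above (hypothetical column names 'spectrogram' and 'waveform'; assumes a data table that actually contains both fields):

# hypothetical sketch: a data table with 'spectrogram' and 'waveform' columns
gen = BatchGenerator(data_table=data_table, batch_size=4,
                     x_field=['spectrogram', 'waveform'],
                     y_field=['label', 'start', 'end'])
X, Y = next(gen)
first_spec = X[0]['spectrogram']   # equivalent to X[0][0]
first_label = Y[0]['label']        # field access is used even for a single y_field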

Args:
batch_size: int

The number of instances in each batch. The last batch of an epoch might have fewer examples, depending on the number of instances in the hdf5_table. If the batch size is greater than the number of instances available, batch_size will be set to the number of instances and a warning will be issued.

data_table: pytables table (instance of table.Table())

The HDF5 table containing the data

annot_in_data_table: bool

Whether the annotation fields (e.g. 'label') are in the data_table (True, default) or in a separate annot_table (False).

annot_table: pytables table (instance of table.Table())

A separate table for the annotations (labels), in case they are not included as fields in the data_table. This table must have a 'data_index' field, which corresponds to the index (row number) of the data instance in the data_table. Usually, a separate table will be used when the data is strongly annotated (i.e. possibly more than one annotation per data instance). When there is only one annotation for each data instance, it is recommended that annotations are included in the data_table for performance gains.

x: numpy array

Array containing the data images.

y: numpy array

Array containing the data labels. This array is expected to have a one-to-one correspondence to the x array (i.e.: y[0] is expected to have the label for x[0], y[1] for x[1], etc). If there are multiple labels for each data instance in x, use a data_table and an annot_table instead.

select_indices: list of ints

Indices of the instances that will be retrieved from the HDF5 table by the BatchGenerator. By default, all instances are retrieved.

output_transform_func: function

A function to be applied to the batch, transforming the instances. Must accept ‘X’ and ‘Y’ and, after processing, also return ‘X’ and ‘Y’ in a tuple.

x_field: str

The name of the column containing the X data in the hdf5_table

y_field: str

The name of the column containing the Y labels in the hdf5_table

shuffle: bool

If True, instances are selected randomly (without replacement). If False, instances are selected in the order they appear in the database.

refresh_on_epoch_end: bool

If True, and shuffle is also True, resampling is performed at the end of each epoch resulting in different batches for every epoch. If False, the same batches are used in all epochs. Has no effect if shuffle is False.

return_batch_ids: bool

If False, each batch will consist of X and Y. If True, the instance indices (as they are in the hdf5_table) will also be included, resulting in (ids, X, Y).

filter: str

A valid PyTables query. If provided, the BatchGenerator will query the HDF5 database before defining the batches and only the matching records will be used. Only relevant when the data is passed through the data_table argument. If both 'filter' and 'select_indices' are passed, 'select_indices' is ignored.
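
For instance, assuming the data table has a 'label' column, a standard PyTables condition string could be used to restrict the batches to a single class (a hypothetical sketch, not based on the test assets used in the examples below):

# hypothetical sketch: batch only the rows whose 'label' column equals 1
pos_generator = BatchGenerator(data_table=data_table, annot_in_data_table=True,
                               batch_size=3, filter='(label == 1)')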

n_extend: int

Extend every batch by including the last n_extend samples from the previous batch and the first n_extend samples from the following batch. The first batch is only extended at the end, while the last batch is only extended at the beginning. The default value is zero, i.e., no extension.
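
As an illustration of the resulting overlap (not the library's internal indexing code), with 10 instances, batch_size=4 and n_extend=1, the yielded batches would cover rows along these lines:

# illustrative only: 10 instances, batch_size=4, n_extend=1
# batch 0 -> [0, 1, 2, 3] + [4]         extended at the end only
# batch 1 -> [3] + [4, 5, 6, 7] + [8]   extended at both ends
# batch 2 -> [7] + [8, 9]               extended at the beginning only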

Attr:
data: pytables table (instance of table.Table())

The HDF5 table containing the data

n_instances: int

The number of instances (rows) in the hdf5_table

n_batches: int

The number of batches of size ‘batch_size’ for each epoch

entry_indices: list of ints

A list of all instance indices, in the order used to generate batches for this epoch

batch_indices: list of tuples (int,int)

A list of (start,end) indices for each batch. These indices refer to the ‘entry_indices’ attribute.

batch_count: int

The current batch within the epoch. This will be the batch yielded on the next call to ‘next()’.

from_memory: bool

True if the data are loaded from memory rather than an HDF5 table.

Examples:
>>> from tables import open_file
>>> from ketos.data_handling.database_interface import open_table
>>> h5 = open_file("ketos/tests/assets/11x_same_spec.h5", 'r') # create the database handle  
>>> data_table = open_table(h5, "/group_1/table_data")
>>> annot_table = open_table(h5, "/group_1/table_annot")
>>> #Create a BatchGenerator from a data_table and separate annotations in an annot_table
>>> train_generator = BatchGenerator(data_table=data_table, annot_in_data_table=False, annot_table=annot_table, batch_size=3, x_field='data', return_batch_ids=True) #create a batch generator 
>>> #Run 2 epochs. 
>>> n_epochs = 2    
>>> for e in range(n_epochs):
...    for batch_num in range(train_generator.n_batches):
...        ids, batch_X, batch_Y = next(train_generator)
...        print("epoch:{0}, batch {1} | instance ids:{2}, X batch shape: {3} labels for instance {4}: {5}".format(e, batch_num, ids, batch_X.shape, ids[0], batch_Y[0]))
epoch:0, batch 0 | instance ids:[0, 1, 2], X batch shape: (3, 12, 12) labels for instance 0: [2, 3]
epoch:0, batch 1 | instance ids:[3, 4, 5], X batch shape: (3, 12, 12) labels for instance 3: [2, 3]
epoch:0, batch 2 | instance ids:[6, 7, 8, 9, 10], X batch shape: (5, 12, 12) labels for instance 6: [2, 3]
epoch:1, batch 0 | instance ids:[0, 1, 2], X batch shape: (3, 12, 12) labels for instance 0: [2, 3]
epoch:1, batch 1 | instance ids:[3, 4, 5], X batch shape: (3, 12, 12) labels for instance 3: [2, 3]
epoch:1, batch 2 | instance ids:[6, 7, 8, 9, 10], X batch shape: (5, 12, 12) labels for instance 6: [2, 3]
>>> h5.close() #close the database handle.
>>> # Creating a BatchGenerator from a data table that includes annotations
>>> h5 = open_file("ketos/tests/assets/mini_narw.h5", 'r') # create the database handle  
>>> data_table = open_table(h5, "/train/data")
>>> import numpy as np
>>> #Applying a custom function to the batch
>>> #Takes the mean of each instance in X; leaves Y untouched
>>> def apply_to_batch(X,Y):
...    X = np.mean(X, axis=(1,2)) #since X is a 3d array
...    return (X,Y)
>>> train_generator = BatchGenerator(data_table=data_table, batch_size=3, annot_in_data_table=True, return_batch_ids=False, output_transform_func=apply_to_batch) 
>>> X,Y = next(train_generator)                
>>> #Now each X instance is one single number, instead of a 2d array
>>> #A batch of size 3 is an array of the 3 means
>>> X.shape
(3,)
>>> #Here is what one X instance looks like
>>> X[0]
-37.247124
>>> #Y is the same as before 
>>> Y.shape
(3,)
>>> h5.close()

Methods

get_indices()

Get the index sequence used for sampling the data table

get_samples(indices[, annot_indices])

Get data samples for specified indices

reset([indices])

Reset the batch generator.

set_return_batch_ids(v)

Change the behaviour of the generator between returning only (X,Y) or (ids,X,Y)

set_shuffle(v)

Change the behaviour of the generator between shuffling or not shuffling the indices.

get_indices()[source]

Get the index sequence used for sampling the data table

Returns:
: array

Indices

get_samples(indices, annot_indices=None)[source]

Get data samples for specified indices

Args:
indices: list of ints

Row indices of the samples in the data table

annot_indices: list of ints

Row indices of the matching samples in the annotation table, if applicable.

Returns:
: tuple

A batch of instances (X,Y)

reset(indices=None)[source]

Reset the batch generator.

Resets the batch index counter and reshuffles the sample indices if shuffle was set to True.

Args:
indices: array

Manually specify the sequence of indices that should be used after reset.
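
For example, a training loop can call reset() before starting a new epoch so that iteration begins again from the first batch (a minimal sketch; assumes a generator created as in the examples above):

# minimal sketch: restart the generator before a new epoch
train_generator.reset()
X, Y = next(train_generator)   # first batch of the new epoch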

set_return_batch_ids(v)[source]

Change the behaviour of the generator between returning only (X,Y) or (ids,X,Y)

Args:
v: bool

Whether to return the instance ids in addition to X,Y

set_shuffle(v)[source]

Change the behaviour of the generator between shuffling or not shuffling the indices.

Args:
v: bool

Whether to shuffle the indices