BatchGenerator
- class ketos.data_handling.data_feeding.BatchGenerator(batch_size, data_table=None, annot_in_data_table=True, annot_table=None, x=None, y=None, select_indices=None, output_transform_func=None, x_field='data', y_field='label', shuffle=False, refresh_on_epoch_end=False, return_batch_ids=False, filter=None, n_extend=0)
Creates batches to be fed to a model
Instances of this class are Python generators. They load one batch at a time from an HDF5 database, which is particularly useful when working with datasets that are larger than memory.
It is also possible to load the entire data set into memory and provide it to the BatchGenerator via the arguments x and y. This can be convenient when working with smaller data sets.
Yields: (X,Y) or (ids,X,Y) if ‘return_batch_ids’ is True.
X is a batch of data in the form of an np.array of shape (batch_size,mx,nx), where (mx,nx) is the shape of one instance of X in the database. The number of dimensions in addition to ‘batch_size’ is not necessarily 2; it matches the instance shape (1 for 1d instances, 3 for 3d instances, etc.).
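A quick way to see the shape rule is to stack instances with NumPy. This sketch uses np.stack only as a stand-in for however the batch is actually assembled; it does not call ketos and simply shows that a batch adds one leading axis of length batch_size to the instance shape.

```python
import numpy as np

batch_size = 4

# Each instance keeps its own shape; stacking adds a leading batch axis.
for instance_shape in [(12,), (12, 12), (2, 12, 12)]:
    instances = [np.zeros(instance_shape, dtype=np.float32) for _ in range(batch_size)]
    X = np.stack(instances)  # shape: (batch_size, *instance_shape)
    print(instance_shape, "->", X.shape)
```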
It is also possible to load multiple data objects per instance by specifying multiple x_field values, e.g., x_field=[‘spectrogram’, ‘waveform’]. In such cases, the returned X is an np.array with shape (batch_size,) whose elements are np.void records, each with length equal to the number of x fields. Each field of a record is an np.array and can be accessed either by integer index or by x_field name, e.g., the first spectrogram in the batch can be accessed as X[0][0] or X[0][‘spectrogram’].
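The record layout described above corresponds to what NumPy calls a structured array. The sketch below builds one by hand to show both access patterns; the field shapes (a 12x12 spectrogram and a 100-sample waveform) are hypothetical, and the exact dtype produced by the BatchGenerator may differ.

```python
import numpy as np

batch_size = 3

# Hypothetical field shapes, chosen only for illustration.
dtype = np.dtype([
    ("spectrogram", np.float32, (12, 12)),
    ("waveform", np.float32, (100,)),
])
X = np.zeros(batch_size, dtype=dtype)  # shape (batch_size,)

first = X[0]                         # a np.void record holding both fields
spec_by_index = first[0]             # access by integer index
spec_by_name = first["spectrogram"]  # access by field name

print(type(first).__name__, X.shape)
print(spec_by_index.shape, spec_by_name.shape)

# Selecting a field on the whole batch returns a regular ndarray.
print(X["spectrogram"].shape)
```

Selecting a field on the whole batch (X[‘spectrogram’]) is often more convenient than looping over records when feeding a single input to a model.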
Similarly, Y is an np.array of shape (batch_size,) with the corresponding labels. Each item in the array is a named record with n_fields fields, where n_fields is the number of fields specified in the ‘y_field’ argument. For instance, if y_field=[‘label’, ‘start’, ‘end’], you can access the first label with Y[0][‘label’]. Note that even if y_field=[‘label’], you still use the Y[0][‘label’] syntax.
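The same NumPy structured-array indexing applies to Y. In this sketch the field types and values are made up for illustration; only the access syntax mirrors the description above.

```python
import numpy as np

# Hypothetical annotation fields matching y_field=['label', 'start', 'end'].
y_dtype = np.dtype([("label", np.int32), ("start", np.float32), ("end", np.float32)])
Y = np.array([(1, 0.5, 2.0), (0, 1.0, 3.5), (1, 4.0, 5.0)], dtype=y_dtype)

print(Y.shape)        # (3,)
print(Y[0]["label"])  # 1
print(Y["label"])     # [1 0 1]

# With a single field, the records are still indexed by name.
Y_single = np.array([(1,), (0,), (1,)], dtype=[("label", np.int32)])
print(Y_single[0]["label"])  # 1
```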
Important note: The above remarks regarding the shapes of X and Y assume that the output transform function output_transform_func only modifies the contents and not the shapes of X and Y, which may not always be the case.
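As an illustration of a transform that does change the shapes, the hypothetical function below adds a channel axis to X and one-hot encodes the labels. It assumes Y is a structured array with a ‘label’ field holding integer class indices and that there are two classes. A function like this would be passed as output_transform_func; the generator itself is not used here.

```python
import numpy as np


def add_channel_and_one_hot(X, Y, n_classes=2):
    """Hypothetical output_transform_func that changes the shapes of X and Y."""
    X = X[..., np.newaxis]  # (batch_size, mx, nx) -> (batch_size, mx, nx, 1)
    Y = np.eye(n_classes, dtype=np.float32)[Y["label"]]  # (batch_size,) -> (batch_size, n_classes)
    return (X, Y)


X = np.zeros((3, 12, 12), dtype=np.float32)
Y = np.array([(1,), (0,), (1,)], dtype=[("label", np.int32)])

X_out, Y_out = add_channel_and_one_hot(X, Y)
print(X_out.shape, Y_out.shape)  # (3, 12, 12, 1) (3, 2)
```

After such a transform, Y is a plain float array, so the Y[0][‘label’] syntax no longer applies.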
- Args:
- batch_size: int
The number of instances in each batch. The last batch of an epoch might contain a different number of instances than the others, depending on the number of instances available (in the Examples section, 11 instances with batch_size=3 yield batches of 3, 3 and 5 instances). If the batch size is greater than the number of instances available, batch_size will be set to the number of instances and a warning will be issued.
- data_table: pytables table (instance of table.Table())
The HDF5 table containing the data
- annot_in_data_table: bool
Whether the annotation fields (e.g., ‘label’) are in the data_table (True, default) or in a separate annot_table (False).
- annot_table: pytables table (instance of table.Table())
A separate table for the annotations (labels), in case they are not included as fields in the data_table. This table must have a ‘data_index’ field, which corresponds to the index (row number) of the data instance in the data_table. Usually, a separate table is used when the data is strongly annotated (i.e., possibly more than one annotation per data instance). When there is only one annotation for each data instance, it is recommended to include the annotations in the data_table for performance gains.
- x: numpy array
Array containing the data instances.
- y: numpy array
Array containing the data labels. This array is expected to have a one-to-one correspondence to the x array (i.e.: y[0] is expected to have the label for x[0], y[1] for x[1], etc). If there are multiple labels for each data instance in x, use a data_table and an annot_table instead.
- select_indices: list of ints
Indices of the instances that will be retrieved from the data_table by the BatchGenerator. By default, all instances are retrieved.
- output_transform_func: function
A function to be applied to the batch, transforming the instances. Must accept ‘X’ and ‘Y’ and, after processing, also return ‘X’ and ‘Y’ in a tuple.
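- x_field: str or list of str
The name of the column containing the X data in the data_table. Pass a list of names to load multiple data objects per instance.
- y_field: str or list of str
The name of the column containing the Y labels in the data_table (or in the annot_table, if annot_in_data_table is False). Pass a list of names to retrieve multiple annotation fields.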
- shuffle: bool
If True, instances are selected randomly (without replacement). If False, instances are selected in the order they appear in the database.
- refresh_on_epoch_end: bool
If True, and shuffle is also True, resampling is performed at the end of each epoch resulting in different batches for every epoch. If False, the same batches are used in all epochs. Has no effect if shuffle is False.
- return_batch_ids: bool
If False, each batch will consist of X and Y. If True, the instance indices (their row numbers in the data_table) will also be included, i.e., (ids, X, Y).
- filter: str
A valid PyTables query. If provided, the BatchGenerator will query the HDF5 database before defining the batches, and only the matching records will be used. Only relevant when data is passed through the data_table argument. If both ‘filter’ and ‘select_indices’ are passed, ‘select_indices’ is ignored.
- n_extend: int
Extend every batch by including the last n_extend samples from the previous batch and the first n_extend samples from the following batch. The first batch is only extended at the end, while the last batch is only extended at the beginning. The default value is zero, i.e., no extension.
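The n_extend rule can be expressed on (start, end) index ranges. The helper below is a hypothetical sketch, not part of ketos; it assumes end-exclusive ranges over consecutive batches and an n_extend no larger than the neighbouring batches.

```python
def extend_batches(batch_indices, n_extend):
    """Hypothetical helper: widen end-exclusive (start, end) ranges by n_extend.

    The first batch is only extended at the end and the last batch only at
    the beginning; every other batch borrows from both neighbours.
    """
    last = len(batch_indices) - 1
    extended = []
    for i, (start, end) in enumerate(batch_indices):
        new_start = start - n_extend if i > 0 else start  # last samples of the previous batch
        new_end = end + n_extend if i < last else end     # first samples of the following batch
        extended.append((new_start, new_end))
    return extended


# 11 instances split into batches of 3, 3 and 5, as in the Examples section.
batches = [(0, 3), (3, 6), (6, 11)]
print(extend_batches(batches, n_extend=1))  # [(0, 4), (2, 7), (5, 11)]
```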
- Attr:
- data: pytables table (instance of table.Table())
The HDF5 table containing the data
- n_instances: int
The number of instances (rows) in the data_table
- n_batches: int
The number of batches of size ‘batch_size’ for each epoch
- entry_indices: list of ints
A list of all instance indices, in the order used to generate batches for this epoch
- batch_indices: list of tuples (int,int)
A list of (start,end) indices for each batch. These indices refer to the ‘entry_indices’ attribute.
- batch_count: int
The current batch within the epoch. This will be the batch yielded on the next call to ‘next()’.
- from_memory: bool
True if the data are loaded from memory rather than an HDF5 table.
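The relation between entry_indices and batch_indices can be sketched as follows. The helper make_batch_indices is hypothetical; it assumes end-exclusive (start, end) pairs and that leftover instances are appended to the last batch, which is consistent with the batch sizes in the Examples section but is not confirmed here as the library's exact logic.

```python
import numpy as np


def make_batch_indices(n_instances, batch_size, shuffle=False, rng=None):
    """Hypothetical helper: order the instances and split them into batches.

    Returns entry_indices (the instance order for this epoch) and
    batch_indices, a list of end-exclusive (start, end) pairs that index
    into entry_indices. Assumes n_instances >= 1.
    """
    batch_size = min(batch_size, n_instances)  # clamp, as described for batch_size
    entry_indices = np.arange(n_instances)
    if shuffle:
        rng = np.random.default_rng() if rng is None else rng
        entry_indices = rng.permutation(entry_indices)  # random order, no replacement
    n_batches = n_instances // batch_size
    batch_indices = [(i * batch_size, (i + 1) * batch_size) for i in range(n_batches)]
    # Leftover instances are appended to the last batch.
    start, _ = batch_indices[-1]
    batch_indices[-1] = (start, n_instances)
    return entry_indices, batch_indices


entry_indices, batch_indices = make_batch_indices(11, 3)
print(batch_indices)  # [(0, 3), (3, 6), (6, 11)]
print([entry_indices[s:e].tolist() for s, e in batch_indices])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
```

With shuffle=True, calling the helper again at the end of every epoch corresponds to refresh_on_epoch_end=True, while reusing the first result for all epochs corresponds to refresh_on_epoch_end=False.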
- Examples:
>>> import numpy as np
>>> from tables import open_file
>>> from ketos.data_handling.database_interface import open_table
>>> from ketos.data_handling.data_feeding import BatchGenerator
>>> h5 = open_file("ketos/tests/assets/11x_same_spec.h5", 'r') # create the database handle
>>> data_table = open_table(h5, "/group_1/table_data")
>>> annot_table = open_table(h5, "/group_1/table_annot")
>>> # Create a BatchGenerator from a data_table with separate annotations in an annot_table
>>> train_generator = BatchGenerator(data_table=data_table, annot_in_data_table=False, annot_table=annot_table, batch_size=3, x_field='data', return_batch_ids=True)
>>> # Run 2 epochs
>>> n_epochs = 2
>>> for e in range(n_epochs):
...     for batch_num in range(train_generator.n_batches):
...         ids, batch_X, batch_Y = next(train_generator)
...         print("epoch:{0}, batch {1} | instance ids:{2}, X batch shape: {3} labels for instance {4}: {5}".format(e, batch_num, ids, batch_X.shape, ids[0], batch_Y[0]))
epoch:0, batch 0 | instance ids:[0, 1, 2], X batch shape: (3, 12, 12) labels for instance 0: [2, 3]
epoch:0, batch 1 | instance ids:[3, 4, 5], X batch shape: (3, 12, 12) labels for instance 3: [2, 3]
epoch:0, batch 2 | instance ids:[6, 7, 8, 9, 10], X batch shape: (5, 12, 12) labels for instance 6: [2, 3]
epoch:1, batch 0 | instance ids:[0, 1, 2], X batch shape: (3, 12, 12) labels for instance 0: [2, 3]
epoch:1, batch 1 | instance ids:[3, 4, 5], X batch shape: (3, 12, 12) labels for instance 3: [2, 3]
epoch:1, batch 2 | instance ids:[6, 7, 8, 9, 10], X batch shape: (5, 12, 12) labels for instance 6: [2, 3]
>>> h5.close() # close the database handle
>>> # Create a BatchGenerator from a data table that includes the annotations
>>> h5 = open_file("ketos/tests/assets/mini_narw.h5", 'r') # create the database handle
>>> data_table = open_table(h5, "/train/data")
>>> # Apply a custom function to the batch:
>>> # take the mean of each instance in X and leave Y untouched
>>> def apply_to_batch(X,Y):
...     X = np.mean(X, axis=(1,2)) # since X is a 3d array
...     return (X,Y)
>>> train_generator = BatchGenerator(data_table=data_table, batch_size=3, annot_in_data_table=True, return_batch_ids=False, output_transform_func=apply_to_batch)
>>> X,Y = next(train_generator)
>>> # Now each X instance is a single number instead of a 2d array
>>> # A batch of size 3 is an array of the 3 means
>>> X.shape
(3,)
>>> # Here is what one X instance looks like
>>> X[0]
-37.247124
>>> # Y is the same as before
>>> Y.shape
(3,)
>>> h5.close()
Methods
- get_indices(): Get the index sequence used for sampling the data table.
- get_samples(indices[, annot_indices]): Get data samples for specified indices.
- reset([indices]): Reset the batch generator.
- Change the behaviour of the generator between returning only (X,Y) or (ids,X,Y).
- set_shuffle(v): Change the behaviour of the generator between shuffling or not shuffling the indices.
- get_indices()
Get the index sequence used for sampling the data table
- Returns:
- : array
Indices
- get_samples(indices, annot_indices=None)
Get data samples for specified indices
- Args:
- indices: list of ints
Row indices of the samples in the data table
- annot_indices: list of ints
Row indices of the matching samples in the annotation table, if applicable.
- Returns:
- : tuple
A batch of instances (X,Y)
- reset(indices=None)
Reset the batch generator.
Resets the batch index counter and reshuffles the sample indices if shuffle was set to True.
- Args:
- indices: array
Manually specify the sequence of indices that should be used after reset.
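To make the counter and reset semantics concrete, here is a toy in-memory generator. ToyBatchGenerator is a hypothetical stand-in written for illustration, not ketos's implementation: it yields (ids, X, Y) like return_batch_ids=True, keeps the same order across epochs like refresh_on_epoch_end=False, and its reset() mirrors the behaviour documented above (reset the counter, reshuffle if shuffle is True, or use the given indices).

```python
import numpy as np


class ToyBatchGenerator:
    """Hypothetical in-memory stand-in illustrating batch_count and reset()."""

    def __init__(self, x, y, batch_size, shuffle=False, seed=None):
        self.x, self.y = np.asarray(x), np.asarray(y)
        self.batch_size = min(batch_size, len(self.x))
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self, indices=None):
        # Reset the counter and choose the index sequence for the new epoch.
        if indices is not None:
            self.entry_indices = np.asarray(indices)  # manually specified order
        elif self.shuffle:
            self.entry_indices = self.rng.permutation(len(self.x))  # reshuffle
        else:
            self.entry_indices = np.arange(len(self.x))
        self.n_batches = max(len(self.entry_indices) // self.batch_size, 1)
        self.batch_count = 0

    def __iter__(self):
        return self

    def __next__(self):
        start = self.batch_count * self.batch_size
        last = self.batch_count == self.n_batches - 1
        # The last batch also takes any leftover instances.
        end = len(self.entry_indices) if last else start + self.batch_size
        ids = self.entry_indices[start:end]
        # Wrap around so the generator never stops; the next epoch reuses the same order.
        self.batch_count = 0 if last else self.batch_count + 1
        return ids, self.x[ids], self.y[ids]


gen = ToyBatchGenerator(np.arange(11), np.arange(11) % 2, batch_size=3)
print([next(gen)[0].tolist() for _ in range(3)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
gen.reset(indices=[10, 9, 8, 7, 6, 5])
print(next(gen)[0].tolist())  # [10, 9, 8]
```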