Source code for ketos.data_handling.data_feeding

# ================================================================================ #
#   Authors: Fabio Frazao and Oliver Kirsebom                                      #
#   Contact: fsfrazao@dal.ca, oliver.kirsebom@dal.ca                               #
#   Organization: MERIDIAN (https://meridian.cs.dal.ca/)                           #
#   Team: Data Analytics                                                           #
#   Project: ketos                                                                 #
#   Project goal: The ketos library provides functionalities for handling          #
#   and processing acoustic data and applying deep neural networks to sound        #
#   detection and classification tasks.                                            #
#                                                                                  #
#   License: GNU GPLv3                                                             #
#                                                                                  #
#       This program is free software: you can redistribute it and/or modify       #
#       it under the terms of the GNU General Public License as published by       #
#       the Free Software Foundation, either version 3 of the License, or          #
#       (at your option) any later version.                                        #
#                                                                                  #
#       This program is distributed in the hope that it will be useful,            #
#       but WITHOUT ANY WARRANTY; without even the implied warranty of             #
#       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              #
#       GNU General Public License for more details.                               # 
#                                                                                  #
#       You should have received a copy of the GNU General Public License          #
#       along with this program.  If not, see <https://www.gnu.org/licenses/>.     #
# ================================================================================ #

""" Data feeding module within the ketos library

    This module provides utilities to load data and feed it to models.

    Contents:
        BatchGenerator class

        JointBatchGen class

        MultiModalBatchGen class
"""
import warnings
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from ketos.data_handling.data_handling import check_data_sanity


class BatchGenerator():
    """ Creates batches to be fed to a model

        Instances of this class are python generators. They will load one batch at
        a time from a HDF5 database, which is particularly useful when working with
        larger than memory datasets.

        It is also possible to load the entire data set into memory and provide it
        to the BatchGenerator via the arguments x and y. This can be convenient when
        working with smaller data sets.

        Yields: (X,Y) or (ids,X,Y) if 'return_batch_ids' is True.
            X is a batch of data in the form of an np.array of shape (batch_size,mx,nx)
            where mx,nx are the shape of one instance of X in the database. The number
            of dimensions in addition to 'batch_size' will not necessarily be 2, but
            correspond to the instance shape (1 for 1d instances, 3 for 3d, etc).

            It is also possible to load multiple data objects per instance by specifying
            multiple x_field values, e.g., 'x_field=['spectrogram', 'waveform']'. In such
            cases, the return argument X is a np.array with shape (batch_size,) and each
            element is a np.void array with length equal to the number of x fields. Each
            element in this array is a np.array and can be accessed either through use of
            integer indices or the x_field names, e.g., the first spectrogram in the batch
            can be accessed as X[0][0] or X[0]['spectrogram'].

            Similarly, Y is an np.array of shape (batch_size) with the corresponding labels.
            When several y fields are specified, each item in the array is a named array
            of shape=(n_fields), where n_fields is the number of fields specified in the
            'y_field' argument. For instance, if 'y_field'=['label', 'start', 'end'], you
            can access the first label with Y[0]['label']. When only a single y field is
            specified (e.g. y_field='label'), Y is instead a plain array of the values of
            that field, i.e., the first label is simply Y[0]. When the annotations are
            stored in a separate annot_table, Y is a list with one entry per instance,
            each entry holding all annotations that belong to that instance.

            Important note: The above remarks regarding the shapes of X and Y assume that
            the output transform function `output_transform_func` only modifies the
            contents and not the shapes of X and Y, which may not always be the case.

        Args:
            batch_size: int
                The number of instances in each batch. The last batch of an epoch might
                have more instances, since the instances left over after dividing the
                data into batches of size 'batch_size' are appended to the last batch.
                If the batch size is greater than the number of instances available,
                batch_size will be set to the number of instances and a warning will
                be issued.
            data_table: pytables table (instance of table.Table())
                The HDF5 table containing the data
            annot_in_data_table: bool
                Whether or not the annotation fields (e.g.: 'label') are in the data_table
                (True, default) or in a separate annot_table (False).
            annot_table: pytables table (instance of table.Table())
                A separate table for the annotations (labels), in case they are not
                included as fields in the data_table. This table must have a 'data_index'
                field, which corresponds to the index (row number) of the data instance
                in the data_table. Usually, a separate table will be used when the data
                is strongly annotated (i.e.: possibly more than one annotation per data
                instance). When there is only one annotation for each data instance, it's
                recommended that annotations are included in the data_table for
                performance gains.
            x: numpy array
                Array containing the data images.
            y: numpy array
                Array containing the data labels. This array is expected to have a
                one-to-one correspondence to the x array (i.e.: y[0] is expected to have
                the label for x[0], y[1] for x[1], etc). If there are multiple labels for
                each data instance in x, use a data_table and an annot_table instead.
            select_indices: list of ints
                Indices of those instances that will be retrieved from the HDF5 table by
                the BatchGenerator. By default all instances are retrieved.
            output_transform_func: function
                A function to be applied to the batch, transforming the instances.
                Must accept 'X' and 'Y' and, after processing, also return 'X' and 'Y'
                in a tuple.
            x_field: str
                The name of the column containing the X data in the hdf5_table
            y_field: str
                The name of the column containing the Y labels in the hdf5_table
            shuffle: bool
                If True, instances are selected randomly (without replacement).
                If False, instances are selected in the order they appear in the database
            refresh_on_epoch_end: bool
                If True, and shuffle is also True, resampling is performed at the end of
                each epoch resulting in different batches for every epoch. If False, the
                same batches are used in all epochs. Has no effect if shuffle is False.
            return_batch_ids: bool
                If False, each batch will consist of X and Y. If True, the instance
                indices (as they are in the hdf5_table) will be included ((ids, X, Y)).
            filter: str
                A valid PyTables query. If provided, the Batch Generator will query the
                hdf5 database before defining the batches and only the matching records
                will be used. Only relevant when data is passed through the data_table
                argument. If both 'filter' and 'select_indices' are passed,
                'select_indices' is ignored.
            n_extend: int
                Extend every batch by including the last n_extend samples from the
                previous batch and the first n_extend samples from the following batch.
                The first batch is only extended at the end, while the last batch is
                only extended at the beginning. The default value is zero, i.e., no
                extension. The extension is preserved whenever the batches are
                re-created (see `reset`).

        Attr:
            data: pytables table (instance of table.Table())
                The HDF5 table containing the data
            n_instances: int
                The number of instances (rows) in the hdf5_table
            n_batches: int
                The number of batches of size 'batch_size' for each epoch
            data_indices: list of ints
                A list of all instance indices, in the order used to generate batches
                for this epoch
            batch_indices_data: list of lists of ints
                The instance indices contained in each batch.
            batch_indices_annot: list or None
                Per batch, the matching row indices in the annotation table when the
                annotations are stored in a separate annot_table (entries are None if no
                annotation matches the selected instances); identical to
                batch_indices_data when loading from memory; None when the annotations
                are stored in the data_table.
            batch_count: int
                The current batch within the epoch. This will be the batch yielded on
                the next call to 'next()'.
            n_extend: int
                The number of samples by which each batch is extended (see Args).
            from_memory: bool
                True if the data are loaded from memory rather than an HDF5 table.

        Examples:
            >>> from tables import open_file
            >>> from ketos.data_handling.database_interface import open_table
            >>> h5 = open_file("ketos/tests/assets/11x_same_spec.h5", 'r') # create the database handle
            >>> data_table = open_table(h5, "/group_1/table_data")
            >>> annot_table = open_table(h5, "/group_1/table_annot")
            >>> #Create a BatchGenerator from a data_table and separate annotations in a anot_table
            >>> train_generator = BatchGenerator(data_table=data_table, annot_in_data_table=False, annot_table=annot_table, batch_size=3, x_field='data', return_batch_ids=True) #create a batch generator
            >>> #Run 2 epochs.
            >>> n_epochs = 2
            >>> for e in range(n_epochs):
            ...     for batch_num in range(train_generator.n_batches):
            ...         ids, batch_X, batch_Y = next(train_generator)
            ...         print("epoch:{0}, batch {1} | instance ids:{2}, X batch shape: {3} labels for instance {4}: {5}".format(e, batch_num, ids, batch_X.shape, ids[0], batch_Y[0]))
            epoch:0, batch 0 | instance ids:[0, 1, 2], X batch shape: (3, 12, 12) labels for instance 0: [2, 3]
            epoch:0, batch 1 | instance ids:[3, 4, 5], X batch shape: (3, 12, 12) labels for instance 3: [2, 3]
            epoch:0, batch 2 | instance ids:[6, 7, 8, 9, 10], X batch shape: (5, 12, 12) labels for instance 6: [2, 3]
            epoch:1, batch 0 | instance ids:[0, 1, 2], X batch shape: (3, 12, 12) labels for instance 0: [2, 3]
            epoch:1, batch 1 | instance ids:[3, 4, 5], X batch shape: (3, 12, 12) labels for instance 3: [2, 3]
            epoch:1, batch 2 | instance ids:[6, 7, 8, 9, 10], X batch shape: (5, 12, 12) labels for instance 6: [2, 3]
            >>> h5.close() #close the database handle.
            >>> # Creating a Batch Generator from a data tables that includes annotations
            >>> h5 = open_file("ketos/tests/assets/mini_narw.h5", 'r') # create the database handle
            >>> data_table = open_table(h5, "/train/data")
            >>> #Applying a custom function to the batch
            >>> #Takes the mean of each instance in X; leaves Y untouched
            >>> def apply_to_batch(X,Y):
            ...     X = np.mean(X, axis=(1,2)) #since X is a 3d array
            ...     return (X,Y)
            >>> train_generator = BatchGenerator(data_table=data_table, batch_size=3, annot_in_data_table=True, return_batch_ids=False, output_transform_func=apply_to_batch)
            >>> X,Y = next(train_generator)
            >>> #Now each X instance is one single number, instead of a 2d array
            >>> #A batch of size 3 is an array of the 3 means
            >>> X.shape
            (3,)
            >>> #Here is how one X instance looks like
            >>> X[0]
            -37.247124
            >>> #Y is the same as before
            >>> Y.shape
            (3,)
            >>> h5.close()
    """
    def __init__(self, batch_size, data_table=None, annot_in_data_table=True, annot_table=None, x=None, y=None,
                 select_indices=None, output_transform_func=None, x_field='data', y_field='label',
                 shuffle=False, refresh_on_epoch_end=False, return_batch_ids=False, filter=None, n_extend=0):

        self.from_memory = x is not None and y is not None
        self.annot_in_data_table = annot_in_data_table
        self.filter = filter
        self.unique_labels = None

        if self.from_memory:
            #TODO: Reinstate 'check_data_sanity' once it is more flexible
            # check data sanity currently has restrictive assumptions.
            # For example, that y is a nx1 array, which is usually true for labels,
            # but prevents the use of the batch generator for some other purposes,
            # such as simply return multiple of support data columns with the training data
            # for pre-processing purposes.
            #check_data_sanity(x, y)
            self.x = x
            self.y = y

            if select_indices is None:
                self.select_indices = np.arange(len(self.x), dtype=int)
            else:
                self.select_indices = select_indices

            self.n_instances = len(self.select_indices)

        else:
            assert (data_table is not None), 'data_table + annot_table or x + y must be specified'
            if self.annot_in_data_table == False:
                assert annot_table is not None, 'if annotations are not present in the data_table ' \
                    '(annot_in_data_table=False), an annotations table (annot_table) must be specified'

            self.data = data_table
            self.annot = annot_table
            self.x_field = x_field
            # always work with a list of y fields; the first one is the 'label' field
            self.y_field = y_field if isinstance(y_field, list) else [y_field]
            self.label_field = self.y_field[0]

            if select_indices is None:
                self.n_instances = self.data.nrows
                self.select_indices = self.data.col('id')
            else:
                self.n_instances = len(select_indices)
                self.select_indices = select_indices

            # a filter query overrides any user-provided select_indices
            if self.filter is not None:
                self.id_row_index = self.data.get_where_list(self.filter)
                self.select_indices = self.data[self.id_row_index]['id']
                self.n_instances = len(self.select_indices)

        self.batch_size = batch_size
        if self.batch_size > self.n_instances:
            warnings.warn("The batch size is greater than the number of instances available. Setting batch_size to n_instances.")
            self.batch_size = self.n_instances

        self.shuffle = shuffle
        self.output_transform_func = output_transform_func
        self.batch_count = 0
        self.refresh_on_epoch_end = refresh_on_epoch_end
        self.return_batch_ids = return_batch_ids
        # kept as an attribute so that batches re-created by reset() keep the same extension
        self.n_extend = n_extend
        self.n_batches = int(self.n_instances // self.batch_size)

        self.__update_indices__()
        self.__create_batches__(self.n_extend)

    def __update_indices__(self, indices=None):
        """ Updates the indices used to divide the instances into batches.

            A list of indices is kept in the self.data_indices attribute.
            The order of the indices determines which instances will be placed in each batch.
            If self.shuffle is True, the indices are randomly reorganized, resulting in
            batches with randomly selected instances.

            Args:
                indices: array
                    If given, this sequence is used as-is (no copy, no shuffling).
        """
        if indices is None:
            self.data_indices = self.select_indices.copy()
            if self.shuffle:
                np.random.shuffle(self.data_indices)
        else:
            self.data_indices = indices

    def __create_batches__(self, n_ext=0):
        """ Prepare batches.

            Divides the indices into batches of self.batch_size, based on the list
            generated by `__update_indices__()`. Instances left over after forming the
            complete batches are appended to the last batch.

            Sets the attributes self.batch_indices_data and self.batch_indices_annot.

            Args:
                n_ext: int
                    Extend every batch by including the last n_ext samples from the
                    previous batch and the first n_ext samples from the following batch.
                    The first batch is only extended at the end, while the last batch
                    is only extended at the beginning. The default value is zero, i.e.,
                    no extension.
        """
        ids = self.data_indices # for brevity
        n_complete_batches = int(self.n_instances // self.batch_size) # number of batches that can accommodate self.batch_size instances
        extra_instances = self.n_instances % self.batch_size

        if n_complete_batches == 0:
            list_of_indices = [list(ids)]
        else:
            n = self.batch_size
            # slice bounds are clamped so the first batch is only extended at the end
            # and the extension never reaches into the left-over instances
            list_of_indices = [list(ids[max(0, i*n - n_ext):min(n*n_complete_batches, (i+1)*n + n_ext)])
                               for i in range(n_complete_batches)]
            if extra_instances > 0:
                extra_instance_ids = list(ids[-extra_instances:])
                list_of_indices[-1] = list_of_indices[-1] + extra_instance_ids

        if self.from_memory:
            data_indices = list_of_indices
            annot_indices = list_of_indices
        else:
            data_indices = list_of_indices
            if self.annot_in_data_table == False:
                # build the membership set once instead of scanning the array for every annotation row
                selected = set(np.asarray(self.select_indices).tolist())
                index = np.array([(row['data_index'], annot_idx) for annot_idx, row in
                                  enumerate(self.annot.iterrows()) if row['data_index'] in selected])
                if len(index) > 0:
                    # for every data index in a batch, collect the rows of the annotation table that refer to it
                    annot_indices = [[index[index[:, 0] == data_idx, 1] for data_idx in batch] for batch in list_of_indices]
                    annot_indices = [np.concatenate(batch) for batch in annot_indices]
                else:
                    annot_indices = [None for _ in data_indices]
            else:
                annot_indices = None

        self.batch_indices_data = data_indices
        self.batch_indices_annot = annot_indices

    def reset(self, indices=None):
        """ Reset the batch generator.

            Resets the batch index counter and, if refresh_on_epoch_end is True and at
            least one batch has been drawn, regenerates (and reshuffles, if shuffle is
            True) the sample indices. The batches are re-created with the same n_extend
            value that was given at initialization.

            Args:
                indices: array
                    Manually specify the sequence of indices that should be used after reset.
        """
        if (self.refresh_on_epoch_end and self.batch_count > 0) or indices is not None:
            self.__update_indices__(indices)
            self.__create_batches__(self.n_extend)

        self.batch_count = 0

    def get_indices(self):
        ''' Get the index sequence used for sampling the data table

            Returns:
                : array
                    Indices
        '''
        return self.data_indices

    def set_return_batch_ids(self, v):
        ''' Change the behaviour of the generator between returning only X,Y or id,X,Y

            Args:
                v: bool
                    Whether to return id in addition to X,Y
        '''
        self.return_batch_ids = v

    def set_shuffle(self, v):
        ''' Change the behaviour of the generator between shuffling or not shuffling the indices.

            Args:
                v: bool
                    Whether to shuffle the indices
        '''
        self.shuffle = v

    def get_samples(self, indices, annot_indices=None):
        """ Get data samples for specified indices

            Args:
                indices: list of ints
                    Row indices of the samples in the data table
                annot_indices: list of ints
                    Row indices of the matching samples in the annotation table, if applicable.

            Returns:
                : tuple
                    A batch of instances (X,Y)
        """
        if self.from_memory:
            X = np.take(self.x, indices, axis=0)
            Y = np.take(self.y, indices, axis=0)
        else:
            X = self.data[indices][self.x_field]
            if self.annot_in_data_table == True:
                Y = self.data[indices][self.y_field]
                # if there is only 1 y-field, convert the structured array to a plain array
                if len(self.y_field) == 1:
                    Y = Y[self.label_field]
            else:
                if annot_indices is None:
                    Y = [None for _ in X]
                else:
                    data_indices = self.annot.col('data_index')[annot_indices]
                    Y = self.annot[annot_indices][self.y_field]
                    # count how many times each data index occurs in the annotation table
                    index_mul = [np.sum(data_indices == i) for i in indices]
                    # group the labels according to their data index
                    sections = np.cumsum(index_mul)[:-1]
                    Y = np.split(Y, sections)
                    # if there is only 1 y-field, convert each group to a plain list
                    if len(self.y_field) == 1:
                        Y = [list(y[self.label_field]) for y in Y]

        if self.output_transform_func is not None:
            X, Y = self.output_transform_func(X, Y)

        return (X, Y)

    def __iter__(self):
        return self

    def __next__(self):
        """ Return the next batch.

            Returns:
                tuple
                    A batch of instances (X,Y) or, if 'return_batch_ids' is True, a batch
                    of instances accompanied by their indices (ids, X, Y)
        """
        data_row_index = self.batch_indices_data[self.batch_count]
        if not self.from_memory and not self.annot_in_data_table:
            annot_row_index = self.batch_indices_annot[self.batch_count]
        else:
            annot_row_index = None

        self.batch_count += 1
        # end of epoch: wrap around (and possibly refresh) before loading the samples;
        # the row indices for this batch have already been captured above
        if self.batch_count > (self.n_batches - 1):
            self.reset()

        (X, Y) = self.get_samples(indices=data_row_index, annot_indices=annot_row_index)

        if self.return_batch_ids:
            return (data_row_index, X, Y)
        else:
            return (X, Y)
class JointBatchGen():
    """ Join two or more batch generators.

        A joint batch generator is composed of multiple BatchGenerator objects.
        It offers a flexible way of composing custom batches for training neural networks.
        Each batch is composed by joining the batches of all generators in the
        'batch_generators' list.

        In order to be able to combine batch generators in this manner, the batch
        generators must yield data batches (X,Y) with the same format. Furthermore,
        the first dimension must be the batch size. In the case of multimodal generators,
        the second dimension must be the number of modes. For example, if the generator
        is returning a waveform and a spectrogram, and the batch size was set to 32,
        the JointBatchGen expects X to have length 32 and every element in X to have
        length 2 (corresponding to the two modalities, waveform and spectrogram).

        An assertion is made at initialization to check that all batch generators yield
        data with consistent formats. If the assertion fails, an error is thrown.

        Note that the joined generators are always configured to return their batch
        ids, since the joint generator needs them internally; 'return_batch_ids' only
        controls what the joint generator itself returns.

        Args:
            batch_generators: list of BatchGenerator objects
                A list of 2 or more BatchGenerator instances.
            n_batches: str or int (default:'min')
                The number of batches for the joint generator. It can be an integer number,
                'min', which will use the lowest n_batches among the batch generators, or
                'max', which will use the highest value.
            shuffle_batch: bool (default:False)
                If True, shuffle the joint batch before returning it. Note that this only
                concerns the joint batches and is independent of whether the joined
                generators shuffle or not.
            reset_generators: bool (default:False)
                If True, reset the current batch counter of each generator whenever the
                joint generator reaches the n_batches value. This evokes the end-of-epoch
                behaviour for each batch generator (i.e.: if a batch generator was created
                with 'refresh_on_epoch_end=True', then it will shuffle at this time, even
                if that generator's batch counter is not yet at the maximum).
            return_batch_ids: bool
                If False, each batch will consist of X and Y. If True, the generator index
                and the instance indices (as they are in the hdf5_table) will be included
                ((ids, X, Y)). Default is False.
            output_transform_func: function
                A function to be applied to the joint batch, transforming the instances.
                Must accept 'X' and 'Y' and, after processing, also return 'X' and 'Y'
                in a tuple.

        Example:
            >>> from tables import open_file
            >>> from ketos.data_handling.database_interface import open_table
            >>> h5 = open_file("ketos/tests/assets/multimodal.h5", 'r') # create the database handle
            >>> tbl_pos = open_table(h5, "/train/pos/data") #table with positive samples
            >>> tbl_neg = open_table(h5, "/train/neg/data") #table with negative samples
            >>> #Create batch generators for multi-modal data (waveform, spectrogram)
            >>> generator_pos = BatchGenerator(data_table=tbl_pos, batch_size=2, x_field=['waveform','spectrogram'])
            >>> generator_neg = BatchGenerator(data_table=tbl_neg, batch_size=3, x_field=['waveform','spectrogram'])
            >>> #Join the generators
            >>> generator = JointBatchGen([generator_pos, generator_neg])
            >>> #Loading the first batch, we note that the joint generator has a batch size of 2+3=5
            >>> #and the waveforms and spectrograms have shapes (3000,) and (94,129), respectively.
            >>> X, Y = next(generator)
            >>> print(len(X), len(X[0]), X[0][0].shape, X[0][1].shape)
            5 2 (3000,) (94, 129)
            >>> h5.close() #close the database handle.
    """
    def __init__(self, batch_generators, n_batches="min", shuffle_batch=False, reset_generators=False,
                 return_batch_ids=False, output_transform_func=None):
        self.batch_generators = batch_generators
        self.reset_generators = reset_generators
        self.shuffle_batch = shuffle_batch
        self.return_batch_ids = return_batch_ids
        self.output_transform_func = output_transform_func

        assert n_batches in ("min", "max") or isinstance(n_batches, int), "n_batches must be 'min', 'max' or an integer"

        if n_batches == "min":
            self.n_batches = min([gen.n_batches for gen in self.batch_generators])
        elif n_batches == "max":
            self.n_batches = max([gen.n_batches for gen in self.batch_generators])
        else:
            self.n_batches = n_batches

        self.batch_count = 0

        # overwrite return_batch_ids attribute of individual generators,
        # determine batch size, and check if any of the generators are
        # loading annotations from a separate annotation table.
        self.batch_size = 0
        annot_in_data_table = True
        for generator in self.batch_generators:
            generator.set_return_batch_ids(True)
            self.batch_size += generator.batch_size
            if not generator.annot_in_data_table:
                annot_in_data_table = False

        # check that the batch generators return consistent data types by drawing
        # one batch from each generator and inspecting its first element
        x_sizes = []
        y_sizes = []
        for generator in self.batch_generators:
            i, x, y = next(generator)
            x_sizes.append(len(x[0]) if isinstance(x[0], (np.void, list, tuple)) else 0)
            y_sizes.append(len(y[0]) if isinstance(y[0], (np.void, list, tuple)) else 0)
            generator.reset()

        assert np.all(np.array(x_sizes) == x_sizes[0]), 'Attempt to join batch generators with different X '\
            'output formats. Only batch generators with the same X and Y format may be joined'

        if annot_in_data_table:
            assert np.all(np.array(y_sizes) == y_sizes[0]), 'Attempt to join batch generators with different Y '\
                'output formats. Only batch generators with the same X and Y format may be joined'

        # 0 means every element of a batch is a single array; >0 means a multi-field record
        self.xsiz = x_sizes[0]
        self.ysiz = y_sizes[0]

    def __iter__(self):
        return self

    def __next__(self):
        """ Return the next joint batch: (X,Y) or, if return_batch_ids is True, (ids,X,Y),
            where each row of ids is (generator index, instance index).
        """
        X = []
        Y = []
        ids = []
        for gen_id, gen in enumerate(self.batch_generators):
            # child generators always return ids (enforced in __init__)
            i, x, y = next(gen)
            i = np.column_stack((gen_id * np.ones(len(i)), i))
            i = i.astype(int)
            ids.append(i)
            if self.xsiz == 0:
                X.append(x)
            else:
                X += [e for e in x]
            if self.ysiz == 0:
                Y.append(y)
            else:
                Y += [e for e in y]

        # join along the batch (first) dimension; unlike np.vstack this also
        # works when each instance is a scalar (1d batch arrays)
        if self.xsiz == 0:
            X = np.concatenate(X, axis=0)
        if self.ysiz == 0:
            Y = np.concatenate(Y)

        ids = np.vstack(ids)
        siz = len(X)

        if self.shuffle_batch == True:
            indices = np.arange(siz)
            np.random.shuffle(indices)
            if self.xsiz == 0:
                X = X[indices]
            else:
                X = [X[i] for i in indices]
            if self.ysiz == 0:
                Y = Y[indices]
            else:
                Y = [Y[i] for i in indices]
            if self.return_batch_ids:
                ids = ids[indices]

        self.batch_count += 1
        if self.batch_count > (self.n_batches - 1):
            self.batch_count = 0
            if self.reset_generators == True:
                for gen in self.batch_generators:
                    gen.reset()

        if self.output_transform_func is not None:
            X, Y = self.output_transform_func(X, Y)

        if self.return_batch_ids:
            return (ids, X, Y)
        else:
            return (X, Y)

    def reset(self, indices=None):
        """ Resets the individual batch generators.

            Args:
                indices: list of arrays
                    Manually specify, for each batch generator, the sequence of indices
                    that should be used after reset.
        """
        if indices is None:
            indices = [None for _ in self.batch_generators]
        else:
            assert isinstance(indices, list) and len(indices) == len(self.batch_generators),\
                "the length of 'indices' must match the number of batch generators."

        for idx, gen in zip(indices, self.batch_generators):
            gen.reset(indices=idx)

        self.batch_count = 0

    def get_indices(self):
        ''' Get the index sequences used for sampling the data tables

            Returns:
                : list of arrays
                    Indices, one array per batch generator
        '''
        return [g.get_indices() for g in self.batch_generators]

    def set_return_batch_ids(self, v):
        ''' Change the behaviour of the generator between returning only X,Y or id,X,Y

            Only the joint generator's own output is affected; the individual batch
            generators must keep returning their ids, since __next__ relies on them.

            Args:
                v: bool
                    Whether to return id in addition to X,Y
        '''
        self.return_batch_ids = v

    def set_shuffle(self, v):
        ''' Change the behaviour of the individual generators between shuffling or not shuffling the indices.

            Args:
                v: bool
                    Whether to shuffle the indices
        '''
        for g in self.batch_generators:
            g.set_shuffle(v)
class MultiModalBatchGen():
    """ Join two or more batch generators.

        A multi-modal batch generator is composed of multiple BatchGenerator objects.
        It is intended for use with multi-modal models, i.e., models that integrate
        several data representations (e.g. waveform and spectrogram). In particular,
        the multi-modal batch generator provides a handy way to load data stored in
        separate tables.

        OBS: While the data may be stored in separate tables, the sample sequence must be
        identical across tables and the tables must of course contain the same number
        of entries.

        Each batch is composed by collecting batches from the individual generators to
        form a nested list, where the first dimension is the batch size and the second
        dimension is the number of modes. (This ordering is chosen to be consistent with
        the conventions used in the `BatchGenerator` and `JointBatchGen` classes.)

        An assertion is made at initialization to check that the batch generators are
        producing the same number of batches with the same batch size. If the assertion
        fails, an error is thrown.

        To keep the modalities aligned, the multi-modal generator takes control of the
        individual generators: they are set to always return their batch ids, to use
        the requested batch size and shuffle setting, and their own end-of-epoch
        refresh is disabled. Index regeneration at the end of an epoch is instead
        performed here (see 'refresh_on_epoch_end'), and the same index sequence is
        propagated to all generators.

        TODO: Add example

        Args:
            batch_generators: list of BatchGenerator objects
                A list of 2 or more BatchGenerator instances.
            batch_size: int
                The number of instances in each batch. If None (default), the batch
                size of the individual generators is used. If larger than the number
                of instances available, it is reduced to that number and a warning
                is issued.
            shuffle: bool
                If True, instances are selected randomly (without replacement).
                If False, instances are selected in the order they appear in the database
            refresh_on_epoch_end: bool
                If True, and shuffle is also True, resampling is performed at the end of
                each epoch resulting in different batches for every epoch. If False, the
                same batches are used in all epochs. Has no effect if shuffle is False.
            return_batch_ids: bool
                If False, each batch will consist of X and Y. If True, the instance
                indices (as they are in the hdf5_table) will be included ((ids, X, Y)).
                Default is False.
            output_transform_func: function
                A function to be applied to the combined batch, transforming the instances.
                Must accept 'X' and 'Y' and, after processing, also return 'X' and 'Y'
                in a tuple.
    """
    def __init__(self, batch_generators, batch_size=None, shuffle=False, refresh_on_epoch_end=False,
                 return_batch_ids=False, output_transform_func=None):
        self.batch_generators = batch_generators
        self.return_batch_ids = return_batch_ids
        self.refresh_on_epoch_end = refresh_on_epoch_end
        self.output_transform_func = output_transform_func
        self.batch_count = 0

        # overwrite attributes of individual generators
        for generator in self.batch_generators:
            if batch_size is not None:
                bs = batch_size
                if bs > generator.n_instances:
                    warnings.warn("The batch size is greater than the number of instances available. Setting batch_size to n_instances.")
                    bs = generator.n_instances
                generator.batch_size = bs
                # keep n_batches consistent with the new batch size; the batches
                # themselves are re-created by the reset() call below
                generator.n_batches = int(generator.n_instances // bs)

            generator.set_return_batch_ids(True)
            generator.set_shuffle(shuffle)
            # independent refreshing would reshuffle each modality differently;
            # refreshing is coordinated by this class instead
            generator.refresh_on_epoch_end = False

            assert generator.n_batches == self.batch_generators[0].n_batches, "All batch generators must have the same number of batches"
            assert generator.batch_size == self.batch_generators[0].batch_size, "All batch generators must have the same batch size"

        # number of batches and batch size
        self.n_batches = self.batch_generators[0].n_batches
        self.batch_size = self.batch_generators[0].batch_size

        # draw a common index sequence (honouring 'shuffle') and create the batches
        self.reset(indices=self._draw_indices())

    def _draw_indices(self):
        """ Draw a fresh index sequence from the first generator's selected indices,
            shuffled if the generators are set to shuffle.
        """
        g0 = self.batch_generators[0]
        indices = np.array(g0.select_indices, copy=True)
        if g0.shuffle:
            np.random.shuffle(indices)
        return indices

    def __iter__(self):
        return self

    def __next__(self):
        """ Get the next batch

            Returns:
                tuple
                    (X,Y) or, if return_batch_ids is True, (ids,X,Y), where X and Y are
                    nested lists indexed as [instance][modality].
        """
        self.batch_count += 1

        # collect the outputs of the individual batch generators
        # (the generators always return ids, enforced in __init__)
        X, Y = [], []
        for generator in self.batch_generators:
            ids, x, y = next(generator)
            X.append(x)
            Y.append(y)

        # invert ordering: gen,batch -> batch,gen
        X = [[x[i] for x in X] for i in range(len(X[0]))]
        Y = [[y[i] for y in Y] for i in range(len(Y[0]))]

        # end of epoch: the individual generators have already wrapped around;
        # refresh all of them with one common index sequence if requested
        if self.batch_count >= self.n_batches:
            if self.refresh_on_epoch_end:
                self.reset()
            else:
                self.batch_count = 0

        if self.output_transform_func is not None:
            X, Y = self.output_transform_func(X, Y)

        if self.return_batch_ids:
            return (ids, X, Y)
        else:
            return (X, Y)

    def reset(self, indices=None):
        """ Reset the batch generator.

            Resets the batch index counter and propagates one common index sequence to
            all individual generators. If no indices are given, a fresh sequence is drawn
            (reshuffled if shuffle is True) when refresh_on_epoch_end is True and at least
            one batch has been drawn; otherwise the current sequence is kept.

            Args:
                indices: array
                    Manually specify the sequence of indices that should be used after reset.
        """
        if indices is None:
            if self.refresh_on_epoch_end and self.batch_count > 0:
                indices = self._draw_indices()
            else:
                indices = self.batch_generators[0].get_indices()

        for generator in self.batch_generators:
            generator.reset(indices=indices)

        self.batch_count = 0

    def get_indices(self):
        ''' Get the index sequence used for sampling the data tables

            Returns:
                : array
                    Indices
        '''
        return self.batch_generators[0].get_indices()

    def set_return_batch_ids(self, v):
        ''' Change the behaviour of the generator between returning only X,Y or id,X,Y

            Only the multi-modal generator's own output is affected; the individual
            batch generators must keep returning their ids, since __next__ relies on them.

            Args:
                v: bool
                    Whether to return id in addition to X,Y
        '''
        self.return_batch_ids = v

    def set_shuffle(self, v):
        ''' Change the behaviour of the generator between shuffling or not shuffling the indices.

            Takes effect the next time a fresh index sequence is drawn.

            Args:
                v: bool
                    Whether to shuffle the indices
        '''
        for g in self.batch_generators:
            g.set_shuffle(v)