Source code for daart.callbacks

"""Callback classes to control training."""

import logging
import numpy as np

# Public API of this module. Listing only the callback classes here also makes the
# Sphinx auto-API documentation ignore the imported modules (logging, numpy).
__all__ = ['BaseCallback', 'EarlyStopping', 'AnnealHparam', 'PseudoLabels', 'UPS']


class BaseCallback:
    """Abstract base class for callbacks.

    Concrete callbacks subclass this and override :meth:`on_epoch_end`; the base
    implementation only raises ``NotImplementedError``.

    """

    def on_epoch_end(self, data_generator, model, trainer, **kwargs):
        """Hook invoked by the trainer; must be overridden by subclasses.

        Parameters
        ----------
        data_generator : object
            data generator used for training
        model : object
            model being trained
        trainer : object
            trainer running the training loop
        **kwargs
            extra, callback-specific keyword arguments

        Raises
        ------
        NotImplementedError
            always, in this base class

        """
        raise NotImplementedError()
class EarlyStopping(BaseCallback):
    """Stop training when a monitored quantity has stopped improving.

    Adapted from:
    https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

    """

    def __init__(self, patience=10, delta=0.0):
        """

        Note: It must be noted that the patience parameter counts the number of validation checks
        with no improvement, and not the number of training epochs. Therefore, with parameters
        `check_val_interval=10` and `patience=3`, the trainer will perform at least 40 training
        epochs before being stopped.

        Parameters
        ----------
        patience : int, optional
            number of consecutive validation checks without improvement that are tolerated
            before training is halted
        delta : float, optional
            minimum decrease in the validation loss to qualify as an improvement

        """
        self.patience = patience
        self.delta = delta
        self.counter = 0  # validation checks since the last improvement
        self.best_epoch = 0
        self.best_loss = np.inf

    def on_epoch_end(self, data_generator, model, trainer, logger=None, **kwargs):
        """Update early-stopping state and request a halt once patience is exhausted.

        Does nothing unless ``trainer.curr_batch`` matches ``trainer.val_check_batch``. On a
        validation check, reads the validation loss from ``logger`` and sets
        ``trainer.should_halt = True`` when ``patience`` consecutive checks have passed without
        improvement and ``trainer.curr_epoch`` exceeds ``trainer.min_epochs``.

        Parameters
        ----------
        data_generator : object
            unused
        model : object
            unused
        trainer : object
            must expose `curr_batch`, `val_check_batch`, `curr_epoch` and `min_epochs`;
            `should_halt` is set on it when stopping
        logger : object
            must provide `get_loss('val')`; required on validation checks

        Raises
        ------
        ValueError
            if a validation check happens and no `logger` was supplied

        """
        # skip if this is not a validation check; np.any lets `val_check_batch` be either a
        # scalar or an array (presumably of batch indices)
        if not np.any(trainer.curr_batch == trainer.val_check_batch):
            return

        if logger is None:
            raise ValueError('EarlyStopping requires a `logger` to read the validation loss')

        # use overall validation loss for early stopping
        loss = logger.get_loss('val')

        # update best loss and epoch that it occurred
        if loss < self.best_loss - self.delta:
            self.best_loss = loss
            self.best_epoch = trainer.curr_epoch
            self.counter = 0
        else:
            self.counter += 1

        # exit training once the loss has failed to improve for `patience` checks
        if (trainer.curr_epoch > trainer.min_epochs) and (self.counter >= self.patience):
            trainer.should_halt = True
            print_str = '\n== early stopping criteria met; exiting train loop ==\n'
            print_str += 'training epochs: %d\n' % trainer.curr_epoch
            print_str += 'end cost: %04f\n' % loss
            print_str += 'best epoch: %i\n' % self.best_epoch
            print_str += 'best cost: %04f\n' % self.best_loss
            logging.info(print_str)
class AnnealHparam(BaseCallback):
    """Linearly anneal a value in an hparam dict from `val_start` to its initial value."""

    def __init__(self, hparams, key, epoch_start, epoch_end, val_start=0):
        """
        Parameters
        ----------
        hparams : dict
            hparam dict that is an attribute of a daart model; modified in place
        key : str
            key to value to anneal; the value stored under `key` at construction time is the
            final (target) value of the schedule
        epoch_start : int
            keep value at `val_start` until this epoch
        epoch_end : int
            linearly interpolate from `val_start` at `epoch_start` to the target value at
            `epoch_end`; the target value is kept afterwards
        val_start : int or float, optional
            value used before `epoch_start`

        """
        # basic error checking
        assert epoch_start <= epoch_end, 'epoch_start must not be larger than epoch_end'
        assert key in hparams, 'key %r not found in hparams' % (key,)
        # store data
        self.hparams = hparams
        self.key = key
        self.epoch_start = epoch_start
        self.epoch_end = epoch_end
        self.val_start = val_start
        self.val_end = self.hparams[self.key]

    def on_epoch_end(self, data_generator, model, trainer, **kwargs):
        """Write the annealed value for `trainer.curr_epoch` into the hparam dict."""
        epoch = trainer.curr_epoch
        if epoch < self.epoch_start:
            value = self.val_start
        elif epoch >= self.epoch_end:
            # `>=` also covers epoch_start == epoch_end, which would otherwise divide by zero
            value = self.val_end
        else:
            frac = (epoch - self.epoch_start) / (self.epoch_end - self.epoch_start)
            # interpolate between both endpoints so the schedule is continuous at epoch_start
            value = self.val_start + frac * (self.val_end - self.val_start)
        self.hparams[self.key] = value
class PseudoLabels(BaseCallback):
    """Implement PseudoLabels algorithm.

    From `epoch_start` onwards, every call to :meth:`on_epoch_end` runs the model over the data
    generator and replaces each dataset's `labels_weak` with hard pseudo-labels. A time point
    gets the first class whose output probability is at least `prob_threshold`. If no class
    qualifies, it gets the background class (index 0).

    """

    def __init__(self, prob_threshold=0.95, epoch_start=10):
        """
        Parameters
        ----------
        prob_threshold : float, optional
            minimum output probability for a class to be assigned as the pseudo-label
        epoch_start : int, optional
            pseudo-labels are not computed for epochs before this one

        """
        self.prob_threshold = prob_threshold
        self.epoch_start = epoch_start

    def on_epoch_end(self, data_generator, model, trainer, **kwargs):
        """Recompute pseudo-labels and store them in `dataset.data['labels_weak']`."""
        if trainer.curr_epoch < self.epoch_start:
            return

        # push training data through model; collect output probabilities.
        # outputs_dict['labels'] is a list (over datasets) of lists (over batches) of numpy
        # arrays, presumably of shape (n_t, n_classes) -- TODO confirm against predict_labels
        outputs_dict = model.predict_labels(data_generator, remove_pad=False)

        # threshold the probabilities to produce pseudo-labels
        pseudo_labels = [
            [self._threshold_batch(batch) for batch in dataset]
            for dataset in outputs_dict['labels']
        ]

        # update the data generator with the new pseudo-labels
        for dataset, labels in zip(data_generator.datasets, pseudo_labels):
            dataset.data['labels_weak'] = labels

    def _threshold_batch(self, batch):
        """Convert one batch of class probabilities into integer hard labels."""
        if batch.shape[0] == 0:
            # empty batches are kept as-is, only cast to int
            return batch.astype(int)
        # build a boolean mask instead of overwriting the model outputs in place; values below
        # the threshold (including NaN) can never be selected
        confident = batch >= self.prob_threshold
        # argmax over a boolean mask returns the first confident class; rows without any
        # confident class yield 0, i.e. the background class
        return np.argmax(confident, axis=1).astype(int)
class UPS(BaseCallback):
    """Implement uncertainty-aware pseudo-labels algorithm.

    See details in: https://arxiv.org/pdf/2101.06329.pdf

    From `epoch_start` onwards, every call to :meth:`on_epoch_end` runs the model `n_passes`
    times over the data generator. A class becomes the pseudo-label of a time point only if
    its output probability satisfies both conditions across passes:

    - the median is at least `prob_threshold`;
    - the variance is at most `variance_threshold`.

    Time points with no such class get the background class (index 0).

    """

    def __init__(self, prob_threshold=0.95, variance_threshold=0.05, epoch_start=10, n_passes=10):
        """
        Parameters
        ----------
        prob_threshold : float, optional
            minimum median probability (across passes) for a class to be assigned
        variance_threshold : float, optional
            maximum variance of the probability (across passes) for a class to be assigned
        epoch_start : int, optional
            pseudo-labels are not computed for epochs before this one
        n_passes : int, optional
            number of forward passes used to estimate the variability of the outputs

        Raises
        ------
        ValueError
            if `n_passes` is smaller than 1

        """
        if n_passes < 1:
            raise ValueError('n_passes must be at least 1, got %r' % (n_passes,))
        self.prob_threshold = prob_threshold
        self.variance_threshold = variance_threshold
        self.epoch_start = epoch_start
        self.n_passes = n_passes

    def on_epoch_end(self, data_generator, model, trainer, **kwargs):
        """Recompute pseudo-labels and store them in `dataset.data['labels_weak']`."""
        if trainer.curr_epoch < self.epoch_start:
            return

        # push training data through model several times to get a sense of variability in the
        # output probabilities. NOTE(review): `mode='train'` is presumably what makes passes
        # differ (e.g. active dropout) -- confirm against model.predict_labels
        outputs_list = []  # list (over passes) of lists (over datasets) of lists (over batches)
        for _ in range(self.n_passes):
            outputs_dict = model.predict_labels(data_generator, remove_pad=False, mode='train')
            outputs_list.append(outputs_dict['labels'])

        # zip(*outputs_list) regroups by dataset (each item: tuple over passes of batch lists);
        # zip(*dataset_passes) then regroups by batch (each item: tuple over passes of arrays)
        pseudo_labels = [
            [self._threshold_batch(batch_passes) for batch_passes in zip(*dataset_passes)]
            for dataset_passes in zip(*outputs_list)
        ]

        # update the data generator with the new pseudo-labels
        for dataset, labels in zip(data_generator.datasets, pseudo_labels):
            dataset.data['labels_weak'] = labels

    def _threshold_batch(self, batch_passes):
        """Turn the per-pass probabilities of one batch into integer hard labels."""
        # batch_data has shape (n_t, n_classes, n_passes), assuming 2D (n_t, n_classes) outputs
        batch_data = np.stack(batch_passes, axis=2)
        n_t, n_classes = batch_data.shape[:2]
        if n_t == 0:
            # empty batches stay 2D, matching the shape of the (empty) model outputs
            return np.zeros((n_t, n_classes), dtype=int)
        batch_medians = np.median(batch_data, axis=2)  # shape (n_t, n_classes)
        batch_vars = np.var(batch_data, axis=2)  # shape (n_t, n_classes)
        confident = (batch_medians >= self.prob_threshold) & \
            (batch_vars <= self.variance_threshold)
        # first confident class per time point; rows with none map to 0 (background class)
        return np.argmax(confident, axis=1).astype(int)