Source code for daart.utils

"""Utility functions for daart package."""

import logging
import os

from daart.data import compute_sequence_pad, DataGenerator
from daart.transforms import ZScore


# to ignore imports for sphinx-autoapidoc
__all__ = ['build_data_generator', 'collect_callbacks']


def _find_first_existing(candidates):
    """Return the first path in ``candidates`` that exists on disk, or None if none exists."""
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def _find_labels_file(data_dir, subdir, expt_id, kind):
    """Locate the labels csv for ``expt_id`` in ``data_dir/subdir``; warn and return None if absent."""
    base_dir = os.path.join(data_dir, subdir)
    labels_file = _find_first_existing([
        os.path.join(base_dir, expt_id + '_labels.csv'),
        os.path.join(base_dir, expt_id + '.csv'),
    ])
    if labels_file is None:
        logging.warning(f'did not find {kind} labels file for {expt_id} in {base_dir}')
    return labels_file


def build_data_generator(hparams: dict) -> DataGenerator:
    """Helper function to build a data generator from hparam dict.

    For every experiment id in ``hparams['expt_ids']`` this collects a list of signal names,
    a list of transforms and a list of file paths (one entry per signal), then passes the
    per-experiment lists to ``DataGenerator``.

    Keys read from ``hparams``:
        * required: 'expt_ids', 'data_dir', 'device', 'sequence_length', 'batch_size',
          'trial_splits', 'train_frac'
        * optional: 'input_type' (default 'markers'), 'lambda_strong', 'lambda_weak',
          'lambda_task' (each defaults to 0), 'expt_ids_to_keep' (defaults to 'expt_ids';
          an explicit None is treated the same as a missing key)

    Keys written to ``hparams`` (the dict is modified in place):
        * 'sequence_pad': result of ``compute_sequence_pad(hparams)``
        * 'input_size': ``data_gen.input_size``
        * 'output_size': ``len(data_gen.label_names)``
        * 'task_size': only when 'lambda_task' > 0; second dimension of the first entry of
          ``data_gen.datasets[0].data['tasks']`` whose second dimension is nonzero, else 0

    Parameters
    ----------
    hparams : dict
        hyperparameter dict, see above for the keys that are read/written

    Returns
    -------
    DataGenerator
        data generator built from the collected signals, transforms and paths

    Raises
    ------
    FileNotFoundError
        if no markers/features file is found for one of the experiment ids

    """

    # these do not depend on the experiment, so look them up once
    input_type = hparams.get('input_type', 'markers')
    expt_ids_to_keep = hparams.get('expt_ids_to_keep')
    if expt_ids_to_keep is None:
        expt_ids_to_keep = hparams['expt_ids']

    signals = []
    transforms = []
    paths = []

    for expt_id in hparams['expt_ids']:

        data_dir = hparams['data_dir']

        signals_curr = []
        transforms_curr = []
        paths_curr = []

        # DLC markers or features (e.g. from simba)
        # "_labeled" file names take precedence over plain ones; within each group the
        # search order is h5, csv, npy
        base_dir = os.path.join(data_dir, input_type)
        markers_file = _find_first_existing([
            os.path.join(base_dir, expt_id + '_labeled.h5'),
            os.path.join(base_dir, expt_id + '_labeled.csv'),
            os.path.join(base_dir, expt_id + '_labeled.npy'),
            os.path.join(base_dir, expt_id + '.h5'),
            os.path.join(base_dir, expt_id + '.csv'),
            os.path.join(base_dir, expt_id + '.npy'),
        ])
        if markers_file is None:
            msg = f'did not find marker file for {expt_id} in {base_dir}'
            logging.info(msg)
            raise FileNotFoundError(msg)
        signals_curr.append('markers')
        transforms_curr.append(ZScore())
        paths_curr.append(markers_file)

        # hand labels
        if hparams.get('lambda_strong', 0) > 0:
            if expt_id not in expt_ids_to_keep:
                # experiment is excluded from hand labels: keep the signal slot, but
                # without a file
                hand_labels_file = None
            else:
                hand_labels_file = _find_labels_file(data_dir, 'labels-hand', expt_id, 'hand')
            signals_curr.append('labels_strong')
            transforms_curr.append(None)
            paths_curr.append(hand_labels_file)

        # heuristic labels
        if hparams.get('lambda_weak', 0) > 0:
            heur_labels_file = _find_labels_file(
                data_dir, 'labels-heuristic', expt_id, 'heuristic')
            signals_curr.append('labels_weak')
            transforms_curr.append(None)
            paths_curr.append(heur_labels_file)

        # tasks
        if hparams.get('lambda_task', 0) > 0:
            # NOTE: unlike the label files, existence of the tasks file is not checked here
            tasks_labels_file = os.path.join(data_dir, 'tasks', expt_id + '.csv')
            signals_curr.append('tasks')
            transforms_curr.append(ZScore())
            paths_curr.append(tasks_labels_file)

        # define data generator signals
        signals.append(signals_curr)
        transforms.append(transforms_curr)
        paths.append(paths_curr)

    # compute padding needed to account for convolutions
    hparams['sequence_pad'] = compute_sequence_pad(hparams)

    # build data generator
    data_gen = DataGenerator(
        hparams['expt_ids'], signals, transforms, paths, device=hparams['device'],
        sequence_length=hparams['sequence_length'], sequence_pad=hparams['sequence_pad'],
        batch_size=hparams['batch_size'], trial_splits=hparams['trial_splits'],
        train_frac=hparams['train_frac'], input_type=input_type,
    )

    # automatically compute input/output sizes from data
    hparams['input_size'] = data_gen.input_size
    hparams['output_size'] = len(data_gen.label_names)

    if hparams.get('lambda_task', 0) > 0:
        # take the size from the first entry whose second dimension is nonzero; fall back
        # to 0 if there is no such entry
        hparams['task_size'] = next(
            (
                batch.shape[1]
                for batch in data_gen.datasets[0].data['tasks']
                if batch.shape[1] != 0
            ),
            0,
        )

    return data_gen
def collect_callbacks(hparams: dict) -> list:
    """Helper function to build a list of callbacks from hparam dict.

    Callbacks are appended in this order:

    1. ``EarlyStopping`` if ``hparams['enable_early_stop']`` is truthy
       (uses 'early_stop_history' as patience)
    2. if 'semi_supervised_algo' is 'pseudo_labels' or 'ups' and 'lambda_weak' is nonzero:
       an ``AnnealHparam`` on key 'lambda_weak' (from 'anneal_start' to 'anneal_end'),
       followed by ``PseudoLabels`` or ``UPS`` respectively; if 'lambda_weak' is 0 or
       missing, a warning is logged and none of these callbacks is added
    3. an ``AnnealHparam`` on key 'kl_weight' if ``hparams.get('variational')`` is truthy

    The callback classes are imported from ``daart.callbacks`` only when they are needed.

    Parameters
    ----------
    hparams : dict
        hyperparameter dict; 'enable_early_stop' is required, 'semi_supervised_algo'
        defaults to 'none', 'variational' defaults to False, 'lambda_weak' defaults to 0;
        the remaining keys ('early_stop_history', 'anneal_start', 'anneal_end',
        'prob_threshold', 'variance_threshold') are only read when the corresponding
        callback is built

    Returns
    -------
    list
        callback objects; empty if no option is enabled

    """

    callbacks = []

    if hparams['enable_early_stop']:
        from daart.callbacks import EarlyStopping
        # Note that patience does not account for val check interval values greater than 1;
        # for example, if val_check_interval=5 and patience=20, then the model will train
        # for at least 5 * 20 = 100 epochs before training can terminate
        callbacks.append(EarlyStopping(patience=hparams['early_stop_history']))

    semi_supervised_algo = hparams.get('semi_supervised_algo', 'none')
    if semi_supervised_algo in ('pseudo_labels', 'ups'):
        if hparams.get('lambda_weak', 0) == 0:
            # without a weight on the weak-label loss the pseudo label callbacks are skipped
            logging.warning('warning! use lambda_weak in model.yaml to weight pseudo label loss')
        else:
            from daart.callbacks import AnnealHparam
            callbacks.append(AnnealHparam(
                hparams=hparams, key='lambda_weak', epoch_start=hparams['anneal_start'],
                epoch_end=hparams['anneal_end'],
            ))
            if semi_supervised_algo == 'pseudo_labels':
                from daart.callbacks import PseudoLabels
                callbacks.append(PseudoLabels(
                    prob_threshold=hparams['prob_threshold'],
                    epoch_start=hparams['anneal_start'],
                ))
            else:
                from daart.callbacks import UPS
                callbacks.append(UPS(
                    prob_threshold=hparams['prob_threshold'],
                    variance_threshold=hparams['variance_threshold'],
                    epoch_start=hparams['anneal_start'],
                ))

    if hparams.get('variational', False):
        from daart.callbacks import AnnealHparam
        # the epoch window for the 'kl_weight' callback is fixed (0 to 100), not read
        # from hparams
        callbacks.append(AnnealHparam(
            hparams=hparams, key='kl_weight', epoch_start=0, epoch_end=100,
        ))

    return callbacks