# Source code for daart.io

"""File IO for daart package."""

import csv
import numpy as np
import os
import pickle
from typing import List, Optional, Union
from typeguard import typechecked
import yaml


# Public API of this module: the names exported by a wildcard (`import *`) import.
__all__ = [
    'get_subdirs',
    'get_expt_dir',
    'get_model_dir',
    'get_model_params',
    'find_experiment',
    'read_expt_info_from_csv',
    'export_expt_info_to_csv',
    'export_hparams',
    'make_dir_if_not_exists',
]


@typechecked
def get_subdirs(path: str) -> List[str]:
    """Get all first-level subdirectories in a given path (no recursion).

    Parameters
    ----------
    path : str
        absolute path

    Returns
    -------
    list
        names (not full paths) of the first-level subdirectories in :obj:`path`

    Raises
    ------
    NotADirectoryError
        if :obj:`path` does not exist or is not a directory
    StopIteration
        if :obj:`path` contains no subdirectories

    """
    # `os.path.exists` would also accept regular files, which `os.walk` silently yields
    # nothing for; require a directory so the raised error matches the actual problem
    if not os.path.isdir(path):
        raise NotADirectoryError('%s is not a path' % path)
    try:
        subdirs = next(os.walk(path))[1]
    except StopIteration:
        # os.walk ignores listing errors by default and then yields nothing
        subdirs = []
    if not subdirs:
        raise StopIteration('%s does not contain any subdirectories' % path)
    return subdirs
@typechecked
def get_expt_dir(base_dir: str, expt_ids: Union[str, List[str]]) -> str:
    """Construct experiment directory given base directory and list of experiment ids.

    A single experiment id (a str, or a list of length one) maps to ``base_dir/<expt_id>``.
    Several ids map to a "multisession" directory ``base_dir/multi-<N>``:

    - an existing ``multi*`` directory whose ``expt_info.csv`` lists the same ids (in any
      order) is reused;
    - otherwise ``N`` is one past the largest existing index, and the new directory and its
      ``expt_info.csv`` are written to disk.

    Parameters
    ----------
    base_dir : str
        base results directory; may not exist yet in the multisession case
    expt_ids : str or list
        single experiment id (str) or a list of experiment ids that will define a
        "multisession" directory

    Returns
    -------
    str
        path of the experiment directory (absolute if `base_dir` is)

    """
    if isinstance(expt_ids, str):
        return os.path.join(base_dir, expt_ids)
    if len(expt_ids) <= 1:
        return os.path.join(base_dir, expt_ids[0])

    # multisession; see if a matching multisession dir already exists; if not, create one
    try:
        subdirs = get_subdirs(base_dir)
    except (StopIteration, NotADirectoryError):
        # fresh results directory: either empty or not created yet (the csv export below
        # creates it)
        subdirs = []

    target_ids = sorted(expt_ids)
    expt_dir = None
    max_val = -1
    for subdir in subdirs:
        if not subdir.startswith('multi'):
            continue
        # assumes every `multi*` dir is named `multi-<int>` and holds an expt_info.csv
        rows = read_expt_info_from_csv(os.path.join(base_dir, subdir, 'expt_info.csv'))
        if sorted(row['expt'] for row in rows) == target_ids:
            expt_dir = subdir
            break
        max_val = max(max_val, int(subdir.split('-')[-1]))

    if expt_dir is None:
        expt_dir = 'multi-' + str(max_val + 1)
        # save csv with expt ids so later calls can find this directory again
        export_expt_info_to_csv(os.path.join(base_dir, expt_dir), expt_ids)

    return os.path.join(base_dir, expt_dir)
@typechecked
def get_model_dir(base_dir: str, model_params: dict) -> str:
    """Build the model directory path from a model param dict.

    Segmenter models are grouped by their backbone name; every other model class is grouped
    by the model class name itself.

    Parameters
    ----------
    base_dir : str
        base results directory
    model_params : dict
        must contain `model_class`; must also contain `backbone` when `model_class` is
        'segmenter'; may contain `experiment_name`

    Returns
    -------
    str
        path of the model directory (absolute if `base_dir` is); when `experiment_name`
        is absent an empty final component is joined, leaving a trailing separator

    """
    model_class = model_params['model_class']
    subdir = model_params['backbone'] if model_class == 'segmenter' else model_class
    experiment_name = model_params.get('experiment_name', '')
    return os.path.join(base_dir, subdir, experiment_name)
@typechecked
def get_model_params(hparams: dict) -> dict:
    """Extract the hyperparameters considered essential for defining a model of its type.

    Parameters
    ----------
    hparams : dict
        full hyperparameter dict; the subset relevant to ``hparams['model_class']`` (and, for
        segmenters, ``hparams['backbone']``) is copied out of it

    Returns
    -------
    dict
        reduced hparams dict containing only the identifying entries

    Raises
    ------
    KeyError
        if a required hyperparameter is missing from `hparams`
    NotImplementedError
        if the model class, or (for segmenters) the backbone, is not recognized

    """
    model_class = hparams['model_class']
    backbone = hparams['backbone']

    # params shared by every model class (tree-based models drop some of them below)
    general_keys = [
        'rng_seed_train', 'rng_seed_model', 'trial_splits', 'train_frac', 'backbone',
        'sequence_length', 'batch_size', 'input_type',
    ]
    params = {'model_class': model_class}
    for key in general_keys:
        params[key] = hparams[key]

    if model_class in ('random-forest', 'xgboost'):
        for key in ('rng_seed_train', 'backbone', 'sequence_length', 'batch_size'):
            del params[key]
        return params

    if model_class != 'segmenter':
        raise NotImplementedError('"%s" is not a valid model class' % model_class)

    # loss weights and semi-supervised settings
    params['lambda_weak'] = hparams['lambda_weak']
    params['lambda_strong'] = hparams['lambda_strong']
    params['lambda_pred'] = hparams['lambda_pred']
    params['lambda_task'] = hparams.get('lambda_task', 0)
    params['variational'] = hparams.get('variational', False)
    algo = hparams.get('semi_supervised_algo', None)
    params['semi_supervised_algo'] = algo
    if algo == 'pseudo_labels':
        algo_keys = ['prob_threshold', 'anneal_start', 'anneal_end']
    elif algo == 'ups':
        algo_keys = ['prob_threshold', 'variance_threshold', 'anneal_start', 'anneal_end']
    else:
        algo_keys = []
    for key in algo_keys:
        params[key] = hparams[key]

    # backbone-specific architecture params (validated before any of them are read)
    if backbone == 'temporal-mlp':
        arch_keys = ['n_lags', 'activation', 'l2_reg']
    elif backbone in ['lstm', 'gru']:
        arch_keys = ['activation', 'l2_reg', 'bidirectional']
    elif backbone in ['tcn', 'dtcn']:
        arch_keys = ['n_lags', 'activation', 'l2_reg']
        if backbone == 'dtcn':
            arch_keys.append('dropout')
    else:
        raise NotImplementedError('"%s" is not a valid backbone network' % backbone)

    params['learning_rate'] = hparams['learning_rate']
    params['n_hid_layers'] = hparams['n_hid_layers']
    if hparams['n_hid_layers'] != 0:
        # hidden width only matters when there are hidden layers
        params['n_hid_units'] = hparams['n_hid_units']
    for key in arch_keys:
        params[key] = hparams[key]
    return params
def _load_version_hparams(version_dir: str) -> Optional[dict]:
    """Load the hparams saved in one test-tube version directory.

    ``hparams.pkl`` is tried first; ``hparams.yaml`` is used if the pickle file does not
    exist. Returns None when the loaded object is not a dict (e.g. an empty yaml file loads
    as None). An OSError raised while opening the yaml fallback propagates to the caller.

    """
    try:
        # NOTE(review): pickle.load can execute arbitrary code from the file; only search
        # results directories you trust
        with open(os.path.join(version_dir, 'hparams.pkl'), 'rb') as f:
            loaded = pickle.load(f)
    except FileNotFoundError:
        with open(os.path.join(version_dir, 'hparams.yaml'), 'r') as f:
            loaded = yaml.safe_load(f)
    return loaded if isinstance(loaded, dict) else None


@typechecked
def find_experiment(
        hparams: dict, verbose: bool = False,
        keys_to_sweep: Optional[List[str]] = None) -> List[str]:
    """Search test-tube versions for fitted experiments with the same hyperparameters.

    Parameters
    ----------
    hparams : dict
        needs to contain enough information to specify a test tube experiment (model +
        training parameters): either `tt_expt_dir`, or `tt_experiment_name` together with
        `model_dir`, `expt_dir`, or (`results_dir` and `expt_ids`) to build it. Missing
        `expt_dir` / `model_dir` entries are computed and written back into this dict.
    verbose : bool
        True to print the mismatching hparams of each non-matching version
    keys_to_sweep : list of strs, optional
        model params excluded from the comparison, so matching versions may have any value
        for them; when empty or None the search stops at the first completed match

    Returns
    -------
    list
        full paths of version directories whose saved hparams match and whose
        `training_completed` flag is set; empty if there are none

    Raises
    ------
    KeyError
        if an entry of `keys_to_sweep` is not one of the model params

    """
    if keys_to_sweep is None:
        keys_to_sweep = []

    # fill out path info if not present
    if 'tt_expt_dir' in hparams:
        tt_expt_dir = hparams['tt_expt_dir']
    else:
        if 'model_dir' not in hparams:
            if 'expt_dir' not in hparams:
                hparams['expt_dir'] = get_expt_dir(hparams['results_dir'], hparams['expt_ids'])
            hparams['model_dir'] = get_model_dir(hparams['expt_dir'], hparams)
        tt_expt_dir = os.path.join(hparams['model_dir'], hparams['tt_experiment_name'])

    try:
        tt_versions = get_subdirs(tt_expt_dir)
    except (StopIteration, NotADirectoryError):
        # no versions yet: experiment directory is empty or has not been created
        return []

    # get model-specific params, then remove params we don't want a specific value for
    hparams_req = get_model_params(hparams)
    for key in keys_to_sweep:
        del hparams_req[key]

    version_list = []
    for version in tt_versions:
        version_dir = os.path.join(tt_expt_dir, version)
        try:
            hparams_ = _load_version_hparams(version_dir)
            if hparams_ is None:
                # empty or malformed hparams file; cannot be the version we seek
                continue
            if all(hparams_[key] == hparams_req[key] for key in hparams_req):
                # found match - only keep it if training finished
                if hparams_['training_completed']:
                    version_list.append(version_dir)
                    if len(keys_to_sweep) == 0:
                        # we found the only model we're looking for
                        break
            elif verbose:
                print('unmatched keys, %s:' % version)
                for key in hparams_req:
                    if hparams_[key] != hparams_req[key]:
                        print('{}: {} vs {}'.format(key, hparams_[key], hparams_req[key]))
                print()
        except OSError:
            # various reasons why loading may fail; all mean this version is not what we seek
            continue
        except KeyError:
            # usually occurs when checking older models against newer models with more hparams
            continue

    if len(version_list) == 0 and verbose:
        print('could not find match for requested hyperparameters: {}'.format(hparams_req))
    return version_list
@typechecked
def read_expt_info_from_csv(expt_file: str) -> List[dict]:
    """Read csv file that contains expt id info.

    Parameters
    ----------
    expt_file : str
        /full/path/to/expt_info.csv

    Returns
    -------
    list
        one dict per csv row, keyed by the header fields (e.g. ``{'expt': <id>}`` for files
        written by :func:`export_expt_info_to_csv`)

    """
    # the csv module requires newline='' so quoted fields containing line breaks are
    # read back verbatim instead of being newline-translated
    with open(expt_file, newline='') as csv_file:
        return [dict(row) for row in csv.DictReader(csv_file)]
@typechecked
def export_expt_info_to_csv(expt_dir: str, ids_list: List[str]) -> None:
    """Export list of expt ids to csv file.

    Writes ``<expt_dir>/expt_info.csv`` with a single ``expt`` column and one row per id.
    `expt_dir` (and any missing parents) is created if needed; an existing csv file is
    overwritten.

    Parameters
    ----------
    expt_dir : str
        absolute path for where to save `expt_info.csv` file
    ids_list : list
        list which contains each expt name

    """
    # exist_ok avoids the check-then-create race of a separate isdir() test
    os.makedirs(expt_dir, exist_ok=True)
    expt_file = os.path.join(expt_dir, 'expt_info.csv')
    # newline='' lets the csv writer control line endings (otherwise '\r\r\n' on Windows)
    with open(expt_file, mode='w', newline='') as f:
        expt_writer = csv.DictWriter(f, fieldnames=['expt'])
        expt_writer.writeheader()
        expt_writer.writerows({'expt': expt_id} for expt_id in ids_list)
@typechecked
def export_hparams(hparams: dict, filename: Optional[str] = None) -> None:
    """Write a hyperparameter dict to disk as a yaml file.

    Parameters
    ----------
    hparams : dict
        hyperparameters to serialize with ``yaml.dump``
    filename : str, optional
        destination file; when None, ``hparams['tt_version_dir']/hparams.yaml`` is used

    """
    # NOTE(review): yaml.dump may emit python-specific tags (e.g. for tuples) that
    # yaml.safe_load cannot read back; confirm hparams hold only plain types
    out_file = filename if filename is not None else os.path.join(
        hparams['tt_version_dir'], 'hparams.yaml')
    with open(out_file, 'w') as stream:
        yaml.dump(hparams, stream)
@typechecked
def make_dir_if_not_exists(save_file: str) -> None:
    """Create the directories needed to save a file at the given path.

    Parameters
    ----------
    save_file : str
        path of the file to be saved; its directory component (if any) is created along
        with any missing parents, and nothing happens if it already exists

    """
    save_dir = os.path.dirname(save_file)
    # a bare filename has no directory component, and os.makedirs('') would raise
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)