Source code for daart.eval

"""Evaluation functions for the daart package."""

import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import recall_score, precision_score
from typeguard import typechecked
from typing import List, Optional, Union

from daart.io import make_dir_if_not_exists

# Explicit public API: listing only the functions defined in this module keeps the
# Sphinx auto-API documentation from also documenting names this module merely imports.
__all__ = [
    'get_precision_recall',
    'int_over_union',
    'run_lengths',
    'plot_training_curves',
    'load_metrics_csv_as_df',
]


@typechecked
def get_precision_recall(
        true_classes: np.ndarray,
        pred_classes: np.ndarray,
        background: Union[int, None] = 0,
        n_classes: Optional[int] = None
) -> dict:
    """Compute per-class precision, recall and F1 score for a classifier.

    Parameters
    ----------
    true_classes : array-like
        entries should be in [0, K-1] where K is the number of classes
    pred_classes : array-like
        entries should be in [0, K-1] where K is the number of classes
    background : int or NoneType
        defines the background class that identifies points with no supervised label; these time
        points are omitted from the precision and recall calculations and the background class
        gets no entry in the returned arrays; if NoneType, no background class is utilized
    n_classes : int, optional
        total number of non-background classes; if NoneType, will be inferred from the largest
        label in `true_classes` (so classes that never occur in the ground truth still get an
        entry); metrics are reported for the first `n_classes` non-negative integers that are
        not `background` (i.e. 1..n_classes when background is 0, 0..n_classes-1 when None)

    Returns
    -------
    dict
        'precision' (array-like): precision for each non-background class
        'recall' (array-like): recall for each non-background class
        'f1' (array-like): F1 score for each non-background class
        classes without any predictions / ground truth points get a score of 0

    """
    assert true_classes.shape[0] == pred_classes.shape[0]

    # find all data points that are not background
    if background is not None:
        obs_idxs = np.where(true_classes != background)[0]
    else:
        obs_idxs = np.arange(true_classes.shape[0])

    true_obs = true_classes[obs_idxs]
    pred_obs = pred_classes[obs_idxs]

    if n_classes is None:
        # labels are assumed to lie in [0, K-1]; infer K from the largest observed label rather
        # than from the number of unique labels, otherwise a class missing from the ground truth
        # would shift/drop the classes above it
        max_label = int(np.max(true_obs)) if true_obs.size > 0 else -1
        n_classes = sum(1 for c in range(max_label + 1) if c != background)

    # set of labels to include in metric computations: the first `n_classes` non-negative
    # integers that are not the background class (range(n_classes + 1) always holds enough)
    labels = np.array(
        [c for c in range(n_classes + 1) if c != background][:n_classes], dtype=int)

    precision = precision_score(
        true_obs, pred_obs, labels=labels, average=None, zero_division=0)
    recall = recall_score(
        true_obs, pred_obs, labels=labels, average=None, zero_division=0)

    # small epsilon avoids 0/0 for classes whose precision and recall are both 0 (F1 is then 0)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)

    return {'precision': precision, 'recall': recall, 'f1': f1}
@typechecked
def int_over_union(array1: np.ndarray, array2: np.ndarray) -> dict:
    """Compute the per-value intersection over union (IoU) of two 1D arrays.

    Parameters
    ----------
    array1 : array-like
        integer array of shape (n,)
    array2 : array-like
        integer array of shape (n,)

    Returns
    -------
    dict
        maps every value present in either array (in ascending order) to its IoU: the number of
        positions where both arrays hold that value divided by the number of positions where at
        least one of them does

    """
    def _iou(value):
        hits1 = array1 == value
        hits2 = array2 == value
        return np.sum(hits1 & hits2) / np.sum(hits1 | hits2)

    # union1d returns the sorted unique values found in either array
    return {value: _iou(value) for value in np.union1d(array1, array2)}
@typechecked
def run_lengths(array: np.ndarray) -> dict:
    """Compute distribution of run lengths for an array with integer entries.

    Parameters
    ----------
    array : array-like
        single-dimensional array of non-negative integers; a negative entry raises a KeyError

    Returns
    -------
    dict
        keys are the integers 0 through the max value in array (inclusive), values are lists of
        the lengths of consecutive runs of that value, in order of appearance; an empty array
        returns an empty dict

    Example
    -------
    >>> a = np.array([1, 1, 1, 0, 0, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1])
    >>> run_lengths(a)
    {0: [2, 1], 1: [3, 4], 2: [], 3: [], 4: [6]}

    """
    if array.size == 0:
        # np.max has no identity for an empty array
        return {}

    # plain Python int keys keep the output readable; numpy scalars produced by groupby below
    # hash and compare equal to them, so the lookups still succeed
    seqs = {k: [] for k in range(int(np.max(array)) + 1)}
    for value, run in itertools.groupby(array):
        seqs[value].append(sum(1 for _ in run))
    return seqs
@typechecked
def plot_training_curves(
        metrics_file: str,
        dtype: str = 'val',
        expt_ids: Optional[list] = None,
        save_file: Optional[str] = None,
        format: str = 'pdf'
) -> None:
    """Create training plots for each term in the objective function.

    The `dtype` argument controls which type of trials are plotted ('train' or 'val'). When more
    than one entry is given in `expt_ids`, curves are colored by dataset.

    This function generates a single figure with one panel for each of the following metrics
    that has a column in the metrics file:
        - 'loss' (total loss)
        - 'loss_weak' (weak label loss)
        - 'loss_strong' (strong label loss)
        - 'loss_pred' (prediction loss)
        - 'loss_task'
        - 'loss_kl'
        - 'fc'

    Only epochs greater than 10 and non-NaN values are plotted.

    Parameters
    ----------
    metrics_file : str
        csv file saved during training
    dtype : str
        'train' | 'val'
    expt_ids : list, optional
        dataset names for easier parsing
    save_file : str, optional
        absolute path of save file; does not need file extension
    format : str, optional
        format of saved image; 'pdf' | 'png' | 'jpeg' | ...

    """
    metrics_list = [
        'loss', 'loss_weak', 'loss_strong', 'loss_pred', 'loss_task', 'loss_kl', 'fc'
    ]
    metrics_df = load_metrics_csv_as_df(metrics_file, metrics_list, expt_ids=expt_ids)

    # only color by dataset when there is more than one dataset to distinguish
    hue = 'dataset' if isinstance(expt_ids, list) and len(expt_ids) > 1 else None

    sns.set_style('white')
    sns.set_context('talk')

    # bracket access: `df.dtype` / `df.val` only resolve to columns by accident of attribute
    # lookup and would break if pandas ever added attributes with those names
    data_queried = metrics_df[
        (metrics_df['epoch'] > 10)
        & ~pd.isna(metrics_df['val'])
        & (metrics_df['dtype'] == dtype)]

    g = sns.FacetGrid(
        data_queried, col='loss', col_wrap=2, hue=hue, sharey=False, height=4)
    g = g.map(plt.plot, 'epoch', 'val').add_legend()

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        g.savefig(save_file + '.' + format, dpi=300, format=format)

    # NOTE(review): the current figure is closed even when `save_file` is None, so nothing is
    # left open for interactive display in that case -- confirm this is intended
    plt.close()
@typechecked
def load_metrics_csv_as_df(
        metric_file: str,
        metrics_list: List[str],
        expt_ids: Optional[List[str]] = None,
        test: bool = False
) -> pd.DataFrame:
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    Parameters
    ----------
    metric_file : str
        csv file saved during training
    metrics_list : list
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    expt_ids : list, optional
        dataset names for easier parsing
    test : bool, optional
        True to only return test values (computed once at end of training)

    Returns
    -------
    pandas.DataFrame
        long-format dataframe (default RangeIndex) with one row per csv row / data type / metric
        and columns 'dataset', 'dtype' ('train' | 'val' | 'test'), 'epoch', 'loss' (metric name)
        and 'val' (metric value); metrics without a matching column in the csv are skipped, and
        an empty dataframe with these columns is returned if nothing matches

    """
    metrics = pd.read_csv(metric_file)

    # (csv column prefix, label stored in the 'dtype' column); for each csv row, val entries
    # are emitted before train entries
    if test:
        prefixes = [('test', 'test')]
    else:
        prefixes = [('val', 'val'), ('tr', 'train')]

    # column availability is the same for every row, so check it once
    available = set(metrics.columns)

    # collect plain records and build the frame once, instead of creating and concatenating
    # one single-row DataFrame per entry
    records = []
    for _, row in metrics.iterrows():
        if row['dataset'] == -1:
            dataset = 'all'
        elif expt_ids is not None:
            dataset = expt_ids[int(row['dataset'])]
        else:
            dataset = row['dataset']
        for prefix, dtype_label in prefixes:
            for metric in metrics_list:
                name = '%s_%s' % (prefix, metric)
                if name not in available:
                    continue
                records.append({
                    'dataset': dataset,
                    'dtype': dtype_label,
                    'epoch': row['epoch'],
                    'loss': metric,
                    'val': row[name],
                })

    # alphabetical column order, matching the previous `pd.concat(..., sort=True)` output
    return pd.DataFrame(records, columns=['dataset', 'dtype', 'epoch', 'loss', 'val'])