plot_training_curves

daart.eval.plot_training_curves(metrics_file: str, dtype: str = 'val', expt_ids: list | None = None, save_file: str | None = None, format: str = 'pdf') → None

Create training plots for each term in the objective function.

The dtype argument selects which trials are plotted (‘train’ or ‘val’). Multiple models can also be plotted together by varying one (and only one) of the following parameters:

TODO

All of these entries must be arrays of length 1 except one, which may be an array of arbitrary length (one element per already-trained model). The function generates a single figure with one panel for each of the following terms:

  • total loss

  • weak label loss

  • strong label loss

  • prediction loss
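The length rule above (every entry has length 1, except at most one entry that may list several already-trained models) can be checked with a small helper. This is a minimal sketch, not part of daart: the helper `find_varying_param` and its dict-of-lists input are hypothetical, since the parameters themselves are not listed on this page.

```python
from collections.abc import Sequence
from typing import Optional


def find_varying_param(params: dict[str, Sequence]) -> Optional[str]:
    """Return the name of the one entry with several values, or None.

    Hypothetical helper: enforces that every entry has length 1 except
    at most one, which may have arbitrary length (one value per model).
    """
    varying = []
    for name, values in params.items():
        if len(values) == 0:
            raise ValueError(f"entry {name!r} is empty")
        if len(values) > 1:
            varying.append(name)
    if len(varying) > 1:
        # More than one entry varies: the models cannot be laid out on one axis.
        raise ValueError(f"only one entry may vary, got {varying}")
    return varying[0] if varying else None


# Hypothetical parameter names, for illustration only.
print(find_varying_param({"learning_rate": [1e-4], "seed": [0, 1, 2]}))  # seed
```

Returning None when nothing varies covers the common case of plotting a single model.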

Parameters:
  • metrics_file (str) – CSV file of metrics saved during training

  • dtype (str) – which trials to plot; ‘train’ | ‘val’

  • expt_ids (list, optional) – dataset names for easier parsing

  • save_file (str, optional) – absolute path of the save file; does not need a file extension

  • format (str, optional) – format of the saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …
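The sketch below illustrates the general idea of one panel per loss term, read from a metrics CSV and filtered by dtype. It is not daart's implementation. The CSV layout (columns `epoch`, `dtype`, and one column per term named `loss`, `loss_weak`, `loss_strong`, `loss_pred`) and both function names are hypothetical assumptions; the real metrics file may be organized differently. Parsing uses only the standard `csv` module; plotting uses matplotlib.

```python
import csv
from collections import defaultdict

# Hypothetical column names: the metrics CSV written by daart may use others.
LOSS_TERMS = ("loss", "loss_weak", "loss_strong", "loss_pred")


def curves_from_rows(rows, dtype="val", terms=LOSS_TERMS):
    """Average each loss term per epoch over rows whose 'dtype' matches.

    Returns {term: (sorted epochs, mean value per epoch)}; empty or 'nan'
    entries are skipped, so a term may cover fewer epochs than others.
    """
    sums = {t: defaultdict(float) for t in terms}
    counts = {t: defaultdict(int) for t in terms}
    for row in rows:
        if row.get("dtype") != dtype:
            continue
        epoch = int(row["epoch"])
        for t in terms:
            value = (row.get(t) or "").strip()
            if value and value.lower() != "nan":  # skip missing entries
                sums[t][epoch] += float(value)
                counts[t][epoch] += 1
    return {
        t: (sorted(sums[t]), [sums[t][e] / counts[t][e] for e in sorted(sums[t])])
        for t in terms
    }


def plot_curves_sketch(metrics_file, dtype="val", save_file=None, format="pdf"):
    """One panel per loss term; appends '.<format>' to save_file (assumed)."""
    import matplotlib.pyplot as plt  # imported lazily so parsing works without it

    with open(metrics_file, newline="") as f:
        curves = curves_from_rows(csv.DictReader(f), dtype=dtype)

    ncols = 2
    nrows = (len(curves) + ncols - 1) // ncols
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 3 * nrows), squeeze=False
    )
    for ax, (term, (epochs, values)) in zip(axes.flat, curves.items()):
        ax.plot(epochs, values)
        ax.set_title(term)
        ax.set_xlabel("epoch")
    for ax in list(axes.flat)[len(curves):]:
        ax.axis("off")  # hide unused panels
    fig.tight_layout()
    if save_file is not None:
        fig.savefig(f"{save_file}.{format}", format=format)
    return fig  # caller may plt.close(fig) when done
```

Under these assumptions, `plot_curves_sketch('/path/to/metrics.csv', dtype='train', save_file='/path/to/curves', format='png')` would write `/path/to/curves.png`. Keeping the parsing step separate from plotting makes it easy to test without a display or a plotting backend.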