plot_training_curves
- daart.eval.plot_training_curves(metrics_file: str, dtype: str = 'val', expt_ids: list | None = None, save_file: str | None = None, format: str = 'pdf') → None
Create training plots for each term in the objective function.
The dtype argument selects which trials are plotted: training (‘train’) or validation (‘val’). Multiple models can also be plotted together by varying one (and only one) of the following parameters:
TODO
Each of these entries must be an array of length 1, except for one, which can be an array of any length (one element per already-trained model). The function generates a single plot with one panel for each of the following terms:
- total loss
- weak label loss
- strong label loss
- prediction loss
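The per-term panels can be illustrated by the data-extraction step: read the metrics CSV, keep only rows whose dtype matches, and collect one (epoch, value) series per objective term. The following is a minimal sketch using only the standard library. The column names (`epoch`, `dtype`, `loss`, `loss_weak`, `loss_strong`, `loss_pred`) and the helper `load_curves` are hypothetical assumptions, not the actual daart CSV layout or implementation.

```python
import csv
import io
from collections import defaultdict

# Hypothetical column names; the real daart metrics CSV may use different ones.
TERMS = {
    "loss": "total loss",
    "loss_weak": "weak label loss",
    "loss_strong": "strong label loss",
    "loss_pred": "prediction loss",
}


def load_curves(csv_text, dtype="val"):
    """Return {term label: [(epoch, value), ...]} for rows whose dtype matches."""
    if dtype not in ("train", "val"):
        raise ValueError("dtype must be 'train' or 'val'")
    curves = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        if row["dtype"] != dtype:
            continue  # keep only the requested trial type
        epoch = int(row["epoch"])
        for column, label in TERMS.items():
            # Skip terms that are absent or empty in this row (e.g. an unused loss).
            value = row.get(column)
            if value:
                curves[label].append((epoch, float(value)))
    return dict(curves)


example = """epoch,dtype,loss,loss_weak,loss_strong,loss_pred
0,train,1.5,0.9,0.4,0.2
0,val,1.7,1.0,0.5,0.2
1,train,1.1,0.6,0.3,0.2
1,val,1.3,0.8,0.3,0.2
"""

val_curves = load_curves(example, dtype="val")
print(val_curves["total loss"])  # [(0, 1.7), (1, 1.3)]
```

Each entry of the returned dictionary would then be drawn in its own panel, for example one axis per term from matplotlib's `plt.subplots`.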
- Parameters:
metrics_file (str) – csv file saved during training
dtype (str) – ‘train’ | ‘val’
expt_ids (list, optional) – dataset names for easier parsing
save_file (str, optional) – absolute path of save file; does not need file extension
format (str, optional) – format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …
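Because save_file does not need an extension while format names the image type, one plausible way to combine the two is to append the format as the extension when it is missing. The helper below is a hypothetical sketch of that behavior, not daart's actual code; the parameter name `format` mirrors the documented argument even though it shadows the Python builtin.

```python
import os


def resolve_save_path(save_file, format="pdf"):
    """Hypothetical helper: append '.<format>' unless save_file already ends with it."""
    if save_file is None:
        return None  # no save path given, so nothing would be written
    ext = "." + format.lower()
    _, current_ext = os.path.splitext(save_file)
    if current_ext.lower() == ext:
        return save_file  # extension already present; leave the path unchanged
    return save_file + ext


print(resolve_save_path("/tmp/curves", "png"))  # /tmp/curves.png
```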