Inference

Once you have trained a model, you’ll likely want to run inference on new videos.

As with training, there is a set of high-level functions for performing inference and evaluating performance; this page walks through the main steps.

Load model

Given the directory of a trained model, which holds the saved weights (best_val_model.pt) and the hyperparameters (hparams.yaml), construct the model and load its weights.

import os
import torch
import yaml

from daart.models import Segmenter

model_dir = '/path/to/model_dir'  # replace with the path to your trained model
model_file = os.path.join(model_dir, 'best_val_model.pt')

hparams_file = os.path.join(model_dir, 'hparams.yaml')
with open(hparams_file, 'r') as f:
    hparams = yaml.safe_load(f)

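# build the model from its hyperparameters
model = Segmenter(hparams)
# map_location loads the saved tensors onto the CPU, whatever device they were saved from
model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
model.to(hparams['device'])  # then move the model to the device listed in hparams
model.eval()  # switch to evaluation mode (e.g. disables dropout)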
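Before loading, it can help to confirm that the model directory actually contains the two files read above. The helper below, check_model_dir, is hypothetical (it is not part of daart); it only assumes the file names best_val_model.pt and hparams.yaml used in the loading code.

```python
import os

# file names read by the model-loading code above
REQUIRED_FILES = ('best_val_model.pt', 'hparams.yaml')


def check_model_dir(model_dir):
    """Return the required files missing from model_dir (hypothetical helper, not daart API)."""
    if not os.path.isdir(model_dir):
        raise NotADirectoryError('model directory not found: %s' % model_dir)
    # keep the order of REQUIRED_FILES so the result is predictable
    return [f for f in REQUIRED_FILES if not os.path.isfile(os.path.join(model_dir, f))]
```

For example, check_model_dir(model_dir) returns an empty list when both files are present, so a non-empty result can be raised as an error before torch.load is ever called.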

Build data generator

To run inference on a new session, provide a csv file containing the markers or features for that session. These must be the same type of inputs the model was trained on.

from daart.data import DataGenerator
from daart.transforms import ZScore

sess_id = '<name_of_session>'
input_file = '/path/to/markers_or_features_csv'

# define data generator signals
signals = ['markers']  # same for markers or features
transforms = [ZScore()]
paths = [input_file]

# build data generator
data_gen_test = DataGenerator(
    [sess_id], [signals], [transforms], [paths], device=hparams['device'],
    sequence_length=hparams['sequence_length'], batch_size=hparams['batch_size'],
    trial_splits=hparams['trial_splits'],
    sequence_pad=hparams['sequence_pad'], input_type=hparams['input_type'],
)
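The ZScore transform is assumed here to standardize each input column (each marker coordinate or feature) to zero mean and unit variance across the session. The NumPy sketch below, zscore_columns, illustrates that computation on a (frames × columns) array; it is a hypothetical illustration rather than daart's own implementation, and the eps guard for constant columns is an added assumption.

```python
import numpy as np


def zscore_columns(x, eps=1e-8):
    """Standardize each column of a (T, D) array to zero mean and unit variance.

    Sketch only: `eps` avoids division by zero for constant columns
    (an assumption, not necessarily what daart's ZScore does).
    """
    x = np.asarray(x, dtype=float)
    mean = x.mean(axis=0)  # per-column mean over time
    std = x.std(axis=0)    # per-column standard deviation over time
    return (x - mean) / np.maximum(std, eps)


# synthetic "markers": 5 frames x 2 columns
x = np.array([[0., 10.], [1., 20.], [2., 30.], [3., 40.], [4., 50.]])
z = zscore_columns(x)
print(z.mean(axis=0), z.std(axis=0))  # column means ~0, standard deviations ~1
```

Because the statistics in this sketch come from the session being transformed, markers recorded at different scales end up on a comparable footing before reaching the model.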

Run inference

Inference can be performed by passing the newly constructed data generator to the model’s predict_labels method:

import numpy as np

# predict probabilities from model
print('computing states for %s...' % sess_id, end='')
tmp = model.predict_labels(data_gen_test, return_scores=True)
probs = np.vstack(tmp['labels'][0])
print('done')

# get discrete state by taking argmax over probabilities at each time point
states = np.argmax(probs, axis=1)
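To make the shapes concrete: the code above treats tmp['labels'][0] as a list holding one (sequence length × number of classes) probability array per sequence of the session. The sketch below substitutes synthetic arrays for real model output (the sizes are arbitrary) to show how np.vstack stitches the sequences back into one (frames × classes) array and how argmax turns it into one discrete state per frame. The per-state frame count with np.bincount is an added illustration, not part of daart.

```python
import numpy as np

rng = np.random.default_rng(0)

# synthetic stand-in for tmp['labels'][0]: 3 sequences, 4 frames each, 3 classes
n_classes = 3
seqs = []
for _ in range(3):
    logits = rng.normal(size=(4, n_classes))
    # softmax along the class axis so each row sums to 1, like probabilities
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    seqs.append(p / p.sum(axis=1, keepdims=True))

probs = np.vstack(seqs)            # (12, 3): sequences concatenated along time
states = np.argmax(probs, axis=1)  # (12,): most probable class at each frame

# number of frames assigned to each state
counts = np.bincount(states, minlength=n_classes)
print(probs.shape, states.shape, counts)
```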