BaseModel

class daart.models.base.BaseModel(*args, **kwargs)

Bases: torch.nn.Module

Template for PyTorch models.

Methods Summary

build_model()

Build the model from hparams.

forward(*args, **kwargs)

Push data through the model.

get_parameters()

Get all model parameters that have gradient updates turned on.

load_parameters_from_file(filepath)

Load parameters from a .pt file.

save(filepath)

Save model parameters to filepath.

training_step(*args, **kwargs)

Compute the loss.

Methods Documentation

build_model()

Build the model from hparams.
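A minimal sketch of the build-from-hparams pattern, assuming a subclass stores an `hparams` dict and calls `build_model()` from its constructor. The class name `MLPClassifier` and the keys `input_size`, `n_hid_units`, and `output_size` are hypothetical, not daart's actual hyperparameter names, and the class subclasses `torch.nn.Module` directly as a stand-in for `BaseModel`.

```python
import torch
from torch import nn


class MLPClassifier(nn.Module):
    """Hypothetical subclass; nn.Module stands in for daart's BaseModel."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.build_model()

    def build_model(self):
        # every layer size comes from the (hypothetical) hparams keys
        self.hidden = nn.Linear(self.hparams['input_size'], self.hparams['n_hid_units'])
        self.activation = nn.ReLU()
        self.output = nn.Linear(self.hparams['n_hid_units'], self.hparams['output_size'])

    def forward(self, x):
        return self.output(self.activation(self.hidden(x)))


hparams = {'input_size': 8, 'n_hid_units': 16, 'output_size': 3}
model = MLPClassifier(hparams)
print(model.hidden)  # Linear(in_features=8, out_features=16, bias=True)
```

Keeping layer construction in `build_model` rather than inline in `__init__` means the architecture is fully determined by the hparams dict, which makes it easy to rebuild the same model later from stored hyperparameters.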

forward(*args, **kwargs)

Push data through the model.
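Subclasses override `forward`, but callers normally invoke the module itself (`model(x)`), because `nn.Module.__call__` runs any registered hooks before dispatching to `forward`. This sketch uses a hypothetical `LinearModel` with inputs of shape (n_batch, n_in); daart's actual input format may differ.

```python
import torch
from torch import nn


class LinearModel(nn.Module):
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.linear = nn.Linear(n_in, n_out)

    def forward(self, x):
        # push a batch of shape (n_batch, n_in) through the model
        return self.linear(x)


model = LinearModel(n_in=4, n_out=2)
calls = []
# forward hooks fire when the module is called, not when forward() is invoked directly
model.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(5, 4)
y = model(x)          # runs forward, then the hook
z = model.forward(x)  # bypasses the hook
print(y.shape, len(calls))  # torch.Size([5, 2]) 1
```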

get_parameters()

Get all model parameters that have gradient updates turned on.
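One common way to collect "parameters that have gradient updates turned on" is to filter `self.parameters()` on `requires_grad`. This sketch assumes that approach (it is not necessarily daart's code) and shows why it matters when some layers are frozen: only trainable tensors are handed to the optimizer. The class `TwoLayerModel` is hypothetical.

```python
import torch
from torch import nn


class TwoLayerModel(nn.Module):
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 2)

    def get_parameters(self):
        # keep only tensors that will receive gradient updates
        return [p for p in self.parameters() if p.requires_grad]


model = TwoLayerModel()
# freeze the encoder, e.g. when fine-tuning only the classifier
for p in model.encoder.parameters():
    p.requires_grad = False

trainable = model.get_parameters()
optimizer = torch.optim.Adam(trainable, lr=1e-3)
print(len(list(model.parameters())), len(trainable))  # 4 2
```

Returning a list rather than a `filter` object lets the result be iterated more than once; a `filter` is exhausted after a single pass.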

load_parameters_from_file(filepath)

Load parameters from a .pt file.

save(filepath)

Save model parameters to filepath.
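Saving and loading are commonly done through the module's `state_dict`. This sketch assumes that `save` writes `state_dict()` with `torch.save` and that `load_parameters_from_file` reads it back with `torch.load` and `load_state_dict`; it is not necessarily how daart implements them. Passing `map_location='cpu'` lets a checkpoint written on a GPU be restored on a CPU-only machine. The class `TinyModel` and the file name are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def save(self, filepath):
        # store only parameters and buffers, not the pickled class
        torch.save(self.state_dict(), filepath)

    def load_parameters_from_file(self, filepath):
        # map_location='cpu' remaps tensors saved on a GPU onto the CPU
        state_dict = torch.load(filepath, map_location='cpu')
        self.load_state_dict(state_dict)


original = TinyModel()
restored = TinyModel()  # starts from a different random initialization

with tempfile.TemporaryDirectory() as tmpdir:
    filepath = os.path.join(tmpdir, 'model.pt')
    original.save(filepath)
    restored.load_parameters_from_file(filepath)

print(torch.equal(original.linear.weight, restored.linear.weight))  # True
```

Because only the state dict is stored, the loading side must first construct a model with the same architecture (for example, from the same hparams) before calling `load_parameters_from_file`.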

training_step(*args, **kwargs)

Compute the loss.
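A sketch of how a loss-computing `training_step` might sit inside an optimization loop, assuming it receives a batch of inputs `x` and integer labels `y` and returns a scalar loss tensor while the caller handles backpropagation. The cross-entropy objective, the `(x, y)` batch format, and the class `LinearClassifier` are assumptions for illustration; daart's models may take different inputs and compute additional loss terms.

```python
import torch
from torch import nn


class LinearClassifier(nn.Module):
    """Hypothetical stand-in for a BaseModel subclass."""

    def __init__(self, n_in, n_classes):
        super().__init__()
        self.linear = nn.Linear(n_in, n_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.linear(x)

    def training_step(self, x, y):
        # compute the loss for one batch; the caller runs backward() and step()
        logits = self(x)
        return self.loss_fn(logits, y)


torch.manual_seed(0)
model = LinearClassifier(n_in=4, n_classes=3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(32, 4)
y = torch.randint(0, 3, (32,))

losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = model.training_step(x, y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f'first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}')
```

Keeping `training_step` limited to computing the loss leaves the choice of optimizer, gradient clipping, and logging to the surrounding training loop.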