scglue.models.glue.GLUETrainer

class scglue.models.glue.GLUETrainer(net, lam_data=None, lam_kl=None, lam_graph=None, lam_align=None, modality_weight=None, optim=None, lr=None, **kwargs)

Bases: Trainer

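Trainer for GLUE models. It fits the network passed as net, weighting the individual loss terms by the lam_* arguments.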

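Parameters:

net: GLUE network to be trained
lam_data: weight of the data (reconstruction) loss
lam_kl: weight of the KL divergence terms
lam_graph: weight of the graph loss
lam_align: weight of the adversarial alignment loss
modality_weight: relative weight of each modality, indexed by modality name
optim: optimizer to use
lr: learning rate
**kwargs: additional keyword arguments passed to the optimizer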
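Every keyword after net defaults to None, even though a trainer presumably cannot run without loss weights, an optimizer and a learning rate. A minimal sketch, assuming (this page does not confirm it) that the constructor rejects any of these left as None; check_required and the numbers in config are hypothetical and purely illustrative, not scglue code or scglue defaults.

```python
from typing import Any, Mapping

# Hypothetical helper: names mirror the GLUETrainer signature, but this is
# not scglue code. It assumes every listed keyword must be set explicitly.
REQUIRED = ("lam_data", "lam_kl", "lam_graph", "lam_align",
            "modality_weight", "optim", "lr")


def check_required(kwargs: Mapping[str, Any]) -> None:
    """Raise ValueError for the first required keyword that is missing or None."""
    for name in REQUIRED:
        if kwargs.get(name) is None:
            raise ValueError(f"`{name}` must be specified!")


# Illustrative values only; not taken from scglue.
config = {
    "lam_data": 1.0, "lam_kl": 1.0, "lam_graph": 0.02, "lam_align": 0.05,
    "modality_weight": {"rna": 1.0, "atac": 1.0},
    "optim": "RMSprop", "lr": 2e-3,
}
check_required(config)  # passes: every required keyword is set
```

Failing fast like this surfaces a forgotten weight at construction time rather than as an obscure error partway through training.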

Methods

compute_losses

Compute loss functions
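How compute_losses combines its terms is not documented here. The sketch below is a simplified, hypothetical weighting scheme (plain floats stand in for tensors; combine_losses and its argument names are illustrative, not the scglue API): each modality's reconstruction loss is added to its KL term scaled by lam_kl, modalities are averaged using modality_weight, the graph term is scaled by lam_graph, and the discriminator loss is subtracted with weight lam_align so that the encoders are pushed to confuse the discriminator.

```python
from typing import Any, Dict, Mapping


def combine_losses(
    x_nll: Mapping[str, float], x_kl: Mapping[str, float],
    g_nll: float, g_kl: float, dsc_loss: float,
    modality_weight: Mapping[str, float],
    lam_data: float, lam_kl: float, lam_graph: float, lam_align: float,
) -> Dict[str, Any]:
    """Hypothetical weighting scheme; scalar floats stand in for tensors."""
    # Per-modality ELBO-style loss: reconstruction plus KL term scaled by lam_kl.
    x_elbo = {k: x_nll[k] + lam_kl * x_kl[k] for k in x_nll}
    # Weighted average across modalities using the relative modality weights.
    total_w = sum(modality_weight[k] for k in x_elbo)
    x_elbo_avg = sum(modality_weight[k] * x_elbo[k] for k in x_elbo) / total_w
    # ELBO-style loss for the guidance graph.
    g_elbo = g_nll + lam_kl * g_kl
    vae_loss = lam_data * x_elbo_avg + lam_graph * g_elbo
    # The generator side is rewarded for fooling the modality discriminator,
    # so the discriminator loss enters with a negative sign.
    gen_loss = vae_loss - lam_align * dsc_loss
    return {"x_elbo": x_elbo, "g_elbo": g_elbo,
            "vae_loss": vae_loss, "gen_loss": gen_loss}
```

In an adversarial setup of this kind, the discriminator itself would be updated separately to minimise dsc_loss.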

fit

Fit network
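A generic sketch of how a fit method can drive the per-batch train_step and val_step methods listed on this page. ToyTrainer, its fit signature, the max_epochs argument and the list-of-batches data format are all hypothetical; scglue's actual fit arguments are not described here.

```python
from typing import Dict, Iterable, List


class ToyTrainer:
    """Hypothetical trainer that records which steps fit() calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def train_step(self, batch: float) -> Dict[str, float]:
        self.calls.append("train")
        return {"loss": batch}  # toy: the batch value is its own loss

    def val_step(self, batch: float) -> Dict[str, float]:
        self.calls.append("val")
        return {"loss": batch}

    def fit(self, train_data: Iterable[float], val_data: Iterable[float],
            max_epochs: int = 2) -> List[Dict[str, float]]:
        train_data, val_data = list(train_data), list(val_data)
        history = []
        for _ in range(max_epochs):
            # One pass over the training batches, then one over validation.
            train = [self.train_step(b)["loss"] for b in train_data]
            val = [self.val_step(b)["loss"] for b in val_data]
            history.append({
                "train_loss": sum(train) / len(train),
                "val_loss": sum(val) / len(val),
            })
        return history
```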

format_data

Format data tensors. Returns two mappings from modality name to tensor, followed by three tensors (return type: typing.Tuple[typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]).
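The return type implies that a flat batch is regrouped into two per-modality mappings plus three standalone tensors. A sketch under an assumed layout (one data array per modality, then one flag array per modality, then three graph arrays); split_batch and the field names x, xflag, eidx, ewt and esgn are hypothetical, and plain Python objects stand in for tensors.

```python
from typing import Any, List, Mapping, Sequence, Tuple


def split_batch(
    data: Sequence[Any], keys: List[str],
) -> Tuple[Mapping[str, Any], Mapping[str, Any], Any, Any, Any]:
    """Hypothetical splitter matching the five-part return type.

    Assumed layout: one data array per modality, then one flag array per
    modality, then three graph arrays (edge index, edge weight, edge sign).
    """
    k = len(keys)
    if len(data) != 2 * k + 3:
        raise ValueError(f"expected {2 * k + 3} items, got {len(data)}")
    x = dict(zip(keys, data[:k]))            # modality name -> data
    xflag = dict(zip(keys, data[k:2 * k]))   # modality name -> flags
    eidx, ewt, esgn = data[2 * k:]           # graph tensors
    return x, xflag, eidx, ewt, esgn
```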

get_losses

Get loss values for given data
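Loss values for a whole dataset are commonly obtained by evaluating each batch without updating parameters and then averaging. A sketch assuming a batch-size-weighted mean; average_losses and its (batch_size, losses) input format are hypothetical, not part of scglue.

```python
from typing import Dict, List, Tuple


def average_losses(batches: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Batch-size-weighted mean of per-batch loss dicts.

    Each item is (batch_size, {loss_name: value}); assumes at least one batch
    and that every batch reports the same loss names.
    """
    totals: Dict[str, float] = {}
    n = 0
    for size, losses in batches:
        n += size
        for name, value in losses.items():
            totals[name] = totals.get(name, 0.0) + size * value
    return {name: total / n for name, total in totals.items()}
```

Weighting by batch size keeps a small final batch from skewing the mean.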

load_state_dict

Load state from a state dict

state_dict

State dict
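state_dict and load_state_dict suggest the PyTorch convention of capturing and restoring state as a plain dict. The toy below assumes the trainer state is a couple of simple attributes; ToyTrainerState and its fields are hypothetical, and what GLUETrainer actually stores is not described on this page.

```python
from typing import Any, Dict


class ToyTrainerState:
    """Hypothetical trainer whose state is a few plain attributes."""

    def __init__(self, lam_align: float) -> None:
        self.lam_align = lam_align
        self.epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        # A fresh dict each call, so a saved checkpoint is unaffected by
        # later changes to the trainer (the values here are immutable).
        return {"lam_align": self.lam_align, "epoch": self.epoch}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.lam_align = state["lam_align"]
        self.epoch = state["epoch"]


src = ToyTrainerState(lam_align=0.05)
src.epoch = 7
checkpoint = src.state_dict()
dst = ToyTrainerState(lam_align=1.0)
dst.load_state_dict(checkpoint)  # dst now matches src at save time
```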

train_step

A single training step

val_step

A single validation step
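The usual contract between the two steps, assumed here rather than taken from scglue, is that a training step computes losses and updates parameters while a validation step only computes losses. A one-parameter least-squares toy makes the difference visible; LinearToy and its methods are hypothetical.

```python
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]  # (input, target) pairs


class LinearToy:
    """Hypothetical one-parameter model y = w * x fitted by gradient descent."""

    def __init__(self, w: float = 0.0, lr: float = 0.1) -> None:
        self.w = w
        self.lr = lr

    def _loss_and_grad(self, batch: Batch) -> Tuple[float, float]:
        # Mean squared error and its derivative with respect to w.
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        return loss, grad

    def train_step(self, batch: Batch) -> Dict[str, float]:
        loss, grad = self._loss_and_grad(batch)
        self.w -= self.lr * grad  # parameters change only here
        return {"loss": loss}

    def val_step(self, batch: Batch) -> Dict[str, float]:
        loss, _ = self._loss_and_grad(batch)  # evaluation only, no update
        return {"loss": loss}
```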

Attributes

logger
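The logger attribute is listed without a description. Assuming it is a standard logging.Logger (an assumption, not stated on this page), per-epoch losses could be reported as below; the logger name, log_epoch helper and message format are hypothetical.

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")
logger = logging.getLogger("GLUETrainer")  # hypothetical logger name


def log_epoch(epoch: int, losses: dict) -> str:
    """Format one line summarising an epoch's losses, log it, and return it."""
    parts = ", ".join(f"{k}={v:.3f}" for k, v in sorted(losses.items()))
    msg = f"[Epoch {epoch}] {parts}"
    logger.info(msg)
    return msg
```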