scglue.models.glue.GLUETrainer.fit
- GLUETrainer.fit(data, graph, val_split=None, data_batch_size=None, graph_batch_size=None, align_burnin=None, safe_burnin=True, max_epochs=None, patience=None, reduce_lr_patience=None, wait_n_lrs=None, random_seed=None, directory=None, plugins=None)
Fit the network.
- Parameters
data (scglue.models.data.ArrayDataset) – Data dataset
graph (scglue.models.data.GraphDataset) – Graph dataset
val_split (typing.Optional[float]) – Fraction of the data held out for validation
data_batch_size (typing.Optional[int]) – Number of samples in each data minibatch
graph_batch_size (typing.Optional[int]) – Number of edges in each graph minibatch
align_burnin (typing.Optional[int]) – Number of epochs to wait before starting alignment
safe_burnin (bool) – Whether to postpone learning rate scheduling and early stopping until after the burn-in stage
max_epochs (typing.Optional[int]) – Maximum number of epochs
patience (typing.Optional[int]) – Patience of early stopping, i.e. the number of epochs without improvement before training stops
reduce_lr_patience (typing.Optional[int]) – Patience of learning rate reduction, i.e. the number of epochs without improvement before the learning rate is reduced
wait_n_lrs (typing.Optional[int]) – Number of learning rate scheduling events to wait for before early stopping is enabled
random_seed (typing.Optional[int]) – Random seed
directory (typing.Optional[os.PathLike]) – Directory to store checkpoints and TensorBoard logs
plugins (typing.Optional[typing.List[scglue.models.base.TrainingPlugin]]) – Optional list of training plugins
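The interplay of align_burnin, safe_burnin, reduce_lr_patience, wait_n_lrs and patience can be illustrated with a small, self-contained simulation. This is a hypothetical model written for illustration, not scglue's implementation: it assumes a plateau-style rule (the learning rate is multiplied by an assumed factor of 0.1 after reduce_lr_patience epochs without improvement in validation loss) and assumes that early stopping only counts stagnant epochs once wait_n_lrs learning rate reductions have happened. The names simulate_schedule and ScheduleResult, and the starting learning rate, are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ScheduleResult:
    """Outcome of a simulated run (hypothetical helper, not part of scglue)."""
    stopped_epoch: int
    lr_events: List[int]  # epochs at which the learning rate was reduced
    final_lr: float


def simulate_schedule(
    val_losses: List[float],
    align_burnin: int,
    safe_burnin: bool,
    reduce_lr_patience: int,
    wait_n_lrs: int,
    patience: int,
    lr: float = 2e-3,        # assumed starting learning rate
    lr_factor: float = 0.1,  # assumed reduction factor
) -> ScheduleResult:
    """Replay per-epoch validation losses through a plateau-based LR rule
    and an early-stopping rule gated on `wait_n_lrs` LR reductions.
    Illustrative model only; scglue's actual logic may differ."""
    best = float("inf")
    since_best_lr = 0  # stagnant epochs counted for LR reduction
    since_best_es = 0  # stagnant epochs counted for early stopping
    lr_events: List[int] = []
    for epoch, loss in enumerate(val_losses, start=1):
        # With safe_burnin, burn-in epochs neither set the best loss
        # nor count towards either patience.
        if safe_burnin and epoch <= align_burnin:
            continue
        if loss < best:
            best = loss
            since_best_lr = since_best_es = 0
            continue
        since_best_lr += 1
        if since_best_lr >= reduce_lr_patience:
            lr *= lr_factor
            lr_events.append(epoch)
            since_best_lr = 0
        # Early stopping only starts counting after enough LR events.
        if len(lr_events) >= wait_n_lrs:
            since_best_es += 1
            if since_best_es >= patience:
                return ScheduleResult(epoch, lr_events, lr)
    # No early stop: training ran for every provided epoch.
    return ScheduleResult(len(val_losses), lr_events, lr)


if __name__ == "__main__":
    # Losses are low before alignment starts (epochs 1-2), then jump
    # once the alignment term is added.
    losses = [1.0, 0.9, 3.0, 2.8, 2.6, 2.5, 2.4, 2.45, 2.46, 2.47, 2.48, 2.49]
    for safe in (True, False):
        r = simulate_schedule(losses, align_burnin=2, safe_burnin=safe,
                              reduce_lr_patience=2, wait_n_lrs=1, patience=3)
        print(safe, r.stopped_epoch, r.lr_events)
    # True 11 [9, 11]
    # False 6 [4, 6]
```

In this toy run, skipping the burn-in epochs lets training continue to epoch 11. Counting them makes the artificially low pre-alignment losses the benchmark, so the learning rate decays early and training stops at epoch 6, which is the situation safe_burnin is meant to avoid.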
- Return type