scglue.models.glue.GLUETrainer.fit

GLUETrainer.fit(data, graph, val_split=None, data_batch_size=None, graph_batch_size=None, align_burnin=None, safe_burnin=True, max_epochs=None, patience=None, reduce_lr_patience=None, wait_n_lrs=None, random_seed=None, directory=None, plugins=None)

Fit network
- Parameters
  - data (ArrayDataset) – Data dataset
  - graph (GraphDataset) – Graph dataset
  - data_batch_size (Optional[int]) – Number of samples in each data minibatch
  - graph_batch_size (Optional[int]) – Number of edges in each graph minibatch
  - align_burnin (Optional[int]) – Number of epochs to wait before starting alignment
  - safe_burnin (bool) – Whether to postpone learning rate scheduling and early stopping until after the burn-in stage
  - reduce_lr_patience (Optional[int]) – Patience to reduce learning rate
  - wait_n_lrs (Optional[int]) – Number of learning rate scheduling events to wait before starting early stopping
  - directory (Optional[PathLike]) – Directory to store checkpoints and tensorboard logs
  - plugins (Optional[List[TrainingPlugin]]) – Optional list of training plugins
- Return type
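
The scheduling parameters above (align_burnin, safe_burnin, reduce_lr_patience, wait_n_lrs, patience, max_epochs) interact with one another, so a toy simulation may help when choosing values. The sketch below is plain Python and does not call scglue; `ToySchedule` and `simulate_schedule` are hypothetical names, the loss values are made up, and the counting rules (for example, reducing the learning rate once the number of non-improving epochs reaches the patience) are assumptions read off the parameter descriptions, not the library's actual implementation.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass
class ToySchedule:
    """Hypothetical container mirroring the names of the fit() arguments."""
    align_burnin: int
    safe_burnin: bool
    max_epochs: int
    patience: int
    reduce_lr_patience: int
    wait_n_lrs: int


def simulate_schedule(val_losses: List[float], cfg: ToySchedule) -> List[Tuple[int, str]]:
    """Return (epoch, event) pairs for a sequence of validation losses.

    Assumed semantics (a simplification, not scglue's code):
    - alignment starts at epoch align_burnin + 1;
    - with safe_burnin, burn-in epochs are not monitored at all;
    - the learning rate is reduced after reduce_lr_patience
      consecutive non-improving epochs;
    - early stopping only counts non-improving epochs once at least
      wait_n_lrs learning rate reductions have happened.
    """
    events: List[Tuple[int, str]] = []
    best = float("inf")
    bad_lr = 0    # non-improving epochs since the last LR reduction
    bad_stop = 0  # non-improving epochs counted toward early stopping
    n_lrs = 0     # learning rate reductions so far

    for epoch, loss in enumerate(val_losses[: cfg.max_epochs], start=1):
        if epoch == cfg.align_burnin + 1:
            events.append((epoch, "alignment starts"))
        if cfg.safe_burnin and epoch <= cfg.align_burnin:
            continue  # monitoring is postponed during burn-in
        if loss < best:
            best, bad_lr, bad_stop = loss, 0, 0
            continue
        bad_lr += 1
        if bad_lr >= cfg.reduce_lr_patience:
            n_lrs += 1
            bad_lr = 0
            events.append((epoch, "reduce lr"))
        if n_lrs >= cfg.wait_n_lrs:
            bad_stop += 1
            if bad_stop >= cfg.patience:
                events.append((epoch, "early stop"))
                break
    return events


# Made-up losses that plateau during burn-in and improve afterwards.
losses = [1.0, 2.0, 2.0, 2.0, 2.0, 0.5]
cfg = ToySchedule(align_burnin=4, safe_burnin=True, max_epochs=10,
                  patience=2, reduce_lr_patience=2, wait_n_lrs=1)
print(simulate_schedule(losses, cfg))
print(simulate_schedule(losses, replace(cfg, safe_burnin=False)))
```

In this toy run, disabling safe_burnin lets the plateau during burn-in trigger a learning rate reduction and then early stopping before alignment begins, which is the situation the safe_burnin description suggests the option is meant to avoid.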