scglue.models.glue.GLUETrainer.fit
- GLUETrainer.fit(data, graph, val_split=None, data_batch_size=None, graph_batch_size=None, align_burnin=None, safe_burnin=True, max_epochs=None, patience=None, reduce_lr_patience=None, wait_n_lrs=None, random_seed=None, directory=None, plugins=None)
Fit network
- Parameters:
  - data (scglue.models.data.ArrayDataset) – Data dataset
  - graph (scglue.models.data.GraphDataset) – Graph dataset
  - val_split (typing.Optional[float]) – Validation split
  - data_batch_size (typing.Optional[int]) – Number of samples in each data minibatch
  - graph_batch_size (typing.Optional[int]) – Number of edges in each graph minibatch
  - align_burnin (typing.Optional[int]) – Number of epochs to wait before starting alignment
  - safe_burnin (bool) – Whether to postpone learning rate scheduling and early stopping until after the burnin stage
  - max_epochs (typing.Optional[int]) – Maximal number of epochs
  - patience (typing.Optional[int]) – Patience of early stopping
  - reduce_lr_patience (typing.Optional[int]) – Patience to reduce learning rate
  - wait_n_lrs (typing.Optional[int]) – Wait n learning rate scheduling events before starting early stopping
  - random_seed (typing.Optional[int]) – Random seed
  - directory (typing.Optional[os.PathLike]) – Directory to store checkpoints and tensorboard logs
  - plugins (typing.Optional[typing.List[scglue.models.base.TrainingPlugin]]) – Optional list of training plugins
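The scheduling parameters above (`align_burnin`, `safe_burnin`, `reduce_lr_patience`, `wait_n_lrs`, `patience`, `max_epochs`) interact with one another, so a toy simulation may help build intuition. The sketch below is not scglue's implementation: the function `simulate_fit_schedule` and its counting rules are hypothetical, and it assumes one plausible reading of the descriptions, namely that a learning rate reduction fires after `reduce_lr_patience` non-improving epochs and that early stopping only starts counting once `wait_n_lrs` reductions have occurred.

```python
from typing import Dict, Optional, Sequence, Union


def simulate_fit_schedule(
    val_losses: Sequence[float],
    align_burnin: int = 0,
    safe_burnin: bool = True,
    patience: int = 2,
    reduce_lr_patience: int = 1,
    wait_n_lrs: int = 1,
    max_epochs: Optional[int] = None,
) -> Dict[str, Union[int, bool]]:
    """Hypothetical toy model of the schedule; NOT scglue's actual logic."""
    n_epochs = len(val_losses)
    if max_epochs is not None:
        n_epochs = min(max_epochs, n_epochs)

    best_lr = best_es = float("inf")  # best losses seen by each mechanism
    lr_bad = es_bad = 0               # consecutive non-improving epochs
    n_lrs = 0                         # learning rate reductions so far

    for epoch in range(1, n_epochs + 1):
        loss = val_losses[epoch - 1]

        # Assumption: with safe_burnin, burn-in epochs are ignored entirely
        # by both the LR scheduler and early stopping.
        if safe_burnin and epoch <= align_burnin:
            continue

        # Learning rate scheduling: reduce after reduce_lr_patience bad epochs.
        if loss < best_lr:
            best_lr, lr_bad = loss, 0
        else:
            lr_bad += 1
            if lr_bad >= reduce_lr_patience:
                n_lrs += 1
                lr_bad = 0

        # Assumption: early stopping stays inactive until wait_n_lrs
        # learning rate reductions have happened.
        if n_lrs < wait_n_lrs:
            continue

        if loss < best_es:
            best_es, es_bad = loss, 0
        else:
            es_bad += 1
            if es_bad >= patience:
                return {"stopped_epoch": epoch, "lr_reductions": n_lrs,
                        "early_stopped": True}

    return {"stopped_epoch": n_epochs, "lr_reductions": n_lrs,
            "early_stopped": False}


# Made-up validation losses: improvement for three epochs, then a plateau.
losses = [5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 3.8, 3.9]
result = simulate_fit_schedule(losses, align_burnin=2)
```

In this toy model, raising `wait_n_lrs` or `patience` delays the stop, while `max_epochs` caps the run regardless of the loss curve. Consult the scglue source for the exact behavior.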
- Return type: