Source code for scglue.models

r"""
Integration models
"""

import os
from pathlib import Path
from typing import Mapping, Optional

import dill
import networkx as nx
import numpy as np
import pandas as pd
from anndata import AnnData

from ..data import estimate_balancing_weight
from ..typehint import Kws
from ..utils import config, logged
from .base import Model
from .dx import integration_consistency
from .nn import autodevice
from .scclue import SCCLUEModel
from .scglue import PairedSCGLUEModel, SCGLUEModel


@logged
def configure_dataset(
        adata: AnnData, prob_model: str,
        use_highly_variable: bool = True,
        use_layer: Optional[str] = None,
        use_rep: Optional[str] = None,
        use_batch: Optional[str] = None,
        use_cell_type: Optional[str] = None,
        use_dsc_weight: Optional[str] = None,
        use_obs_names: bool = False
) -> None:
    r"""
    Configure dataset for model training

    Parameters
    ----------
    adata
        Dataset to be configured
    prob_model
        Probabilistic generative model used by the decoder,
        must be one of ``{"Normal", "ZIN", "ZILN", "NB", "ZINB"}``.
    use_highly_variable
        Whether to use highly variable features
    use_layer
        Data layer to use (key in ``adata.layers``)
    use_rep
        Data representation to use as the first encoder transformation
        (key in ``adata.obsm``)
    use_batch
        Data batch to use (key in ``adata.obs``)
    use_cell_type
        Data cell type to use (key in ``adata.obs``)
    use_dsc_weight
        Discriminator sample weight to use (key in ``adata.obs``)
    use_obs_names
        Whether to use ``obs_names`` to mark paired cells across
        different datasets

    Note
    ----
    The ``use_rep`` option applies to encoder inputs, but not the decoders,
    which are always fitted on data in the original space.
    """
    if config.ANNDATA_KEY in adata.uns:
        configure_dataset.logger.warning(
            "`configure_dataset` has already been called. "
            "Previous configuration will be overwritten!"
        )
    data_config = {}
    data_config["prob_model"] = prob_model
    if use_highly_variable:
        if "highly_variable" not in adata.var:
            raise ValueError("Please mark highly variable features first!")
        data_config["use_highly_variable"] = True
        data_config["features"] = adata.var.query("highly_variable").index.to_numpy().tolist()
    else:
        data_config["use_highly_variable"] = False
        data_config["features"] = adata.var_names.to_numpy().tolist()
    if use_layer:
        if use_layer not in adata.layers:
            raise ValueError("Invalid `use_layer`!")
        data_config["use_layer"] = use_layer
    else:
        data_config["use_layer"] = None
    if use_rep:
        if use_rep not in adata.obsm:
            raise ValueError("Invalid `use_rep`!")
        data_config["use_rep"] = use_rep
        data_config["rep_dim"] = adata.obsm[use_rep].shape[1]
    else:
        data_config["use_rep"] = None
        data_config["rep_dim"] = None
    if use_batch:
        if use_batch not in adata.obs:
            raise ValueError("Invalid `use_batch`!")
        data_config["use_batch"] = use_batch
        data_config["batches"] = pd.Index(
            adata.obs[use_batch]
        ).dropna().drop_duplicates().sort_values().to_numpy()  # AnnData does not support saving pd.Index in uns
    else:
        data_config["use_batch"] = None
        data_config["batches"] = None
    if use_cell_type:
        if use_cell_type not in adata.obs:
            raise ValueError("Invalid `use_cell_type`!")
        data_config["use_cell_type"] = use_cell_type
        data_config["cell_types"] = pd.Index(
            adata.obs[use_cell_type]
        ).dropna().drop_duplicates().sort_values().to_numpy()  # AnnData does not support saving pd.Index in uns
    else:
        data_config["use_cell_type"] = None
        data_config["cell_types"] = None
    if use_dsc_weight:
        if use_dsc_weight not in adata.obs:
            raise ValueError("Invalid `use_dsc_weight`!")
        data_config["use_dsc_weight"] = use_dsc_weight
    else:
        data_config["use_dsc_weight"] = None
    data_config["use_obs_names"] = use_obs_names
    adata.uns[config.ANNDATA_KEY] = data_config
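

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): configuring an
# RNA and an ATAC dataset before GLUE training. The file names ("rna-pp.h5ad",
# "atac-pp.h5ad"), the "counts" layer and the "X_pca"/"X_lsi" representations
# are assumptions for illustration; substitute whatever keys exist in your own
# AnnData objects. Highly variable features are assumed to be marked already
# (``adata.var["highly_variable"]``), otherwise the function raises.
# ---------------------------------------------------------------------------
def _example_configure_dataset() -> None:
    import anndata as ad  # local import so the sketch stays self-contained

    rna = ad.read_h5ad("rna-pp.h5ad")    # hypothetical preprocessed scRNA-seq data
    atac = ad.read_h5ad("atac-pp.h5ad")  # hypothetical preprocessed scATAC-seq data

    # Negative binomial decoder fitted on raw counts, PCA as encoder input
    configure_dataset(
        rna, "NB", use_highly_variable=True,
        use_layer="counts", use_rep="X_pca"
    )
    # Same probabilistic model for ATAC, with LSI as encoder input
    configure_dataset(
        atac, "NB", use_highly_variable=True,
        use_rep="X_lsi"
    )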


def load_model(fname: os.PathLike) -> Model:
    r"""
    Load model from file

    Parameters
    ----------
    fname
        Specifies path to the file

    Returns
    -------
    model
        Loaded model
    """
    fname = Path(fname)
    with fname.open("rb") as f:
        model = dill.load(f)
    model.upgrade()  # pylint: disable=no-member
    model.net.device = autodevice()  # pylint: disable=no-member
    return model
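

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): reloading a
# previously saved model and re-using its encoder. "glue.dill" and
# "rna-pp.h5ad" are placeholder paths, and "rna" is an assumed modality name;
# the dataset must still carry the same ``configure_dataset`` configuration it
# was trained with.
# ---------------------------------------------------------------------------
def _example_load_model() -> None:
    import anndata as ad

    model = load_model("glue.dill")    # device is re-selected via autodevice()
    rna = ad.read_h5ad("rna-pp.h5ad")
    # Project cells of the "rna" modality into the shared embedding space
    rna.obsm["X_glue"] = model.encode_data("rna", rna)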


@logged
def fit_SCGLUE(
        adatas: Mapping[str, AnnData], graph: nx.Graph,
        model: type = SCGLUEModel,
        init_kws: Kws = None, compile_kws: Kws = None, fit_kws: Kws = None,
        balance_kws: Kws = None
) -> SCGLUEModel:
    r"""
    Fit GLUE model to integrate single-cell multi-omics data

    Parameters
    ----------
    adatas
        Single-cell datasets (indexed by modality name)
    graph
        Guidance graph
    model
        Model class, must be one of
        {:class:`scglue.models.scglue.SCGLUEModel`,
        :class:`scglue.models.scglue.PairedSCGLUEModel`}
    init_kws
        Model initialization keyword arguments
        (see the constructor of the ``model`` class,
        either :class:`scglue.models.scglue.SCGLUEModel`
        or :class:`scglue.models.scglue.PairedSCGLUEModel`)
    compile_kws
        Model compile keyword arguments
        (see the ``compile`` method of the ``model`` class,
        either :meth:`scglue.models.scglue.SCGLUEModel.compile`
        or :meth:`scglue.models.scglue.PairedSCGLUEModel.compile`)
    fit_kws
        Model fitting keyword arguments
        (see :meth:`scglue.models.scglue.SCGLUEModel.fit`)
    balance_kws
        Balancing weight estimation keyword arguments
        (see :func:`scglue.data.estimate_balancing_weight`)

    Returns
    -------
    model
        Fitted model object
    """
    init_kws = init_kws or {}
    compile_kws = compile_kws or {}
    fit_kws = fit_kws or {}
    balance_kws = balance_kws or {}

    fit_SCGLUE.logger.info("Pretraining SCGLUE model...")
    pretrain_init_kws = init_kws.copy()
    pretrain_init_kws.update({"shared_batches": False})
    pretrain_fit_kws = fit_kws.copy()
    pretrain_fit_kws.update({"align_burnin": np.inf, "safe_burnin": False})
    if "directory" in pretrain_fit_kws:
        pretrain_fit_kws["directory"] = \
            os.path.join(pretrain_fit_kws["directory"], "pretrain")

    pretrain = model(adatas, sorted(graph.nodes), **pretrain_init_kws)
    pretrain.compile(**compile_kws)
    pretrain.fit(adatas, graph, **pretrain_fit_kws)
    if "directory" in pretrain_fit_kws:
        pretrain.save(os.path.join(pretrain_fit_kws["directory"], "pretrain.dill"))

    fit_SCGLUE.logger.info("Estimating balancing weight...")
    for k, adata in adatas.items():
        adata.obsm[f"X_{config.TMP_PREFIX}"] = pretrain.encode_data(k, adata)
    if init_kws.get("shared_batches"):
        use_batch = set(
            adata.uns[config.ANNDATA_KEY]["use_batch"]
            for adata in adatas.values()
        )
        use_batch = use_batch.pop() if len(use_batch) == 1 else None
    else:
        use_batch = None
    estimate_balancing_weight(
        *adatas.values(), use_rep=f"X_{config.TMP_PREFIX}", use_batch=use_batch,
        key_added="balancing_weight", **balance_kws
    )
    for adata in adatas.values():
        adata.uns[config.ANNDATA_KEY]["use_dsc_weight"] = "balancing_weight"
        del adata.obsm[f"X_{config.TMP_PREFIX}"]

    fit_SCGLUE.logger.info("Fine-tuning SCGLUE model...")
    finetune_fit_kws = fit_kws.copy()
    if "directory" in finetune_fit_kws:
        finetune_fit_kws["directory"] = \
            os.path.join(finetune_fit_kws["directory"], "fine-tune")

    finetune = model(adatas, sorted(graph.nodes), **init_kws)
    finetune.adopt_pretrained_model(pretrain)
    finetune.compile(**compile_kws)
    fit_SCGLUE.logger.debug("Increasing random seed by 1 to prevent identical data order...")
    finetune.random_seed += 1
    finetune.fit(adatas, graph, **finetune_fit_kws)
    if "directory" in finetune_fit_kws:
        finetune.save(os.path.join(finetune_fit_kws["directory"], "fine-tune.dill"))

    return finetune
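

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): the typical
# end-to-end call. It assumes both datasets have already been configured with
# ``configure_dataset`` (see the sketch above) and that "guidance.graphml"
# holds a pre-built guidance graph whose nodes cover all features of both
# modalities. File names and the "glue" checkpoint directory are placeholder
# assumptions.
# ---------------------------------------------------------------------------
def _example_fit_scglue() -> None:
    import anndata as ad
    import networkx as nx

    rna = ad.read_h5ad("rna-pp.h5ad")
    atac = ad.read_h5ad("atac-pp.h5ad")
    guidance = nx.read_graphml("guidance.graphml")

    glue = fit_SCGLUE(
        {"rna": rna, "atac": atac}, guidance,
        # with a directory set, checkpoints go to "glue/pretrain" and
        # "glue/fine-tune" as implemented above
        fit_kws={"directory": "glue"}
    )
    glue.save("glue.dill")

    # Shared embeddings for downstream analysis
    rna.obsm["X_glue"] = glue.encode_data("rna", rna)
    atac.obsm["X_glue"] = glue.encode_data("atac", atac)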