Source code for scglue.models.dx

r"""
Model diagnostics
"""

from typing import Mapping

import h5py
import networkx as nx
import pandas as pd
from anndata import AnnData

try:
    from anndata._core.sparse_dataset import SparseDataset
except ImportError:  # Newer version of anndata
    from anndata._core.sparse_dataset import (
        BaseCompressedSparseDataset as SparseDataset,
    )

from ..data import count_prep, metacell_corr
from ..utils import config, logged
from .scglue import SCGLUEModel


@logged
def integration_consistency(
    model: SCGLUEModel, adatas: Mapping[str, AnnData], graph: nx.Graph, **kwargs
) -> pd.DataFrame:
    r"""
    Integration consistency score, defined as the consistency between
    aligned-space meta-cell correlation and the guidance graph

    For each tested number of meta cells, the score is the weighted average
    of ``sign * corr`` over the non-self-loop edges returned by
    :func:`scglue.data.metacell_corr` (edges with missing ``corr`` are
    dropped), using the edge attribute ``weight`` as the weight.

    Parameters
    ----------
    model
        Integration model to be evaluated
    adatas
        Datasets (indexed by modality name)
    graph
        Guidance graph
    **kwargs
        Additional keyword arguments are passed to
        :func:`scglue.data.metacell_corr`. If ``agg_fns`` or ``prep_fns``
        are not given, they are chosen per modality from the configured
        ``prob_model``: ``"sum"`` aggregation and ``count_prep``
        preprocessing for ``"NB"``/``"ZINB"``, ``"mean"`` aggregation and
        no preprocessing otherwise.

    Returns
    -------
    consistency_df
        Consistency score at different numbers of meta cells
        (columns ``n_meta`` and ``consistency``). Meta cell numbers
        (10, 20, 50, 100, 200) exceeding the number of cells in the
        smallest dataset are skipped.

    Raises
    ------
    RuntimeError
        If ``X`` of any dataset is backed (an HDF5 dataset or a backed
        sparse dataset)
    """
    for adata in adatas.values():
        if isinstance(adata.X, (h5py.Dataset, SparseDataset)):
            raise RuntimeError("Backed data is not currently supported!")
    logger = integration_consistency.logger

    # Avoid unwanted updates to the input objects: work on new AnnData
    # wrappers (with a copied ``obsm``) instead of the caller's objects
    adatas = {
        k: AnnData(
            X=adata.X,
            obs=adata.obs,
            var=adata.var,
            obsm=adata.obsm.copy(),
            layers=adata.layers,
            uns=adata.uns,
            dtype=adata.X.dtype,
        )
        for k, adata in adatas.items()
    }
    for k, adata in adatas.items():
        adata.obsm["X_glue"] = model.encode_data(k, adata)

    # Use the configured layer (if any) as the data matrix
    for k, adata in adatas.items():
        use_layer = adata.uns[config.ANNDATA_KEY]["use_layer"]
        if use_layer:
            logger.info('Using layer "%s" for modality "%s"', use_layer, k)
            adata.X = adata.layers[use_layer]

    # Fill in per-modality defaults only when the caller did not provide them
    if "agg_fns" not in kwargs:
        agg_fns = []
        for k, adata in adatas.items():
            if adata.uns[config.ANNDATA_KEY]["prob_model"] in ("NB", "ZINB"):
                logger.info('Selecting aggregation "sum" for modality "%s"', k)
                agg_fns.append("sum")
            else:
                logger.info('Selecting aggregation "mean" for modality "%s"', k)
                agg_fns.append("mean")
        kwargs["agg_fns"] = agg_fns
    if "prep_fns" not in kwargs:
        prep_fns = []
        for k, adata in adatas.items():
            if adata.uns[config.ANNDATA_KEY]["prob_model"] in ("NB", "ZINB"):
                logger.info('Selecting log-norm preprocessing for modality "%s"', k)
                prep_fns.append(count_prep)
            else:
                logger.info('Selecting no preprocessing for modality "%s"', k)
                prep_fns.append(None)
        kwargs["prep_fns"] = prep_fns

    # The smallest dataset bounds the usable number of meta cells; this is
    # invariant across the loop below, so compute it once
    min_cells = min(adata.shape[0] for adata in adatas.values())

    n_metas, consistencies = [], []
    for n_meta in (10, 20, 50, 100, 200):
        if n_meta > min_cells:
            continue
        corr = metacell_corr(
            *adatas.values(),
            skeleton=graph,
            use_rep="X_glue",
            n_meta=n_meta,
            **kwargs
        )
        corr = corr.edge_subgraph(
            e for e in corr.edges if e[0] != e[1]
        )  # Exclude self-loops
        # A list for ``subset`` is accepted by every pandas version, whereas
        # a bare string is only accepted by newer releases
        edgelist = nx.to_pandas_edgelist(corr).dropna(subset=["corr"])
        n_metas.append(n_meta)
        consistencies.append(
            (edgelist["sign"] * edgelist["weight"] * edgelist["corr"]).sum()
            / edgelist["weight"].sum()
        )
    return pd.DataFrame({"n_meta": n_metas, "consistency": consistencies})