# Source code of module ``scglue.graph``

r"""
Graph-related functions
"""

from itertools import chain
from typing import Any, Callable, Iterable, Mapping, Optional, Set

import networkx as nx
from anndata import AnnData
from tqdm.auto import tqdm

from .utils import logged


def compose_multigraph(*graphs: nx.Graph) -> nx.MultiGraph:
    r"""
    Compose multi-graph from multiple graphs with no edge collision

    Every edge of every input graph is added to the result as its own
    multi-edge, carrying that edge's attribute dict.

    Parameters
    ----------
    graphs
        An arbitrary number of graphs to be composed from

    Returns
    -------
    composed
        Composed multi-graph

    Note
    ----
    The resulting multi-graph would be directed if any of the input graphs
    is directed.
    """
    directed = any(nx.is_directed(g) for g in graphs)
    if directed:
        # Make every input directed so the composition is uniformly directed
        graphs = [g.to_directed() for g in graphs]
    composed = nx.MultiDiGraph() if directed else nx.MultiGraph()
    edge_triples = (
        (edge[0], edge[1], g.edges[edge])
        for g in graphs
        for edge in g.edges
    )
    composed.add_edges_from(edge_triples)
    return composed
def collapse_multigraph(
        graph: nx.MultiGraph, merge_fns: Optional[Mapping[str, Callable]] = None
) -> nx.Graph:
    r"""
    Collapse multi-edges into simple-edges

    Parameters
    ----------
    graph
        Input multi-graph
    merge_fns
        Attribute-specific merge functions, indexed by attribute name.
        Each merge function should accept a list of values (one per
        parallel edge) and return a single value.

    Returns
    -------
    collapsed
        Collapsed graph

    Note
    ----
    The collapsed graph would be directed if the input graph is directed.
    Edges causing ValueError in ``merge_fns`` will be discarded: as soon as
    any merge function raises ValueError for an edge, that edge is removed
    and no further merge functions are evaluated for it. Merged values are
    only written to an edge when every merge function succeeds.
    """
    if nx.is_directed(graph):  # MultiDiGraph
        collapsed = nx.DiGraph(graph)
    else:  # MultiGraph
        collapsed = nx.Graph(graph)
    if not merge_fns:
        return collapsed
    # Materialize the edge list because edges may be removed inside the loop
    for e in tqdm(list(collapsed.edges), desc="collapse_multigraph"):
        # Attribute dicts of all parallel edges between this vertex pair
        attrs = list(graph.get_edge_data(*e).values())
        try:
            merged = {
                k: fn([attr[k] for attr in attrs])
                for k, fn in merge_fns.items()
            }
        except ValueError:
            # Drop the edge once; evaluating the remaining merge functions
            # would otherwise touch (or re-remove) an edge that is gone
            collapsed.remove_edge(*e)
            continue
        collapsed.edges[e].update(merged)
    return collapsed
def reachable_vertices(graph: nx.Graph, source: Iterable[Any]) -> Set[Any]:
    r"""
    Identify vertices reachable from source vertices
    (including source vertices themselves)

    Parameters
    ----------
    graph
        Input graph
    source
        Source vertices

    Returns
    -------
    reachable_vertices
        Reachable vertices. Source vertices that are not present in
        ``graph`` are still included in the result.
    """
    sources = set(source)
    reachable = set(sources)
    for vertex in sources:
        # Vertices absent from the graph contribute only themselves
        if graph.has_node(vertex):
            reachable.update(nx.descendants(graph, vertex))
    return reachable
def _handle_check_failure(action: str, option: str, msg: str, logger: Any) -> None:
    r"""
    Carry out the action requested for a failed :func:`check_graph` check

    Parameters
    ----------
    action
        Requested action, must be one of {"ignore", "warn", "error"}
    option
        Name of the :func:`check_graph` argument that supplied ``action``,
        used in the error message for invalid actions
    msg
        Description of the failed check
    logger
        Logger used to emit the warning when ``action`` is "warn"

    Raises
    ------
    ValueError
        If ``action`` is "error", or if ``action`` is not a valid action
    """
    if action == "error":
        raise ValueError(msg)
    if action == "warn":
        logger.warning(msg)
    elif action != "ignore":
        raise ValueError(f"Invalid `{option}`: {action}")


@logged
def check_graph(
        graph: nx.Graph, adatas: Iterable[AnnData],
        cov: str = "error", attr: str = "error",
        loop: str = "error", sym: str = "error"
) -> None:
    r"""
    Check if a graph is a valid guidance graph

    Parameters
    ----------
    graph
        Graph to be checked
    adatas
        AnnData objects where graph nodes are variables
    cov
        Action to take if graph nodes do not cover all variables,
        must be one of {"ignore", "warn", "error"}
    attr
        Action to take if some graph edge lacks the required "weight" or
        "sign" attribute, must be one of {"ignore", "warn", "error"}
    loop
        Action to take if some graph node lacks a self-loop,
        must be one of {"ignore", "warn", "error"}
    sym
        Action to take if graph is not symmetric (some edge ``(u, v)`` has
        no reverse edge ``(v, u)``), must be one of {"ignore", "warn", "error"}

    Raises
    ------
    ValueError
        If a check fails and its action is "error", or if a check fails and
        its action is not one of the valid values

    Note
    ----
    Action values are only validated when the corresponding check fails.
    """
    passed = True
    logger = check_graph.logger

    logger.info("Checking variable coverage...")
    if not all(
        all(graph.has_node(var_name) for var_name in adata.var_names)
        for adata in adatas
    ):
        passed = False
        _handle_check_failure(
            cov, "cov", "Some variables are not covered by the graph!", logger
        )

    logger.info("Checking edge attributes...")
    if not all(
        "weight" in edge_attr and "sign" in edge_attr
        for _, _, edge_attr in graph.edges(data=True)
    ):
        passed = False
        # Dispatch on `attr` (previously mistakenly validated against `cov`)
        _handle_check_failure(
            attr, "attr", "Missing weight or sign as edge attribute!", logger
        )

    logger.info("Checking self-loops...")
    if not all(
        graph.has_edge(node, node) for node in graph.nodes
    ):
        passed = False
        _handle_check_failure(loop, "loop", "Missing self-loop!", logger)

    logger.info("Checking graph symmetry...")
    if not all(
        graph.has_edge(e[1], e[0]) for e in graph.edges
    ):
        passed = False
        _handle_check_failure(sym, "sym", "Graph is not symmetric!", logger)

    if passed:
        logger.info("All checks passed!")