# Source code for rpasdt.algorithm.graph_drawing

"""Graph drawing utilities."""
from typing import List

import networkx as nx

from rpasdt.algorithm.taxonomies import (
    DiffusionGraphNodeRenderTypeEnum,
    GraphLayout,
)

# Default seed for the randomised layouts (random, spring). A fixed seed makes
# repeated draws of the same graph reproducible unless a caller overrides it.
_DEFAULT_LAYOUT_SEED = 100


def _random_layout(graph, seed=_DEFAULT_LAYOUT_SEED, *args, **kwargs):
    """Place nodes with ``nx.random_layout`` using ``seed``; ignore other arguments."""
    return nx.random_layout(graph, seed=seed)


def _spring_layout(graph, seed=_DEFAULT_LAYOUT_SEED, *args, **kwargs):
    """Place nodes with ``nx.spring_layout`` using ``seed``; ignore other arguments."""
    return nx.spring_layout(graph, seed=seed)


# Dispatch table mapping each GraphLayout to a callable
# ``(graph, *args, **kwargs)`` that delegates to the matching networkx layout
# function and returns its result. Entries other than RANDOM and SPRING use
# only ``graph`` and silently ignore any extra arguments.
GRAPH_LAYOUT_DRAW_ALGORITHM = {
    GraphLayout.CIRCULAR: (
        lambda graph, *args, **kwargs: nx.circular_layout(graph)
    ),
    GraphLayout.KAMADA_KAWAI: (
        lambda graph, *args, **kwargs: nx.kamada_kawai_layout(graph)
    ),
    GraphLayout.PLANAR: (
        lambda graph, *args, **kwargs: nx.planar_layout(graph)
    ),
    GraphLayout.RANDOM: _random_layout,
    GraphLayout.SPECTRAL: (
        lambda graph, *args, **kwargs: nx.spectral_layout(graph)
    ),
    GraphLayout.SPRING: _spring_layout,
    GraphLayout.SHELL: (
        lambda graph, *args, **kwargs: nx.shell_layout(graph)
    ),
}


def compute_graph_draw_position(
    graph: nx.Graph, layout: GraphLayout = None, *args, **kwargs
):
    """Compute node positions for drawing ``graph`` with the requested layout.

    Args:
        graph: Graph whose nodes should be positioned.
        layout: Layout algorithm to use. Falls back to ``GraphLayout.SPRING``
            when not given (or when a falsy value is passed).
        *args: Forwarded to the layout callable looked up in
            ``GRAPH_LAYOUT_DRAW_ALGORITHM``.
        **kwargs: Forwarded likewise. The RANDOM and SPRING entries accept a
            ``seed`` argument. The other entries ignore extra arguments.

    Returns:
        Whatever the selected layout callable returns (for networkx layout
        functions, a mapping of node to position).

    Raises:
        KeyError: If ``layout`` has no entry in ``GRAPH_LAYOUT_DRAW_ALGORITHM``.
    """
    layout = layout or GraphLayout.SPRING
    draw_algorithm = GRAPH_LAYOUT_DRAW_ALGORITHM[layout]
    # Forward the extra arguments: they were previously accepted but dropped,
    # so callers could never override e.g. the spring/random layout ``seed``.
    return draw_algorithm(graph, *args, **kwargs)


def get_diffusion_graph(
    source_graph: nx.Graph,
    infected_nodes: List[int],
    graph_node_rendering_type: DiffusionGraphNodeRenderTypeEnum = DiffusionGraphNodeRenderTypeEnum.ONLY_INFECTED,
) -> nx.Graph:
    """Select the part of ``source_graph`` to render for a diffusion.

    Args:
        source_graph: Graph on which the diffusion runs.
        infected_nodes: Nodes to keep when only infected nodes are rendered.
        graph_node_rendering_type: ``DiffusionGraphNodeRenderTypeEnum.FULL``
            keeps every node of ``source_graph``. Any other value keeps only
            ``infected_nodes``.

    Returns:
        ``source_graph.subgraph(...)`` over the selected nodes (per networkx,
        a subgraph view backed by ``source_graph``).
    """
    render_every_node = (
        DiffusionGraphNodeRenderTypeEnum.FULL == graph_node_rendering_type
    )
    kept_nodes = source_graph.nodes() if render_every_node else infected_nodes
    return source_graph.subgraph(kept_nodes)