Source code for sphinx_graph.vertex.state

"""Shared state for the sphinx-graph extension."""

from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping
from contextlib import contextmanager
from typing import TYPE_CHECKING

import rustworkx as rx
from sphinx.errors import DocumentError, SphinxError
from sphinx.util import logging

from sphinx_graph.vertex.info import Info

if TYPE_CHECKING:
    from sphinx.application import Sphinx
    from sphinx.environment import BuildEnvironment

# Module-level logger; ``logging`` here is ``sphinx.util.logging``, not the
# standard-library module.
logger = logging.getLogger(__name__)

# Only ``State`` is listed as public; the remaining names in this module are
# not exported by ``from ... import *``.
__all__ = [
    "State",
]


class DuplicateIdError(DocumentError):
    """Raised when a vertex with the same ID is added to the graph twice.

    ``insert_vertex`` raises this when the uid being inserted is already
    present among the collected vertices.
    """

    # Overrides the ``category`` attribute of the Sphinx error base class.
    # NOTE(review): presumably used by Sphinx as the label when the error is
    # reported to the user -- confirm against the Sphinx exception docs.
    category = "Document Error"


@contextmanager
def _vertices_tmp(env: BuildEnvironment) -> Iterator[dict[str, Info]]:
    vertices: dict[str, Info] = getattr(env, "graph_vertices_tmp", {})
    yield vertices
    env.graph_vertices_tmp = vertices  # type: ignore[attr-defined]


def purge(_app: Sphinx, env: BuildEnvironment, docname: str) -> None:
    """Clear out all stale vertices belonging to one document.

    Every vertex whose ``docname`` matches the given one is removed from the
    environment's ``graph_vertices_tmp`` collection.

    If there are vertices left in the document, they will be added again during
    parsing.

    Args:
        _app: The Sphinx application (unused).
        env: The build environment holding the collected vertices.
        docname: Name of the document whose vertices should be dropped.
    """
    with _vertices_tmp(env) as vertices:
        # The context manager writes back the very dict object it yielded, so
        # the entries must be deleted in place; rebinding the local name to a
        # new filtered dict would be silently discarded.
        stale_uids = [
            uid for uid, vert in vertices.items() if vert.docname == docname
        ]
        for uid in stale_uids:
            del vertices[uid]


def merge(
    _app: Sphinx,
    env: BuildEnvironment,
    _docnames: list[str],
    other: BuildEnvironment,
) -> None:
    """Fold the vertices collected in ``other`` into ``env``.

    Used to combine environments during parallel builds. Entries from ``other``
    overwrite entries in ``env`` that share the same uid.
    """
    with _vertices_tmp(env) as ours:
        with _vertices_tmp(other) as theirs:
            ours.update(theirs)


def insert_vertex(env: BuildEnvironment, uid: str, info: Info) -> None:
    """Register a vertex under ``uid`` in the build environment.

    Raises:
        DuplicateIdError: If a vertex with the same ``uid`` was already
            inserted.
    """
    with _vertices_tmp(env) as registry:
        if uid not in registry:
            registry[uid] = info
        else:
            raise DuplicateIdError(f"Vertex {uid} already exists.")


def build_and_check_graph(env: BuildEnvironment) -> State:
    """Turn the collected vertices into a graph and return the resulting State.

    Each collected vertex becomes a graph node whose payload is the vertex uid.
    Edge creation and the accompanying consistency checks are delegated to
    ``build_graph_edges``. The finished vertex table and graph are stored on
    the environment as ``graph_vertices`` and ``graph_graph``.
    """
    collected: dict[str, Info] = env.graph_vertices_tmp  # type: ignore[attr-defined]
    graph: rx.PyDiGraph[str, str | None] = rx.PyDiGraph()

    # Map every uid to (graph node id, Info); nodes are added in collection order.
    vertices: dict[str, tuple[int, Info]] = {
        uid: (graph.add_node(uid), info) for uid, info in collected.items()
    }

    build_graph_edges(vertices, graph)

    # Only persisted once edge building has completed without raising.
    env.graph_vertices = vertices  # type: ignore[attr-defined]
    env.graph_graph = graph  # type: ignore[attr-defined]

    return State(vertices, graph)


class State:
    """State object for Sphinx Graph vertices.

    Bundles the vertex table (uid -> ``(node id, Info)``) with the graph that
    describes the relationships between those vertices.
    """

    def __init__(
        self,
        vertices: dict[str, tuple[int, Info]],
        graph: rx.PyDiGraph[str, str | None],
    ) -> None:
        """Create a new state object.

        Args:
            vertices: Mapping from vertex uid to ``(graph node id, Info)``.
            graph: Graph whose node payloads are vertex uids.
        """
        self._vertices = vertices
        self._graph = graph

    @classmethod
    def read(cls, env: BuildEnvironment) -> State:
        """Read the State object for the given environment.

        This is a read-only view of the state. Changes will not be saved.

        Falls back to an empty vertex table and an empty graph when the
        environment does not carry ``graph_vertices`` / ``graph_graph`` yet.
        """
        vertices = getattr(env, "graph_vertices", {})
        graph: rx.PyDiGraph[str, str | None]
        try:
            graph = env.graph_graph  # type: ignore[attr-defined]
        except AttributeError:
            # Build the empty fallback only when it is actually needed rather
            # than allocating a throwaway graph on every call.
            graph = rx.PyDiGraph(multigraph=False)
        # Use ``cls`` so that subclasses get an instance of their own type.
        return cls(vertices, graph)

    @property
    def graph(self) -> rx.PyDiGraph[str, str | None]:
        """A graph representing the relationships between vertices.

        Vertices in the graph are stored using 'node ids'. These can be
        retrieved using the `State.node_ids` mapping.
        """
        return self._graph

    @property
    def vertices(self) -> Mapping[str, Info]:
        """A mapping from vertex uid to vertex Info."""
        return Vertices(self._vertices)

    @property
    def node_ids(self) -> Mapping[str, int]:
        """A mapping from vertex uid to graph node ID."""
        return NodeIds(self._vertices)

    def children(self, uid: str) -> Iterable[str]:
        """Iterate over the children of the given node.

        This is a generator: an unknown ``uid`` raises ``KeyError`` when
        iteration starts, not when the method is called.
        """
        node_id, _info = self._vertices[uid]
        yield from self._graph.successors(node_id)

    def ancestors(self, uid: str) -> Iterable[str]:
        """Recursively find all direct parents and ancestors of the given node.

        This is a generator: an unknown ``uid`` raises ``KeyError`` when
        iteration starts, not when the method is called.
        """
        node_id = self.node_ids[uid]
        for anc_node_id in rx.ancestors(self.graph, node_id):
            yield self.graph[anc_node_id]

    def descendants(self, uid: str) -> Iterable[str]:
        """Recursively find all direct children and descendants of the given node.

        This is a generator: an unknown ``uid`` raises ``KeyError`` when
        iteration starts, not when the method is called.
        """
        node_id = self.node_ids[uid]
        for desc_node_id in rx.descendants(self.graph, node_id):
            yield self.graph[desc_node_id]
class Vertices(Mapping[str, Info]):
    """A dict-like view of vertex Info keyed by vertex ID."""

    def __init__(self, vertices: dict[str, tuple[int, Info]]) -> None:
        # Keeps a reference to the underlying table (no copy is made).
        self._vertices = vertices

    def __getitem__(self, key: str) -> Info:
        _node_id, info = self._vertices[key]
        return info

    def __iter__(self) -> Iterator[str]:
        return iter(self._vertices)

    def __len__(self) -> int:
        return len(self._vertices)


class NodeIds(Mapping[str, int]):
    """A dict-like view of graph node IDs keyed by vertex ID."""

    def __init__(self, vertices: dict[str, tuple[int, Info]]) -> None:
        # Keeps a reference to the underlying table (no copy is made).
        self._vertices = vertices

    def __getitem__(self, key: str) -> int:
        node_id, _info = self._vertices[key]
        return node_id

    def __iter__(self) -> Iterator[str]:
        return iter(self._vertices)

    def __len__(self) -> int:
        return len(self._vertices)


def build_graph_edges(
    vertices: Mapping[str, tuple[int, Info]], graph: rx.PyDiGraph[str, str | None]
) -> None:
    """Build the graph from the list of vertices.

    This is called during setup, and doesn't need to be called again.

    For every vertex, an edge is added from each of its parents to the vertex
    itself, carrying the link fingerprint as edge payload. Warnings are logged
    for missing required fingerprints, for fingerprints that no longer match
    the parent, and for cyclic dependencies.

    Args:
        vertices: Mapping from vertex uid to ``(graph node id, Info)``.
        graph: Graph that already contains one node per vertex; edges are
            added to it in place.

    Raises:
        SphinxError: If a vertex references a parent uid that does not exist.
    """
    # add all 'parent' edges
    for uid, (node_id, info) in vertices.items():
        fingerprints_required = info.config.require_fingerprints
        for parent_uid, fingerprint in info.parents.items():
            # Only the lookup can signal a missing parent, so it is the only
            # statement guarded by the KeyError handler.
            try:
                parent_node_id, parent = vertices[parent_uid]
            except KeyError as e:
                msg = (
                    f"vertex '{uid}' has a parent link to '{parent_uid}',"
                    f" but '{parent_uid}' doesn't exist"
                )
                logger.exception(msg)
                raise SphinxError(msg) from e
            graph.add_edge(parent_node_id, node_id, fingerprint)

            if fingerprints_required and fingerprint is None:
                logger.warning(
                    f"link fingerprints are required, but {uid} doesn't have a"
                    f" fingerprint for its link to its parent {parent_uid}.\nthe"
                    f" fingerprint can be added by changing the parent reference on"
                    f" {uid} to '{parent_uid}:{parent.fingerprint}'.",
                )
            if fingerprint and fingerprint != parent.fingerprint:
                logger.warning(
                    f"suspect link found. vertex {uid} is linked to vertex"
                    f" {parent_uid} with a fingerprint of '{fingerprint}', but"
                    f" {parent_uid}'s fingerprint is '{parent.fingerprint}'.\n{uid}"
                    " should be reviewed, and the link fingerprint manually updated.",
                )

    # Translate each cycle from node ids back to vertex uids for reporting.
    cycles = [
        [graph[node_id] for node_id in node_ids]
        for node_ids in rx.simple_cycles(graph)
    ]
    if cycles:
        # Close each cycle by repeating its first uid. Joining the whole path in
        # one go keeps a single-node cycle readable ("[a -> a]") instead of
        # leaving an empty segment between two arrows.
        suffix = ", ".join(
            f"[{' -> '.join([*uids, uids[0]])}]" for uids in cycles
        )
        logger.warning(
            f"vertices must not have cyclic dependencies. cycles detected: {suffix}"
        )