Source code for pybel.io.nodelink

# -*- coding: utf-8 -*-

"""Conversion functions for BEL graphs with node-link JSON."""

import gzip
import json
from io import BytesIO
from itertools import chain, count
from operator import methodcaller
from typing import Any, Mapping, TextIO, Union

from networkx.utils import open_file

from .utils import ensure_version
from ..constants import (
    ANNOTATIONS,
    CITATION,
    FUSION,
    GRAPH_ANNOTATION_CURIE,
    GRAPH_ANNOTATION_LIST,
    GRAPH_ANNOTATION_MIRIAM,
    MEMBERS,
    PARTNER_3P,
    PARTNER_5P,
    PRODUCTS,
    REACTANTS,
    SOURCE_MODIFIER,
    TARGET_MODIFIER,
)
from ..dsl import BaseEntity
from ..language import citation_dict
from ..struct import BELGraph
from ..struct.graph import _handle_modifier
from ..tokens import parse_result_to_dsl
from ..utils import hash_edge, tokenize_version

__all__ = [
    "to_nodelink",
    "to_nodelink_file",
    "to_nodelink_gz",
    "to_nodelink_jsons",
    "from_nodelink",
    "from_nodelink_file",
    "from_nodelink_gz",
    "from_nodelink_jsons",
    "to_nodelink_gz_io",
    "from_nodelink_gz_io",
]





def _prepare_graph_dict(g):
    # Convert annotation list definitions (which are sets) to canonicalized/sorted lists
    g[GRAPH_ANNOTATION_LIST] = {
        keyword: list(sorted(values)) for keyword, values in g.get(GRAPH_ANNOTATION_LIST, {}).items()
    }

    g[GRAPH_ANNOTATION_CURIE] = list(sorted(g[GRAPH_ANNOTATION_CURIE]))
    g[GRAPH_ANNOTATION_MIRIAM] = list(sorted(g[GRAPH_ANNOTATION_MIRIAM]))























def _to_nodelink_json_helper(graph: BELGraph) -> Mapping[str, Any]:
    """Convert a BEL graph to a node-link format.

    :param graph: BEL Graph

    Adapted from :func:`networkx.readwrite.json_graph.node_link_data`
    """
    nodes = sorted(graph, key=methodcaller("as_bel"))

    mapping = dict(zip(nodes, count()))

    return {
        "directed": True,
        "multigraph": True,
        "graph": graph.graph.copy(),
        "nodes": [_augment_node(node) for node in nodes],
        "links": [
            dict(
                chain(
                    data.copy().items(),
                    [("source", mapping[u]), ("target", mapping[v]), ("key", key)],
                ),
            )
            for u, v, key, data in graph.edges(keys=True, data=True)
        ],
    }


def _augment_node(node: BaseEntity) -> BaseEntity:
    """Add the SHA-512 identifier to a node's dictionary."""
    rv = node.copy()
    rv["id"] = node.md5
    rv["bel"] = node.as_bel()
    for m in chain(node.get(MEMBERS, []), node.get(REACTANTS, []), node.get(PRODUCTS, [])):
        m.update(_augment_node(m))
    if FUSION in node:
        node[FUSION][PARTNER_3P].update(_augment_node(node[FUSION][PARTNER_3P]))
        node[FUSION][PARTNER_5P].update(_augment_node(node[FUSION][PARTNER_5P]))
    return rv


def _recover_graph_dict(graph: BELGraph):
    graph.graph[GRAPH_ANNOTATION_LIST] = {
        keyword: set(values) for keyword, values in graph.graph.get(GRAPH_ANNOTATION_LIST, {}).items()
    }
    graph.graph[GRAPH_ANNOTATION_CURIE] = set(graph.graph.get(GRAPH_ANNOTATION_CURIE, []))
    graph.graph[GRAPH_ANNOTATION_MIRIAM] = set(graph.graph.get(GRAPH_ANNOTATION_MIRIAM, []))


def _from_nodelink_json_helper(data: Mapping[str, Any]) -> BELGraph:
    """Return graph from node-link data format.

    Adapted from :func:`networkx.readwrite.json_graph.node_link_graph`
    """
    graph = BELGraph()
    graph.graph = data.get("graph", {})
    _recover_graph_dict(graph)

    mapping = []

    for node_data in data["nodes"]:
        node = parse_result_to_dsl(node_data)
        graph.add_node_from_data(node)
        mapping.append(node)

    for data in data["links"]:
        u = mapping[data["source"]]
        v = mapping[data["target"]]

        edge_data = {k: v for k, v in data.items() if k not in {"source", "target", "key"}}

        for side in (SOURCE_MODIFIER, TARGET_MODIFIER):
            side_data = edge_data.get(side)
            if side_data:
                _handle_modifier(side_data)

        if CITATION in edge_data:
            edge_data[CITATION] = citation_dict(**edge_data[CITATION])

        if ANNOTATIONS in edge_data:
            edge_data[ANNOTATIONS] = graph._clean_annotations(edge_data[ANNOTATIONS])

        graph.add_edge(u, v, key=hash_edge(u, v, edge_data), **edge_data)

    return graph


def to_nodelink_gz_io(graph: BELGraph) -> BytesIO:
    """Get a BEL graph as a compressed BytesIO."""
    bytes_io = BytesIO()
    with gzip.GzipFile(fileobj=bytes_io, mode="w") as file:
        s = to_nodelink_jsons(graph)
        file.write(s.encode("utf-8"))
    bytes_io.seek(0)
    return bytes_io


def from_nodelink_gz_io(bytes_io: BytesIO) -> BELGraph:
    """Get BEL from gzipped nodelink JSON."""
    with gzip.GzipFile(fileobj=bytes_io, mode="r") as file:
        s = file.read()
    j = s.decode("utf-8")
    return from_nodelink_jsons(j)