"""Serialization module for the `expertsystem`.
The `.io` module provides tools to export or import objects from the
:mod:`.reaction.particle`, :mod:`.reaction` and :mod:`.amplitude` modules to and from
disk, so that they can be used by external packages, or just to store (cache)
the state of the system.
"""
import json
from collections import abc
from pathlib import Path
import attr
import yaml
from expertsystem.reaction import Result
from expertsystem.reaction.particle import Particle, ParticleCollection
from expertsystem.reaction.topology import StateTransitionGraph, Topology
from . import _dict, _dot
def asdict(instance: object) -> dict:
    """Serialize a supported `expertsystem` object to a `dict`.

    Supported types are `.Particle`, `.ParticleCollection`, `.Result`,
    `.StateTransitionGraph` and `.Topology`. The actual conversion is
    delegated to the matching ``_dict.from_*`` function.

    Raises:
        NotImplementedError: if ``instance`` is none of the supported types.
    """
    # Checked in order: the first type that ``instance`` is an instance of
    # determines the converter.
    converters = (
        (Particle, _dict.from_particle),
        (ParticleCollection, _dict.from_particle_collection),
        (Result, _dict.from_result),
        (StateTransitionGraph, _dict.from_stg),
        (Topology, _dict.from_topology),
    )
    for supported_type, convert in converters:
        if isinstance(instance, supported_type):
            return convert(instance)
    class_name = instance.__class__.__name__
    raise NotImplementedError(
        f"No conversion for dict available for class {class_name}"
    )
def fromdict(definition: dict) -> object:
    """Build an object from its `dict` representation.

    The target type is inferred from the top-level keys of ``definition``.
    A particle definition only needs to contain the mandatory `.Particle`
    fields (optional fields may be present as well). Every other type is
    recognized by an exact match of its key set.

    Raises:
        NotImplementedError: if the keys match none of the known layouts.
    """
    keys = set(definition.keys())
    if keys.issuperset(__REQUIRED_PARTICLE_FIELDS):
        return _dict.build_particle(definition)
    exact_layouts = (
        ({"particles"}, _dict.build_particle_collection),
        ({"transitions", "formalism_type"}, _dict.build_result),
        ({"topology", "edge_props", "node_props"}, _dict.build_stg),
        (__REQUIRED_TOPOLOGY_FIELDS, _dict.build_topology),
    )
    for expected_keys, build in exact_layouts:
        if keys == expected_keys:
            return build(definition)
    raise NotImplementedError(f"Could not determine type from keys {keys}")
# Attributes of `Particle` that have no default value. A dict containing (at
# least) these keys is treated as a particle definition by `fromdict`.
__REQUIRED_PARTICLE_FIELDS = {
    field.name
    for field in attr.fields(Particle)
    # attr.NOTHING is a sentinel: compare by identity. ``==`` would first call
    # the default value's own ``__eq__``, which need not behave sensibly.
    if field.default is attr.NOTHING
}
# Attributes of `Topology` that are passed to its constructor. `fromdict`
# treats a dict with exactly these keys as a topology definition.
__REQUIRED_TOPOLOGY_FIELDS = {
    field.name for field in attr.fields(Topology) if field.init
}
def asdot(
    instance: object,
    *,
    render_node: bool = True,
    render_final_state_id: bool = True,
    render_resonance_id: bool = False,
    render_initial_state_id: bool = False,
    strip_spin: bool = False,
    collapse_graphs: bool = False,
) -> str:
    """Convert an `object` to a DOT language `str`.

    Only works for objects that can be represented as a graph:

    - a single `.StateTransitionGraph` or `.Topology`;
    - a `.Result`, in which case its ``transitions`` are rendered;
    - a sequence (e.g. a `list`) of graphs.

    The ``strip_spin`` and ``collapse_graphs`` options are only used when
    rendering several graphs (a `.Result` or a sequence). They have no effect
    for a single graph.

    Raises:
        NotImplementedError: if ``instance`` cannot be represented as a
            graph. Text (`str`, `bytes`) is rejected explicitly: it is an
            `abc.Sequence`, but not a sequence of graphs.

    .. seealso:: :doc:`/usage/visualize`
    """
    if isinstance(instance, (StateTransitionGraph, Topology)):
        return _dot.graph_to_dot(
            instance,
            render_node=render_node,
            render_final_state_id=render_final_state_id,
            render_resonance_id=render_resonance_id,
            render_initial_state_id=render_initial_state_id,
        )
    is_text = isinstance(instance, (str, bytes, bytearray))
    is_graph_sequence = isinstance(instance, abc.Sequence) and not is_text
    if isinstance(instance, Result) or is_graph_sequence:
        if isinstance(instance, Result):
            instance = instance.transitions
        return _dot.graph_list_to_dot(
            instance,
            render_node=render_node,
            render_final_state_id=render_final_state_id,
            render_resonance_id=render_resonance_id,
            render_initial_state_id=render_initial_state_id,
            strip_spin=strip_spin,
            collapse_graphs=collapse_graphs,
        )
    raise NotImplementedError(
        f"Cannot convert a {instance.__class__.__name__} to DOT language"
    )
def load(filename: str) -> object:
    """Load an object from a JSON or YAML file.

    The format is chosen by the (case-insensitive) file extension: ``json``,
    ``yaml`` or ``yml``. The file is read as UTF-8, and its parsed content is
    passed to `fromdict` to reconstruct the object.

    Raises:
        NotImplementedError: if there is no loader for the file extension.
            The extension is checked before the file is opened.
    """
    file_extension = _get_file_extension(filename)
    if file_extension not in ["json", "yaml", "yml"]:
        raise NotImplementedError(
            f'No loader defined for file type "{file_extension}"'
        )
    # JSON and YAML files are UTF-8 text; do not rely on the platform's
    # default encoding (e.g. cp1252 on Windows).
    with open(filename, encoding="utf-8") as stream:
        if file_extension == "json":
            definition = json.load(stream)
        else:
            # SafeLoader only constructs plain Python objects, so loading a
            # file cannot instantiate arbitrary Python classes.
            definition = yaml.load(stream, Loader=yaml.SafeLoader)
    return fromdict(definition)
class _IncreasedIndent(yaml.Dumper):
    """YAML dumper tweaked for more readable output files.

    Nested indentation is never suppressed, and an additional line break is
    written while the emitter's indent stack has depth one. Per the linked
    answer, that depth corresponds to top-level entries.
    """

    # pylint: disable=too-many-ancestors

    def increase_indent(self, flow=False, indentless=False):  # type: ignore
        # Discard the caller's ``indentless`` request and always indent.
        # NOTE(review): PyYAML asks for indentless block sequences inside
        # mappings; overriding it indents list items below their key.
        return super().increase_indent(flow=flow, indentless=False)

    def write_line_break(self, data=None):  # type: ignore
        """See https://stackoverflow.com/a/44284819."""
        super().write_line_break(data)
        at_top_level = len(self.indents) == 1
        if at_top_level:
            # Extra, empty line break to visually separate top-level entries.
            super().write_line_break()
def write(instance: object, filename: str) -> None:
    """Write an object to disk in the format implied by the file extension.

    Supported (case-insensitive) extensions:

    - ``json``: the `asdict` representation, indented by 2 spaces;
    - ``yaml`` / ``yml``: the `asdict` representation in block style, with
      key order preserved (``sort_keys=False``);
    - ``gv``: DOT language. A `str` ``instance`` (e.g. the output of `asdot`)
      is written verbatim; anything else is converted with `asdot`.

    The output is fully serialized before the file is opened. An unsupported
    extension, or an object that cannot be converted, therefore leaves
    ``filename`` untouched instead of creating or truncating it. The file is
    written as UTF-8.

    Raises:
        NotImplementedError: if there is no writer for the file extension, or
            if ``instance`` cannot be converted to the requested format.
    """
    file_extension = _get_file_extension(filename)
    if file_extension == "json":
        output_str = json.dumps(asdict(instance), indent=2)
    elif file_extension in ["yaml", "yml"]:
        # Without a ``stream`` argument, yaml.dump returns the document as str.
        output_str = yaml.dump(
            asdict(instance),
            sort_keys=False,
            Dumper=_IncreasedIndent,
            default_flow_style=False,
        )
    elif file_extension == "gv":
        if isinstance(instance, str):  # direct output of asdot
            output_str = instance
        else:
            output_str = asdot(instance)
    else:
        raise NotImplementedError(
            f'No writer defined for file type "{file_extension}"'
        )
    # Open the file only once, and only after serialization succeeded.
    with open(filename, "w", encoding="utf-8") as stream:
        stream.write(output_str)
def _get_file_extension(filename: str) -> str:
path = Path(filename)
extension = path.suffix.lower()
if not extension:
raise Exception(f"No file extension in file {filename}")
extension = extension[1:]
return extension