Source code for expertsystem.ui

"""Main interface of the `expertsystem`.

This module contains the functions that you need for the most common use cases
of the `expertsystem`.
"""

# Names exported by ``from expertsystem.ui import *``; InteractionTypes is
# re-exported from expertsystem.state.propagation.
__all__ = [
    "StateTransitionManager", "load_default_particle_list", "InteractionTypes",
]

import logging
from copy import deepcopy
from multiprocessing import Pool
from os import path

from progress.bar import IncrementalBar

from expertsystem.amplitude.canonical_decay import CanonicalAmplitudeGenerator
from expertsystem.amplitude.helicity_decay import HelicityAmplitudeGenerator
from expertsystem.state import particle
from expertsystem.state.propagation import (
    FullPropagator,
    InteractionTypes,
)
from expertsystem.topology import graph
from expertsystem.topology.topology_builder import (
    SimpleStateTransitionTopologyBuilder,
)

from ._default_settings import (
    DEFAULT_PARTICLE_LIST_FILE,
    DEFAULT_PARTICLE_LIST_PATH,
    create_default_interaction_settings,
)
from ._system_control import (
    GammaCheck,
    LeptonCheck,
    analyse_solution_failure,
    create_interaction_setting_groups,
    filter_interaction_types,
    match_external_edges,
    perform_external_edge_identical_particle_combinatorics,
    remove_duplicate_solutions,
)


class StateTransitionManager:  # pylint: disable=too-many-instance-attributes
    """Main handler for decay topologies.

    The workflow is `prepare_graphs` -> `find_solutions` ->
    `write_amplitude_model`. First, topology graphs are built for the given
    initial and final state, and every interaction node gets a list of
    candidate interaction settings. Next, quantum numbers are propagated
    through the graphs. Finally, the surviving solutions are turned into an
    amplitude model.

    Args:
        initial_state: Initial state particles, handed to
            ``particle.initialize_graph``. With ``topology_building="isobar"``
            a single particle gives a two-body decay node, more than one an
            N-body scattering node.
        final_state: Final state particles, handed to
            ``particle.initialize_graph``.
        allowed_intermediate_particles: Particles that may appear as
            intermediate states. If empty (the default), all particles in the
            default particle list are used.
        interaction_type_settings: Mapping from `.InteractionTypes` to
            interaction settings. If `None` or empty, defaults from
            ``create_default_interaction_settings`` are used, with mass
            conservation switched off when an isobar topology has more than
            one initial state particle. Explicitly passed settings are kept
            unchanged.
        formalism_type: One of ``"helicity"``, ``"canonical-helicity"`` or
            ``"canonical"``.
        topology_building: Only ``"isobar"`` sets up interaction nodes here.
            For any other value the topology builder is created without
            interaction nodes; use `set_topology_builder` to provide one.
        number_of_threads: Number of worker processes used by
            `find_solutions`; a value of 1 or less runs serially.
        propagation_mode: Passed on to ``FullPropagator``.

    Raises:
        NotImplementedError: If ``formalism_type`` is not supported.

    Note:
        The constructor calls `load_default_particle_list`, which resets the
        global particle database.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        initial_state,
        final_state,
        allowed_intermediate_particles=None,
        interaction_type_settings=None,
        formalism_type="helicity",
        topology_building="isobar",
        number_of_threads=4,
        propagation_mode="fast",
    ):
        if allowed_intermediate_particles is None:
            allowed_intermediate_particles = []
        allowed_formalism_types = [
            "helicity",
            "canonical-helicity",
            "canonical",
        ]
        if formalism_type not in allowed_formalism_types:
            raise NotImplementedError(
                f"Formalism type {formalism_type} not implemented."
                f" Use {allowed_formalism_types} instead."
            )
        self.__formalism_type = formalism_type
        self.number_of_threads = number_of_threads
        self.propagation_mode = propagation_mode
        self.initial_state = initial_state
        self.final_state = final_state

        # Turn off mass conservation in case more than one initial state
        # particle is present (isobar scattering). This only applies to the
        # *default* settings; settings passed in by the caller are never
        # silently replaced.
        is_isobar_scattering = (
            topology_building == "isobar" and len(initial_state) > 1
        )
        if not interaction_type_settings:
            if is_isobar_scattering:
                interaction_type_settings = (
                    create_default_interaction_settings(formalism_type, False)
                )
            else:
                interaction_type_settings = (
                    create_default_interaction_settings(formalism_type)
                )
        self.interaction_type_settings = interaction_type_settings

        self.interaction_determinators = [LeptonCheck(), GammaCheck()]
        self.allowed_intermediate_particles = allowed_intermediate_particles
        self.final_state_groupings = []
        self.allowed_interaction_types = [
            InteractionTypes.Strong,
            InteractionTypes.EM,
            InteractionTypes.Weak,
        ]

        # Quantum numbers handed to remove_duplicate_solutions in
        # find_solutions.
        self.filter_remove_qns = []
        self.filter_ignore_qns = []
        if formalism_type == "helicity":
            self.filter_remove_qns = [
                particle.InteractionQuantumNumberNames.S,
                particle.InteractionQuantumNumberNames.L,
            ]
        if "helicity" in formalism_type:
            self.filter_ignore_qns = [
                particle.InteractionQuantumNumberNames.ParityPrefactor
            ]

        int_nodes = []
        if topology_building == "isobar":
            if len(initial_state) == 1:
                int_nodes.append(graph.InteractionNode("TwoBodyDecay", 1, 2))
            else:
                int_nodes.append(
                    graph.InteractionNode(
                        "NBodyScattering", len(initial_state), len(final_state)
                    )
                )
        self.topology_builder = SimpleStateTransitionTopologyBuilder(int_nodes)

        load_default_particle_list()

    @property
    def formalism_type(self) -> str:
        """Formalism type chosen at construction (read-only)."""
        return self.__formalism_type

    def set_topology_builder(self, topology_builder):
        """Replace the builder used by `prepare_graphs` to create topologies.

        The builder must provide ``build_graphs(n_initial, n_final)``.
        """
        self.topology_builder = topology_builder

    def add_final_state_grouping(self, fs_group):
        """Register a grouping of final state particles.

        Args:
            fs_group: Either a flat list (one group, which is wrapped into a
                list of groups) or a list of lists (several groups). The
                stored groupings are handed to ``particle.initialize_graph``.

        Raises:
            ValueError: If ``fs_group`` is not a list.
        """
        if not isinstance(fs_group, list):
            raise ValueError(
                "The final state grouping has to be of type list."
            )
        # A flat list describes a single group: normalize to a list of groups.
        if fs_group and not isinstance(fs_group[0], list):
            fs_group = [fs_group]
        self.final_state_groupings.append(fs_group)

    def set_allowed_interaction_types(self, allowed_interaction_types):
        """Restrict the interaction types that nodes may be assigned.

        Args:
            allowed_interaction_types: Iterable of `.InteractionTypes`; each
                must have an entry in ``interaction_type_settings``.

        Raises:
            TypeError: If an element is not an `.InteractionTypes`.
            ValueError: If no settings exist for one of the types.
        """
        for allowed_types in allowed_interaction_types:
            if not isinstance(allowed_types, InteractionTypes):
                raise TypeError(
                    "allowed interaction types must be of type"
                    " [InteractionTypes]"
                )
            if allowed_types not in self.interaction_type_settings:
                logging.info(self.interaction_type_settings.keys())
                raise ValueError(
                    f"interaction {allowed_types} not found in settings"
                )
        self.allowed_interaction_types = allowed_interaction_types

    def prepare_graphs(self):
        """Build topologies, seed them and group them by interaction settings.

        Returns:
            The result of ``create_interaction_setting_groups``: groups of
            ``(graph, node_settings)`` pairs keyed by interaction strength,
            ready for `find_solutions`.
        """
        topology_graphs = self._build_topologies()
        init_graphs = self._create_seed_graphs(topology_graphs)
        graph_node_setting_pairs = self._determine_node_settings(init_graphs)
        # create groups of settings ordered by "probability"
        graph_settings_groups = create_interaction_setting_groups(
            graph_node_setting_pairs
        )
        return graph_settings_groups

    def _build_topologies(self):
        """Let the topology builder create all graphs for the given states."""
        all_graphs = self.topology_builder.build_graphs(
            len(self.initial_state), len(self.final_state)
        )
        logging.info(f"number of topology graphs: {len(all_graphs)}")
        return all_graphs

    def _create_seed_graphs(self, topology_graphs):
        """Initialize the initial and final state edges of each topology."""
        init_graphs = []
        for topology_graph in topology_graphs:
            topology_graph.set_graph_element_properties_comparator(
                particle.CompareGraphElementPropertiesFunctor()
            )
            init_graphs.extend(
                particle.initialize_graph(
                    topology_graph,
                    self.initial_state,
                    self.final_state,
                    self.final_state_groupings,
                )
            )
        logging.info(f"initialized {len(init_graphs)} graphs!")
        return init_graphs

    def _determine_node_settings(self, graphs):
        """Assign candidate interaction settings to every node of each graph.

        Returns:
            A list of ``(graph, node_settings)`` pairs, where
            ``node_settings`` maps each node ID to deep copies of the
            settings of the interaction types allowed at that node.
        """
        graph_node_setting_pairs = []
        for instance in graphs:
            final_state_edges = graph.get_final_state_edges(instance)
            initial_state_edges = graph.get_initial_state_edges(instance)
            node_settings = {}
            for node_id in instance.nodes:
                node_int_types = self._determine_interaction_types(
                    instance, node_id, initial_state_edges, final_state_edges
                )
                logging.debug(
                    "using %s interaction order for node: %s",
                    str(node_int_types),
                    str(node_id),
                )
                # Deep copies, so that propagation can modify the settings
                # per node without affecting the shared templates.
                node_settings[node_id] = [
                    deepcopy(self.interaction_type_settings[x])
                    for x in node_int_types
                ]
            graph_node_setting_pairs.append((instance, node_settings))
        return graph_node_setting_pairs

    def _determine_interaction_types(
        self, instance, node_id, initial_state_edges, final_state_edges
    ):
        """Return the allowed interaction types for one node of a graph.

        Only edges that are initial state (ingoing) or final state (outgoing)
        edges carry particle information, so only those are inspected.
        """
        in_edge_ids = graph.get_edges_ingoing_to_node(instance, node_id)
        out_edge_ids = graph.get_edges_outgoing_to_node(instance, node_id)
        in_edge_props = [
            instance.edge_props[edge_id]
            for edge_id in in_edge_ids
            if edge_id in initial_state_edges
        ]
        out_edge_props = [
            instance.edge_props[edge_id]
            for edge_id in out_edge_ids
            if edge_id in final_state_edges
        ]
        node_props = {}
        if node_id in instance.node_props:
            node_props = instance.node_props[node_id]

        node_int_types = []
        for int_det in self.interaction_determinators:
            determined_interactions = int_det.check(
                in_edge_props, out_edge_props, node_props
            )
            # Intersect with previous determinations; an empty list so far
            # means "no restriction yet".
            if node_int_types:
                node_int_types = list(
                    set(determined_interactions) & set(node_int_types)
                )
            else:
                node_int_types = determined_interactions
        return filter_interaction_types(
            node_int_types, self.allowed_interaction_types
        )

    def find_solutions(
        self, graph_setting_groups
    ):  # pylint: disable=too-many-locals
        """Propagate quantum numbers and collect the allowed transitions.

        Args:
            graph_setting_groups: Mapping from interaction strength to a list
                of ``(graph, node_settings)`` pairs, such as the groups
                returned by `prepare_graphs`. Groups are processed in
                descending order of strength.

        Returns:
            A tuple ``(final_solutions, violated_laws)``. ``violated_laws``
            is only filled when no solution was found at all.
        """
        results = {}
        logging.info(
            "Number of interaction settings groups being processed: %d",
            len(graph_setting_groups),
        )
        for strength, graph_setting_group in sorted(
            graph_setting_groups.items(), reverse=True
        ):
            logging.info(
                "processing interaction settings group with "
                f"strength {strength}",
            )
            logging.info(f"{len(graph_setting_group)} entries in this group")
            logging.info(f"running with {self.number_of_threads} threads...")
            results.setdefault(strength, []).extend(
                self._propagate_setting_group(graph_setting_group)
            )

        for key, value in results.items():
            logging.info(
                f"number of solutions for strength ({key}) "
                f"after qn propagation: {sum(len(x[0]) for x in value)}",
            )

        # Remove duplicate solutions that differ only in the quantum numbers
        # listed in filter_remove_qns / filter_ignore_qns.
        results = remove_duplicate_solutions(
            results, self.filter_remove_qns, self.filter_ignore_qns
        )

        node_non_satisfied_rules = []
        solutions = []
        for result in results.values():
            for (tempsolutions, non_satisfied_laws) in result:
                solutions.extend(tempsolutions)
                node_non_satisfied_rules.append(non_satisfied_laws)
        logging.info(f"total number of found solutions: {len(solutions)}")
        violated_laws = []
        if not solutions:
            violated_laws = analyse_solution_failure(node_non_satisfied_rules)
            logging.info(f"violated rules: {violated_laws}")

        # finally perform combinatorics of identical external edges
        # (initial or final state edges) and prepare graphs for
        # amplitude generation
        match_external_edges(solutions)
        final_solutions = []
        for sol in solutions:
            final_solutions.extend(
                perform_external_edge_identical_particle_combinatorics(sol)
            )
        return (final_solutions, violated_laws)

    def _propagate_setting_group(self, graph_setting_group):
        """Run quantum number propagation on every pair of one group.

        Shows a progress bar. With more than one worker a process pool is
        used, in which case results are collected in completion order.
        """
        temp_results = []
        progress_bar = IncrementalBar(
            "Propagating quantum numbers...", max=len(graph_setting_group)
        )
        progress_bar.update()
        if self.number_of_threads > 1:
            with Pool(self.number_of_threads) as pool:
                for result in pool.imap_unordered(
                    self._propagate_quantum_numbers, graph_setting_group, 1
                ):
                    temp_results.append(result)
                    progress_bar.next()
        else:
            for graph_setting_pair in graph_setting_group:
                temp_results.append(
                    self._propagate_quantum_numbers(graph_setting_pair)
                )
                progress_bar.next()
        progress_bar.finish()
        logging.info("Finished!")
        return temp_results

    def _propagate_quantum_numbers(self, state_graph_node_settings_pair):
        """Solve one ``(graph, node_settings)`` pair.

        Returns:
            ``(solutions, non_satisfied_conservation_laws)``.
        """
        propagator = self._initialize_qn_propagator(
            state_graph_node_settings_pair[0],
            state_graph_node_settings_pair[1],
        )
        solutions = propagator.find_solutions()
        return (solutions, propagator.get_non_satisfied_conservation_laws())

    def _initialize_qn_propagator(self, state_graph, node_settings):
        """Create a propagator with node settings and allowed particles."""
        propagator = FullPropagator(state_graph, self.propagation_mode)
        for node_id, interaction_settings in node_settings.items():
            propagator.assign_settings_to_node(node_id, interaction_settings)
        # specify set of particles which are allowed to be intermediate
        # particles. If list is empty, then all particles in the default
        # particle list are used
        propagator.set_allowed_intermediate_particles(
            self.allowed_intermediate_particles
        )
        return propagator

    def write_amplitude_model(self, solutions: list, output_file: str) -> None:
        """Generate an amplitude model from the solutions.

        The type of amplitude model (`.HelicityAmplitudeGenerator` or
        `.CanonicalAmplitudeGenerator`) is determined from the
        :code:`formalism_type` that was chosen when constructing the
        `.StateTransitionManager`.

        Args:
            solutions: Solution graphs, for instance the first element of
                the tuple returned by `find_solutions`.
            output_file: Path of the file the model is written to.

        Raises:
            NotImplementedError: If no amplitude generator exists for the
                formalism type.
        """
        if self.formalism_type == "helicity":
            amplitude_generator = HelicityAmplitudeGenerator()
        elif self.formalism_type in ["canonical-helicity", "canonical"]:
            amplitude_generator = CanonicalAmplitudeGenerator()
        else:
            raise NotImplementedError(
                "No amplitude generator for formalism type"
                f" {self.formalism_type}"
            )
        amplitude_generator.generate(solutions)
        amplitude_generator.write_to_file(output_file)
def load_default_particle_list() -> None:
    """Reset the particle database to the list bundled with `expertsystem`.

    .. warning:: Every particle definition is dropped before loading, so
        particles that are not part of the bundled list are gone afterwards.

    Raises:
        FileNotFoundError: If the bundled particle list file does not exist.
    """
    particle_list_path = DEFAULT_PARTICLE_LIST_PATH
    if path.exists(particle_list_path):
        # Start from an empty database, then fill it from the bundled file.
        particle.DATABASE = {}
        particle.load_particles(particle_list_path)
        logging.info(
            f"Loaded {len(particle.DATABASE)} particles from {DEFAULT_PARTICLE_LIST_FILE}!"
        )
        return
    raise FileNotFoundError(
        f"\n Failed to load {DEFAULT_PARTICLE_LIST_FILE}!"
        "\n Please contact the developers: https://github.com/ComPWA"
    )