Source code for pytoolbox.states

import collections, inspect

from . import module

# Snapshot of the module globals taken *before* the public definitions below. At the
# bottom of the module, ``_all.diff(globals())`` builds ``__all__`` from it (presumably
# the names added in between — see ``pytoolbox.module.All`` to confirm).
_all = module.All(globals())


class StateEnumMetaclass(type):
    """Metaclass deriving state sets from a class-level ``TRANSITIONS`` mapping.

    When the created class (directly or through inheritance) has a ``TRANSITIONS``
    attribute — a mapping from each state to the collection of states it may transit
    to — the following class attributes are computed:

    * ``ALL_STATES``: frozenset of the keys of ``TRANSITIONS``.
    * ``FINAL_STATES``: frozenset of the states whose transitions are empty (falsy).
    * ``INVERSE_TRANSITIONS``: dict mapping a state to the frozenset of states having a
      transition to it. Every key of ``TRANSITIONS`` is present, mapped to an empty
      frozenset when no state can reach it. A transition target that is not itself a
      key of ``TRANSITIONS`` is present as well.

    Classes without ``TRANSITIONS`` are left untouched.
    """

    def __init__(cls, name, bases, cls_dict):
        super().__init__(name, bases, cls_dict)
        if not hasattr(cls, 'TRANSITIONS'):
            return
        transitions = cls.TRANSITIONS
        cls.ALL_STATES = frozenset(transitions.keys())
        cls.FINAL_STATES = frozenset(s for s, t in transitions.items() if not t)

        # Seed every known state with an empty set. Otherwise states without incoming
        # transitions (e.g. initial states) are missing, and looking them up (as
        # get_transit_from does) raises KeyError.
        inverse_transitions = collections.defaultdict(set, {s: set() for s in transitions})
        for state, targets in transitions.items():
            for target in targets:
                inverse_transitions[target].add(state)
        cls.INVERSE_TRANSITIONS = {s: frozenset(t) for s, t in inverse_transitions.items()}
class StateEnumMergeMetaclass(StateEnumMetaclass):
    """Metaclass building a state enum by merging the states of its base classes.

    The created class receives:

    * for every ``*_STATES`` attribute found on its bases or declared in its body, the
      union of all those values, stored as a frozenset;
    * a ``TRANSITIONS`` mapping (a ``collections.defaultdict(set)``) where each state
      maps to the union of its transitions in every base and in the class body;
    * one class attribute per merged state, named after the state and holding it.

    ``StateEnumMetaclass.__init__`` then recomputes ``ALL_STATES``, ``FINAL_STATES``
    and ``INVERSE_TRANSITIONS`` from the merged ``TRANSITIONS``.

    Bases without ``TRANSITIONS`` (e.g. plain mixins) contribute no transitions
    instead of raising ``AttributeError``.
    """

    def __init__(cls, name, bases, cls_dict):
        merged_states = collections.defaultdict(set)
        transitions = collections.defaultdict(set)

        # Retrieve the "state" attributes of the base classes.
        for base in bases:
            for key, value in inspect.getmembers(base):
                if key.endswith('_STATES'):
                    merged_states[key].update(value)
            for state, targets in getattr(base, 'TRANSITIONS', {}).items():
                transitions[state].update(targets)

        # Merge what the class body itself declares. Previously it was overwritten
        # below by the bases' values, silently dropping those declarations.
        for key, value in cls_dict.items():
            if key.endswith('_STATES'):
                merged_states[key].update(value)
        for state, targets in cls_dict.get('TRANSITIONS', {}).items():
            transitions[state].update(targets)

        # Update the "state" class attributes.
        for key, value in merged_states.items():
            setattr(cls, key, frozenset(value))
        for state in transitions:
            setattr(cls, state, state)
        cls.TRANSITIONS = transitions
        super().__init__(name, bases, cls_dict)
class StateEnum(object, metaclass=StateEnumMetaclass):
    """Base class for state enumerations driven by a ``TRANSITIONS`` mapping.

    Subclasses declare ``TRANSITIONS``; ``StateEnumMetaclass`` then derives
    ``ALL_STATES``, ``FINAL_STATES`` and ``INVERSE_TRANSITIONS`` from it.
    """

    @classmethod
    def get(cls, name):
        """Resolve the lower-case ``name`` to a single state or a group of states.

        Return None when ``name`` is not already lower-case. Otherwise upper-case it:
        if the result is one of ``ALL_STATES``, return that state. Else return the
        class attribute ``<NAME>_STATES`` (e.g. ``'final'`` gives ``FINAL_STATES``),
        or None when no such attribute exists.
        """
        if name != name.lower():
            return None
        upper_name = name.upper()
        if upper_name in cls.ALL_STATES:
            return upper_name
        return getattr(cls, upper_name + '_STATES', None)

    @classmethod
    def get_transit_from(cls, state, auto_inverse=False):
        """
        Return a set with the states having a transition to given `state`.

        If `auto_inverse` is set to True then a tuple is returned containing the smallest set from:

        * (States allowed to transit to given `state`, True)
        * (States not allowed to transit to given `state`, False)

        On a tie, the allowed states are returned.
        """
        allowed = cls.INVERSE_TRANSITIONS[state]
        if auto_inverse:
            denied = cls.ALL_STATES - allowed
            if len(allowed) > len(denied):
                return denied, False
            return allowed, True
        return allowed
# Public API: computed by ``_all.diff(globals())`` against the snapshot taken near the top
# of the module (presumably the names defined since then, i.e. the classes above — see
# ``pytoolbox.module.All.diff`` to confirm).
__all__ = _all.diff(globals())