diff --git a/.github/workflows/ci_extra.yml b/.github/workflows/ci_extra.yml new file mode 100644 index 0000000..7309c12 --- /dev/null +++ b/.github/workflows/ci_extra.yml @@ -0,0 +1,39 @@ +# This workflow is for any branch. It runs additional tests for several python versions. + +name: Build extra + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10"] + + steps: + + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Lint with pylint + run: | + pylint pyformlang || true + - name: Lint with pycodestyle + run: | + pycodestyle pyformlang || true + - name: Check with pyright + run: | + pyright --stats pyformlang + - name: Test with pytest + run: | + pytest --showlocals -v pyformlang diff --git a/.github/workflows/ci_feature.yml b/.github/workflows/ci_feature.yml new file mode 100644 index 0000000..edf0626 --- /dev/null +++ b/.github/workflows/ci_feature.yml @@ -0,0 +1,54 @@ +# This workflow is for feature branches. It sets up python, lints with several analyzers, +# runs tests, collects test coverage and makes a coverage comment. + +name: Build feature + +on: + push: + branches-ignore: "master" + pull_request: + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8"] + + steps: + + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Lint with pylint + run: | + pylint pyformlang || true + - name: Lint with pycodestyle + run: | + pycodestyle pyformlang || true + - name: Check with pyright + run: | + pyright --stats pyformlang + - name: Test with pytest + run: | + pytest --showlocals -v pyformlang + + - name: Build coverage file + run: | + pytest pyformlang --junitxml=pytest.xml --cov=pyformlang | tee ./pytest-coverage.txt + - name: Make coverage comment + uses: MishaKav/pytest-coverage-comment@main + id: coverageComment + with: + pytest-coverage-path: ./pytest-coverage.txt + junitxml-path: ./pytest.xml + default-branch: master diff --git a/.github/workflows/python-package.yml b/.github/workflows/ci_master.yml similarity index 68% rename from .github/workflows/python-package.yml rename to .github/workflows/ci_master.yml index caf70a6..2ab30a8 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/ci_master.yml @@ -1,9 +1,11 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions +# This workflow is for master branch only. It sets up python, lints with several analyzers, +# runs tests, collects test coverage, makes a coverage comment and creates a coverage badge. -name: Python package +name: Build master -on: [push, pull_request] +on: + push: + branches: "master" jobs: build: @@ -12,9 +14,10 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8"] steps: + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 @@ -23,31 +26,33 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r requirements.txt + - name: Lint with pylint run: | pylint pyformlang || true - name: Lint with pycodestyle run: | pycodestyle pyformlang || true + - name: Check with pyright + run: | + pyright --stats pyformlang - name: Test with pytest run: | pytest --showlocals -v pyformlang + - name: Build coverage file - if: ${{ matrix.python-version == '3.8'}} run: | pytest pyformlang --junitxml=pytest.xml --cov=pyformlang | tee ./pytest-coverage.txt - - name: Pytest coverage comment - if: ${{ matrix.python-version == '3.8'}} + - name: Make coverage comment uses: MishaKav/pytest-coverage-comment@main id: coverageComment with: pytest-coverage-path: ./pytest-coverage.txt junitxml-path: ./pytest.xml default-branch: master + - name: Create coverage Badge - if: ${{ github.ref_name == 'master' && matrix.python-version == '3.8'}} uses: schneegans/dynamic-badges-action@v1.0.0 with: auth: ${{ secrets.GIST_SECRET }} diff --git a/pyformlang/cfg/cfg.py b/pyformlang/cfg/cfg.py index 1517440..1b5762e 100644 --- a/pyformlang/cfg/cfg.py +++ b/pyformlang/cfg/cfg.py @@ -47,16 +47,12 @@ def __init__(self, start_symbol: Hashable = None, productions: Iterable[Production] = None) -> None: super().__init__() - if variables is not None: - variables = {to_variable(x) for x in variables} - self._variables = variables or set() - if terminals is not None: - terminals = {to_terminal(x) for x in terminals} - self._terminals = terminals or set() + self._variables = {to_variable(x) for x in variables or set()} + self._terminals = {to_terminal(x) for x in terminals or set()} + self._start_symbol = None if start_symbol is not None: - start_symbol = to_variable(start_symbol) - self._variables.add(start_symbol) - self._start_symbol = start_symbol + self._start_symbol = to_variable(start_symbol) + self._variables.add(self._start_symbol) self._productions = set() for production in productions or set(): self.add_production(production) diff --git a/pyformlang/cfg/tests/test_terminal.py b/pyformlang/cfg/tests/test_terminal.py index 7cd9a0e..53fa5a0 100644 --- a/pyformlang/cfg/tests/test_terminal.py +++ b/pyformlang/cfg/tests/test_terminal.py @@ -24,6 +24,8 @@ def test_creation(self): assert epsilon.to_text() == "epsilon" assert Terminal("C").to_text() == '"TER:C"' assert repr(Epsilon()) == "epsilon" + assert str(terminal0) == "0" + assert repr(terminal0) == "Terminal(0)" def test_eq(self): assert "epsilon" == Epsilon() diff --git a/pyformlang/cfg/tests/test_variable.py b/pyformlang/cfg/tests/test_variable.py index 56c186e..f3ed904 100644 --- a/pyformlang/cfg/tests/test_variable.py +++ b/pyformlang/cfg/tests/test_variable.py @@ -20,3 +20,5 @@ def test_creation(self): assert str(variable0) == str(variable3) assert str(variable0) != str(variable1) assert "A" == Variable("A") + assert str(variable1) == "1" + assert repr(variable1) == "Variable(1)" diff --git a/pyformlang/finite_automaton/epsilon_nfa.py b/pyformlang/finite_automaton/epsilon_nfa.py index e80a125..09af9e7 100644 --- a/pyformlang/finite_automaton/epsilon_nfa.py +++ b/pyformlang/finite_automaton/epsilon_nfa.py @@ -65,24 +65,14 @@ def __init__( start_states: AbstractSet[Hashable] = None, final_states: AbstractSet[Hashable] = None) -> None: super().__init__() - if states is not None: - states = {to_state(x) for x in states} - self._states = states or set() - if input_symbols is not None: - input_symbols = {to_symbol(x) for x in input_symbols} - self._input_symbols = input_symbols or set() + self._states = {to_state(x) for x in states or set()} + self._input_symbols = {to_symbol(x) for x in input_symbols or set()} self._transition_function = transition_function \ or NondeterministicTransitionFunction() - if start_states is not None: - start_states = {to_state(x) for x in start_states} - self._start_states = start_states or set() - if final_states is not None: - final_states = {to_state(x) for x in final_states} - self._final_states = final_states or set() - for state in self._final_states: - self._states.add(state) - for state in self._start_states: - self._states.add(state) + self._start_states = {to_state(x) for x in start_states or set()} + self._states.update(self._start_states) + self._final_states = {to_state(x) for x in final_states or set()} + self._states.update(self._final_states) def _get_next_states_iterable( self, diff --git a/pyformlang/finite_automaton/finite_automaton.py b/pyformlang/finite_automaton/finite_automaton.py index 654a039..30ce648 100644 --- a/pyformlang/finite_automaton/finite_automaton.py +++ b/pyformlang/finite_automaton/finite_automaton.py @@ -436,14 +436,14 @@ def to_fst(self) -> FST: """ fst = FST() for start_state in self._start_states: - fst.add_start_state(start_state.value) + fst.add_start_state(start_state) for final_state in self._final_states: - fst.add_final_state(final_state.value) + fst.add_final_state(final_state) for s_from, symb_by, s_to in self._transition_function: - fst.add_transition(s_from.value, - symb_by.value, - s_to.value, - [symb_by.value]) + fst.add_transition(s_from, + symb_by, + s_to, + [symb_by]) return fst def is_acyclic(self) -> bool: @@ -700,10 +700,10 @@ def __try_add(set_to_add_to: Set[Any], element_to_add: Any) -> bool: @staticmethod def __add_start_state_to_graph(graph: MultiDiGraph, state: State) -> None: """ Adds a starting node to a given graph """ - graph.add_node("starting_" + str(state.value), + graph.add_node("starting_" + str(state), label="", shape=None, height=.0, width=.0) - graph.add_edge("starting_" + str(state.value), + graph.add_edge("starting_" + str(state), state.value) diff --git a/pyformlang/finite_automaton/tests/test_state.py b/pyformlang/finite_automaton/tests/test_state.py index 8046f88..6305e50 100644 --- a/pyformlang/finite_automaton/tests/test_state.py +++ b/pyformlang/finite_automaton/tests/test_state.py @@ -22,6 +22,7 @@ def test_repr(self): assert str(state1) == "ABC" state2 = State(1) assert str(state2) == "1" + assert repr(state1) == "State(ABC)" def test_eq(self): """ Tests the equality of states diff --git a/pyformlang/finite_automaton/tests/test_symbol.py b/pyformlang/finite_automaton/tests/test_symbol.py index fcb114c..51e666b 100644 --- a/pyformlang/finite_automaton/tests/test_symbol.py +++ b/pyformlang/finite_automaton/tests/test_symbol.py @@ -22,6 +22,7 @@ def test_repr(self): assert str(symbol1) == "ABC" symbol2 = Symbol(1) assert str(symbol2) == "1" + assert repr(symbol2) == "Symbol(1)" def test_eq(self): """ Tests equality of symbols diff --git a/pyformlang/fst/__init__.py b/pyformlang/fst/__init__.py index afd33d1..b7fdb5f 100644 --- a/pyformlang/fst/__init__.py +++ b/pyformlang/fst/__init__.py @@ -12,7 +12,11 @@ """ -from .fst import FST +from .fst import FST, TransitionFunction, State, Symbol, Epsilon -__all__ = ["FST"] +__all__ = ["FST", + "TransitionFunction", + "State", + "Symbol", + "Epsilon"] diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index dcecafa..9d5c14e 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -1,29 +1,41 @@ """ Finite State Transducer """ -import json -from typing import Any, Iterable -import networkx as nx +from typing import Dict, List, Set, AbstractSet, \ + Tuple, Iterator, Iterable, Hashable + +from networkx import MultiDiGraph from networkx.drawing.nx_pydot import write_dot -from pyformlang.indexed_grammar import DuplicationRule, ProductionRule, \ - EndRule, ConsumptionRule, IndexedGrammar, Rules +from .transition_function import TransitionFunction +from .transition_function import TransitionKey, TransitionValues, Transition +from .utils import StateRenaming +from ..objects.finite_automaton_objects import State, Symbol, Epsilon +from ..objects.finite_automaton_objects.utils import to_state, to_symbol + +InputTransition = Tuple[Hashable, Hashable, Hashable, Iterable[Hashable]] -class FST: +class FST(Iterable[Transition]): """ Representation of a Finite State Transducer""" - def __init__(self): - self._states = set() # Set of states - self._input_symbols = set() # Set of input symbols - self._output_symbols = set() # Set of output symbols - # Dict from _states x _input_symbols U {epsilon} into a subset of - # _states X _output_symbols* - self._delta = {} - self._start_states = set() - self._final_states = set() # _final_states is final states + def __init__(self, + states: AbstractSet[Hashable] = None, + input_symbols: AbstractSet[Hashable] = None, + output_symbols: AbstractSet[Hashable] = None, + transition_function: TransitionFunction = None, + start_states: AbstractSet[Hashable] = None, + final_states: AbstractSet[Hashable] = None) -> None: + self._states = {to_state(x) for x in states or set()} + self._input_symbols = {to_symbol(x) for x in input_symbols or set()} + self._output_symbols = {to_symbol(x) for x in output_symbols or set()} + self._transition_function = transition_function or TransitionFunction() + self._start_states = {to_state(x) for x in start_states or set()} + self._states.update(self._start_states) + self._final_states = {to_state(x) for x in final_states or set()} + self._states.update(self._final_states) @property - def states(self): + def states(self) -> Set[State]: """ Get the states of the FST Returns @@ -34,7 +46,7 @@ def states(self): return self._states @property - def input_symbols(self): + def input_symbols(self) -> Set[Symbol]: """ Get the input symbols of the FST Returns @@ -45,7 +57,7 @@ def input_symbols(self): return self._input_symbols @property - def output_symbols(self): + def output_symbols(self) -> Set[Symbol]: """ Get the output symbols of the FST Returns @@ -56,7 +68,7 @@ def output_symbols(self): return self._output_symbols @property - def start_states(self): + def start_states(self) -> Set[State]: """ Get the start states of the FST Returns @@ -67,7 +79,7 @@ def start_states(self): return self._start_states @property - def final_states(self): + def final_states(self) -> Set[State]: """ Get the final states of the FST Returns @@ -77,25 +89,11 @@ def final_states(self): """ return self._final_states - @property - def transitions(self): - """Gives the transitions as a dictionary""" - return self._delta - - def get_number_transitions(self) -> int: - """ Get the number of transitions in the FST - - Returns - ---------- - n_transitions : int - The number of transitions - """ - return sum(len(x) for x in self._delta.values()) - - def add_transition(self, s_from: Any, - input_symbol: Any, - s_to: Any, - output_symbols: Iterable[Any]): + def add_transition(self, + s_from: Hashable, + input_symbol: Hashable, + s_to: Hashable, + output_symbols: Iterable[Hashable]) -> None: """ Add a transition to the FST Parameters @@ -109,20 +107,22 @@ def add_transition(self, s_from: Any, output_symbols : iterable of Any The symbols to output """ + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + s_to = to_state(s_to) + output_symbols = tuple(to_symbol(x) for x in output_symbols + if x != Epsilon()) self._states.add(s_from) self._states.add(s_to) - if input_symbol != "epsilon": + if input_symbol != Epsilon(): self._input_symbols.add(input_symbol) - for output_symbol in output_symbols: - if output_symbol != "epsilon": - self._output_symbols.add(output_symbol) - head = (s_from, input_symbol) - if head in self._delta: - self._delta[head].append((s_to, output_symbols)) - else: - self._delta[head] = [(s_to, output_symbols)] - - def add_transitions(self, transitions_list): + self._output_symbols.update(output_symbols) + self._transition_function.add_transition(s_from, + input_symbol, + s_to, + output_symbols) + + def add_transitions(self, transitions: Iterable[InputTransition]) -> None: """ Adds several transitions to the FST @@ -131,15 +131,38 @@ def add_transitions(self, transitions_list): transitions_list : list of tuples The tuples have the form (s_from, in_symbol, s_to, out_symbols) """ - for s_from, input_symbol, s_to, output_symbols in transitions_list: - self.add_transition( - s_from, - input_symbol, - s_to, - output_symbols - ) + for s_from, input_symbol, s_to, output_symbols in transitions: + self.add_transition(s_from, + input_symbol, + s_to, + output_symbols) + + def remove_transition(self, + s_from: Hashable, + input_symbol: Hashable, + s_to: Hashable, + output_symbols: Iterable[Hashable]) -> None: + """ Removes the given transition from the FST """ + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + s_to = to_state(s_to) + output_symbols = tuple(to_symbol(x) for x in output_symbols) + self._transition_function.remove_transition(s_from, + input_symbol, + s_to, + output_symbols) + + def get_number_transitions(self) -> int: + """ Get the number of transitions in the FST + + Returns + ---------- + n_transitions : int + The number of transitions + """ + return self._transition_function.get_number_transitions() - def add_start_state(self, start_state: Any): + def add_start_state(self, start_state: Hashable) -> None: """ Add a start state Parameters @@ -147,10 +170,11 @@ def add_start_state(self, start_state: Any): start_state : any The start state """ + start_state = to_state(start_state) self._states.add(start_state) self._start_states.add(start_state) - def add_final_state(self, final_state: Any): + def add_final_state(self, final_state: Hashable) -> None: """ Add a final state Parameters @@ -158,11 +182,33 @@ def add_final_state(self, final_state: Any): final_state : any The final state to add """ + final_state = to_state(final_state) self._final_states.add(final_state) self._states.add(final_state) - def translate(self, input_word: Iterable[Any], max_length: int = -1) -> \ - Iterable[Any]: + def __call__(self, s_from: Hashable, input_symbol: Hashable) \ + -> TransitionValues: + """ Calls the transition function of the FST """ + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + return self._transition_function(s_from, input_symbol) + + def __contains__(self, transition: InputTransition) -> bool: + """ Whether the given transition is present in the FST """ + s_from, input_symbol, s_to, output_symbols = transition + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + s_to = to_state(s_to) + output_symbols = tuple(to_symbol(x) for x in output_symbols) + return (s_to, output_symbols) in self(s_from, input_symbol) + + def __iter__(self) -> Iterator[Transition]: + """ Gets an iterator of transitions of the FST """ + yield from self._transition_function + + def translate(self, + input_word: Iterable[Hashable], + max_length: int = -1) -> Iterable[List[Symbol]]: """ Translate a string into another using the FST Parameters @@ -179,7 +225,8 @@ def translate(self, input_word: Iterable[Any], max_length: int = -1) -> \ The translation of the input word """ # (remaining in the input, generated so far, current_state) - to_process = [] + input_word = [to_symbol(x) for x in input_word if x != Epsilon()] + to_process: List[Tuple[List[Symbol], List[Symbol], State]] = [] seen_by_state = {state: [] for state in self.states} for start_state in self._start_states: to_process.append((input_word, [], start_state)) @@ -192,126 +239,21 @@ def translate(self, input_word: Iterable[Any], max_length: int = -1) -> \ yield generated # We try to read an input if len(remaining) != 0: - for next_state, output_string in self._delta.get( - (current_state, remaining[0]), []): + for next_state, output_symbols in self(current_state, + remaining[0]): to_process.append( (remaining[1:], - generated + output_string, + generated + list(output_symbols), next_state)) # We try to read an epsilon transition if max_length == -1 or len(generated) < max_length: - for next_state, output_string in self._delta.get( - (current_state, "epsilon"), []): + for next_state, output_symbols in self(current_state, + Epsilon()): to_process.append((remaining, - generated + output_string, + generated + list(output_symbols), next_state)) - def intersection(self, indexed_grammar): - """ Compute the intersection with an other object - - Equivalent to: - >> fst and indexed_grammar - """ - rules = indexed_grammar.rules - new_rules = [EndRule("T", "epsilon")] - self._extract_consumption_rules_intersection(rules, new_rules) - self._extract_indexed_grammar_rules_intersection(rules, new_rules) - self._extract_terminals_intersection(rules, new_rules) - self._extract_epsilon_transitions_intersection(new_rules) - self._extract_fst_delta_intersection(new_rules) - self._extract_fst_epsilon_intersection(new_rules) - self._extract_fst_duplication_rules_intersection(new_rules) - rules = Rules(new_rules, rules.optim) - return IndexedGrammar(rules).remove_useless_rules() - - def _extract_fst_duplication_rules_intersection(self, new_rules): - for state_p in self._final_states: - for start_state in self._start_states: - new_rules.append(DuplicationRule( - "S", - str((start_state, "S", state_p)), - "T")) - - def _extract_fst_epsilon_intersection(self, new_rules): - for state_p in self._states: - new_rules.append(EndRule( - str((state_p, "epsilon", state_p)), - "epsilon")) - - def _extract_fst_delta_intersection(self, new_rules): - for key, pair in self._delta.items(): - state_p = key[0] - terminal = key[1] - for transition in pair: - state_q = transition[0] - symbol = transition[1] - new_rules.append(EndRule(str((state_p, terminal, state_q)), - symbol)) - - def _extract_epsilon_transitions_intersection(self, new_rules): - for state_p in self._states: - for state_q in self._states: - for state_r in self._states: - new_rules.append(DuplicationRule( - str((state_p, "epsilon", state_q)), - str((state_p, "epsilon", state_r)), - str((state_r, "epsilon", state_q)))) - - def _extract_indexed_grammar_rules_intersection(self, rules, new_rules): - for rule in rules.rules: - if rule.is_duplication(): - for state_p in self._states: - for state_q in self._states: - for state_r in self._states: - new_rules.append(DuplicationRule( - str((state_p, rule.left_term, state_q)), - str((state_p, rule.right_terms[0], state_r)), - str((state_r, rule.right_terms[1], state_q)))) - elif rule.is_production(): - for state_p in self._states: - for state_q in self._states: - new_rules.append(ProductionRule( - str((state_p, rule.left_term, state_q)), - str((state_p, rule.right_term, state_q)), - str(rule.production))) - elif rule.is_end_rule(): - for state_p in self._states: - for state_q in self._states: - new_rules.append(DuplicationRule( - str((state_p, rule.left_term, state_q)), - str((state_p, rule.right_term, state_q)), - "T")) - - def _extract_terminals_intersection(self, rules, new_rules): - terminals = rules.terminals - for terminal in terminals: - for state_p in self._states: - for state_q in self._states: - for state_r in self._states: - new_rules.append(DuplicationRule( - str((state_p, terminal, state_q)), - str((state_p, "epsilon", state_r)), - str((state_r, terminal, state_q)))) - new_rules.append(DuplicationRule( - str((state_p, terminal, state_q)), - str((state_p, terminal, state_r)), - str((state_r, "epsilon", state_q)))) - - def _extract_consumption_rules_intersection(self, rules, new_rules): - consumptions = rules.consumption_rules - for consumption_rule in consumptions: - for consumption in consumptions[consumption_rule]: - for state_r in self._states: - for state_s in self._states: - new_rules.append(ConsumptionRule( - consumption.f_parameter, - str((state_r, consumption.left_term, state_s)), - str((state_r, consumption.right, state_s)))) - - def __and__(self, other): - return self.intersection(other) - - def union(self, other_fst): + def union(self, other_fst: "FST") -> "FST": """ Makes the union of two fst Parameters @@ -332,7 +274,7 @@ def union(self, other_fst): other_fst._copy_into(union_fst, state_renaming, 1) return union_fst - def __or__(self, other_fst): + def __or__(self, other_fst: "FST") -> "FST": """ Makes the union of two fst Parameters @@ -348,33 +290,48 @@ def __or__(self, other_fst): """ return self.union(other_fst) - def _copy_into(self, union_fst, state_renaming, idx): + def _copy_into(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: self._add_extremity_states_to(union_fst, state_renaming, idx) self._add_transitions_to(union_fst, state_renaming, idx) - def _add_transitions_to(self, union_fst, state_renaming, idx): - for head, transition in self.transitions.items(): - s_from, input_symbol = head - for s_to, output_symbols in transition: - union_fst.add_transition( - state_renaming.get_name(s_from, idx), - input_symbol, - state_renaming.get_name(s_to, idx), - output_symbols) - - def _add_extremity_states_to(self, union_fst, state_renaming, idx): + def _add_transitions_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: + for (s_from, input_symbol), (s_to, output_symbols) in self: + union_fst.add_transition( + state_renaming.get_renamed_state(s_from, idx), + input_symbol, + state_renaming.get_renamed_state(s_to, idx), + output_symbols) + + def _add_extremity_states_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: self._add_start_states_to(union_fst, state_renaming, idx) self._add_final_states_to(union_fst, state_renaming, idx) - def _add_final_states_to(self, union_fst, state_renaming, idx): + def _add_final_states_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: for state in self.final_states: - union_fst.add_final_state(state_renaming.get_name(state, idx)) + union_fst.add_final_state( + state_renaming.get_renamed_state(state, idx)) - def _add_start_states_to(self, union_fst, state_renaming, idx): + def _add_start_states_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: for state in self.start_states: - union_fst.add_start_state(state_renaming.get_name(state, idx)) + union_fst.add_start_state( + state_renaming.get_renamed_state(state, idx)) - def concatenate(self, other_fst): + def concatenate(self, other_fst: "FST") -> "FST": """ Makes the concatenation of two fst Parameters @@ -398,14 +355,14 @@ def concatenate(self, other_fst): for final_state in self.final_states: for start_state in other_fst.start_states: fst_concatenate.add_transition( - state_renaming.get_name(final_state, 0), - "epsilon", - state_renaming.get_name(start_state, 1), + state_renaming.get_renamed_state(final_state, 0), + Epsilon(), + state_renaming.get_renamed_state(start_state, 1), [] ) return fst_concatenate - def __add__(self, other): + def __add__(self, other: "FST") -> "FST": """ Makes the concatenation of two fst Parameters @@ -421,13 +378,13 @@ def __add__(self, other): """ return self.concatenate(other) - def _get_state_renaming(self, other_fst): - state_renaming = FSTStateRemaining() - state_renaming.add_states(list(self.states), 0) + def _get_state_renaming(self, other_fst: "FST") -> StateRenaming: + state_renaming = StateRenaming() + state_renaming.add_states(self.states, 0) state_renaming.add_states(other_fst.states, 1) return state_renaming - def kleene_star(self): + def kleene_star(self) -> "FST": """ Computes the kleene star of the FST @@ -437,29 +394,29 @@ def kleene_star(self): A FST representing the kleene star of the FST """ fst_star = FST() - state_renaming = FSTStateRemaining() - state_renaming.add_states(list(self.states), 0) + state_renaming = StateRenaming() + state_renaming.add_states(self.states, 0) self._add_extremity_states_to(fst_star, state_renaming, 0) self._add_transitions_to(fst_star, state_renaming, 0) for final_state in self.final_states: for start_state in self.start_states: fst_star.add_transition( - state_renaming.get_name(final_state, 0), - "epsilon", - state_renaming.get_name(start_state, 0), + state_renaming.get_renamed_state(final_state, 0), + Epsilon(), + state_renaming.get_renamed_state(start_state, 0), [] ) for final_state in self.start_states: for start_state in self.final_states: fst_star.add_transition( - state_renaming.get_name(final_state, 0), - "epsilon", - state_renaming.get_name(start_state, 0), + state_renaming.get_renamed_state(final_state, 0), + Epsilon(), + state_renaming.get_renamed_state(start_state, 0), [] ) return fst_star - def to_networkx(self) -> nx.MultiDiGraph: + def to_networkx(self) -> MultiDiGraph: """ Transform the current fst into a networkx graph @@ -469,13 +426,13 @@ def to_networkx(self) -> nx.MultiDiGraph: A networkx MultiDiGraph representing the fst """ - graph = nx.MultiDiGraph() + graph = MultiDiGraph() for state in self._states: - graph.add_node(state, + graph.add_node(state.value, is_start=state in self.start_states, is_final=state in self.final_states, peripheries=2 if state in self.final_states else 1, - label=state) + label=state.value) if state in self.start_states: graph.add_node("starting_" + str(state), label="", @@ -483,18 +440,18 @@ def to_networkx(self) -> nx.MultiDiGraph: height=.0, width=.0) graph.add_edge("starting_" + str(state), - state) - for s_from, input_symbol in self._delta: - for s_to, output_symbols in self._delta[(s_from, input_symbol)]: - graph.add_edge( - s_from, - s_to, - label=(json.dumps(input_symbol) + " -> " + - json.dumps(output_symbols))) + state.value) + for (s_from, input_symbol), (s_to, output_symbols) in self: + input_symbol = input_symbol.value + output_symbols = tuple(map(lambda x: x.value, output_symbols)) + graph.add_edge( + s_from.value, + s_to.value, + label=(input_symbol, output_symbols)) return graph @classmethod - def from_networkx(cls, graph): + def from_networkx(cls, graph: MultiDiGraph) -> "FST": """ Import a networkx graph into an finite state transducer. \ The imported graph requires to have the good format, i.e. to come \ @@ -519,10 +476,8 @@ def from_networkx(cls, graph): for s_to in graph[s_from]: for transition in graph[s_from][s_to].values(): if "label" in transition: - in_symbol, out_symbols = transition["label"].split( - " -> ") - in_symbol = json.loads(in_symbol) - out_symbols = json.loads(out_symbols) + label = transition["label"] + in_symbol, out_symbols = label fst.add_transition(s_from, in_symbol, s_to, @@ -534,7 +489,7 @@ def from_networkx(cls, graph): fst.add_final_state(node) return fst - def write_as_dot(self, filename): + def write_as_dot(self, filename: str) -> None: """ Write the FST in dot format into a file @@ -546,63 +501,18 @@ def write_as_dot(self, filename): """ write_dot(self.to_networkx(), filename) + def copy(self) -> "FST": + """ Copies the FST """ + return FST(states=self.states, + input_symbols=self.input_symbols, + output_symbols=self.output_symbols, + transition_function=self._transition_function.copy(), + start_states=self.start_states, + final_states=self.final_states) -class FSTStateRemaining: - """Class for remaining the states in FST""" - - def __init__(self): - self._state_renaming = {} - self._seen_states = set() - - def add_state(self, state, idx): - """ - Add a state - Parameters - ---------- - state : str - The state to add - idx : int - The index of the FST - """ - if state in self._seen_states: - counter = 0 - new_state = state + str(counter) - while new_state in self._seen_states: - counter += 1 - new_state = state + str(counter) - self._state_renaming[(state, idx)] = new_state - self._seen_states.add(new_state) - else: - self._state_renaming[(state, idx)] = state - self._seen_states.add(state) - - def add_states(self, states, idx): - """ - Add states - Parameters - ---------- - states : list of str - The states to add - idx : int - The index of the FST - """ - for state in states: - self.add_state(state, idx) - - def get_name(self, state, idx): - """ - Get the renaming. - - Parameters - ---------- - state : str - The state to rename - idx : int - The index of the FST + def __copy__(self) -> "FST": + return self.copy() - Returns - ------- - new_name : str - The new name of the state - """ - return self._state_renaming[(state, idx)] + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: + """Gives the transitions as a dictionary""" + return self._transition_function.to_dict() diff --git a/pyformlang/fst/tests/test_fst.py b/pyformlang/fst/tests/test_fst.py index ec4e4da..7914389 100644 --- a/pyformlang/fst/tests/test_fst.py +++ b/pyformlang/fst/tests/test_fst.py @@ -4,10 +4,7 @@ import pytest -from pyformlang.fst import FST -from pyformlang.indexed_grammar import ( - DuplicationRule, ProductionRule, EndRule, - ConsumptionRule, IndexedGrammar, Rules) +from pyformlang.fst import FST, TransitionFunction, State, Symbol @pytest.fixture @@ -94,34 +91,6 @@ def test_translate(self): assert ["b", "c"] in translation assert ["b"] + ["c"] * 9 in translation - def test_intersection_indexed_grammar(self): - """ Test the intersection with indexed grammar """ - l_rules = [] - rules = Rules(l_rules) - indexed_grammar = IndexedGrammar(rules) - fst = FST() - intersection = fst & indexed_grammar - assert intersection.is_empty() - - l_rules.append(ProductionRule("S", "D", "f")) - l_rules.append(DuplicationRule("D", "A", "B")) - l_rules.append(ConsumptionRule("f", "A", "Afinal")) - l_rules.append(ConsumptionRule("f", "B", "Bfinal")) - l_rules.append(EndRule("Afinal", "a")) - l_rules.append(EndRule("Bfinal", "b")) - - rules = Rules(l_rules) - indexed_grammar = IndexedGrammar(rules) - intersection = fst.intersection(indexed_grammar) - assert intersection.is_empty() - - fst.add_start_state("q0") - fst.add_final_state("final") - fst.add_transition("q0", "a", "q1", ["a"]) - fst.add_transition("q1", "b", "final", ["b"]) - intersection = fst.intersection(indexed_grammar) - assert not intersection.is_empty() - def test_union(self, fst0, fst1): """ Tests the union""" fst_union = fst0.union(fst1) @@ -210,12 +179,73 @@ def test_paper(self): (2, "alone", 3, ["seul"])]) fst.add_start_state(0) fst.add_final_state(3) - assert list(fst.translate(["I", "am", "alone"])) == \ - [['Je', 'suis', 'seul'], - ['Je', 'suis', 'tout', 'seul']] + translation = list(fst.translate(["I", "am", "alone"])) + assert ['Je', 'suis', 'seul'] in translation + assert ['Je', 'suis', 'tout', 'seul'] in translation + assert len(translation) == 2 fst = FST.from_networkx(fst.to_networkx()) - assert list(fst.translate(["I", "am", "alone"])) == \ - [['Je', 'suis', 'seul'], - ['Je', 'suis', 'tout', 'seul']] + translation = list(fst.translate(["I", "am", "alone"])) + assert ['Je', 'suis', 'seul'] in translation + assert ['Je', 'suis', 'tout', 'seul'] in translation + assert len(translation) == 2 fst.write_as_dot("fst.dot") assert path.exists("fst.dot") + + def test_contains(self, fst0: FST): + """ Tests the containment of transition in the FST """ + assert ("q0", "a", "q1", ["b"]) in fst0 + assert ("a", "b", "c", ["d"]) not in fst0 + fst0.add_transition("a", "b", "c", {"d"}) + assert ("a", "b", "c", ["d"]) in fst0 + + def test_iter(self, fst0: FST): + """ Tests the iteration of FST transitions """ + fst0.add_transition("q1", "A", "q2", ["B"]) + fst0.add_transition("q1", "A", "q2", ["C", "D"]) + transitions = list(iter(fst0)) + assert (("q0", "a"), ("q1", tuple("b"))) in transitions + assert (("q1", "A"), ("q2", tuple("B"))) in transitions + assert (("q1", "A"), ("q2", ("C", "D"))) in transitions + assert len(transitions) == 3 + + def test_remove_transition(self, fst0: FST): + """ Tests the removal of transition from the FST """ + assert ("q0", "a", "q1", ["b"]) in fst0 + fst0.remove_transition("q0", "a", "q1", ["b"]) + assert ("q0", "a", "q1", ["b"]) not in fst0 + fst0.remove_transition("q0", "a", "q1", ["b"]) + assert ("q0", "a", "q1", ["b"]) not in fst0 + assert fst0.get_number_transitions() == 0 + + def test_initialization(self): + """ Tests the initialization of the FST """ + fst = FST(states={0}, + input_symbols={"a", "b"}, + output_symbols={"c"}, + start_states={1}, + final_states={2}) + assert fst.states == {0, 1, 2} + assert fst.input_symbols == {"a", "b"} + assert fst.output_symbols == {"c"} + assert fst.get_number_transitions() == 0 + assert not list(iter(fst)) + + function = TransitionFunction() + function.add_transition(State(1), Symbol("a"), State(2), (Symbol("b"),)) + function.add_transition(State(1), Symbol("a"), State(2), (Symbol("c"),)) + fst = FST(transition_function=function) + assert fst.get_number_transitions() == 2 + assert (1, "a", 2, ["b"]) in fst + assert (1, "a", 2, ["c"]) in fst + assert fst(1, "a") == {(2, tuple("b")), (2, tuple("c"))} + + def test_copy(self, fst0: FST): + """ Tests the copying of the FST """ + fst_copy = fst0.copy() + assert fst_copy.states == fst0.states + assert fst_copy.input_symbols == fst0.input_symbols + assert fst_copy.output_symbols == fst0.output_symbols + assert fst_copy.start_states == fst0.start_states + assert fst_copy.final_states == fst0.final_states + assert fst_copy.to_dict() == fst0.to_dict() + assert fst_copy is not fst0 diff --git a/pyformlang/fst/transition_function.py b/pyformlang/fst/transition_function.py new file mode 100644 index 0000000..9f75805 --- /dev/null +++ b/pyformlang/fst/transition_function.py @@ -0,0 +1,78 @@ +""" The transition function of Finite State Transducer """ + +from typing import Dict, Set, Tuple, Iterator, Iterable +from copy import deepcopy + +from ..objects.finite_automaton_objects import State, Symbol + +TransitionKey = Tuple[State, Symbol] +TransitionValue = Tuple[State, Tuple[Symbol, ...]] +TransitionValues = Set[TransitionValue] +Transition = Tuple[TransitionKey, TransitionValue] + + +class TransitionFunction(Iterable[Transition]): + """ The transition function of Finite State Transducer """ + + def __init__(self) -> None: + self._transitions: Dict[TransitionKey, TransitionValues] = {} + + def add_transition(self, + s_from: State, + input_symbol: Symbol, + s_to: State, + output_symbols: Tuple[Symbol, ...]) -> None: + """ Adds given transition to the function """ + key = (s_from, input_symbol) + value = (s_to, output_symbols) + self._transitions.setdefault(key, set()).add(value) + + def remove_transition(self, + s_from: State, + input_symbol: Symbol, + s_to: State, + output_symbols: Tuple[Symbol, ...]) -> None: + """ Removes given transition from the function """ + key = (s_from, input_symbol) + value = (s_to, output_symbols) + self._transitions.get(key, set()).discard(value) + + def get_number_transitions(self) -> int: + """ Gets the number of transitions in the function + + Returns + ---------- + n_transitions : int + The number of transitions + """ + return sum(len(x) for x in self._transitions.values()) + + def __call__(self, s_from: State, input_symbol: Symbol) \ + -> TransitionValues: + """ Calls the transition function """ + return self._transitions.get((s_from, input_symbol), set()) + + def __contains__(self, transition: Transition) -> bool: + """ Whether the given transition is present in the function """ + key, value = transition + return value in self(*key) + + def __iter__(self) -> Iterator[Transition]: + """ Gets an iterator of transitions of the function """ + for key, values in self._transitions.items(): + for value in values: + yield key, value + + def copy(self) -> "TransitionFunction": + """ Copies the transition function """ + new_tf = TransitionFunction() + for key, value in self: + new_tf.add_transition(*key, *value) + return new_tf + + def __copy__(self) -> "TransitionFunction": + return self.copy() + + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: + """ Gives the transition function as a dictionary """ + return deepcopy(self._transitions) diff --git a/pyformlang/fst/utils.py b/pyformlang/fst/utils.py new file mode 100644 index 0000000..0a6c243 --- /dev/null +++ b/pyformlang/fst/utils.py @@ -0,0 +1,69 @@ +""" Utility for FST """ + +from typing import Dict, Set, Iterable, Tuple + +from ..objects.finite_automaton_objects import State +from ..objects.finite_automaton_objects.utils import to_state + + +class StateRenaming: + """ Class for renaming the states in FST """ + + def __init__(self) -> None: + self._state_renaming: Dict[Tuple[str, int], str] = {} + self._seen_states: Set[str] = set() + + def add_state(self, state: State, idx: int) -> None: + """ + Add a state + Parameters + ---------- + state : State + The state to add + idx : int + The index of the FST + """ + current_name = str(state) + if current_name in self._seen_states: + counter = 0 + new_name = current_name + str(counter) + while new_name in self._seen_states: + counter += 1 + new_name = current_name + str(counter) + self._state_renaming[(current_name, idx)] = new_name + self._seen_states.add(new_name) + else: + self._state_renaming[(current_name, idx)] = current_name + self._seen_states.add(current_name) + + def add_states(self, states: Iterable[State], idx: int) -> None: + """ + Add states + Parameters + ---------- + states : Iterable of States + The states to add + idx : int + The index of the FST + """ + for state in states: + self.add_state(state, idx) + + def get_renamed_state(self, state: State, idx: int) -> State: + """ + Get the renaming. + + Parameters + ---------- + state : State + The state to rename + idx : int + The index of the FST + + Returns + ------- + new_name : State + Renamed state + """ + renaming = self._state_renaming[(str(state), idx)] + return to_state(renaming) diff --git a/pyformlang/indexed_grammar/__init__.py b/pyformlang/indexed_grammar/__init__.py index 14da624..00e6f18 100644 --- a/pyformlang/indexed_grammar/__init__.py +++ b/pyformlang/indexed_grammar/__init__.py @@ -23,16 +23,23 @@ """ from .rules import Rules +from .reduced_rule import ReducedRule from .consumption_rule import ConsumptionRule from .end_rule import EndRule from .production_rule import ProductionRule from .duplication_rule import DuplicationRule from .indexed_grammar import IndexedGrammar +from ..objects.cfg_objects import CFGObject, Variable, Terminal, Epsilon __all__ = ["Rules", + "ReducedRule", "ConsumptionRule", "EndRule", "ProductionRule", "DuplicationRule", - "IndexedGrammar"] + "IndexedGrammar", + "CFGObject", + "Variable", + "Terminal", + "Epsilon"] diff --git a/pyformlang/indexed_grammar/consumption_rule.py b/pyformlang/indexed_grammar/consumption_rule.py index 39b2e16..b59609f 100644 --- a/pyformlang/indexed_grammar/consumption_rule.py +++ b/pyformlang/indexed_grammar/consumption_rule.py @@ -3,9 +3,12 @@ the stack """ -from typing import Any, Iterable, AbstractSet +from typing import List, Set, Hashable, Any + +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class ConsumptionRule(ReducedRule): @@ -23,31 +26,16 @@ class ConsumptionRule(ReducedRule): The non terminal on the right (here B) """ - @property - def right_term(self): - raise NotImplementedError + def __init__(self, + f_param: Hashable, + left_term: Hashable, + right_term: Hashable) -> None: + self._f = to_terminal(f_param) + self._left_term = to_variable(left_term) + self._right_term = to_variable(right_term) @property - def right_terms(self): - raise NotImplementedError - - def __init__(self, f_param: Any, left: Any, right: Any): - self._f = f_param - self._right = right - self._left_term = left - - def is_consumption(self) -> bool: - """Whether the rule is a consumption rule or not - - Returns - ---------- - is_consumption : bool - Whether the rule is a consumption rule or not - """ - return True - - @property - def f_parameter(self) -> Any: + def f_parameter(self) -> Terminal: """Gets the symbol which is consumed Returns @@ -58,38 +46,49 @@ def f_parameter(self) -> Any: return self._f @property - def production(self): + def production(self) -> Terminal: raise NotImplementedError @property - def right(self) -> Any: - """Gets the symbole on the right of the rule + def left_term(self) -> Variable: + """Gets the symbol on the left of the rule + + left : any + The left symbol of the rule + """ + return self._left_term + + @property + def right_term(self) -> Variable: + """Gets the symbol on the right of the rule right : any The right symbol """ - return self._right + return self._right_term @property - def left_term(self) -> Any: - """Gets the symbol on the left of the rule + def right_terms(self) -> List[CFGObject]: + """Gives the non-terminals on the right of the rule - left : any - The left symbol of the rule + Returns + --------- + right_terms : iterable of any + The right terms of the rule """ - return self._left_term + return [self._right_term] @property - def non_terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: """Gets the non-terminals used in the rule non_terminals : iterable of any The non_terminals used in the rule """ - return [self._left_term, self._right] + return {self._left_term, self._right_term} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used in the rule terminals : set of any @@ -97,10 +96,12 @@ def terminals(self) -> AbstractSet[Any]: """ return {self._f} - def __repr__(self): - return self._left_term + " [ " + self._f + " ] -> " + self._right + def __eq__(self, other: Any) -> bool: + if not isinstance(other, ConsumptionRule): + return False + return other.left_term == self.left_term \ + and other.right_term == self.right_term \ + and other.f_parameter == self.f_parameter - def __eq__(self, other): - return other.is_consumption() and other.left_term == \ - self.left_term and other.right == self.right and \ - other.f_parameter() == self.f_parameter + def __repr__(self) -> str: + return f"{self._left_term} [ {self._f} ] -> {self._right_term}" diff --git a/pyformlang/indexed_grammar/duplication_rule.py b/pyformlang/indexed_grammar/duplication_rule.py index cc719b5..de9238f 100644 --- a/pyformlang/indexed_grammar/duplication_rule.py +++ b/pyformlang/indexed_grammar/duplication_rule.py @@ -2,9 +2,12 @@ A representation of a duplication rule, i.e. a rule that duplicates the stack """ -from typing import Any, Iterable, AbstractSet, Tuple +from typing import List, Set, Hashable, Any + +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable class DuplicationRule(ReducedRule): @@ -21,34 +24,39 @@ class DuplicationRule(ReducedRule): The second non-terminal on the right of the rule (C here) """ - @property - def production(self): - raise NotImplementedError + def __init__(self, + left_term: Hashable, + right_term0: Hashable, + right_term1: Hashable) -> None: + self._left_term = to_variable(left_term) + self._right_terms = (to_variable(right_term0), + to_variable(right_term1)) @property - def right_term(self): + def f_parameter(self) -> Terminal: raise NotImplementedError @property - def f_parameter(self): + def production(self) -> Terminal: raise NotImplementedError - def __init__(self, left_term, right_term0, right_term1): - self._left_term = left_term - self._right_terms = (right_term0, right_term1) - - def is_duplication(self) -> bool: - """Whether the rule is a duplication rule or not + @property + def left_term(self) -> Variable: + """Gives the non-terminal on the left of the rule Returns - ---------- - is_duplication : bool - Whether the rule is a duplication rule or not + --------- + left_term : any + The left term of the rule """ - return True + return self._left_term + + @property + def right_term(self) -> CFGObject: + raise NotImplementedError @property - def right_terms(self) -> Tuple[Any, Any]: + def right_terms(self) -> List[CFGObject]: """Gives the non-terminals on the right of the rule Returns @@ -56,21 +64,10 @@ def right_terms(self) -> Tuple[Any, Any]: right_terms : iterable of any The right terms of the rule """ - return self._right_terms + return list(self._right_terms) @property - def left_term(self) -> Any: - """Gives the non-terminal on the left of the rule - - Returns - --------- - left_term : any - The left term of the rule - """ - return self._left_term - - @property - def non_terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: """Gives the set of non-terminals used in this rule Returns @@ -78,10 +75,10 @@ def non_terminals(self) -> Iterable[Any]: non_terminals : iterable of any The non terminals used in this rule """ - return [self._left_term, self._right_terms[0], self._right_terms[1]] + return {self._left_term, *self._right_terms} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used in the rule Returns @@ -91,11 +88,13 @@ def terminals(self) -> AbstractSet[Any]: """ return set() - def __repr__(self): - """Gives a string representation of the rule, ignoring the sigmas""" - return self._left_term + " -> " + self._right_terms[0] + \ - " " + self._right_terms[1] + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DuplicationRule): + return False + return other.left_term == self._left_term \ + and other.right_terms == self.right_terms - def __eq__(self, other): - return other.is_duplication() and other.left_term == \ - self._left_term and other.right_terms == self.right_terms + def __repr__(self) -> str: + """Gives a string representation of the rule, ignoring the sigmas""" + return f"{self._left_term} -> " \ + + f"{self._right_terms[0]} {self._right_terms[1]}" diff --git a/pyformlang/indexed_grammar/end_rule.py b/pyformlang/indexed_grammar/end_rule.py index 7979b84..1433ca9 100644 --- a/pyformlang/indexed_grammar/end_rule.py +++ b/pyformlang/indexed_grammar/end_rule.py @@ -2,9 +2,12 @@ Represents a end rule, i.e. a rule which give only a terminal """ -from typing import Any, Iterable, AbstractSet +from typing import List, Set, Hashable, Any + +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class EndRule(ReducedRule): @@ -19,30 +22,31 @@ class EndRule(ReducedRule): The terminal on the right, "a" here """ + def __init__(self, left_term: Hashable, right_term: Hashable) -> None: + self._left_term = to_variable(left_term) + self._right_term = to_terminal(right_term) + @property - def production(self): + def f_parameter(self) -> Terminal: raise NotImplementedError @property - def right_terms(self): + def production(self) -> Terminal: raise NotImplementedError - def __init__(self, left, right): - self._left_term = left - self._right_term = right - - def is_end_rule(self) -> bool: - """Whether the rule is an end rule or not + @property + def left_term(self) -> Variable: + """Gets the non-terminal on the left of the rule Returns - ---------- - is_end : bool - Whether the rule is an end rule or not + --------- + left_term : any + The left non-terminal of the rule """ - return True + return self._left_term @property - def right_term(self) -> Any: + def right_term(self) -> Terminal: """Gets the terminal on the right of the rule Returns @@ -53,18 +57,18 @@ def right_term(self) -> Any: return self._right_term @property - def left_term(self) -> Any: - """Gets the non-terminal on the left of the rule + def right_terms(self) -> List[CFGObject]: + """Gives the terminals on the right of the rule Returns --------- - left_term : any - The left non-terminal of the rule + right_terms : iterable of any + The right terms of the rule """ - return self._left_term + return [self._right_term] @property - def non_terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: """Gets the non-terminals used Returns @@ -72,10 +76,10 @@ def non_terminals(self) -> Iterable[Any]: non_terminals : iterable of any The non terminals used in this rule """ - return [self._left_term] + return {self._left_term} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used Returns @@ -85,14 +89,12 @@ def terminals(self) -> AbstractSet[Any]: """ return {self._right_term} - def __repr__(self): - """Gets the string representation of the rule""" - return self._left_term + " -> " + self._right_term - - def __eq__(self, other): - return other.is_end_rule() and other.left_term == self.left_term\ + def __eq__(self, other: Any) -> bool: + if not isinstance(other, EndRule): + return False + return other.left_term == self.left_term \ and other.right_term == self.right_term - @property - def f_parameter(self): - raise NotImplementedError + def __repr__(self) -> str: + """Gets the string representation of the rule""" + return f"{self._left_term} -> {self._right_term}" diff --git a/pyformlang/indexed_grammar/indexed_grammar.py b/pyformlang/indexed_grammar/indexed_grammar.py index 5e66917..0037222 100644 --- a/pyformlang/indexed_grammar/indexed_grammar.py +++ b/pyformlang/indexed_grammar/indexed_grammar.py @@ -2,13 +2,21 @@ Representation of an indexed grammar """ -from typing import Any, Iterable, AbstractSet +# pylint: disable=cell-var-from-loop -import pyformlang +from typing import Dict, List, Set, FrozenSet, Tuple, Hashable +from pyformlang.cfg import CFGObject, Variable, Terminal +from pyformlang.fst import FST + +from .rules import Rules +from .reduced_rule import ReducedRule from .duplication_rule import DuplicationRule from .production_rule import ProductionRule -from .rules import Rules +from .consumption_rule import ConsumptionRule +from .end_rule import EndRule +from .utils import addrec_bis +from ..objects.cfg_objects.utils import to_variable class IndexedGrammar: @@ -24,40 +32,60 @@ class IndexedGrammar: def __init__(self, rules: Rules, - start_variable: Any = "S"): - self.rules = rules - self.start_variable = start_variable + start_variable: Hashable = "S") -> None: + self._rules = rules + self._start_variable = to_variable(start_variable) # Precompute all non-terminals - self.non_terminals = rules.non_terminals - self.non_terminals.append(self.start_variable) - self.non_terminals = set(self.non_terminals) + non_terminals = self.non_terminals # We cache the marked items in case of future update of the query - self.marked = {} + self._marked: Dict[CFGObject, Set[FrozenSet[Variable]]] = {} # Initialize the marked symbols # Mark the identity - for non_terminal_a in self.non_terminals: - self.marked[non_terminal_a] = set() + for non_terminal_a in non_terminals: + self._marked[non_terminal_a] = set() temp = frozenset({non_terminal_a}) - self.marked[non_terminal_a].add(temp) + self._marked[non_terminal_a].add(temp) # Mark all end symbols - for non_terminal_a in self.non_terminals: - if exists(self.rules.rules, - lambda x: x.is_end_rule() - and x.left_term == non_terminal_a): - self.marked[non_terminal_a].add(frozenset()) + for non_terminal_a in non_terminals: + if any(map(lambda x: isinstance(x, EndRule) + and x.left_term == non_terminal_a, + self._rules.rules)): + self._marked[non_terminal_a].add(frozenset()) + + @property + def rules(self) -> Rules: + """ Get the rules of the grammar """ + return self._rules + + @property + def start_variable(self) -> Variable: + """ Get the start variable of the grammar """ + return self._start_variable @property - def terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: + """Get all the non-terminals in the grammar + + Returns + ---------- + terminals : iterable of any + The non-terminals used in the grammar + """ + return {self.start_variable} | self._rules.non_terminals + + @property + def terminals(self) -> Set[Terminal]: """Get all the terminals in the grammar Returns ---------- terminals : iterable of any - The terminals used in the rules + The terminals used in the grammar """ - return self.rules.terminals + return self._rules.terminals - def _duplication_processing(self, rule: DuplicationRule): + def _duplication_processing(self, rule: DuplicationRule) \ + -> Tuple[bool, bool]: """Processes a duplication rule Parameters @@ -68,9 +96,9 @@ def _duplication_processing(self, rule: DuplicationRule): was_modified = False need_stop = False right_term_marked0 = [] - for marked_term0 in self.marked[rule.right_terms[0]]: + for marked_term0 in self._marked[rule.right_terms[0]]: right_term_marked1 = [] - for marked_term1 in self.marked[rule.right_terms[1]]: + for marked_term1 in self._marked[rule.right_terms[1]]: if marked_term0 <= marked_term1: temp = marked_term1 elif marked_term1 <= marked_term0: @@ -78,26 +106,27 @@ def _duplication_processing(self, rule: DuplicationRule): else: temp = marked_term0.union(marked_term1) # Check if it was marked before - if temp not in self.marked[rule.left_term]: + if temp not in self._marked[rule.left_term]: was_modified = True if rule.left_term == rule.right_terms[0]: right_term_marked0.append(temp) elif rule.left_term == rule.right_terms[1]: right_term_marked1.append(temp) else: - self.marked[rule.left_term].add(temp) + self._marked[rule.left_term].add(temp) # Stop condition, no need to continue - if rule.left_term == self.start_variable and len( + if rule.left_term == self._start_variable and len( temp) == 0: need_stop = True for temp in right_term_marked1: - self.marked[rule.right_terms[1]].add(temp) + self._marked[rule.right_terms[1]].add(temp) for temp in right_term_marked0: - self.marked[rule.right_terms[0]].add(temp) + self._marked[rule.right_terms[0]].add(temp) return was_modified, need_stop - def _production_process(self, rule: ProductionRule): + def _production_process(self, rule: ProductionRule) \ + -> Tuple[bool, bool]: """Processes a production rule Parameters @@ -108,19 +137,19 @@ def _production_process(self, rule: ProductionRule): was_modified = False # f_rules contains the consumption rules associated with # the current production symbol - f_rules = self.rules.consumption_rules.setdefault( + f_rules = self._rules.consumption_rules.setdefault( rule.production, []) # l_rules contains the left symbol plus what is marked on # the right side l_temp = [(x.left_term, - self.marked[x.right]) for x in f_rules] + self._marked[x.right_term]) for x in f_rules] marked_symbols = [x.left_term for x in f_rules] # Process all combinations of consumption rule was_modified |= addrec_bis(l_temp, - self.marked[rule.left_term], - self.marked[rule.right_term]) + self._marked[rule.left_term], + self._marked[rule.right_term]) # End condition - if frozenset() in self.marked[self.start_variable]: + if frozenset() in self._marked[self._start_variable]: return was_modified, True # Is it useful? if rule.right_term in marked_symbols: @@ -129,17 +158,17 @@ def _production_process(self, rule: ProductionRule): for sub_term in [sub_term for sub_term in term[1] if sub_term not in - self.marked[rule.left_term]]: + self._marked[rule.left_term]]: was_modified = True - self.marked[rule.left_term].add(sub_term) - if (rule.left_term == self.start_variable and + self._marked[rule.left_term].add(sub_term) + if (rule.left_term == self._start_variable and len(sub_term) == 0): return was_modified, True # Edge case - if frozenset() in self.marked[rule.right_term]: - if frozenset() not in self.marked[rule.left_term]: + if frozenset() in self._marked[rule.right_term]: + if frozenset() not in self._marked[rule.left_term]: was_modified = True - self.marked[rule.left_term].add(frozenset()) + self._marked[rule.left_term].add(frozenset()) return was_modified, False def is_empty(self) -> bool: @@ -154,28 +183,28 @@ def is_empty(self) -> bool: was_modified = True while was_modified: was_modified = False - for rule in self.rules.rules: + for rule in self._rules.rules: # If we have a duplication rule, we mark all combinations of # the sets marked on the right side for the symbol on the left # side - if rule.is_duplication(): + if isinstance(rule, DuplicationRule): dup_res = self._duplication_processing(rule) was_modified |= dup_res[0] if dup_res[1]: return False - elif rule.is_production(): + elif isinstance(rule, ProductionRule): prod_res = self._production_process(rule) if prod_res[1]: return False was_modified |= prod_res[0] - if frozenset() in self.marked[self.start_variable]: + if frozenset() in self._marked[self._start_variable]: return False return True - def __bool__(self): + def __bool__(self) -> bool: return not self.is_empty() - def get_reachable_non_terminals(self) -> AbstractSet[Any]: + def get_reachable_non_terminals(self) -> Set[Variable]: """ Get the reachable symbols Returns @@ -184,10 +213,10 @@ def get_reachable_non_terminals(self) -> AbstractSet[Any]: The reachable symbols from the start state """ # Preprocess - reachable_from = {} - consumption_rules = self.rules.consumption_rules - for rule in self.rules.rules: - if rule.is_duplication(): + reachable_from: Dict[Variable, Set[CFGObject]] = {} + consumption_rules = self._rules.consumption_rules + for rule in self._rules.rules: + if isinstance(rule, DuplicationRule): left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] @@ -195,7 +224,7 @@ def get_reachable_non_terminals(self) -> AbstractSet[Any]: reachable_from[left] = set() reachable_from[left].add(right0) reachable_from[left].add(right1) - if rule.is_production(): + if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term if left not in reachable_from: @@ -204,22 +233,23 @@ def get_reachable_non_terminals(self) -> AbstractSet[Any]: for key in consumption_rules: for rule in consumption_rules[key]: left = rule.left_term - right = rule.right + right = rule.right_term if left not in reachable_from: reachable_from[left] = set() reachable_from[left].add(right) # Processing - to_process = [self.start_variable] - reachables = {self.start_variable} + to_process = [self._start_variable] + reachables = {self._start_variable} while to_process: current = to_process.pop() - for symbol in reachable_from.get(current, []): + for symbol in reachable_from.get(current, set()): if symbol not in reachables: - reachables.add(symbol) - to_process.append(symbol) + variable = to_variable(symbol) + reachables.add(variable) + to_process.append(variable) return reachables - def get_generating_non_terminals(self) -> AbstractSet[Any]: + def get_generating_non_terminals(self) -> Set[Variable]: """ Get the generating symbols Returns @@ -228,8 +258,8 @@ def get_generating_non_terminals(self) -> AbstractSet[Any]: The generating symbols from the start state """ # Preprocess - generating_from = {} - duplication_pointer = {} + generating_from: Dict[Variable, Set[Variable]] = {} + duplication_pointer: Dict[CFGObject, List[List]] = {} generating = set() to_process = [] self._preprocess_rules_generating(duplication_pointer, generating, @@ -250,40 +280,42 @@ def get_generating_non_terminals(self) -> AbstractSet[Any]: to_process.append(duplication[0]) return generating - def _preprocess_consumption_rules_generating(self, generating_from): - for key in self.rules.consumption_rules: - for rule in self.rules.consumption_rules[key]: + def _preprocess_consumption_rules_generating( + self, + generating_from: Dict[Variable, Set[Variable]]) \ + -> None: + for key in self._rules.consumption_rules: + for rule in self._rules.consumption_rules[key]: left = rule.left_term - right = rule.right + right = rule.right_term if right in generating_from: generating_from[right].add(left) else: generating_from[right] = {left} - def _preprocess_rules_generating(self, duplication_pointer, generating, - generating_from, to_process): - for rule in self.rules.rules: - if rule.is_duplication(): + def _preprocess_rules_generating( + self, + duplication_pointer: Dict[CFGObject, List[List]], + generating: Set[Variable], + generating_from: Dict[Variable, Set[Variable]], + to_process: List[Variable]) \ + -> None: + for rule in self._rules.rules: + if isinstance(rule, DuplicationRule): left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] temp = [left, 2] - if right0 in duplication_pointer: - duplication_pointer[right0].append(temp) - else: - duplication_pointer[right0] = [temp] - if right1 in duplication_pointer: - duplication_pointer[right1].append(temp) - else: - duplication_pointer[right1] = [temp] - if rule.is_production(): + duplication_pointer.setdefault(right0, []).append(temp) + duplication_pointer.setdefault(right1, []).append(temp) + if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term if right in generating_from: generating_from[right].add(left) else: generating_from[right] = {left} - if rule.is_end_rule(): + if isinstance(rule, EndRule): left = rule.left_term if left not in generating: generating.add(left) @@ -303,36 +335,36 @@ def remove_useless_rules(self) -> "IndexedGrammar": l_rules = [] generating = self.get_generating_non_terminals() reachables = self.get_reachable_non_terminals() - consumption_rules = self.rules.consumption_rules - for rule in self.rules.rules: - if rule.is_duplication(): + consumption_rules = self._rules.consumption_rules + for rule in self._rules.rules: + if isinstance(rule, DuplicationRule): left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] if all(x in generating and x in reachables for x in [left, right0, right1]): l_rules.append(rule) - if rule.is_production(): + if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term if all(x in generating and x in reachables for x in [left, right]): l_rules.append(rule) - if rule.is_end_rule(): + if isinstance(rule, EndRule): left = rule.left_term if left in generating and left in reachables: l_rules.append(rule) for key in consumption_rules: for rule in consumption_rules[key]: left = rule.left_term - right = rule.right + right = rule.right_term if all(x in generating and x in reachables for x in [left, right]): l_rules.append(rule) - rules = Rules(l_rules, self.rules.optim) + rules = Rules(l_rules, self._rules.optim) return IndexedGrammar(rules) - def intersection(self, other: Any) -> "IndexedGrammar": + def intersection(self, other: FST) -> "IndexedGrammar": """ Computes the intersection of the current indexed grammar with the \ other object @@ -356,14 +388,18 @@ def intersection(self, other: Any) -> "IndexedGrammar": When trying to intersection with something else than a regular expression or a finite automaton """ - if isinstance(other, pyformlang.regular_expression.Regex): - other = other.to_epsilon_nfa() - if isinstance(other, pyformlang.finite_automaton.FiniteAutomaton): - fst = other.to_fst() - return fst.intersection(self) - raise NotImplementedError - - def __and__(self, other): + new_rules: List[ReducedRule] = [EndRule("T", "epsilon")] + self._extract_consumption_rules_intersection(other, new_rules) + self._extract_indexed_grammar_rules_intersection(other, new_rules) + self._extract_terminals_intersection(other, new_rules) + self._extract_epsilon_transitions_intersection(other, new_rules) + self._extract_fst_delta_intersection(other, new_rules) + self._extract_fst_epsilon_intersection(other, new_rules) + self._extract_fst_duplication_rules_intersection(other, new_rules) + rules = Rules(new_rules, self.rules.optim) + return IndexedGrammar(rules).remove_useless_rules() + + def __and__(self, other: FST) -> "IndexedGrammar": """ Computes the intersection of the current indexed grammar with the other object @@ -379,96 +415,118 @@ def __and__(self, other): """ return self.intersection(other) - -def exists(list_elements, check_function): - """exists - Check whether at least an element x of l is True for f(x) - :param list_elements: A list of elements to test - :param check_function: The checking function (takes one parameter and \ - return a boolean) - """ - for element in list_elements: - if check_function(element): - return True - return False - - -def addrec_bis(l_sets, marked_left, marked_right): - """addrec_bis - Optimized version of addrec - :param l_sets: a list containing tuples (C, M) where: - * C is a non-terminal on the left of a consumption rule - * M is the set of the marked set for the right non-terminal in the - production rule - :param marked_left: Sets which are marked for the non-terminal on the - left of the production rule - :param marked_right: Sets which are marked for the non-terminal on the - right of the production rule - """ - was_modified = False - for marked in list(marked_right): - l_temp = [x for x in l_sets if x[0] in marked] - s_temp = [x[0] for x in l_temp] - # At least one symbol to consider - if frozenset(s_temp) == marked and len(marked) > 0: - was_modified |= addrec_ter(l_temp, marked_left) - return was_modified - - -def addrec_ter(l_sets, marked_left): - """addrec - Explores all possible combination of consumption rules to mark a - production rule. - :param l_sets: a list containing tuples (C, M) where: - * C is a non-terminal on the left of a consumption rule - * M is the set of the marked set for the right non-terminal in the - production rule - :param marked_left: Sets which are marked for the non-terminal on the - left of the production rule - :return Whether an element was actually marked - """ - # End condition, nothing left to process - temp_in = [x[0] for x in l_sets] - exists_after = [ - exists(l_sets[index + 1:], lambda x: x[0] == l_sets[index][0]) - for index in range(len(l_sets))] - exists_before = [l_sets[index][0] in temp_in[:index] - for index in range(len(l_sets))] - marked_sets = [l_sets[index][1] for index in range(len(l_sets))] - marked_sets = [sorted(x, key=lambda x: -len(x)) for x in marked_sets] - # Try to optimize by having an order of the sets - sorted_zip = sorted(zip(exists_after, exists_before, marked_sets), - key=lambda x: -len(x[2])) - exists_after, exists_before, marked_sets = \ - zip(*sorted_zip) - res = False - # contains tuples of index, temp_set - to_process = [(0, frozenset())] - done = set() - while to_process: - index, new_temp = to_process.pop() - if index >= len(l_sets): - # Check if at least one non-terminal was considered, then if the - # set of non-terminals considered is marked of the right - # non-terminal in the production rule, then if a new set is - # marked or not - if new_temp not in marked_left: - marked_left.add(new_temp) - res = True - continue - if exists_before[index] or exists_after[index]: - to_append = (index + 1, new_temp) - to_process.append(to_append) - if not exists_before[index]: - # For all sets which were marked for the current consumption rule - for marked_set in marked_sets[index]: - if marked_set <= new_temp: - to_append = (index + 1, new_temp) - elif new_temp <= marked_set: - to_append = (index + 1, marked_set) - else: - to_append = (index + 1, new_temp.union(marked_set)) - if to_append not in done: - done.add(to_append) - to_process.append(to_append) - return res + def _extract_fst_duplication_rules_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for final_state in other.final_states: + for start_state in other.start_states: + new_rules.append(DuplicationRule( + "S", + (start_state.value, "S", final_state.value), + "T")) + + def _extract_fst_epsilon_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for state in other.states: + new_rules.append(EndRule( + (state.value, "epsilon", state.value), + "epsilon")) + + def _extract_fst_delta_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for (s_from, symb_from), (s_to, symb_to) in other: + new_rules.append(EndRule( + (s_from.value, symb_from.value, s_to.value), + tuple(map(lambda x: x.value, symb_to)))) + + def _extract_epsilon_transitions_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for state_p in other.states: + for state_q in other.states: + for state_r in other.states: + new_rules.append(DuplicationRule( + (state_p.value, "epsilon", state_q.value), + (state_p.value, "epsilon", state_r.value), + (state_r.value, "epsilon", state_q.value))) + + def _extract_indexed_grammar_rules_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for rule in self.rules.rules: + if isinstance(rule, DuplicationRule): + for state_p in other.states: + for state_q in other.states: + for state_r in other.states: + new_rules.append(DuplicationRule( + (state_p.value, rule.left_term.value, + state_q.value), + (state_p.value, rule.right_terms[0].value, + state_r.value), + (state_r.value, rule.right_terms[1].value, + state_q.value))) + elif isinstance(rule, ProductionRule): + for state_p in other.states: + for state_q in other.states: + new_rules.append(ProductionRule( + (state_p.value, rule.left_term.value, + state_q.value), + (state_p.value, rule.right_term.value, + state_q.value), + rule.production.value)) + elif isinstance(rule, EndRule): + for state_p in other.states: + for state_q in other.states: + new_rules.append(DuplicationRule( + (state_p.value, rule.left_term.value, + state_q.value), + (state_p.value, rule.right_term.value, + state_q.value), + "T")) + + def _extract_terminals_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for terminal in self.rules.terminals: + for state_p in other.states: + for state_q in other.states: + for state_r in other.states: + new_rules.append(DuplicationRule( + (state_p.value, terminal.value, state_q.value), + (state_p.value, "epsilon", state_r.value), + (state_r.value, terminal.value, state_q.value))) + new_rules.append(DuplicationRule( + (state_p.value, terminal.value, state_q.value), + (state_p.value, terminal.value, state_r.value), + (state_r.value, "epsilon", state_q.value))) + + def _extract_consumption_rules_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + consumptions = self.rules.consumption_rules + for terminal in consumptions: + for consumption in consumptions[terminal]: + for state_r in other.states: + for state_s in other.states: + new_rules.append(ConsumptionRule( + consumption.f_parameter.value, + (state_r.value, consumption.left_term.value, + state_s.value), + (state_r.value, consumption.right_term.value, + state_s.value))) diff --git a/pyformlang/indexed_grammar/production_rule.py b/pyformlang/indexed_grammar/production_rule.py index bccb68f..a8bfc47 100644 --- a/pyformlang/indexed_grammar/production_rule.py +++ b/pyformlang/indexed_grammar/production_rule.py @@ -2,9 +2,12 @@ Represents a production rule, i.e. a rule that pushed on the stack """ -from typing import Any, Iterable, AbstractSet +from typing import List, Set, Hashable, Any + +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class ProductionRule(ReducedRule): @@ -21,31 +24,20 @@ class ProductionRule(ReducedRule): The terminal used in the rule, "r" here """ - @property - def right_terms(self): - raise NotImplementedError + def __init__(self, + left_term: Hashable, + right_term: Hashable, + production: Hashable) -> None: + self._left_term = to_variable(left_term) + self._right_term = to_variable(right_term) + self._production = to_terminal(production) @property - def f_parameter(self): + def f_parameter(self) -> Terminal: raise NotImplementedError - def __init__(self, left, right, prod): - self._production = prod - self._left_term = left - self._right_term = right - - def is_production(self) -> bool: - """Whether the rule is a production rule or not - - Returns - ---------- - is_production : bool - Whether the rule is a production rule or not - """ - return True - @property - def production(self) -> Any: + def production(self) -> Terminal: """Gets the terminal used in the production Returns @@ -56,7 +48,7 @@ def production(self) -> Any: return self._production @property - def left_term(self) -> Any: + def left_term(self) -> Variable: """Gets the non-terminal on the left side of the rule Returns @@ -67,7 +59,7 @@ def left_term(self) -> Any: return self._left_term @property - def right_term(self) -> Any: + def right_term(self) -> Variable: """Gets the non-terminal on the right side of the rule Returns @@ -78,7 +70,18 @@ def right_term(self) -> Any: return self._right_term @property - def non_terminals(self) -> Iterable[Any]: + def right_terms(self) -> List[CFGObject]: + """Gives the non-terminals on the right of the rule + + Returns + --------- + right_terms : iterable of any + The right terms of the rule + """ + return [self._right_term] + + @property + def non_terminals(self) -> Set[Variable]: """Gets the non-terminals used in the rule Returns @@ -86,10 +89,10 @@ def non_terminals(self) -> Iterable[Any]: non_terminals : any The non terminals used in this rules """ - return [self._left_term, self._right_term] + return {self._left_term, self._right_term} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used in the rule Returns @@ -99,12 +102,13 @@ def terminals(self) -> AbstractSet[Any]: """ return {self._production} - def __repr__(self): - """Gets the string representation of the rule""" - return self._left_term + " -> " + \ - self._right_term + "[ " + self._production + " ]" + def __eq__(self, other: Any) -> bool: + if not isinstance(other, ProductionRule): + return False + return other.left_term == self.left_term \ + and other.right_term == self.right_term \ + and other.production == self.production - def __eq__(self, other): - return other.is_production() and other.left_term == \ - self.left_term and other.right_term == self.right_term \ - and other.production == self.production + def __repr__(self) -> str: + """Gets the string representation of the rule""" + return f"{self._left_term} -> {self._right_term} [ {self._production} ]" diff --git a/pyformlang/indexed_grammar/reduced_rule.py b/pyformlang/indexed_grammar/reduced_rule.py index 5d463ad..39c6c50 100644 --- a/pyformlang/indexed_grammar/reduced_rule.py +++ b/pyformlang/indexed_grammar/reduced_rule.py @@ -1,8 +1,12 @@ """ Representation of a reduced rule """ + +from typing import List, Set, Any from abc import abstractmethod +from pyformlang.cfg import CFGObject, Variable, Terminal + class ReducedRule: """Representation of all possible reduced forms. @@ -13,102 +17,90 @@ class ReducedRule: * Duplication """ - def is_consumption(self) -> bool: - """Whether the rule is a consumption rule or not - - Returns - ---------- - is_consumption : bool - Whether the rule is a consumption rule or not - """ - return False - - def is_duplication(self) -> bool: - """Whether the rule is a duplication rule or not - - Returns - ---------- - is_duplication : bool - Whether the rule is a duplication rule or not - """ - return False - - def is_production(self) -> bool: - """Whether the rule is a production rule or not + @property + @abstractmethod + def f_parameter(self) -> Terminal: + """The f parameter Returns ---------- - is_production : bool - Whether the rule is a production rule or not + f : cfg.Terminal + The f parameter """ - return False + raise NotImplementedError - def is_end_rule(self) -> bool: - """Whether the rule is an end rule or not + @property + @abstractmethod + def production(self) -> Terminal: + """The production Returns ---------- - is_end : bool - Whether the rule is an end rule or not + right_terms : any + The production """ - return False + raise NotImplementedError @property @abstractmethod - def f_parameter(self): - """The f parameter + def left_term(self) -> Variable: + """The left term Returns ---------- - f : any - The f parameter + left_term : cfg.Variable + The left term of the rule """ raise NotImplementedError @property @abstractmethod - def left_term(self): - """The left term + def right_term(self) -> CFGObject: + """The unique right term Returns ---------- - left_term : any - The left term of the rule + right_term : cfg.cfg_object.CFGObject + The unique right term of the rule """ raise NotImplementedError @property @abstractmethod - def right_terms(self): + def right_terms(self) -> List[CFGObject]: """The right terms Returns ---------- - right_terms : iterable of any + right_terms : list of cfg.cfg_object.CFGObject The right terms of the rule """ raise NotImplementedError @property @abstractmethod - def right_term(self): - """The unique right term + def non_terminals(self) -> Set[Variable]: + """Gets the non-terminals used in the rule - Returns - ---------- - right_term : iterable of any - The unique right term of the rule + terminals : set of cfg.Variable + The non-terminals used in the rule """ raise NotImplementedError @property @abstractmethod - def production(self): - """The production + def terminals(self) -> Set[Terminal]: + """Gets the terminals used in the rule - Returns - ---------- - right_terms : any - The production + terminals : set of cfg.Terminal + The terminals used in the rule """ raise NotImplementedError + + @abstractmethod + def __eq__(self, other: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError diff --git a/pyformlang/indexed_grammar/rule_ordering.py b/pyformlang/indexed_grammar/rule_ordering.py index af9236d..677c9f5 100644 --- a/pyformlang/indexed_grammar/rule_ordering.py +++ b/pyformlang/indexed_grammar/rule_ordering.py @@ -2,15 +2,18 @@ Representation of a way to order rules """ -from typing import Iterable, Dict, Any +from typing import List, Dict from queue import Queue -import random +from random import shuffle +from networkx import DiGraph, core_number, minimum_spanning_tree -import networkx as nx +from pyformlang.cfg import Terminal from .reduced_rule import ReducedRule from .consumption_rule import ConsumptionRule +from .duplication_rule import DuplicationRule +from .production_rule import ProductionRule class RuleOrdering: @@ -25,12 +28,13 @@ class RuleOrdering: The consumption rules of the indexed grammar """ - def __init__(self, rules: Iterable[ReducedRule], - conso_rules: Dict[Any, ConsumptionRule]): + def __init__(self, + rules: List[ReducedRule], + conso_rules: Dict[Terminal, List[ConsumptionRule]]) -> None: self.rules = rules self.conso_rules = conso_rules - def reverse(self) -> Iterable[ReducedRule]: + def reverse(self) -> List[ReducedRule]: """The reverser ordering, simply reverse the order. Returns @@ -41,26 +45,26 @@ def reverse(self) -> Iterable[ReducedRule]: """ return self.rules[::1] - def _get_graph(self): + def _get_graph(self) -> DiGraph: """ Get the graph of the non-terminals in the rules. If there there is a link between A and B (oriented), it means that modifying A may modify B""" - di_graph = nx.DiGraph() + di_graph = DiGraph() for rule in self.rules: - if rule.is_duplication(): + if isinstance(rule, DuplicationRule): if rule.right_terms[0] != rule.left_term: di_graph.add_edge(rule.right_terms[0], rule.left_term) if rule.right_terms[1] != rule.left_term: di_graph.add_edge(rule.right_terms[1], rule.left_term) - if rule.is_production(): + if isinstance(rule, ProductionRule): f_rules = self.conso_rules.setdefault( rule.production, []) for f_rule in f_rules: - if f_rule.right != rule.left_term: - di_graph.add_edge(f_rule.right, rule.left_term) + if f_rule.right_term != rule.left_term: + di_graph.add_edge(f_rule.right_term, rule.left_term) return di_graph - def order_by_core(self, reverse: bool = False) -> Iterable[ReducedRule]: + def order_by_core(self, reverse: bool = False) -> List[ReducedRule]: """Order the rules using the core numbers Parameters @@ -77,7 +81,7 @@ def order_by_core(self, reverse: bool = False) -> Iterable[ReducedRule]: # Graph construction di_graph = self._get_graph() # Get core number, careful the degree is in + out - core_numbers = nx.core_number(di_graph) + core_numbers = dict(core_number(di_graph)) new_order = sorted(self.rules, key=lambda x: core_numbers.setdefault( x.left_term, 0)) @@ -86,7 +90,7 @@ def order_by_core(self, reverse: bool = False) -> Iterable[ReducedRule]: return new_order def order_by_arborescence(self, reverse: bool = True) \ - -> Iterable[ReducedRule]: + -> List[ReducedRule]: """Order the rules using the arborescence method. Parameters @@ -101,7 +105,7 @@ def order_by_arborescence(self, reverse: bool = True) \ The rules ordered using core number """ di_graph = self._get_graph() - arborescence = nx.minimum_spanning_tree(di_graph.to_undirected()) + arborescence = minimum_spanning_tree(di_graph.to_undirected()) to_process = Queue() processed = set() res = {} @@ -126,7 +130,7 @@ def order_by_arborescence(self, reverse: bool = True) \ return new_order @staticmethod - def _get_len_out(di_graph, rule): + def _get_len_out(di_graph: DiGraph, rule: ReducedRule) -> int: """Get the number of out edges of a rule (more exactly, the non \ terminal at its left. @@ -141,7 +145,7 @@ def _get_len_out(di_graph, rule): return len(di_graph[rule.left_term]) return 0 - def order_by_edges(self, reverse=False): + def order_by_edges(self, reverse: bool = False) -> List[ReducedRule]: """Order using the number of edges. Parameters @@ -162,7 +166,7 @@ def order_by_edges(self, reverse=False): new_order.reverse() return new_order - def order_random(self): + def order_random(self) -> List[ReducedRule]: """The random ordering Returns @@ -171,5 +175,5 @@ def order_random(self): :class:`~pyformlang.indexed_grammar.ReducedRule` The rules ordered at random """ - random.shuffle(self.rules) + shuffle(self.rules) return self.rules diff --git a/pyformlang/indexed_grammar/rules.py b/pyformlang/indexed_grammar/rules.py index 017845a..f94aa85 100644 --- a/pyformlang/indexed_grammar/rules.py +++ b/pyformlang/indexed_grammar/rules.py @@ -2,12 +2,15 @@ Representations of rules in a indexed grammar """ -from typing import Iterable, Dict, Any, List +from typing import Dict, List, Set, Tuple, Iterable, Hashable +from pyformlang.cfg import Variable, Terminal + +from .reduced_rule import ReducedRule from .production_rule import ProductionRule from .consumption_rule import ConsumptionRule from .rule_ordering import RuleOrdering -from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class Rules: @@ -30,13 +33,13 @@ class Rules: 8 -> random order """ - def __init__(self, rules: Iterable[ReducedRule], optim: int = 7): - self._rules = [] - self._consumption_rules = {} + def __init__(self, rules: Iterable[ReducedRule], optim: int = 7) -> None: + self._rules: List[ReducedRule] = [] + self._consumption_rules: Dict[Terminal, List[ConsumptionRule]] = {} self._optim = optim for rule in rules: # We separate consumption rule from other - if rule.is_consumption(): + if isinstance(rule, ConsumptionRule): temp = self._consumption_rules.setdefault(rule.f_parameter, []) if rule not in temp: temp.append(rule) @@ -63,7 +66,7 @@ def __init__(self, rules: Iterable[ReducedRule], optim: int = 7): self._rules = rule_ordering.order_random() @property - def optim(self): + def optim(self) -> int: """Gets the optimization number Returns @@ -74,7 +77,7 @@ def optim(self): return self._optim @property - def rules(self) -> Iterable[ReducedRule]: + def rules(self) -> List[ReducedRule]: """Gets the non consumption rules Returns @@ -86,7 +89,7 @@ def rules(self) -> Iterable[ReducedRule]: return self._rules @property - def length(self) -> (int, int): + def length(self) -> Tuple[int, int]: """Get the total number of rules Returns @@ -98,7 +101,7 @@ def length(self) -> (int, int): return len(self._rules), len(self._consumption_rules.values()) @property - def consumption_rules(self) -> Dict[Any, Iterable[ConsumptionRule]]: + def consumption_rules(self) -> Dict[Terminal, List[ConsumptionRule]]: """Gets the consumption rules Returns @@ -111,41 +114,44 @@ def consumption_rules(self) -> Dict[Any, Iterable[ConsumptionRule]]: return self._consumption_rules @property - def terminals(self) -> Iterable[Any]: - """Gets all the terminals used by all the rules + def non_terminals(self) -> Set[Variable]: + """Gets all the non-terminals used by all the rules Returns ---------- - terminals : iterable of any - The terminals used in the rules + non_terminals : iterable of any + The non terminals used in the rule """ - terminals = set() - for temp_rule in self._consumption_rules.values(): - for rule in temp_rule: - terminals = terminals.union(rule.terminals) + non_terminals = set() + for rules in self._consumption_rules.values(): + for rule in rules: + non_terminals.update(rule.non_terminals) for rule in self._rules: - terminals = terminals.union(rule.terminals) - return terminals + non_terminals.update(rule.non_terminals) + return non_terminals @property - def non_terminals(self) -> List[Any]: - """Gets all the non-terminals used by all the rules + def terminals(self) -> Set[Terminal]: + """Gets all the terminals used by all the rules Returns ---------- - non_terminals : iterable of any - The non terminals used in the rule + terminals : iterable of any + The terminals used in the rules """ terminals = set() - for temp_rule in self._consumption_rules.values(): - for rule in temp_rule: - terminals = terminals.union(set(rule.non_terminals)) + for rules in self._consumption_rules.values(): + for rule in rules: + terminals.update(rule.terminals) for rule in self._rules: - terminals = terminals.union(set(rule.non_terminals)) - return list(terminals) + terminals.update(rule.terminals) + return terminals - def remove_production(self, left: Any, right: Any, prod: Any): - """Remove the production rule: + def add_production(self, + left: Hashable, + right: Hashable, + prod: Hashable) -> None: + """Add the production rule: left[sigma] -> right[prod sigma] Parameters @@ -157,14 +163,16 @@ def remove_production(self, left: Any, right: Any, prod: Any): prod : any The production used in the rule """ - self._rules = list(filter(lambda x: not (x.is_production() and - x.left_term == left and - x.right_term == right and - x.production == prod), - self._rules)) + left = to_variable(left) + right = to_variable(right) + prod = to_terminal(prod) + self._rules.append(ProductionRule(left, right, prod)) - def add_production(self, left: Any, right: Any, prod: Any): - """Add the production rule: + def remove_production(self, + left: Hashable, + right: Hashable, + prod: Hashable) -> None: + """Remove the production rule: left[sigma] -> right[prod sigma] Parameters @@ -176,4 +184,11 @@ def add_production(self, left: Any, right: Any, prod: Any): prod : any The production used in the rule """ - self._rules.append(ProductionRule(left, right, prod)) + left = to_variable(left) + right = to_variable(right) + prod = to_terminal(prod) + self._rules = list(filter(lambda x: not (isinstance(x, ProductionRule) + and x.left_term == left + and x.right_term == right + and x.production == prod), + self._rules)) diff --git a/pyformlang/indexed_grammar/tests/test_indexed_grammar.py b/pyformlang/indexed_grammar/tests/test_indexed_grammar.py index 9b184be..6f4bb01 100644 --- a/pyformlang/indexed_grammar/tests/test_indexed_grammar.py +++ b/pyformlang/indexed_grammar/tests/test_indexed_grammar.py @@ -8,6 +8,7 @@ from pyformlang.indexed_grammar import DuplicationRule from pyformlang.indexed_grammar import IndexedGrammar from pyformlang.regular_expression import Regex +from pyformlang.fst import FST class TestIndexedGrammar: @@ -338,7 +339,7 @@ def test_removal_useless(self): assert i_grammar2.non_terminals == \ i_grammar2.get_reachable_non_terminals() - def test_intersection(self): + def test_intersection0(self): """ Tests the intersection of indexed grammar with regex Long to run! """ @@ -349,9 +350,38 @@ def test_intersection(self): EndRule("Bfinal", "b")] rules = Rules(l_rules, 6) indexed_grammar = IndexedGrammar(rules) - i_inter = indexed_grammar.intersection(Regex("a.b")) + fst = Regex("a.b").to_epsilon_nfa().to_fst() + i_inter = indexed_grammar.intersection(fst) assert i_inter + def test_intersection1(self): + """ Test the intersection with fst """ + l_rules = [] + rules = Rules(l_rules) + indexed_grammar = IndexedGrammar(rules) + fst = FST() + intersection = indexed_grammar & fst + assert intersection.is_empty() + + l_rules.append(ProductionRule("S", "D", "f")) + l_rules.append(DuplicationRule("D", "A", "B")) + l_rules.append(ConsumptionRule("f", "A", "Afinal")) + l_rules.append(ConsumptionRule("f", "B", "Bfinal")) + l_rules.append(EndRule("Afinal", "a")) + l_rules.append(EndRule("Bfinal", "b")) + + rules = Rules(l_rules) + indexed_grammar = IndexedGrammar(rules) + intersection = indexed_grammar.intersection(fst) + assert intersection.is_empty() + + fst.add_start_state("q0") + fst.add_final_state("final") + fst.add_transition("q0", "a", "q1", ["a"]) + fst.add_transition("q1", "b", "final", ["b"]) + intersection = indexed_grammar.intersection(fst) + assert not intersection.is_empty() + def get_example_rules(): """ Duplicate example of rules """ diff --git a/pyformlang/indexed_grammar/tests/test_rules.py b/pyformlang/indexed_grammar/tests/test_rules.py index 8a4daaf..e0b3ce7 100644 --- a/pyformlang/indexed_grammar/tests/test_rules.py +++ b/pyformlang/indexed_grammar/tests/test_rules.py @@ -17,18 +17,17 @@ class TestIndexedGrammar: def test_consumption_rules(self): """ Tests the consumption rules """ - conso = ConsumptionRule("end", "C", "T") - terminals = conso.terminals + consumption = ConsumptionRule("end", "C", "T") + terminals = consumption.terminals assert terminals == {"end"} - representation = str(conso) + representation = str(consumption) assert representation == "C [ end ] -> T" def test_duplication_rules(self): """ Tests the duplication rules """ - dupli = DuplicationRule("B0", "A0", "C") - assert dupli.terminals == set() - assert str(dupli) == \ - "B0 -> A0 C" + duplication = DuplicationRule("B0", "A0", "C") + assert duplication.terminals == set() + assert str(duplication) == "B0 -> A0 C" def test_end_rule(self): """ Tests the end rules """ @@ -39,9 +38,9 @@ def test_end_rule(self): def test_production_rules(self): """ Tests the production rules """ - produ = ProductionRule("S", "C", "end") - assert produ.terminals == {"end"} - assert str(produ) == "S -> C[ end ]" + production = ProductionRule("S", "C", "end") + assert production.terminals == {"end"} + assert str(production) == "S -> C [ end ]" def test_rules(self): """ Tests the rules """ diff --git a/pyformlang/indexed_grammar/utils.py b/pyformlang/indexed_grammar/utils.py new file mode 100644 index 0000000..2fce4df --- /dev/null +++ b/pyformlang/indexed_grammar/utils.py @@ -0,0 +1,88 @@ +""" Utility for indexed grammars """ + +# pylint: disable=cell-var-from-loop + +from typing import List, Set, Iterable, Any + + +def addrec_bis(l_sets: Iterable[Any], + marked_left: Set[Any], + marked_right: Set[Any]) -> bool: + """addrec_bis + Optimized version of addrec + :param l_sets: a list containing tuples (C, M) where: + * C is a non-terminal on the left of a consumption rule + * M is the set of the marked set for the right non-terminal in the + production rule + :param marked_left: Sets which are marked for the non-terminal on the + left of the production rule + :param marked_right: Sets which are marked for the non-terminal on the + right of the production rule + """ + was_modified = False + for marked in list(marked_right): + l_temp = [x for x in l_sets if x[0] in marked] + s_temp = [x[0] for x in l_temp] + # At least one symbol to consider + if frozenset(s_temp) == marked and len(marked) > 0: + was_modified |= addrec_ter(l_temp, marked_left) + return was_modified + + +def addrec_ter(l_sets: List[Any], marked_left: Set[Any]) -> bool: + """addrec + Explores all possible combination of consumption rules to mark a + production rule. + :param l_sets: a list containing tuples (C, M) where: + * C is a non-terminal on the left of a consumption rule + * M is the set of the marked set for the right non-terminal in the + production rule + :param marked_left: Sets which are marked for the non-terminal on the + left of the production rule + :return Whether an element was actually marked + """ + # End condition, nothing left to process + temp_in = [x[0] for x in l_sets] + exists_after = [ + any(map(lambda x: x[0] == l_sets[index][0], l_sets[index + 1:])) + for index in range(len(l_sets))] + exists_before = [l_sets[index][0] in temp_in[:index] + for index in range(len(l_sets))] + marked_sets = [l_sets[index][1] for index in range(len(l_sets))] + marked_sets = [sorted(x, key=lambda x: -len(x)) for x in marked_sets] + # Try to optimize by having an order of the sets + sorted_zip = sorted(zip(exists_after, exists_before, marked_sets), + key=lambda x: -len(x[2])) + exists_after, exists_before, marked_sets = \ + zip(*sorted_zip) + res = False + # contains tuples of index, temp_set + to_process = [(0, frozenset())] + done = set() + while to_process: + index, new_temp = to_process.pop() + if index >= len(l_sets): + # Check if at least one non-terminal was considered, then if the + # set of non-terminals considered is marked of the right + # non-terminal in the production rule, then if a new set is + # marked or not + if new_temp not in marked_left: + marked_left.add(new_temp) + res = True + continue + if exists_before[index] or exists_after[index]: + to_append = (index + 1, new_temp) + to_process.append(to_append) + if not exists_before[index]: + # For all sets which were marked for the current consumption rule + for marked_set in marked_sets[index]: + if marked_set <= new_temp: + to_append = (index + 1, new_temp) + elif new_temp <= marked_set: + to_append = (index + 1, marked_set) + else: + to_append = (index + 1, new_temp.union(marked_set)) + if to_append not in done: + done.add(to_append) + to_process.append(to_append) + return res diff --git a/pyformlang/pda/pda.py b/pyformlang/pda/pda.py index f170d02..8a492ed 100644 --- a/pyformlang/pda/pda.py +++ b/pyformlang/pda/pda.py @@ -69,33 +69,23 @@ def __init__(self, transition_function: TransitionFunction = None, start_state: Hashable = None, start_stack_symbol: Hashable = None, - final_states: AbstractSet[Hashable] = None): + final_states: AbstractSet[Hashable] = None) -> None: # pylint: disable=too-many-arguments - if states is not None: - states = {to_state(x) for x in states} - if input_symbols is not None: - input_symbols = {to_symbol(x) for x in input_symbols} - if stack_alphabet is not None: - stack_alphabet = {to_stack_symbol(x) for x in stack_alphabet} - if start_state is not None: - start_state = to_state(start_state) - if start_stack_symbol is not None: - start_stack_symbol = to_stack_symbol(start_stack_symbol) - if final_states is not None: - final_states = {to_state(x) for x in final_states} - self._states: Set[State] = states or set() - self._input_symbols: Set[PDASymbol] = input_symbols or set() - self._stack_alphabet: Set[StackSymbol] = stack_alphabet or set() + self._states = {to_state(x) for x in states or set()} + self._input_symbols = {to_symbol(x) for x in input_symbols or set()} + self._stack_alphabet = {to_stack_symbol(x) + for x in stack_alphabet or set()} self._transition_function = transition_function or TransitionFunction() - self._start_state: Optional[State] = start_state + self._start_state = None if start_state is not None: - self._states.add(start_state) - self._start_stack_symbol: Optional[StackSymbol] = start_stack_symbol + self._start_state = to_state(start_state) + self._states.add(self._start_state) + self._start_stack_symbol = None if start_stack_symbol is not None: - self._stack_alphabet.add(start_stack_symbol) - self._final_states: Set[State] = final_states or set() - for state in self._final_states: - self._states.add(state) + self._start_stack_symbol = to_stack_symbol(start_stack_symbol) + self._stack_alphabet.add(self._start_stack_symbol) + self._final_states = {to_state(x) for x in final_states or set()} + self._states.update(self._final_states) @property def states(self) -> Set[State]: @@ -189,16 +179,6 @@ def add_final_state(self, state: Hashable) -> None: state = to_state(state) self._final_states.add(state) - def get_number_transitions(self) -> int: - """ Gets the number of transitions in the PDA - - Returns - ---------- - n_transitions : int - The number of transitions - """ - return self._transition_function.get_number_transitions() - def add_transition(self, s_from: Hashable, input_symbol: Hashable, @@ -271,6 +251,16 @@ def remove_transition(self, s_to, stack_to) + def get_number_transitions(self) -> int: + """ Gets the number of transitions in the PDA + + Returns + ---------- + n_transitions : int + The number of transitions + """ + return self._transition_function.get_number_transitions() + def __call__(self, s_from: Hashable, input_symbol: Hashable, @@ -665,16 +655,6 @@ def __and__(self, other: DeterministicFiniteAutomaton) -> "PDA": """ return self.intersection(other) - def to_dict(self) -> Dict[TransitionKey, TransitionValues]: - """ - Get the transitions of the PDA as a dictionary - Returns - ------- - transitions : dict - The transitions - """ - return self._transition_function.to_dict() - def to_networkx(self) -> MultiDiGraph: """ Transform the current pda into a networkx graph @@ -785,16 +765,26 @@ def copy(self) -> "PDA": def __copy__(self) -> "PDA": return self.copy() + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: + """ + Get the transitions of the PDA as a dictionary + Returns + ------- + transitions : dict + The transitions + """ + return self._transition_function.to_dict() + @staticmethod def __add_start_state_to_graph(graph: MultiDiGraph, state: State) -> None: """ Adds a starting node to a given graph """ - graph.add_node("starting_" + str(state.value), + graph.add_node("starting_" + str(state), label="", shape=None, height=.0, width=.0) - graph.add_edge("starting_" + str(state.value), + graph.add_edge("starting_" + str(state), state.value) @staticmethod diff --git a/pyformlang/pda/tests/test_pda.py b/pyformlang/pda/tests/test_pda.py index 03173cd..5399041 100644 --- a/pyformlang/pda/tests/test_pda.py +++ b/pyformlang/pda/tests/test_pda.py @@ -87,11 +87,14 @@ def test_represent(self): symb = Symbol("S") assert repr(symb) == "Symbol(S)" state = State("T") + assert str(state) == "T" assert repr(state) == "State(T)" stack_symb = StackSymbol("U") assert repr(stack_symb) == "StackSymbol(U)" assert repr(Epsilon()) == "epsilon" assert str(Epsilon()) == "epsilon" + assert str(StackSymbol(12)) == "12" + assert repr(StackSymbol(12)) == "StackSymbol(12)" def test_transition(self): """ Tests the creation of transition """ diff --git a/pyformlang/pda/transition_function.py b/pyformlang/pda/transition_function.py index fa52bec..4801376 100644 --- a/pyformlang/pda/transition_function.py +++ b/pyformlang/pda/transition_function.py @@ -1,7 +1,7 @@ """ A transition function in a pushdown automaton """ -from copy import deepcopy from typing import Dict, Set, Iterator, Iterable, Tuple +from copy import deepcopy from ..objects.pda_objects import State, Symbol, StackSymbol @@ -17,16 +17,6 @@ class TransitionFunction(Iterable[Transition]): def __init__(self) -> None: self._transitions: Dict[TransitionKey, TransitionValues] = {} - def get_number_transitions(self) -> int: - """ Gets the number of transitions - - Returns - ---------- - n_transitions : int - The number of transitions - """ - return sum(len(x) for x in self._transitions.values()) - # pylint: disable=too-many-arguments def add_transition(self, s_from: State, @@ -64,25 +54,17 @@ def remove_transition(self, stack_to: Tuple[StackSymbol, ...]) -> None: """ Remove the given transition from the function """ key = (s_from, input_symbol, stack_from) - if key in self._transitions: - self._transitions[key].discard((s_to, stack_to)) + self._transitions.get(key, set()).discard((s_to, stack_to)) - def copy(self) -> "TransitionFunction": - """ Copy the current transition function + def get_number_transitions(self) -> int: + """ Gets the number of transitions Returns ---------- - new_tf : :class:`~pyformlang.pda.TransitionFunction` - The copy of the transition function + n_transitions : int + The number of transitions """ - new_tf = TransitionFunction() - for temp_in, transition in self._transitions.items(): - for temp_out in transition: - new_tf.add_transition(*temp_in, *temp_out) - return new_tf - - def __copy__(self) -> "TransitionFunction": - return self.copy() + return sum(len(x) for x in self._transitions.values()) def __call__(self, s_from: State, @@ -99,6 +81,22 @@ def __iter__(self) -> Iterator[Transition]: for value in values: yield key, value + def copy(self) -> "TransitionFunction": + """ Copy the current transition function + + Returns + ---------- + new_tf : :class:`~pyformlang.pda.TransitionFunction` + The copy of the transition function + """ + new_tf = TransitionFunction() + for temp_in, temp_out in self: + new_tf.add_transition(*temp_in, *temp_out) + return new_tf + + def __copy__(self) -> "TransitionFunction": + return self.copy() + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: """Get the dictionary representation of the transitions""" return deepcopy(self._transitions) diff --git a/pyformlang/rsa/recursive_automaton.py b/pyformlang/rsa/recursive_automaton.py index 703bbfa..4c9edbe 100644 --- a/pyformlang/rsa/recursive_automaton.py +++ b/pyformlang/rsa/recursive_automaton.py @@ -100,7 +100,7 @@ def from_regex(cls, regex: Regex, start_nonterminal: Hashable) \ return RecursiveAutomaton(box, {box}) @classmethod - def from_ebnf(cls, text: str, start_nonterminal: Hashable = Symbol("S")) \ + def from_ebnf(cls, text: str, start_nonterminal: Hashable = "S") \ -> "RecursiveAutomaton": """ Create a recursive automaton from ebnf \ (ebnf = Extended Backus-Naur Form) diff --git a/requirements.txt b/requirements.txt index 3179fd5..c65ce0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ sphinx_rtd_theme numpy pylint pycodestyle +pyright pydot pygments>=2.7.4 # not directly required, pinned by Snyk to avoid a vulnerability pylint>=2.7.0 # not directly required, pinned by Snyk to avoid a vulnerability