"""This module implements a CYK parser.""" # Author: https://github.com/ehudt (2018) # # Adapted by Erez from collections import defaultdict import itertools from ..common import ParseError, is_terminal from ..lexer import Token from ..tree import Tree try: xrange except NameError: xrange = range class Symbol(object): """Any grammar symbol.""" def __init__(self, s): self.s = s def __repr__(self): return '%s(%s)' % (type(self).__name__, str(self)) def __str__(self): return str(self.s) def __eq__(self, other): return self.s == str(other) def __ne__(self, other): return not self.__eq__(other) def __hash__(self): return hash((type(self), str(self.s))) class T(Symbol): """Terminal.""" def match(self, s): return self.s == s.type class NT(Symbol): """Non-terminal.""" pass class Rule(object): """Context-free grammar rule.""" def __init__(self, lhs, rhs, weight, alias): super(Rule, self).__init__() assert isinstance(lhs, NT), lhs assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs self.lhs = lhs self.rhs = rhs self.weight = weight self.alias = alias def __str__(self): return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs)) def __repr__(self): return str(self) def __hash__(self): return hash((self.lhs, tuple(self.rhs))) def __eq__(self, other): return self.lhs == other.lhs and self.rhs == other.rhs def __ne__(self, other): return not (self == other) class Grammar(object): """Context-free grammar.""" def __init__(self, rules): self.rules = frozenset(rules) def __eq__(self, other): return self.rules == other.rules def __str__(self): return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n' def __repr__(self): return str(self) # Parse tree data structures class RuleNode(object): """A node in the parse tree, which also contains the full rhs rule.""" def __init__(self, rule, children, weight=0): self.rule = rule self.children = children self.weight = weight def __repr__(self): return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), ', '.join(str(x) for x in self.children)) class Parser(object): """Parser wrapper.""" def __init__(self, rules, start): super(Parser, self).__init__() self.orig_rules = {rule.alias: rule for rule in rules} rules = [self._to_rule(rule) for rule in rules] self.grammar = to_cnf(Grammar(rules)) self.start = NT(start) def _to_rule(self, lark_rule): """Converts a lark rule, (lhs, rhs, callback, options), to a Rule.""" return Rule( NT(lark_rule.origin), [ T(x) if is_terminal(x) else NT(x) for x in lark_rule.expansion ], weight=lark_rule.options.priority if lark_rule.options and lark_rule.options.priority else 0, alias=lark_rule.alias) def parse(self, tokenized): # pylint: disable=invalid-name """Parses input, which is a list of tokens.""" table, trees = _parse(tokenized, self.grammar) # Check if the parse succeeded. if all(r.lhs != self.start for r in table[(0, len(tokenized) - 1)]): raise ParseError('Parsing failed.') parse = trees[(0, len(tokenized) - 1)][NT(self.start)] return self._to_tree(revert_cnf(parse)) def _to_tree(self, rule_node): """Converts a RuleNode parse tree to a lark Tree.""" orig_rule = self.orig_rules[rule_node.rule.alias] children = [] for child in rule_node.children: if isinstance(child, RuleNode): children.append(self._to_tree(child)) else: assert isinstance(child.s, Token) children.append(child.s) t = Tree(orig_rule.origin, children) t.rule=orig_rule return t def print_parse(node, indent=0): if isinstance(node, RuleNode): print(' ' * (indent * 2) + str(node.rule.lhs)) for child in node.children: print_parse(child, indent + 1) else: print(' ' * (indent * 2) + str(node.s)) def _parse(s, g): """Parses sentence 's' using CNF grammar 'g'.""" # The CYK table. Indexed with a 2-tuple: (start pos, end pos) table = defaultdict(set) # Top-level structure is similar to the CYK table. Each cell is a dict from # rule name to the best (lightest) tree for that rule. trees = defaultdict(dict) # Populate base case with existing terminal production rules for i, w in enumerate(s): for terminal, rules in g.terminal_rules.items(): if terminal.match(w): for rule in rules: table[(i, i)].add(rule) if (rule.lhs not in trees[(i, i)] or rule.weight < trees[(i, i)][rule.lhs].weight): trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight) # Iterate over lengths of sub-sentences for l in xrange(2, len(s) + 1): # Iterate over sub-sentences with the given length for i in xrange(len(s) - l + 1): # Choose partition of the sub-sentence in [1, l) for p in xrange(i + 1, i + l): span1 = (i, p - 1) span2 = (p, i + l - 1) for r1, r2 in itertools.product(table[span1], table[span2]): for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []): table[(i, i + l - 1)].add(rule) r1_tree = trees[span1][r1.lhs] r2_tree = trees[span2][r2.lhs] rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight if (rule.lhs not in trees[(i, i + l - 1)] or rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight): trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight) return table, trees # This section implements context-free grammar converter to Chomsky normal form. # It also implements a conversion of parse trees from its CNF to the original # grammar. # Overview: # Applies the following operations in this order: # * TERM: Eliminates non-solitary terminals from all rules # * BIN: Eliminates rules with more than 2 symbols on their right-hand-side. # * UNIT: Eliminates non-terminal unit rules # # The following grammar characteristics aren't featured: # * Start symbol appears on RHS # * Empty rules (epsilon rules) class CnfWrapper(object): """CNF wrapper for grammar. Validates that the input grammar is CNF and provides helper data structures. """ def __init__(self, grammar): super(CnfWrapper, self).__init__() self.grammar = grammar self.rules = grammar.rules self.terminal_rules = defaultdict(list) self.nonterminal_rules = defaultdict(list) for r in self.rules: # Validate that the grammar is CNF and populate auxiliary data structures. assert isinstance(r.lhs, NT), r assert len(r.rhs) in [1, 2], r if len(r.rhs) == 1 and isinstance(r.rhs[0], T): self.terminal_rules[r.rhs[0]].append(r) elif len(r.rhs) == 2 and all(isinstance(x, NT) for x in r.rhs): self.nonterminal_rules[tuple(r.rhs)].append(r) else: assert False, r def __eq__(self, other): return self.grammar == other.grammar def __repr__(self): return repr(self.grammar) class UnitSkipRule(Rule): """A rule that records NTs that were skipped during transformation.""" def __init__(self, lhs, rhs, skipped_rules, weight, alias): super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias) self.skipped_rules = skipped_rules def __eq__(self, other): return isinstance(other, type(self)) and self.skipped_rules == other.skipped_rules __hash__ = Rule.__hash__ def build_unit_skiprule(unit_rule, target_rule): skipped_rules = [] if isinstance(unit_rule, UnitSkipRule): skipped_rules += unit_rule.skipped_rules skipped_rules.append(target_rule) if isinstance(target_rule, UnitSkipRule): skipped_rules += target_rule.skipped_rules return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules, weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias) def get_any_nt_unit_rule(g): """Returns a non-terminal unit rule from 'g', or None if there is none.""" for rule in g.rules: if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT): return rule return None def _remove_unit_rule(g, rule): """Removes 'rule' from 'g' without changing the langugage produced by 'g'.""" new_rules = [x for x in g.rules if x != rule] refs = [x for x in g.rules if x.lhs == rule.rhs[0]] new_rules += [build_unit_skiprule(rule, ref) for ref in refs] return Grammar(new_rules) def _split(rule): """Splits a rule whose len(rhs) > 2 into shorter rules.""" rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs) rule_name = '__SP_%s' % (rule_str) + '_%d' yield Rule(rule.lhs, [rule.rhs[0], NT(rule_name % 1)], weight=rule.weight, alias=rule.alias) for i in xrange(1, len(rule.rhs) - 2): yield Rule(NT(rule_name % i), [rule.rhs[i], NT(rule_name % (i + 1))], weight=0, alias='Split') yield Rule(NT(rule_name % (len(rule.rhs) - 2)), rule.rhs[-2:], weight=0, alias='Split') def _term(g): """Applies the TERM rule on 'g' (see top comment).""" all_t = {x for rule in g.rules for x in rule.rhs if isinstance(x, T)} t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in all_t} new_rules = [] for rule in g.rules: if len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs): new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs] new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias)) new_rules.extend(v for k, v in t_rules.items() if k in rule.rhs) else: new_rules.append(rule) return Grammar(new_rules) def _bin(g): """Applies the BIN rule to 'g' (see top comment).""" new_rules = [] for rule in g.rules: if len(rule.rhs) > 2: new_rules += _split(rule) else: new_rules.append(rule) return Grammar(new_rules) def _unit(g): """Applies the UNIT rule to 'g' (see top comment).""" nt_unit_rule = get_any_nt_unit_rule(g) while nt_unit_rule: g = _remove_unit_rule(g, nt_unit_rule) nt_unit_rule = get_any_nt_unit_rule(g) return g def to_cnf(g): """Creates a CNF grammar from a general context-free grammar 'g'.""" g = _unit(_bin(_term(g))) return CnfWrapper(g) def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias): if not skipped_rules: return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight) else: weight = weight - skipped_rules[0].weight return RuleNode( Rule(lhs, [skipped_rules[0].lhs], weight=weight, alias=alias), [ unroll_unit_skiprule(skipped_rules[0].lhs, orig_rhs, skipped_rules[1:], children, skipped_rules[0].weight, skipped_rules[0].alias) ], weight=weight) def revert_cnf(node): """Reverts a parse tree (RuleNode) to its original non-CNF form (Node).""" if isinstance(node, T): return node # Reverts TERM rule. if node.rule.lhs.s.startswith('__T_'): return node.children[0] else: children = [] for child in map(revert_cnf, node.children): # Reverts BIN rule. if isinstance(child, RuleNode) and child.rule.lhs.s.startswith('__SP_'): children += child.children else: children.append(child) # Reverts UNIT rule. if isinstance(node.rule, UnitSkipRule): return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs, node.rule.skipped_rules, children, node.rule.weight, node.rule.alias) else: return RuleNode(node.rule, children)