This repo contains code to mirror other repos. It also contains the code that is getting mirrored.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

362 lines
12 KiB

  1. """This module implements a CYK parser."""
  2. from collections import defaultdict
  3. import itertools
  4. from ..common import ParseError, is_terminal
  5. from ..lexer import Token
  6. from ..tree import Tree
# Python 2/3 compatibility shim: on Python 3 the name 'xrange' does not
# exist, so alias it to 'range' (which is already lazy on Python 3).
try:
    xrange
except NameError:
    xrange = range
  11. class Symbol(object):
  12. """Any grammar symbol."""
  13. def __init__(self, s):
  14. self.s = s
  15. def __repr__(self):
  16. return '%s(%s)' % (type(self).__name__, str(self))
  17. def __str__(self):
  18. return str(self.s)
  19. def __eq__(self, other):
  20. return self.s == str(other)
  21. def __ne__(self, other):
  22. return not self.__eq__(other)
  23. def __hash__(self):
  24. return hash((type(self), str(self.s)))
  25. class T(Symbol):
  26. """Terminal."""
  27. def match(self, s):
  28. return self.s == s.type
  29. class NT(Symbol):
  30. """Non-terminal."""
  31. pass
  32. class Rule(object):
  33. """Context-free grammar rule."""
  34. def __init__(self, lhs, rhs, weight, alias):
  35. super(Rule, self).__init__()
  36. assert isinstance(lhs, NT), lhs
  37. assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs
  38. self.lhs = lhs
  39. self.rhs = rhs
  40. self.weight = weight
  41. self.alias = alias
  42. def __str__(self):
  43. return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs))
  44. def __repr__(self):
  45. return str(self)
  46. def __hash__(self):
  47. return hash((self.lhs, tuple(self.rhs)))
  48. def __eq__(self, other):
  49. return self.lhs == other.lhs and self.rhs == other.rhs
  50. def __ne__(self, other):
  51. return not (self == other)
  52. class Grammar(object):
  53. """Context-free grammar."""
  54. def __init__(self, rules):
  55. self.rules = frozenset(rules)
  56. def __eq__(self, other):
  57. return self.rules == other.rules
  58. def __str__(self):
  59. return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n'
  60. def __repr__(self):
  61. return str(self)
  62. # Parse tree data structures
  63. class RuleNode(object):
  64. """A node in the parse tree, which also contains the full rhs rule."""
  65. def __init__(self, rule, children, weight=0):
  66. self.rule = rule
  67. self.children = children
  68. self.weight = weight
  69. def __repr__(self):
  70. return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), ', '.join(str(x) for x in self.children))
  71. class Parser(object):
  72. """Parser wrapper."""
  73. def __init__(self, rules, start):
  74. super(Parser, self).__init__()
  75. self.orig_rules = {rule.alias: rule for rule in rules}
  76. rules = [self._to_rule(rule) for rule in rules]
  77. self.grammar = to_cnf(Grammar(rules))
  78. self.start = NT(start)
  79. def _to_rule(self, lark_rule):
  80. """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
  81. return Rule(
  82. NT(lark_rule.origin), [
  83. T(x) if is_terminal(x) else NT(x) for x in lark_rule.expansion
  84. ], weight=lark_rule.options.priority if lark_rule.options and lark_rule.options.priority else 0, alias=lark_rule.alias)
  85. def parse(self, tokenized): # pylint: disable=invalid-name
  86. """Parses input, which is a list of tokens."""
  87. table, trees = _parse(tokenized, self.grammar)
  88. # Check if the parse succeeded.
  89. if all(r.lhs != self.start for r in table[(0, len(tokenized) - 1)]):
  90. raise ParseError('Parsing failed.')
  91. parse = trees[(0, len(tokenized) - 1)][NT(self.start)]
  92. return self._to_tree(revert_cnf(parse))
  93. def _to_tree(self, rule_node):
  94. """Converts a RuleNode parse tree to a lark Tree."""
  95. orig_rule = self.orig_rules[rule_node.rule.alias]
  96. children = []
  97. for i, child in enumerate(rule_node.children):
  98. if isinstance(child, RuleNode):
  99. children.append(self._to_tree(child))
  100. else:
  101. assert isinstance(child.s, Token)
  102. children.append(child.s)
  103. return Tree(orig_rule.origin, children, rule=orig_rule)
  104. def print_parse(node, indent=0):
  105. if isinstance(node, RuleNode):
  106. print(' ' * (indent * 2) + str(node.rule.lhs))
  107. for child in node.children:
  108. print_parse(child, indent + 1)
  109. else:
  110. print(' ' * (indent * 2) + str(node.s))
def _parse(s, g):
    """Parses sentence 's' using CNF grammar 'g'.

    's' is a sequence of tokens; 'g' is a CnfWrapper. Returns a pair
    (table, trees): 'table' maps each (start, end) span to the set of rules
    that can derive it, and 'trees' maps each span to a dict from rule LHS
    to the lightest (lowest total weight) RuleNode for that LHS.
    """
    # The CYK table. Indexed with a 2-tuple: (start pos, end pos)
    table = defaultdict(set)
    # Top-level structure is similar to the CYK table. Each cell is a dict from
    # rule name to the best (lightest) tree for that rule.
    trees = defaultdict(dict)
    # Populate base case with existing terminal production rules
    for i, w in enumerate(s):
        for terminal, rules in g.terminal_rules.items():
            if terminal.match(w):
                for rule in rules:
                    table[(i, i)].add(rule)
                    # Keep only the lightest tree per LHS in each cell.
                    if (rule.lhs not in trees[(i, i)] or
                        rule.weight < trees[(i, i)][rule.lhs].weight):
                        trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight)
    # Iterate over lengths of sub-sentences
    for l in xrange(2, len(s) + 1):
        # Iterate over sub-sentences with the given length
        for i in xrange(len(s) - l + 1):
            # Choose partition of the sub-sentence in [1, l)
            for p in xrange(i + 1, i + l):
                span1 = (i, p - 1)
                span2 = (p, i + l - 1)
                # Try every pair of derivations for the two halves against
                # every binary rule whose RHS matches their LHS pair.
                for r1, r2 in itertools.product(table[span1], table[span2]):
                    for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
                        table[(i, i + l - 1)].add(rule)
                        r1_tree = trees[span1][r1.lhs]
                        r2_tree = trees[span2][r2.lhs]
                        # Tree weight = rule weight plus both subtree weights.
                        rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight
                        if (rule.lhs not in trees[(i, i + l - 1)]
                            or rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight):
                            trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight)
    return table, trees
  145. # This section implements context-free grammar converter to Chomsky normal form.
  146. # It also implements a conversion of parse trees from its CNF to the original
  147. # grammar.
  148. # Overview:
  149. # Applies the following operations in this order:
  150. # * TERM: Eliminates non-solitary terminals from all rules
  151. # * BIN: Eliminates rules with more than 2 symbols on their right-hand-side.
  152. # * UNIT: Eliminates non-terminal unit rules
  153. #
  154. # The following grammar characteristics aren't featured:
  155. # * Start symbol appears on RHS
  156. # * Empty rules (epsilon rules)
  157. class CnfWrapper(object):
  158. """CNF wrapper for grammar.
  159. Validates that the input grammar is CNF and provides helper data structures.
  160. """
  161. def __init__(self, grammar):
  162. super(CnfWrapper, self).__init__()
  163. self.grammar = grammar
  164. self.rules = grammar.rules
  165. self.terminal_rules = defaultdict(list)
  166. self.nonterminal_rules = defaultdict(list)
  167. for r in self.rules:
  168. # Validate that the grammar is CNF and populate auxiliary data structures.
  169. assert isinstance(r.lhs, NT), r
  170. assert len(r.rhs) in [1, 2], r
  171. if len(r.rhs) == 1 and isinstance(r.rhs[0], T):
  172. self.terminal_rules[r.rhs[0]].append(r)
  173. elif len(r.rhs) == 2 and all(isinstance(x, NT) for x in r.rhs):
  174. self.nonterminal_rules[tuple(r.rhs)].append(r)
  175. else:
  176. assert False, r
  177. def __eq__(self, other):
  178. return self.grammar == other.grammar
  179. def __repr__(self):
  180. return repr(self.grammar)
  181. class UnitSkipRule(Rule):
  182. """A rule that records NTs that were skipped during transformation."""
  183. def __init__(self, lhs, rhs, skipped_rules, weight, alias):
  184. super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
  185. self.skipped_rules = skipped_rules
  186. def __eq__(self, other):
  187. return isinstance(other, type(self)) and self.skipped_rules == other.skipped_rules
  188. __hash__ = Rule.__hash__
  189. def build_unit_skiprule(unit_rule, target_rule):
  190. skipped_rules = []
  191. if isinstance(unit_rule, UnitSkipRule):
  192. skipped_rules += unit_rule.skipped_rules
  193. skipped_rules.append(target_rule)
  194. if isinstance(target_rule, UnitSkipRule):
  195. skipped_rules += target_rule.skipped_rules
  196. return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules,
  197. weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias)
  198. def get_any_nt_unit_rule(g):
  199. """Returns a non-terminal unit rule from 'g', or None if there is none."""
  200. for rule in g.rules:
  201. if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT):
  202. return rule
  203. return None
  204. def _remove_unit_rule(g, rule):
  205. """Removes 'rule' from 'g' without changing the langugage produced by 'g'."""
  206. new_rules = [x for x in g.rules if x != rule]
  207. refs = [x for x in g.rules if x.lhs == rule.rhs[0]]
  208. new_rules += [build_unit_skiprule(rule, ref) for ref in refs]
  209. return Grammar(new_rules)
  210. def _split(rule):
  211. """Splits a rule whose len(rhs) > 2 into shorter rules."""
  212. rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs)
  213. rule_name = '__SP_%s' % (rule_str) + '_%d'
  214. yield Rule(rule.lhs, [rule.rhs[0], NT(rule_name % 1)], weight=rule.weight, alias=rule.alias)
  215. for i in xrange(1, len(rule.rhs) - 2):
  216. yield Rule(NT(rule_name % i), [rule.rhs[i], NT(rule_name % (i + 1))], weight=0, alias='Split')
  217. yield Rule(NT(rule_name % (len(rule.rhs) - 2)), rule.rhs[-2:], weight=0, alias='Split')
  218. def _term(g):
  219. """Applies the TERM rule on 'g' (see top comment)."""
  220. all_t = {x for rule in g.rules for x in rule.rhs if isinstance(x, T)}
  221. t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in all_t}
  222. new_rules = []
  223. for rule in g.rules:
  224. if len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs):
  225. new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs]
  226. new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
  227. new_rules.extend(v for k, v in t_rules.items() if k in rule.rhs)
  228. else:
  229. new_rules.append(rule)
  230. return Grammar(new_rules)
  231. def _bin(g):
  232. """Applies the BIN rule to 'g' (see top comment)."""
  233. new_rules = []
  234. for rule in g.rules:
  235. if len(rule.rhs) > 2:
  236. new_rules += _split(rule)
  237. else:
  238. new_rules.append(rule)
  239. return Grammar(new_rules)
  240. def _unit(g):
  241. """Applies the UNIT rule to 'g' (see top comment)."""
  242. nt_unit_rule = get_any_nt_unit_rule(g)
  243. while nt_unit_rule:
  244. g = _remove_unit_rule(g, nt_unit_rule)
  245. nt_unit_rule = get_any_nt_unit_rule(g)
  246. return g
  247. def to_cnf(g):
  248. """Creates a CNF grammar from a general context-free grammar 'g'."""
  249. g = _unit(_bin(_term(g)))
  250. return CnfWrapper(g)
  251. def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
  252. if not skipped_rules:
  253. return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)
  254. else:
  255. weight = weight - skipped_rules[0].weight
  256. return RuleNode(
  257. Rule(lhs, [skipped_rules[0].lhs], weight=weight, alias=alias), [
  258. unroll_unit_skiprule(skipped_rules[0].lhs, orig_rhs,
  259. skipped_rules[1:], children,
  260. skipped_rules[0].weight, skipped_rules[0].alias)
  261. ], weight=weight)
def revert_cnf(node):
    """Reverts a parse tree (RuleNode) to its original non-CNF form (Node).

    Undoes the three CNF transformations by recognizing the synthetic
    non-terminal name prefixes introduced by _term ('__T_') and _split
    ('__SP_'), and the UnitSkipRule marker left by the UNIT elimination.
    """
    # Terminal leaves pass through unchanged.
    if isinstance(node, T):
        return node
    # Reverts TERM rule.
    if node.rule.lhs.s.startswith('__T_'):
        return node.children[0]
    else:
        children = []
        for child in map(revert_cnf, node.children):
            # Reverts BIN rule: splice a synthetic split-node's children
            # back into this node, flattening the binary chain.
            if isinstance(child, RuleNode) and child.rule.lhs.s.startswith('__SP_'):
                children += child.children
            else:
                children.append(child)
        # Reverts UNIT rule.
        if isinstance(node.rule, UnitSkipRule):
            return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs,
                                        node.rule.skipped_rules, children,
                                        node.rule.weight, node.rule.alias)
        else:
            return RuleNode(node.rule, children)