This repo contains code to mirror other repos. It also contains the code that is getting mirrored.
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

369 рядків
12 KiB

  1. """This module implements a CYK parser."""
  2. # Author: https://github.com/ehudt (2018)
  3. #
  4. # Adapted by Erez
  5. from collections import defaultdict
  6. import itertools
  7. from ..common import ParseError, is_terminal
  8. from ..lexer import Token
  9. from ..tree import Tree
  10. try:
  11. xrange
  12. except NameError:
  13. xrange = range
  14. class Symbol(object):
  15. """Any grammar symbol."""
  16. def __init__(self, s):
  17. self.s = s
  18. def __repr__(self):
  19. return '%s(%s)' % (type(self).__name__, str(self))
  20. def __str__(self):
  21. return str(self.s)
  22. def __eq__(self, other):
  23. return self.s == str(other)
  24. def __ne__(self, other):
  25. return not self.__eq__(other)
  26. def __hash__(self):
  27. return hash((type(self), str(self.s)))
  28. class T(Symbol):
  29. """Terminal."""
  30. def match(self, s):
  31. return self.s == s.type
  32. class NT(Symbol):
  33. """Non-terminal."""
  34. pass
  35. class Rule(object):
  36. """Context-free grammar rule."""
  37. def __init__(self, lhs, rhs, weight, alias):
  38. super(Rule, self).__init__()
  39. assert isinstance(lhs, NT), lhs
  40. assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs
  41. self.lhs = lhs
  42. self.rhs = rhs
  43. self.weight = weight
  44. self.alias = alias
  45. def __str__(self):
  46. return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs))
  47. def __repr__(self):
  48. return str(self)
  49. def __hash__(self):
  50. return hash((self.lhs, tuple(self.rhs)))
  51. def __eq__(self, other):
  52. return self.lhs == other.lhs and self.rhs == other.rhs
  53. def __ne__(self, other):
  54. return not (self == other)
  55. class Grammar(object):
  56. """Context-free grammar."""
  57. def __init__(self, rules):
  58. self.rules = frozenset(rules)
  59. def __eq__(self, other):
  60. return self.rules == other.rules
  61. def __str__(self):
  62. return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n'
  63. def __repr__(self):
  64. return str(self)
  65. # Parse tree data structures
  66. class RuleNode(object):
  67. """A node in the parse tree, which also contains the full rhs rule."""
  68. def __init__(self, rule, children, weight=0):
  69. self.rule = rule
  70. self.children = children
  71. self.weight = weight
  72. def __repr__(self):
  73. return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), ', '.join(str(x) for x in self.children))
  74. class Parser(object):
  75. """Parser wrapper."""
  76. def __init__(self, rules, start):
  77. super(Parser, self).__init__()
  78. self.orig_rules = {rule.alias: rule for rule in rules}
  79. rules = [self._to_rule(rule) for rule in rules]
  80. self.grammar = to_cnf(Grammar(rules))
  81. self.start = NT(start)
  82. def _to_rule(self, lark_rule):
  83. """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
  84. return Rule(
  85. NT(lark_rule.origin), [
  86. T(x) if is_terminal(x) else NT(x) for x in lark_rule.expansion
  87. ], weight=lark_rule.options.priority if lark_rule.options and lark_rule.options.priority else 0, alias=lark_rule.alias)
  88. def parse(self, tokenized): # pylint: disable=invalid-name
  89. """Parses input, which is a list of tokens."""
  90. table, trees = _parse(tokenized, self.grammar)
  91. # Check if the parse succeeded.
  92. if all(r.lhs != self.start for r in table[(0, len(tokenized) - 1)]):
  93. raise ParseError('Parsing failed.')
  94. parse = trees[(0, len(tokenized) - 1)][NT(self.start)]
  95. return self._to_tree(revert_cnf(parse))
  96. def _to_tree(self, rule_node):
  97. """Converts a RuleNode parse tree to a lark Tree."""
  98. orig_rule = self.orig_rules[rule_node.rule.alias]
  99. children = []
  100. for child in rule_node.children:
  101. if isinstance(child, RuleNode):
  102. children.append(self._to_tree(child))
  103. else:
  104. assert isinstance(child.s, Token)
  105. children.append(child.s)
  106. t = Tree(orig_rule.origin, children)
  107. t.rule=orig_rule
  108. return t
  109. def print_parse(node, indent=0):
  110. if isinstance(node, RuleNode):
  111. print(' ' * (indent * 2) + str(node.rule.lhs))
  112. for child in node.children:
  113. print_parse(child, indent + 1)
  114. else:
  115. print(' ' * (indent * 2) + str(node.s))
  116. def _parse(s, g):
  117. """Parses sentence 's' using CNF grammar 'g'."""
  118. # The CYK table. Indexed with a 2-tuple: (start pos, end pos)
  119. table = defaultdict(set)
  120. # Top-level structure is similar to the CYK table. Each cell is a dict from
  121. # rule name to the best (lightest) tree for that rule.
  122. trees = defaultdict(dict)
  123. # Populate base case with existing terminal production rules
  124. for i, w in enumerate(s):
  125. for terminal, rules in g.terminal_rules.items():
  126. if terminal.match(w):
  127. for rule in rules:
  128. table[(i, i)].add(rule)
  129. if (rule.lhs not in trees[(i, i)] or
  130. rule.weight < trees[(i, i)][rule.lhs].weight):
  131. trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight)
  132. # Iterate over lengths of sub-sentences
  133. for l in xrange(2, len(s) + 1):
  134. # Iterate over sub-sentences with the given length
  135. for i in xrange(len(s) - l + 1):
  136. # Choose partition of the sub-sentence in [1, l)
  137. for p in xrange(i + 1, i + l):
  138. span1 = (i, p - 1)
  139. span2 = (p, i + l - 1)
  140. for r1, r2 in itertools.product(table[span1], table[span2]):
  141. for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
  142. table[(i, i + l - 1)].add(rule)
  143. r1_tree = trees[span1][r1.lhs]
  144. r2_tree = trees[span2][r2.lhs]
  145. rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight
  146. if (rule.lhs not in trees[(i, i + l - 1)]
  147. or rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight):
  148. trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight)
  149. return table, trees
  150. # This section implements context-free grammar converter to Chomsky normal form.
  151. # It also implements a conversion of parse trees from its CNF to the original
  152. # grammar.
  153. # Overview:
  154. # Applies the following operations in this order:
  155. # * TERM: Eliminates non-solitary terminals from all rules
  156. # * BIN: Eliminates rules with more than 2 symbols on their right-hand-side.
  157. # * UNIT: Eliminates non-terminal unit rules
  158. #
  159. # The following grammar characteristics aren't featured:
  160. # * Start symbol appears on RHS
  161. # * Empty rules (epsilon rules)
  162. class CnfWrapper(object):
  163. """CNF wrapper for grammar.
  164. Validates that the input grammar is CNF and provides helper data structures.
  165. """
  166. def __init__(self, grammar):
  167. super(CnfWrapper, self).__init__()
  168. self.grammar = grammar
  169. self.rules = grammar.rules
  170. self.terminal_rules = defaultdict(list)
  171. self.nonterminal_rules = defaultdict(list)
  172. for r in self.rules:
  173. # Validate that the grammar is CNF and populate auxiliary data structures.
  174. assert isinstance(r.lhs, NT), r
  175. assert len(r.rhs) in [1, 2], r
  176. if len(r.rhs) == 1 and isinstance(r.rhs[0], T):
  177. self.terminal_rules[r.rhs[0]].append(r)
  178. elif len(r.rhs) == 2 and all(isinstance(x, NT) for x in r.rhs):
  179. self.nonterminal_rules[tuple(r.rhs)].append(r)
  180. else:
  181. assert False, r
  182. def __eq__(self, other):
  183. return self.grammar == other.grammar
  184. def __repr__(self):
  185. return repr(self.grammar)
  186. class UnitSkipRule(Rule):
  187. """A rule that records NTs that were skipped during transformation."""
  188. def __init__(self, lhs, rhs, skipped_rules, weight, alias):
  189. super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
  190. self.skipped_rules = skipped_rules
  191. def __eq__(self, other):
  192. return isinstance(other, type(self)) and self.skipped_rules == other.skipped_rules
  193. __hash__ = Rule.__hash__
  194. def build_unit_skiprule(unit_rule, target_rule):
  195. skipped_rules = []
  196. if isinstance(unit_rule, UnitSkipRule):
  197. skipped_rules += unit_rule.skipped_rules
  198. skipped_rules.append(target_rule)
  199. if isinstance(target_rule, UnitSkipRule):
  200. skipped_rules += target_rule.skipped_rules
  201. return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules,
  202. weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias)
  203. def get_any_nt_unit_rule(g):
  204. """Returns a non-terminal unit rule from 'g', or None if there is none."""
  205. for rule in g.rules:
  206. if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT):
  207. return rule
  208. return None
  209. def _remove_unit_rule(g, rule):
  210. """Removes 'rule' from 'g' without changing the langugage produced by 'g'."""
  211. new_rules = [x for x in g.rules if x != rule]
  212. refs = [x for x in g.rules if x.lhs == rule.rhs[0]]
  213. new_rules += [build_unit_skiprule(rule, ref) for ref in refs]
  214. return Grammar(new_rules)
  215. def _split(rule):
  216. """Splits a rule whose len(rhs) > 2 into shorter rules."""
  217. rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs)
  218. rule_name = '__SP_%s' % (rule_str) + '_%d'
  219. yield Rule(rule.lhs, [rule.rhs[0], NT(rule_name % 1)], weight=rule.weight, alias=rule.alias)
  220. for i in xrange(1, len(rule.rhs) - 2):
  221. yield Rule(NT(rule_name % i), [rule.rhs[i], NT(rule_name % (i + 1))], weight=0, alias='Split')
  222. yield Rule(NT(rule_name % (len(rule.rhs) - 2)), rule.rhs[-2:], weight=0, alias='Split')
  223. def _term(g):
  224. """Applies the TERM rule on 'g' (see top comment)."""
  225. all_t = {x for rule in g.rules for x in rule.rhs if isinstance(x, T)}
  226. t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in all_t}
  227. new_rules = []
  228. for rule in g.rules:
  229. if len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs):
  230. new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs]
  231. new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
  232. new_rules.extend(v for k, v in t_rules.items() if k in rule.rhs)
  233. else:
  234. new_rules.append(rule)
  235. return Grammar(new_rules)
  236. def _bin(g):
  237. """Applies the BIN rule to 'g' (see top comment)."""
  238. new_rules = []
  239. for rule in g.rules:
  240. if len(rule.rhs) > 2:
  241. new_rules += _split(rule)
  242. else:
  243. new_rules.append(rule)
  244. return Grammar(new_rules)
  245. def _unit(g):
  246. """Applies the UNIT rule to 'g' (see top comment)."""
  247. nt_unit_rule = get_any_nt_unit_rule(g)
  248. while nt_unit_rule:
  249. g = _remove_unit_rule(g, nt_unit_rule)
  250. nt_unit_rule = get_any_nt_unit_rule(g)
  251. return g
  252. def to_cnf(g):
  253. """Creates a CNF grammar from a general context-free grammar 'g'."""
  254. g = _unit(_bin(_term(g)))
  255. return CnfWrapper(g)
  256. def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
  257. if not skipped_rules:
  258. return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)
  259. else:
  260. weight = weight - skipped_rules[0].weight
  261. return RuleNode(
  262. Rule(lhs, [skipped_rules[0].lhs], weight=weight, alias=alias), [
  263. unroll_unit_skiprule(skipped_rules[0].lhs, orig_rhs,
  264. skipped_rules[1:], children,
  265. skipped_rules[0].weight, skipped_rules[0].alias)
  266. ], weight=weight)
  267. def revert_cnf(node):
  268. """Reverts a parse tree (RuleNode) to its original non-CNF form (Node)."""
  269. if isinstance(node, T):
  270. return node
  271. # Reverts TERM rule.
  272. if node.rule.lhs.s.startswith('__T_'):
  273. return node.children[0]
  274. else:
  275. children = []
  276. for child in map(revert_cnf, node.children):
  277. # Reverts BIN rule.
  278. if isinstance(child, RuleNode) and child.rule.lhs.s.startswith('__SP_'):
  279. children += child.children
  280. else:
  281. children.append(child)
  282. # Reverts UNIT rule.
  283. if isinstance(node.rule, UnitSkipRule):
  284. return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs,
  285. node.rule.skipped_rules, children,
  286. node.rule.weight, node.rule.alias)
  287. else:
  288. return RuleNode(node.rule, children)