This repo contains code to mirror other repos. It also contains the code that is getting mirrored.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

187 lines
5.8 KiB

  1. """Tree matcher based on Lark grammar"""
  2. import re
  3. from collections import defaultdict
  4. from . import Tree, Token
  5. from .common import ParserConf
  6. from .parsers import earley
  7. from .grammar import Rule, Terminal, NonTerminal
  8. def is_discarded_terminal(t):
  9. return t.is_term and t.filter_out
  10. class _MakeTreeMatch:
  11. def __init__(self, name, expansion):
  12. self.name = name
  13. self.expansion = expansion
  14. def __call__(self, args):
  15. t = Tree(self.name, args)
  16. t.meta.match_tree = True
  17. t.meta.orig_expansion = self.expansion
  18. return t
  19. def _best_from_group(seq, group_key, cmp_key):
  20. d = {}
  21. for item in seq:
  22. key = group_key(item)
  23. if key in d:
  24. v1 = cmp_key(item)
  25. v2 = cmp_key(d[key])
  26. if v2 > v1:
  27. d[key] = item
  28. else:
  29. d[key] = item
  30. return list(d.values())
  31. def _best_rules_from_group(rules):
  32. rules = _best_from_group(rules, lambda r: r, lambda r: -len(r.expansion))
  33. rules.sort(key=lambda r: len(r.expansion))
  34. return rules
  35. def _match(term, token):
  36. if isinstance(token, Tree):
  37. name, _args = parse_rulename(term.name)
  38. return token.data == name
  39. elif isinstance(token, Token):
  40. return term == Terminal(token.type)
  41. assert False, (term, token)
  42. def make_recons_rule(origin, expansion, old_expansion):
  43. return Rule(origin, expansion, alias=_MakeTreeMatch(origin.name, old_expansion))
  44. def make_recons_rule_to_term(origin, term):
  45. return make_recons_rule(origin, [Terminal(term.name)], [term])
  46. def parse_rulename(s):
  47. "Parse rule names that may contain a template syntax (like rule{a, b, ...})"
  48. name, args_str = re.match(r'(\w+)(?:{(.+)})?', s).groups()
  49. args = args_str and [a.strip() for a in args_str.split(',')]
  50. return name, args
  51. class ChildrenLexer:
  52. def __init__(self, children):
  53. self.children = children
  54. def lex(self, parser_state):
  55. return self.children
class TreeMatcher:
    """Match the elements of a tree node, based on an ontology
    provided by a Lark grammar.

    Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)

    Initialize with an instance of Lark.
    """

    def __init__(self, parser):
        # XXX TODO calling compile twice returns different results!
        assert parser.options.maybe_placeholders == False
        # XXX TODO: we just ignore the potential existence of a postlexer
        self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())

        # Maps rule name -> rules that may only appear at the root of a match.
        self.rules_for_root = defaultdict(list)

        # Build the tree-matching rule set from the grammar's construction rules.
        # NOTE(review): _build_recons_rules also populates self.rules_for_root as a
        # side effect while the generator is consumed by list() here.
        self.rules = list(self._build_recons_rules(rules))
        self.rules.reverse()

        # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
        self.rules = _best_rules_from_group(self.rules)

        self.parser = parser
        # Cache of Earley parsers, keyed by root rule name (built lazily in match_tree).
        self._parser_cache = {}

    def _build_recons_rules(self, rules):
        "Convert tree-parsing/construction rules to tree-matching rules"
        # Origins of rules flagged expand1 (their single child replaces them in the tree).
        expand1s = {r.origin for r in rules if r.options.expand1}

        # origin -> list of rule aliases declared for it.
        aliases = defaultdict(list)
        for r in rules:
            if r.alias:
                aliases[r.origin].append(r.alias)

        rule_names = {r.origin for r in rules}
        # Symbols that stay nonterminals in the matching grammar: inlined rules
        # (leading '_'), expand1 rules, and aliased rules; everything else
        # becomes a Terminal standing for an already-built subtree.
        nonterminals = {sym for sym in rule_names
                        if sym.name.startswith('_') or sym in expand1s or sym in aliases}

        seen = set()
        for r in rules:
            # Rebuild the expansion, dropping filtered-out terminals and turning
            # non-inlined rule references into terminals (matched as subtrees).
            recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
                          for sym in r.expansion if not is_discarded_terminal(sym)]

            # Skip self-recursive constructs
            if recons_exp == [r.origin] and r.alias is None:
                continue

            sym = NonTerminal(r.alias) if r.alias else r.origin
            rule = make_recons_rule(sym, recons_exp, r.expansion)

            if sym in expand1s and len(recons_exp) != 1:
                # Multi-child expansion of an expand1 rule: only valid at the root.
                self.rules_for_root[sym.name].append(rule)

                if sym.name not in seen:
                    # Emit the subtree-matching unit rule once per symbol name.
                    yield make_recons_rule_to_term(sym, sym)
                    seen.add(sym.name)
            else:
                if sym.name.startswith('_') or sym in expand1s:
                    # Inlined/expand1 rules may appear anywhere in a derivation.
                    yield rule
                else:
                    self.rules_for_root[sym.name].append(rule)

        # Aliased rules additionally match as a unit, both under the alias
        # name and under their origin.
        for origin, rule_aliases in aliases.items():
            for alias in rule_aliases:
                yield make_recons_rule_to_term(origin, NonTerminal(alias))
            yield make_recons_rule_to_term(origin, origin)

    def match_tree(self, tree, rulename):
        """Match the elements of `tree` to the symbols of rule `rulename`.

        Parameters:
            tree (Tree): the tree node to match
            rulename (str): The expected full rule name (including template args)

        Returns:
            Tree: an unreduced tree that matches `rulename`

        Raises:
            UnexpectedToken: If no match was found.

        Note:
            It's the callers' responsibility match the tree recursively.
        """
        if rulename:
            # validate
            name, _args = parse_rulename(rulename)
            assert tree.data == name
        else:
            rulename = tree.data

        # TODO: ambiguity?
        try:
            parser = self._parser_cache[rulename]
        except KeyError:
            # Build (once per root rule) an Earley parser over the shared rules
            # plus the root-only rules for this name.
            rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])

            # TODO pass callbacks through dict, instead of alias?
            callbacks = {rule: rule.alias for rule in rules}

            conf = ParserConf(rules, callbacks, [rulename])
            parser = earley.Parser(conf, _match, resolve_ambiguity=True)
            self._parser_cache[rulename] = parser

        # find a full derivation
        unreduced_tree = parser.parse(ChildrenLexer(tree.children), rulename)
        assert unreduced_tree.data == rulename
        return unreduced_tree