This repo contains code to mirror other repos. It also contains the code that is getting mirrored.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

305 lines
10 KiB

  1. """This module builds a LALR(1) transition-table for lalr_parser.py
  2. For now, shift/reduce conflicts are automatically resolved as shifts.
  3. """
  4. # Author: Erez Shinan (2017)
  5. # Email : erezshin@gmail.com
  6. from collections import defaultdict
  7. from ..utils import classify, classify_bool, bfs, fzset, Enumerator, logger
  8. from ..exceptions import GrammarError
  9. from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet
  10. from ..grammar import Rule
  11. ###{standalone
  12. class Action:
  13. def __init__(self, name):
  14. self.name = name
  15. def __str__(self):
  16. return self.name
  17. def __repr__(self):
  18. return str(self)
  19. Shift = Action('Shift')
  20. Reduce = Action('Reduce')
  21. class ParseTable:
  22. def __init__(self, states, start_states, end_states):
  23. self.states = states
  24. self.start_states = start_states
  25. self.end_states = end_states
  26. def serialize(self, memo):
  27. tokens = Enumerator()
  28. rules = Enumerator()
  29. states = {
  30. state: {tokens.get(token): ((1, arg.serialize(memo)) if action is Reduce else (0, arg))
  31. for token, (action, arg) in actions.items()}
  32. for state, actions in self.states.items()
  33. }
  34. return {
  35. 'tokens': tokens.reversed(),
  36. 'states': states,
  37. 'start_states': self.start_states,
  38. 'end_states': self.end_states,
  39. }
  40. @classmethod
  41. def deserialize(cls, data, memo):
  42. tokens = data['tokens']
  43. states = {
  44. state: {tokens[token]: ((Reduce, Rule.deserialize(arg, memo)) if action==1 else (Shift, arg))
  45. for token, (action, arg) in actions.items()}
  46. for state, actions in data['states'].items()
  47. }
  48. return cls(states, data['start_states'], data['end_states'])
  49. class IntParseTable(ParseTable):
  50. @classmethod
  51. def from_ParseTable(cls, parse_table):
  52. enum = list(parse_table.states)
  53. state_to_idx = {s:i for i,s in enumerate(enum)}
  54. int_states = {}
  55. for s, la in parse_table.states.items():
  56. la = {k:(v[0], state_to_idx[v[1]]) if v[0] is Shift else v
  57. for k,v in la.items()}
  58. int_states[ state_to_idx[s] ] = la
  59. start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
  60. end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
  61. return cls(int_states, start_states, end_states)
  62. ###}
  63. # digraph and traverse, see The Theory and Practice of Compiler Writing
  64. # computes F(x) = G(x) union (union { G(y) | x R y })
  65. # X: nodes
  66. # R: relation (function mapping node -> list of nodes that satisfy the relation)
  67. # G: set valued function
  68. def digraph(X, R, G):
  69. F = {}
  70. S = []
  71. N = {}
  72. for x in X:
  73. N[x] = 0
  74. for x in X:
  75. # this is always true for the first iteration, but N[x] may be updated in traverse below
  76. if N[x] == 0:
  77. traverse(x, S, N, X, R, G, F)
  78. return F
  79. # x: single node
  80. # S: stack
  81. # N: weights
  82. # X: nodes
  83. # R: relation (see above)
  84. # G: set valued function
  85. # F: set valued function we are computing (map of input -> output)
  86. def traverse(x, S, N, X, R, G, F):
  87. S.append(x)
  88. d = len(S)
  89. N[x] = d
  90. F[x] = G[x]
  91. for y in R[x]:
  92. if N[y] == 0:
  93. traverse(y, S, N, X, R, G, F)
  94. n_x = N[x]
  95. assert(n_x > 0)
  96. n_y = N[y]
  97. assert(n_y != 0)
  98. if (n_y > 0) and (n_y < n_x):
  99. N[x] = n_y
  100. F[x].update(F[y])
  101. if N[x] == d:
  102. f_x = F[x]
  103. while True:
  104. z = S.pop()
  105. N[z] = -1
  106. F[z] = f_x
  107. if z == x:
  108. break
class LALR_Analyzer(GrammarAnalyzer):
    """Builds a LALR(1) transition table from a grammar.

    Lookaheads are computed via the reads/includes/lookback relations,
    folded with the digraph() helper above (see the reference cited
    there). Entry point is compute_lalr(), which runs the phases in
    order and leaves the result in self.parse_table.
    """

    def __init__(self, parser_conf, debug=False):
        GrammarAnalyzer.__init__(self, parser_conf, debug)
        # List of (state, nonterminal) pairs: every nonterminal transition
        # in the LR(0) automaton. Filled by compute_reads_relations().
        self.nonterminal_transitions = []
        # (state, nonterminal) -> set of terminals seen directly after the transition
        self.directly_reads = defaultdict(set)
        # "reads" relation: (state, nonterminal) -> set of related transitions
        self.reads = defaultdict(set)
        # "includes" relation: (state, nonterminal) -> set of related transitions
        self.includes = defaultdict(set)
        # "lookback": (state, nonterminal) -> set of (final state, rule) pairs
        self.lookback = defaultdict(set)

    def compute_lr0_states(self):
        """Build the LR(0) item-set automaton breadth-first, starting from
        self.lr0_start_states. Populates self.lr0_states and each state's
        .transitions map."""
        self.lr0_states = set()
        # map of kernels to LR0ItemSets
        cache = {}

        def step(state):
            # Satisfied rule pointers need no outgoing transition.
            _, unsat = classify_bool(state.closure, lambda rp: rp.is_satisfied)
            # Group the remaining rule pointers by the symbol they expect next.
            d = classify(unsat, lambda rp: rp.next)
            for sym, rps in d.items():
                kernel = fzset({rp.advance(sym) for rp in rps})
                new_state = cache.get(kernel, None)
                if new_state is None:
                    # Expand the kernel's nonterminals to get the full closure.
                    closure = set(kernel)
                    for rp in kernel:
                        if not rp.is_satisfied and not rp.next.is_term:
                            closure |= self.expand_rule(rp.next, self.lr0_rules_by_origin)
                    new_state = LR0ItemSet(kernel, closure)
                    cache[kernel] = new_state

                state.transitions[sym] = new_state
                yield new_state

            self.lr0_states.add(state)

        for _ in bfs(self.lr0_start_states.values(), step):
            pass

    def compute_reads_relations(self):
        """Collect all nonterminal transitions, and compute directly_reads
        (terminals seen right after a transition) and the reads relation
        (transitions over nullable nonterminals)."""
        # handle start state
        for root in self.lr0_start_states.values():
            assert(len(root.kernel) == 1)
            for rp in root.kernel:
                assert(rp.index == 0)
                self.directly_reads[(root, rp.next)] = set([ Terminal('$END') ])

        for state in self.lr0_states:
            seen = set()
            for rp in state.closure:
                if rp.is_satisfied:
                    continue
                s = rp.next
                # if s is a not a nonterminal
                if s not in self.lr0_rules_by_origin:
                    continue
                if s in seen:
                    continue
                seen.add(s)
                nt = (state, s)
                self.nonterminal_transitions.append(nt)
                dr = self.directly_reads[nt]
                r = self.reads[nt]
                next_state = state.transitions[s]
                for rp2 in next_state.closure:
                    if rp2.is_satisfied:
                        continue
                    s2 = rp2.next
                    # if s2 is a terminal
                    if s2 not in self.lr0_rules_by_origin:
                        dr.add(s2)
                    if s2 in self.NULLABLE:
                        r.add((next_state, s2))

    def compute_includes_lookback(self):
        """Compute the includes relation between nonterminal transitions and
        the lookback relation from transitions to (final state, rule) pairs."""
        for nt in self.nonterminal_transitions:
            state, nonterminal = nt
            includes = []
            lookback = self.lookback[nt]
            for rp in state.closure:
                if rp.rule.origin != nonterminal:
                    continue
                # traverse the states for rp(.rule)
                state2 = state
                for i in range(rp.index, len(rp.rule.expansion)):
                    s = rp.rule.expansion[i]
                    nt2 = (state2, s)
                    state2 = state2.transitions[s]
                    if nt2 not in self.reads:
                        continue
                    # nt2 is included only if everything after position i
                    # in the expansion is nullable (the for/else triggers
                    # when the loop completes without break).
                    for j in range(i + 1, len(rp.rule.expansion)):
                        if not rp.rule.expansion[j] in self.NULLABLE:
                            break
                    else:
                        includes.append(nt2)
                # state2 is at the final state for rp.rule
                if rp.index == 0:
                    for rp2 in state2.closure:
                        if (rp2.rule == rp.rule) and rp2.is_satisfied:
                            lookback.add((state2, rp2.rule))
            # Record the inverse direction: nt2 includes nt.
            for nt2 in includes:
                self.includes[nt2].add(nt)

    def compute_lookaheads(self):
        """Fold the relations with digraph() to get follow sets, then use
        lookback to attach each rule to the lookahead terminals under which
        it should be reduced (state.lookaheads)."""
        read_sets = digraph(self.nonterminal_transitions, self.reads, self.directly_reads)
        follow_sets = digraph(self.nonterminal_transitions, self.includes, read_sets)

        for nt, lookbacks in self.lookback.items():
            for state, rule in lookbacks:
                for s in follow_sets[nt]:
                    state.lookaheads[s].add(rule)

    def compute_lalr1_states(self):
        """Turn the LR(0) automaton plus computed lookaheads into the final
        action table. Shift/reduce conflicts resolve as shift (with a debug
        warning); reduce/reduce conflicts try rule priority first and raise
        GrammarError if still ambiguous. Sets self.parse_table."""
        m = {}
        # Collected (state, lookahead, rules) conflicts; reported together below.
        reduce_reduce = []
        for state in self.lr0_states:
            actions = {}
            for la, next_state in state.transitions.items():
                actions[la] = (Shift, next_state.closure)
            for la, rules in state.lookaheads.items():
                if len(rules) > 1:
                    # Try to resolve conflict based on priority
                    p = [(r.options.priority or 0, r) for r in rules]
                    p.sort(key=lambda r: r[0], reverse=True)
                    best, second_best = p[:2]
                    if best[0] > second_best[0]:
                        rules = [best[1]]
                    else:
                        reduce_reduce.append((state, la, rules))
                if la in actions:
                    # Shift/reduce conflict: shift wins (see module docstring).
                    if self.debug:
                        logger.warning('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.warning(' * %s', list(rules)[0])
                else:
                    actions[la] = (Reduce, list(rules)[0])

            m[state] = { k.name: v for k, v in actions.items() }

        if reduce_reduce:
            msgs = []
            for state, la, rules in reduce_reduce:
                msg = 'Reduce/Reduce collision in %s between the following rules: %s' % (la, ''.join([ '\n\t- ' + str(r) for r in rules ]))
                if self.debug:
                    msg += '\n collision occurred in state: {%s\n }' % ''.join(['\n\t' + str(x) for x in state.closure])
                msgs.append(msg)
            raise GrammarError('\n\n'.join(msgs))

        # Re-key the table by state closure (the hashable item set).
        states = { k.closure: v for k, v in m.items() }

        # compute end states
        end_states = {}
        for state in states:
            for rp in state:
                for start in self.lr0_start_states:
                    # The accepting state contains the satisfied root rule.
                    if rp.rule.origin.name == ('$root_' + start) and rp.is_satisfied:
                        assert(not start in end_states)
                        end_states[start] = state

        _parse_table = ParseTable(states, { start: state.closure for start, state in self.lr0_start_states.items() }, end_states)

        if self.debug:
            # Keep the readable item-set states for debugging.
            self.parse_table = _parse_table
        else:
            # Otherwise compress states to integers.
            self.parse_table = IntParseTable.from_ParseTable(_parse_table)

    def compute_lalr(self):
        """Run the full LALR(1) construction pipeline, in order."""
        self.compute_lr0_states()
        self.compute_reads_relations()
        self.compute_includes_lookback()
        self.compute_lookaheads()
        self.compute_lalr1_states()