@@ -2,7 +2,7 @@ | |||||
from typing import ( | from typing import ( | ||||
TypeVar, Type, List, Dict, IO, Iterator, Callable, Union, Optional, | TypeVar, Type, List, Dict, IO, Iterator, Callable, Union, Optional, | ||||
Literal, Protocol, | |||||
Literal, Protocol, Tuple, | |||||
) | ) | ||||
from .visitors import Transformer | from .visitors import Transformer | ||||
from .lexer import Token, Lexer, TerminalDef | from .lexer import Token, Lexer, TerminalDef | ||||
@@ -32,6 +32,7 @@ class LarkOptions: | |||||
cache: Union[bool, str] | cache: Union[bool, str] | ||||
g_regex_flags: int | g_regex_flags: int | ||||
use_bytes: bool | use_bytes: bool | ||||
import_sources: List[Union[str, Callable[[List[str], str], Tuple[str, str]]]] | |||||
class Lark: | class Lark: | ||||
@@ -60,6 +61,7 @@ class Lark: | |||||
cache: Union[bool, str] = False, | cache: Union[bool, str] = False, | ||||
g_regex_flags: int = ..., | g_regex_flags: int = ..., | ||||
use_bytes: bool = False, | use_bytes: bool = False, | ||||
import_sources: List[Union[str, Callable[[List[str], str], Tuple[str, str]]]] = ..., | |||||
): | ): | ||||
... | ... | ||||
@@ -90,6 +90,8 @@ class LarkOptions(Serialize): | |||||
Accept an input of type ``bytes`` instead of ``str`` (Python 3 only). | Accept an input of type ``bytes`` instead of ``str`` (Python 3 only). | ||||
edit_terminals | edit_terminals | ||||
A callback for editing the terminals before parse. | A callback for editing the terminals before parse. | ||||
import_sources | |||||
A list of import paths (strings) and/or custom loader callables that specify where imported grammars are loaded from | |||||
**=== End Options ===** | **=== End Options ===** | ||||
""" | """ | ||||
@@ -115,6 +117,7 @@ class LarkOptions(Serialize): | |||||
'edit_terminals': None, | 'edit_terminals': None, | ||||
'g_regex_flags': 0, | 'g_regex_flags': 0, | ||||
'use_bytes': False, | 'use_bytes': False, | ||||
'import_sources': [], | |||||
} | } | ||||
def __init__(self, options_dict): | def __init__(self, options_dict): | ||||
@@ -267,7 +270,7 @@ class Lark(Serialize): | |||||
assert self.options.ambiguity in ('resolve', 'explicit', 'forest', 'auto', ) | assert self.options.ambiguity in ('resolve', 'explicit', 'forest', 'auto', ) | ||||
# Parse the grammar file and compose the grammars (TODO) | # Parse the grammar file and compose the grammars (TODO) | ||||
self.grammar = load_grammar(grammar, self.source, re_module) | |||||
self.grammar = load_grammar(grammar, self.source, re_module, self.options.import_sources) | |||||
# Compile the EBNF grammar into BNF | # Compile the EBNF grammar into BNF | ||||
self.terminals, self.rules, self.ignore_tokens = self.grammar.compile(self.options.start) | self.terminals, self.rules, self.ignore_tokens = self.grammar.compile(self.options.start) | ||||
@@ -20,7 +20,7 @@ from .visitors import Transformer, Visitor, v_args, Transformer_InPlace, Transfo | |||||
inline_args = v_args(inline=True) | inline_args = v_args(inline=True) | ||||
__path__ = os.path.dirname(__file__) | __path__ = os.path.dirname(__file__) | ||||
IMPORT_PATHS = [os.path.join(__path__, 'grammars')] | |||||
IMPORT_PATHS = ['grammars'] | |||||
EXT = '.lark' | EXT = '.lark' | ||||
@@ -648,19 +648,35 @@ class Grammar: | |||||
return terminals, compiled_rules, self.ignore | return terminals, compiled_rules, self.ignore | ||||
def stdlib_loader(base_paths, grammar_path): | |||||
import pkgutil | |||||
for path in IMPORT_PATHS: | |||||
text = pkgutil.get_data('lark', path + '/' + grammar_path) | |||||
if text is None: | |||||
continue | |||||
return '<stdlib:' + grammar_path + '>', text.decode() | |||||
raise FileNotFoundError() | |||||
_imported_grammars = {} | _imported_grammars = {} | ||||
def import_grammar(grammar_path, re_, base_paths=[]): | |||||
def import_grammar(grammar_path, re_, base_paths=(), import_sources=()): | |||||
if grammar_path not in _imported_grammars: | if grammar_path not in _imported_grammars: | ||||
import_paths = base_paths + IMPORT_PATHS | |||||
for import_path in import_paths: | |||||
with suppress(IOError): | |||||
joined_path = os.path.join(import_path, grammar_path) | |||||
with open(joined_path, encoding='utf8') as f: | |||||
text = f.read() | |||||
grammar = load_grammar(text, joined_path, re_) | |||||
_imported_grammars[grammar_path] = grammar | |||||
break | |||||
import_paths = list(import_sources) + list(base_paths) + [stdlib_loader] | |||||
for source in import_paths: | |||||
if isinstance(source, str): | |||||
with suppress(IOError): | |||||
joined_path = os.path.join(source, grammar_path) | |||||
with open(joined_path, encoding='utf8') as f: | |||||
text = f.read() | |||||
grammar = load_grammar(text, joined_path, re_, import_sources) | |||||
_imported_grammars[grammar_path] = grammar | |||||
break | |||||
else: | |||||
with suppress(IOError): | |||||
joined_path, text = source(base_paths, grammar_path) | |||||
grammar = load_grammar(text, joined_path, re_, import_sources) | |||||
_imported_grammars[grammar_path] = grammar | |||||
break | |||||
else: | else: | ||||
open(grammar_path, encoding='utf8') | open(grammar_path, encoding='utf8') | ||||
assert False | assert False | ||||
@@ -817,7 +833,7 @@ class GrammarLoader: | |||||
self.canonize_tree = CanonizeTree() | self.canonize_tree = CanonizeTree() | ||||
self.re_module = re_module | self.re_module = re_module | ||||
def load_grammar(self, grammar_text, grammar_name='<?>'): | |||||
def load_grammar(self, grammar_text, grammar_name='<?>', import_sources=[]): | |||||
"Parse grammar_text, verify, and create Grammar object. Display nice messages on error." | "Parse grammar_text, verify, and create Grammar object. Display nice messages on error." | ||||
try: | try: | ||||
@@ -901,7 +917,7 @@ class GrammarLoader: | |||||
# import grammars | # import grammars | ||||
for dotted_path, (base_paths, aliases) in imports.items(): | for dotted_path, (base_paths, aliases) in imports.items(): | ||||
grammar_path = os.path.join(*dotted_path) + EXT | grammar_path = os.path.join(*dotted_path) + EXT | ||||
g = import_grammar(grammar_path, self.re_module, base_paths=base_paths) | |||||
g = import_grammar(grammar_path, self.re_module, base_paths=base_paths, import_sources=import_sources) | |||||
new_td, new_rd = import_from_grammar_into_namespace(g, '__'.join(dotted_path), aliases) | new_td, new_rd = import_from_grammar_into_namespace(g, '__'.join(dotted_path), aliases) | ||||
term_defs += new_td | term_defs += new_td | ||||
@@ -981,5 +997,5 @@ class GrammarLoader: | |||||
def load_grammar(grammar, source, re_): | |||||
return GrammarLoader(re_).load_grammar(grammar, source) | |||||
def load_grammar(grammar, source, re_, import_sources=[]): |
return GrammarLoader(re_).load_grammar(grammar, source, import_sources) |
@@ -1782,6 +1782,24 @@ def _make_parser_test(LEXER, PARSER): | |||||
""" | """ | ||||
self.assertRaises(IOError, _Lark, grammar) | self.assertRaises(IOError, _Lark, grammar) | ||||
def test_import_custom_sources(self): | |||||
def custom_loader(base_paths, grammar_path): | |||||
import pkgutil | |||||
text = pkgutil.get_data('tests', 'grammars/' + grammar_path) | |||||
if text is None: | |||||
raise FileNotFoundError() | |||||
return '<tests.grammars:' + grammar_path + '>', text.decode() | |||||
grammar = """ | |||||
start: startab | |||||
%import ab.startab | |||||
""" | |||||
p = _Lark(grammar, import_sources=[custom_loader]) | |||||
self.assertEqual(p.parse('ab'), | |||||
Tree('start', [Tree('startab', [Tree('ab__expr', [Token('ab__A', 'a'), Token('ab__B', 'b')])])])) | |||||
@unittest.skipIf(PARSER != 'earley', "Currently only Earley supports priority in rules") | @unittest.skipIf(PARSER != 'earley', "Currently only Earley supports priority in rules") | ||||
def test_earley_prioritization(self): | def test_earley_prioritization(self): | ||||
"Tests effect of priority on result" | "Tests effect of priority on result" | ||||