diff --git a/casimport/__init__.py b/casimport/__init__.py index 03a10b2..1c2f32c 100644 --- a/casimport/__init__.py +++ b/casimport/__init__.py @@ -310,6 +310,7 @@ class CASFinder(MetaPathFinder, Loader): # MetaPathFinder methods def find_spec(self, fullname, path, target=None): + aliases_to_add = () if path is None: ms = ModuleSpec(fullname, self, is_package=True) else: @@ -335,10 +336,11 @@ class CASFinder(MetaPathFinder, Loader): raise ValueError('unable to find base hash url for alias %s' % repr(arg)) # fix up the full name: + aliases_to_add = (fullname,) fullname = 'cas.v1_f_%s' % hashurl.path[1:] ms = ModuleSpec(fullname, self, is_package=False, - loader_state=(hashurl,)) + loader_state=(hashurl,) + aliases_to_add) return ms @@ -348,34 +350,38 @@ class CASFinder(MetaPathFinder, Loader): # Loader methods def exec_module(self, module): if module.__name__ == 'cas': - pass + return + + (url, *aliases) = module.__spec__.loader_state + for load in self._loaders: + try: + data = load.fetch_data(url) + break + except Exception: + pass + else: - (url,) = module.__spec__.loader_state - for load in self._loaders: - try: - data = load.fetch_data(url) - break - except Exception: - pass + for url in self._aliases[ + self._makebasichashurl(url)]: + url = urllib.parse.urlparse(url) + for load in self._loaders: + try: + data = load.fetch_data(url) + break + except Exception: + pass + else: + continue + break else: - for url in self._aliases[ - self._makebasichashurl(url)]: - url = urllib.parse.urlparse(url) - for load in self._loaders: - try: - data = load.fetch_data(url) - break - except Exception: - pass - else: - continue + raise ValueError('unable to find loader for url %s' % repr(urllib.parse.urlunparse(url))) - break - else: - raise ValueError('unable to find loader for url %s' % repr(urllib.parse.urlunparse(url))) + exec(data, module.__dict__) - exec(data, module.__dict__) + # we were successful, add the aliases + for i in aliases: + sys.modules[i] = module _supportedmodules = { 'https': HTTPSCAS.fromconfig, @@ -641,6 +647,18 @@ class Test(unittest.TestCase): # they are the same self.assertIs(hello_alias, hello_hash) + # that when we reimport the alias + from cas.v1_a_hello import hello as hello_alias + + # it is still the same + self.assertIs(hello_alias, hello_hash) + + # and when we reimport the hash + from cas.v1_f_330884aa2febb5e19fb7194ec6a69ed11dd3d77122f1a5175ee93e73cf0161c3 import hello as hello_hash + + # it is still the same + self.assertIs(hello_alias, hello_hash) + def test_aliasimports(self): # setup the cache cachedir = self.tempdir / 'cache'