diff --git a/libarchive/__init__.py b/libarchive/__init__.py index 7fd44b6..985bac4 100644 --- a/libarchive/__init__.py +++ b/libarchive/__init__.py @@ -117,7 +117,10 @@ def get_func(name, items, index): def guess_format(filename): - filename, ext = os.path.splitext(filename) + if isinstance(filename, int): + filename = ext = '' + else: + filename, ext = os.path.splitext(filename) filter = FILTER_EXTENSIONS.get(ext) if filter: filename, ext = os.path.splitext(filename) @@ -312,10 +315,11 @@ class Entry(object): call_and_check(_libarchive.archive_read_next_header2, archive._a, archive._a, e) mode = _libarchive.archive_entry_filetype(e) mode |= _libarchive.archive_entry_perm(e) + if PY3: - pathname=_libarchive.archive_entry_pathname(e) + pathname = _libarchive.archive_entry_pathname(e) else: - pathname=_libarchive.archive_entry_pathname(e).decode(encoding), + pathname = _libarchive.archive_entry_pathname(e).decode(encoding) entry = cls( pathname=pathname, @@ -628,11 +632,7 @@ class SeekableArchive(Archive): def getentry(self, pathname): '''Take a name or entry object and returns an entry object.''' for entry in self: - if PY3: - entry_pathname = entry.pathname - if not PY3: - entry_pathname = entry.pathname[0] - if entry_pathname == pathname: + if entry.pathname == pathname: return entry raise KeyError(pathname) diff --git a/tests.py b/tests.py index fa8cba8..fcba1d7 100644 --- a/tests.py +++ b/tests.py @@ -26,15 +26,16 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import os, unittest, tempfile, random, string, subprocess, sys +import os, unittest, tempfile, random, string, sys +import zipfile +import io -from libarchive import is_archive_name, is_archive +from libarchive import Archive, is_archive_name, is_archive from libarchive.zip import is_zipfile, ZipFile, ZipEntry PY3 = sys.version_info[0] == 3 -TMPDIR = tempfile.mkdtemp() -ZIPCMD = '/usr/bin/zip' +TMPDIR = tempfile.mkdtemp(suffix='.python-libarchive') ZIPFILE = 'test.zip' ZIPPATH = os.path.join(TMPDIR, ZIPFILE) @@ -54,13 +55,10 @@ def make_temp_files(): def make_temp_archive(): - if not os.access(ZIPCMD, os.X_OK): - raise AssertionError('Cannot execute %s.' % ZIPCMD) - cmd = [ZIPCMD, ZIPFILE] make_temp_files() - cmd.extend(FILENAMES) - os.chdir(TMPDIR) - subprocess.call(cmd, stdout=subprocess.PIPE) + with zipfile.ZipFile(ZIPPATH, mode="w") as z: + for name in FILENAMES: + z.write(os.path.join(TMPDIR, name), arcname=name) class TestIsArchiveName(unittest.TestCase): @@ -152,10 +150,7 @@ class TestZipRead(unittest.TestCase): z = ZipFile(self.f, 'r') names = [] for e in z: - if PY3: - names.append(e.filename) - else: - names.append(e.filename[0]) + names.append(e.filename) self.assertEqual(names, FILENAMES, 'File names differ in archive.') #~ def test_non_ascii(self): @@ -250,5 +245,35 @@ class TestZipWrite(unittest.TestCase): self.assertIsNone(z._stream) z.close() + +class TestHighLevelAPI(unittest.TestCase): + def setUp(self): + make_temp_archive() + + def _test_listing_content(self, f): + """ Test helper capturing file paths while iterating the archive. """ + found = [] + with Archive(f) as a: + for entry in a: + found.append(entry.pathname) + + self.assertEqual(set(found), set(FILENAMES)) + + def test_open_by_name(self): + """ Test an archive opened directly by name. """ + self._test_listing_content(ZIPPATH) + + def test_open_by_named_fobj(self): + """ Test an archive using a file-like object opened by name. """ + with open(ZIPPATH, 'rb') as f: + self._test_listing_content(f) + + def test_open_by_unnamed_fobj(self): + """ Test an archive using file-like object opened by fileno(). """ + with open(ZIPPATH, 'rb') as zf: + with io.FileIO(zf.fileno(), mode='r', closefd=False) as f: + self._test_listing_content(f) + + if __name__ == '__main__': unittest.main()