Browse Source

Fixes to handle password protected archives

test_fixup
Vadim Lebedev 2 years ago
parent
commit
0c5afe44e7
4 changed files with 28 additions and 15 deletions
  1. +12
    -2
      libarchive/__init__.py
  2. +3
    -2
      libarchive/_libarchive.i
  3. +3
    -2
      libarchive/_libarchive_wrap.c
  4. +10
    -9
      libarchive/zip.py

+ 12
- 2
libarchive/__init__.py View File

@@ -426,11 +426,13 @@ class Entry(object):
class Archive(object): class Archive(object):
'''A low-level archive reader which provides forward-only iteration. Consider '''A low-level archive reader which provides forward-only iteration. Consider
this a light-weight pythonic libarchive wrapper.''' this a light-weight pythonic libarchive wrapper.'''
def __init__(self, f, mode='r', format=None, filter=None, entry_class=Entry, encoding=ENCODING, blocksize=BLOCK_SIZE):
def __init__(self, f, mode='r', format=None, filter=None, entry_class=Entry,
encoding=ENCODING, blocksize=BLOCK_SIZE, password=None):
assert mode in ('r', 'w', 'wb', 'a'), 'Mode should be "r", "w", "wb", or "a".' assert mode in ('r', 'w', 'wb', 'a'), 'Mode should be "r", "w", "wb", or "a".'
self._stream = None self._stream = None
self.encoding = encoding self.encoding = encoding
self.blocksize = blocksize self.blocksize = blocksize
self.password = password
if isinstance(f, str): if isinstance(f, str):
self.filename = f self.filename = f
f = open(f, mode) f = open(f, mode)
@@ -499,8 +501,12 @@ class Archive(object):
self.format_func(self._a) self.format_func(self._a)
self.filter_func(self._a) self.filter_func(self._a)
if self.mode == 'r': if self.mode == 'r':
if self.password:
self.add_passphrase(self.password)
call_and_check(_libarchive.archive_read_open_fd, self._a, self._a, self.f.fileno(), self.blocksize) call_and_check(_libarchive.archive_read_open_fd, self._a, self._a, self.f.fileno(), self.blocksize)
else: else:
if self.password:
self.set_passphrase(self.password)
call_and_check(_libarchive.archive_write_open_fd, self._a, self._a, self.f.fileno()) call_and_check(_libarchive.archive_write_open_fd, self._a, self._a, self.f.fileno())


def denit(self): def denit(self):
@@ -562,7 +568,7 @@ class Archive(object):
'''Write current archive entry contents to file. f can be a file-like object or '''Write current archive entry contents to file. f can be a file-like object or
a path.''' a path.'''
if isinstance(f, str): if isinstance(f, str):
basedir = os.path.basename(f)
basedir = os.path.dirname(f)
if not os.path.exists(basedir): if not os.path.exists(basedir):
os.makedirs(basedir) os.makedirs(basedir)
f = open(f, 'w') f = open(f, 'w')
@@ -626,6 +632,10 @@ class Archive(object):
def add_passphrase(self, password): def add_passphrase(self, password):
'''Adds a password to the archive.''' '''Adds a password to the archive.'''
_libarchive.archive_read_add_passphrase(self._a, password) _libarchive.archive_read_add_passphrase(self._a, password)
def set_passphrase(self, password):
'''Sets a password for the archive.'''
_libarchive.archive_write_set_passphrase(self._a, password)




class SeekableArchive(Archive): class SeekableArchive(Archive):


+ 3
- 2
libarchive/_libarchive.i View File

@@ -535,8 +535,9 @@ PyObject *archive_read_data_into_str(struct archive *archive, int len) {
} }


PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) { PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) {
int len = PyString_Size(str);
if (!archive_write_data(archive, PyString_AS_STRING(str), len)) {
Py_ssize_t len = PyBytes_Size(str);
if (!archive_write_data(archive, PyBytes_AS_STRING(str), len)) {
PyErr_SetString(PyExc_RuntimeError, "could not write requested data."); PyErr_SetString(PyExc_RuntimeError, "could not write requested data.");
return NULL; return NULL;
} }


+ 3
- 2
libarchive/_libarchive_wrap.c View File

@@ -3207,8 +3207,9 @@ PyObject *archive_read_data_into_str(struct archive *archive, int len) {
} }


PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) { PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) {
int len = PyString_Size(str);
if (!archive_write_data(archive, PyString_AS_STRING(str), len)) {
Py_ssize_t len = PyBytes_Size(str);
if (!archive_write_data(archive, PyBytes_AS_STRING(str), len)) {
PyErr_SetString(PyExc_RuntimeError, "could not write requested data."); PyErr_SetString(PyExc_RuntimeError, "could not write requested data.");
return NULL; return NULL;
} }


+ 10
- 9
libarchive/zip.py View File

@@ -62,8 +62,8 @@ class ZipEntry(Entry):




class ZipFile(SeekableArchive): class ZipFile(SeekableArchive):
def __init__(self, f, mode='r', compression=ZIP_DEFLATED, allowZip64=False):
super(ZipFile, self).__init__(f, mode=mode, format='zip', entry_class=ZipEntry, encoding='CP437')
def __init__(self, f, mode='r', compression=ZIP_DEFLATED, allowZip64=False, password=None):
super(ZipFile, self).__init__(f, mode=mode, format='zip', entry_class=ZipEntry, encoding='CP437', password=password)
if mode == 'w' and compression == ZIP_STORED: if mode == 'w' and compression == ZIP_STORED:
# Disable compression for writing. # Disable compression for writing.
_libarchive.archive_write_set_format_option(self.archive._a, "zip", "compression", "store") _libarchive.archive_write_set_format_option(self.archive._a, "zip", "compression", "store")
@@ -72,38 +72,39 @@ class ZipFile(SeekableArchive):
getinfo = SeekableArchive.getentry getinfo = SeekableArchive.getentry


def namelist(self): def namelist(self):
return list(self.iterpaths)
return list(self.iterpaths())


def infolist(self): def infolist(self):
return list(self) return list(self)


def open(self, name, mode, pwd=None): def open(self, name, mode, pwd=None):
if pwd:
raise NotImplemented('Encryption not supported.')
if mode == 'r': if mode == 'r':
if pwd:
self.add_passphrase(pwd)
return self.readstream(name) return self.readstream(name)
else: else:
return self.writestream(name) return self.writestream(name)


def extract(self, name, path=None, pwd=None): def extract(self, name, path=None, pwd=None):
if pwd: if pwd:
raise NotImplemented('Encryption not supported.')
self.add_passphrase(pwd)
if not path: if not path:
path = os.getcwd() path = os.getcwd()
return self.readpath(name, os.path.join(path, name)) return self.readpath(name, os.path.join(path, name))


def extractall(self, path, names=None, pwd=None): def extractall(self, path, names=None, pwd=None):
if pwd: if pwd:
raise NotImplemented('Encryption not supported.')
self.add_passphrase(pwd)
if not names: if not names:
names = self.namelist() names = self.namelist()
if names: if names:
print(f"Extracting {names} files.")
for name in names: for name in names:
self.extract(name, path) self.extract(name, path)


def read(self, name, pwd=None): def read(self, name, pwd=None):
if pwd: if pwd:
raise NotImplemented('Encryption not supported.')
self.add_passphrase(pwd)
return self.read(name) return self.read(name)


def writestr(self, member, data, compress_type=None): def writestr(self, member, data, compress_type=None):
@@ -112,7 +113,7 @@ class ZipFile(SeekableArchive):
return self.write(member, data) return self.write(member, data)


def setpassword(self, pwd): def setpassword(self, pwd):
raise NotImplemented('Encryption not supported.')
return self.set_passphrase(pwd)


def testzip(self): def testzip(self):
raise NotImplemented() raise NotImplemented()


Loading…
Cancel
Save