@@ -426,11 +426,13 @@ class Entry(object): | |||||
class Archive(object): | class Archive(object): | ||||
'''A low-level archive reader which provides forward-only iteration. Consider | '''A low-level archive reader which provides forward-only iteration. Consider | ||||
this a light-weight pythonic libarchive wrapper.''' | this a light-weight pythonic libarchive wrapper.''' | ||||
def __init__(self, f, mode='r', format=None, filter=None, entry_class=Entry, encoding=ENCODING, blocksize=BLOCK_SIZE): | |||||
def __init__(self, f, mode='r', format=None, filter=None, entry_class=Entry, | |||||
encoding=ENCODING, blocksize=BLOCK_SIZE, password=None): | |||||
assert mode in ('r', 'w', 'wb', 'a'), 'Mode should be "r", "w", "wb", or "a".' | assert mode in ('r', 'w', 'wb', 'a'), 'Mode should be "r", "w", "wb", or "a".' | ||||
self._stream = None | self._stream = None | ||||
self.encoding = encoding | self.encoding = encoding | ||||
self.blocksize = blocksize | self.blocksize = blocksize | ||||
self.password = password | |||||
if isinstance(f, str): | if isinstance(f, str): | ||||
self.filename = f | self.filename = f | ||||
f = open(f, mode) | f = open(f, mode) | ||||
@@ -499,8 +501,12 @@ class Archive(object): | |||||
self.format_func(self._a) | self.format_func(self._a) | ||||
self.filter_func(self._a) | self.filter_func(self._a) | ||||
if self.mode == 'r': | if self.mode == 'r': | ||||
if self.password: | |||||
self.add_passphrase(self.password) | |||||
call_and_check(_libarchive.archive_read_open_fd, self._a, self._a, self.f.fileno(), self.blocksize) | call_and_check(_libarchive.archive_read_open_fd, self._a, self._a, self.f.fileno(), self.blocksize) | ||||
else: | else: | ||||
if self.password: | |||||
self.set_passphrase(self.password) | |||||
call_and_check(_libarchive.archive_write_open_fd, self._a, self._a, self.f.fileno()) | call_and_check(_libarchive.archive_write_open_fd, self._a, self._a, self.f.fileno()) | ||||
def denit(self): | def denit(self): | ||||
@@ -562,7 +568,7 @@ class Archive(object): | |||||
'''Write current archive entry contents to file. f can be a file-like object or | '''Write current archive entry contents to file. f can be a file-like object or | ||||
a path.''' | a path.''' | ||||
if isinstance(f, str): | if isinstance(f, str): | ||||
basedir = os.path.basename(f) | |||||
basedir = os.path.dirname(f) | |||||
if not os.path.exists(basedir): | if not os.path.exists(basedir): | ||||
os.makedirs(basedir) | os.makedirs(basedir) | ||||
f = open(f, 'w') | f = open(f, 'w') | ||||
@@ -626,6 +632,10 @@ class Archive(object): | |||||
def add_passphrase(self, password): | def add_passphrase(self, password): | ||||
'''Adds a password to the archive.''' | '''Adds a password to the archive.''' | ||||
_libarchive.archive_read_add_passphrase(self._a, password) | _libarchive.archive_read_add_passphrase(self._a, password) | ||||
def set_passphrase(self, password): | |||||
'''Sets a password for the archive.''' | |||||
_libarchive.archive_write_set_passphrase(self._a, password) | |||||
class SeekableArchive(Archive): | class SeekableArchive(Archive): | ||||
@@ -535,8 +535,9 @@ PyObject *archive_read_data_into_str(struct archive *archive, int len) { | |||||
} | } | ||||
PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) { | PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) { | ||||
int len = PyString_Size(str); | |||||
if (!archive_write_data(archive, PyString_AS_STRING(str), len)) { | |||||
Py_ssize_t len = PyBytes_Size(str); | |||||
if (!archive_write_data(archive, PyBytes_AS_STRING(str), len)) { | |||||
PyErr_SetString(PyExc_RuntimeError, "could not write requested data."); | PyErr_SetString(PyExc_RuntimeError, "could not write requested data."); | ||||
return NULL; | return NULL; | ||||
} | } | ||||
@@ -3207,8 +3207,9 @@ PyObject *archive_read_data_into_str(struct archive *archive, int len) { | |||||
} | } | ||||
PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) { | PyObject *archive_write_data_from_str(struct archive *archive, PyObject *str) { | ||||
int len = PyString_Size(str); | |||||
if (!archive_write_data(archive, PyString_AS_STRING(str), len)) { | |||||
Py_ssize_t len = PyBytes_Size(str); | |||||
if (!archive_write_data(archive, PyBytes_AS_STRING(str), len)) { | |||||
PyErr_SetString(PyExc_RuntimeError, "could not write requested data."); | PyErr_SetString(PyExc_RuntimeError, "could not write requested data."); | ||||
return NULL; | return NULL; | ||||
} | } | ||||
@@ -62,8 +62,8 @@ class ZipEntry(Entry): | |||||
class ZipFile(SeekableArchive): | class ZipFile(SeekableArchive): | ||||
def __init__(self, f, mode='r', compression=ZIP_DEFLATED, allowZip64=False): | |||||
super(ZipFile, self).__init__(f, mode=mode, format='zip', entry_class=ZipEntry, encoding='CP437') | |||||
def __init__(self, f, mode='r', compression=ZIP_DEFLATED, allowZip64=False, password=None): | |||||
super(ZipFile, self).__init__(f, mode=mode, format='zip', entry_class=ZipEntry, encoding='CP437', password=password) | |||||
if mode == 'w' and compression == ZIP_STORED: | if mode == 'w' and compression == ZIP_STORED: | ||||
# Disable compression for writing. | # Disable compression for writing. | ||||
_libarchive.archive_write_set_format_option(self.archive._a, "zip", "compression", "store") | _libarchive.archive_write_set_format_option(self.archive._a, "zip", "compression", "store") | ||||
@@ -72,38 +72,39 @@ class ZipFile(SeekableArchive): | |||||
getinfo = SeekableArchive.getentry | getinfo = SeekableArchive.getentry | ||||
def namelist(self): | def namelist(self): | ||||
return list(self.iterpaths) | |||||
return list(self.iterpaths()) | |||||
def infolist(self): | def infolist(self): | ||||
return list(self) | return list(self) | ||||
def open(self, name, mode, pwd=None): | def open(self, name, mode, pwd=None): | ||||
if pwd: | |||||
raise NotImplemented('Encryption not supported.') | |||||
if mode == 'r': | if mode == 'r': | ||||
if pwd: | |||||
self.add_passphrase(pwd) | |||||
return self.readstream(name) | return self.readstream(name) | ||||
else: | else: | ||||
return self.writestream(name) | return self.writestream(name) | ||||
def extract(self, name, path=None, pwd=None): | def extract(self, name, path=None, pwd=None): | ||||
if pwd: | if pwd: | ||||
raise NotImplemented('Encryption not supported.') | |||||
self.add_passphrase(pwd) | |||||
if not path: | if not path: | ||||
path = os.getcwd() | path = os.getcwd() | ||||
return self.readpath(name, os.path.join(path, name)) | return self.readpath(name, os.path.join(path, name)) | ||||
def extractall(self, path, names=None, pwd=None): | def extractall(self, path, names=None, pwd=None): | ||||
if pwd: | if pwd: | ||||
raise NotImplemented('Encryption not supported.') | |||||
self.add_passphrase(pwd) | |||||
if not names: | if not names: | ||||
names = self.namelist() | names = self.namelist() | ||||
if names: | if names: | ||||
print(f"Extracting {names} files.") | |||||
for name in names: | for name in names: | ||||
self.extract(name, path) | self.extract(name, path) | ||||
def read(self, name, pwd=None): | def read(self, name, pwd=None): | ||||
if pwd: | if pwd: | ||||
raise NotImplemented('Encryption not supported.') | |||||
self.add_passphrase(pwd) | |||||
return self.read(name) | return self.read(name) | ||||
def writestr(self, member, data, compress_type=None): | def writestr(self, member, data, compress_type=None): | ||||
@@ -112,7 +113,7 @@ class ZipFile(SeekableArchive): | |||||
return self.write(member, data) | return self.write(member, data) | ||||
def setpassword(self, pwd): | def setpassword(self, pwd): | ||||
raise NotImplemented('Encryption not supported.') | |||||
return self.set_passphrase(pwd) | |||||
def testzip(self): | def testzip(self): | ||||
raise NotImplemented() | raise NotImplemented() | ||||