diff --git a/__init__.py b/__init__.py index f2582e8..2170be4 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,6 @@ from . import bencode +from functools import reduce from hashlib import sha1 import importlib.resources import itertools @@ -19,9 +20,22 @@ class Storage: self._buildindex() + def _filepaths(self): + for curfile in self._files: + fname = pathlib.Path( + *(x.decode(self._encoding) for x in + curfile['path'])) + curfilepath = self._rootpath / fname + + yield curfile, fname, curfilepath + + def allfiles(self): + for x, y, curfilepath in self._filepaths(): + yield curfilepath + def _buildindex(self): self._index = [] - files = iter(self._files) + files = self._filepaths() left = 0 curfile = None @@ -29,11 +43,7 @@ class Storage: if curfile is None or curfileoff == curfile['length']: # next file try: - curfile = next(files) - fname = pathlib.Path( - *(x.decode(self._encoding) for x in - curfile['path'])) - curfilepath = self._rootpath / fname + curfile, fname, curfilepath = next(files) except StopIteration: break curfileoff = 0 @@ -51,6 +61,10 @@ class Storage: curfileoff += sz left -= sz + def filesforpiece(self, idx): + for x in self._index[idx]: + yield x['file'] + def apply_piece(self, idx, fun): for i in self._index[idx]: with open(i['file'], 'rb') as fp: @@ -72,14 +86,28 @@ def validate(torrent, basedir): stor = Storage(torrentdir, info['files'], info['piece length'], encoding) pieces = info['pieces'] + piecescnt = len(pieces) // 20 + valid = [ None ] * piecescnt for num, i in enumerate(pieces[x:x+20] for x in range(0, len(pieces), 20)): hash = sha1() stor.apply_piece(num, hash.update) - if hash.digest() != i: - raise ValueError + if hash.digest() == i: + valid[num] = True + else: + valid[num] = False + + # if any piece of a file is bad, it's bad + allfiles = set(stor.allfiles()) + + badpieces = [ x for x, v in enumerate(valid) if not v ] + + badfiles = reduce(set.__or__, (set(stor.filesforpiece(x)) for x in + badpieces), set()) + + return allfiles - badfiles, badfiles class _TestCases(unittest.TestCase): dirname = 'somedir' @@ -155,13 +183,20 @@ class _TestCases(unittest.TestCase): missingfiles = self.origfiledata.copy() - missingfiles['filea.txt'] = b'' - missingfiles['filec.txt'] = b'\x00\x00\x00\x00a\n' - missingfiles['filee.txt'] = b'no' + badfiles = { + 'filea.txt': b'', + 'filec.txt': b'\x00\x00\x00\x00a\n', + 'filee.txt': b'no', + } + + missingfiles.update(badfiles) sd = self.basetempdir / self.dirname sd.mkdir() self.make_files(sd, missingfiles) - self.assertRaises(ValueError, validate, self.torrent, self.basetempdir) + val, inval = validate(self.torrent, self.basetempdir) + + self.assertEqual(set(val), { sd / x for x in missingfiles.keys() if x not in badfiles }) + self.assertEqual(set(inval), { sd / x for x in badfiles.keys() })