| @@ -1,5 +1,6 @@ | |||
| from . import bencode | |||
| from functools import reduce | |||
| from hashlib import sha1 | |||
| import importlib.resources | |||
| import itertools | |||
| @@ -19,9 +20,22 @@ class Storage: | |||
| self._buildindex() | |||
| def _filepaths(self): | |||
| for curfile in self._files: | |||
| fname = pathlib.Path( | |||
| *(x.decode(self._encoding) for x in | |||
| curfile['path'])) | |||
| curfilepath = self._rootpath / fname | |||
| yield curfile, fname, curfilepath | |||
| def allfiles(self): | |||
| for x, y, curfilepath in self._filepaths(): | |||
| yield curfilepath | |||
| def _buildindex(self): | |||
| self._index = [] | |||
| files = iter(self._files) | |||
| files = self._filepaths() | |||
| left = 0 | |||
| curfile = None | |||
| @@ -29,11 +43,7 @@ class Storage: | |||
| if curfile is None or curfileoff == curfile['length']: | |||
| # next file | |||
| try: | |||
| curfile = next(files) | |||
| fname = pathlib.Path( | |||
| *(x.decode(self._encoding) for x in | |||
| curfile['path'])) | |||
| curfilepath = self._rootpath / fname | |||
| curfile, fname, curfilepath = next(files) | |||
| except StopIteration: | |||
| break | |||
| curfileoff = 0 | |||
| @@ -51,6 +61,10 @@ class Storage: | |||
| curfileoff += sz | |||
| left -= sz | |||
| def filesforpiece(self, idx): | |||
| for x in self._index[idx]: | |||
| yield x['file'] | |||
| def apply_piece(self, idx, fun): | |||
| for i in self._index[idx]: | |||
| with open(i['file'], 'rb') as fp: | |||
| @@ -72,14 +86,28 @@ def validate(torrent, basedir): | |||
| stor = Storage(torrentdir, info['files'], info['piece length'], encoding) | |||
| pieces = info['pieces'] | |||
| piecescnt = len(pieces) // 20 | |||
| valid = [ None ] * piecescnt | |||
| for num, i in enumerate(pieces[x:x+20] for x in range(0, len(pieces), | |||
| 20)): | |||
| hash = sha1() | |||
| stor.apply_piece(num, hash.update) | |||
| if hash.digest() != i: | |||
| raise ValueError | |||
| if hash.digest() == i: | |||
| valid[num] = True | |||
| else: | |||
| valid[num] = False | |||
| # if any piece of a file is bad, it's bad | |||
| allfiles = set(stor.allfiles()) | |||
| badpieces = [ x for x, v in enumerate(valid) if not v ] | |||
| badfiles = reduce(set.__or__, (set(stor.filesforpiece(x)) for x in | |||
| badpieces), set()) | |||
| return allfiles - badfiles, badfiles | |||
| class _TestCases(unittest.TestCase): | |||
| dirname = 'somedir' | |||
| @@ -155,13 +183,20 @@ class _TestCases(unittest.TestCase): | |||
| missingfiles = self.origfiledata.copy() | |||
| missingfiles['filea.txt'] = b'' | |||
| missingfiles['filec.txt'] = b'\x00\x00\x00\x00a\n' | |||
| missingfiles['filee.txt'] = b'no' | |||
| badfiles = { | |||
| 'filea.txt': b'', | |||
| 'filec.txt': b'\x00\x00\x00\x00a\n', | |||
| 'filee.txt': b'no', | |||
| } | |||
| missingfiles.update(badfiles) | |||
| sd = self.basetempdir / self.dirname | |||
| sd.mkdir() | |||
| self.make_files(sd, missingfiles) | |||
| self.assertRaises(ValueError, validate, self.torrent, self.basetempdir) | |||
| val, inval = validate(self.torrent, self.basetempdir) | |||
| self.assertEqual(set(val), { sd / x for x in missingfiles.keys() if x not in badfiles }) | |||
| self.assertEqual(set(inval), { sd / x for x in badfiles.keys() }) | |||