diff --git a/vlanmang.py b/vlanmang.py index a1c0005..9d10a5e 100644 --- a/vlanmang.py +++ b/vlanmang.py @@ -52,7 +52,14 @@ class SwitchConfig(object): def vlanconf(self): return self._vlanconf -def intstobits(*ints): +def _octstrtobits(os): + num = 1 + for i in str(os): + num = (num << 8) | ord(i) + + return bin(num)[3:] + +def _intstobits(*ints): v = 0 for i in ints: v |= 1 << i @@ -69,25 +76,40 @@ def checkchanges(module): res = [] for i in mods: + vlans = i.vlanconf.keys() switch = SNMPSwitch(i.host, i.community) portmapping = switch.getportmapping() invportmap = { y: x for x, y in portmapping.iteritems() } lufun = invportmap.__getitem__ - portlist = getportlist(i._vlanconf, lufun) + # get complete set of ports + portlist = getportlist(i.vlanconf, lufun) ports = set(portmapping.iterkeys()) + # make sure switch agrees w/ them all if ports != portlist: raise ValueError('missing or extra ports found: %s' % `ports.symmetric_difference(portlist)`) - pvidmap = getpvidmapping(i._vlanconf, lufun) + # compare pvid + pvidmap = getpvidmapping(i.vlanconf, lufun) switchpvid = switch.getpvid() res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in pvidmap.iteritems() if switchpvid[idx] != vlan) + # compare egress & untagged + switchegress = switch.getegress(*vlans) + egress = getegress(i.vlanconf, lufun) + switchuntagged = switch.getuntagged(*vlans) + untagged = getuntagged(i.vlanconf, lufun) + for i in vlans: + if switchegress[i] != egress[i]: + res.append(('setegress', i, egress[i], switchegress[i])) + if switchuntagged[i] != untagged[i]: + res.append(('setuntagged', i, untagged[i], switchuntagged[i])) + return res def getidxs(lst, lookupfun): @@ -109,7 +131,7 @@ def getpvidmapping(data, lookupfun): def getegress(data, lookupfun): r = {} for id in data: - r[id] = intstobits(*(getidxs(data[id]['u'], lookupfun) + + r[id] = _intstobits(*(getidxs(data[id]['u'], lookupfun) + getidxs(data[id].get('t', []), lookupfun))) return r @@ -117,7 +139,7 @@ def getegress(data, lookupfun): def getuntagged(data, lookupfun): r = {} for id in data: - r[id] = intstobits(*getidxs(data[id]['u'], lookupfun)) + r[id] = _intstobits(*getidxs(data[id]['u'], lookupfun)) return r @@ -146,6 +168,25 @@ class SNMPSwitch(object): self._cd = CommunityData(community, mpModel=0) self._targ = UdpTransportTarget((host, 161)) + def _getmany(self, *oids): + oids = [ ObjectIdentity(*oid) for oid in oids ] + [ oid.resolveWithMib(_mvc) for oid in oids ] + + errorInd, errorStatus, errorIndex, varBinds = \ + next(getCmd(self._eng, self._cd, self._targ, ContextData(), *(ObjectType(oid) for oid in oids))) + + if errorInd: # pragma: no cover + raise ValueError(errorIndication) + elif errorStatus: # pragma: no cover + raise ValueError('%s at %s' % + (errorStatus.prettyPrint(), errorIndex and + varBinds[int(errorIndex)-1][0] or '?')) + else: + if len(varBinds) != len(oids): # pragma: no cover + raise ValueError('too many return values') + + return varBinds + def _get(self, oid): oid = ObjectIdentity(*oid) oid.resolveWithMib(_mvc) @@ -258,6 +299,20 @@ class SNMPSwitch(object): return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') } + def getegress(self, *vlans): + r = { x[-1]: _octstrtobits(y) for x, y in + self._getmany(*(('Q-BRIDGE-MIB', + 'dot1qVlanStaticEgressPorts', x) for x in vlans)) } + + return r + + def getuntagged(self, *vlans): + r = { x[-1]: _octstrtobits(y) for x, y in + self._getmany(*(('Q-BRIDGE-MIB', + 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) } + + return r + class _TestMisc(unittest.TestCase): def setUp(self): import test_data @@ -265,8 +320,14 @@ class _TestMisc(unittest.TestCase): self._test_data = test_data def test_intstobits(self): - self.assertEqual(intstobits(1, 5, 10), '1000100001') - self.assertEqual(intstobits(3, 4, 9), '001100001') + self.assertEqual(_intstobits(1, 5, 10), '1000100001') + self.assertEqual(_intstobits(3, 4, 9), '001100001') + + def test_octstrtobits(self): + self.assertEqual(_octstrtobits('\x00'), '0' * 8) + self.assertEqual(_octstrtobits('\xff'), '1' * 8) + self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4) + self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4) def test_pvidegressuntagged(self): data = { @@ -315,11 +376,13 @@ class _TestMisc(unittest.TestCase): } self.assertEqual(getuntagged(data, lufun), checkuntagged) - @unittest.skip('foo') + #@unittest.skip('foo') + @mock.patch('vlanmang.SNMPSwitch.getuntagged') + @mock.patch('vlanmang.SNMPSwitch.getegress') @mock.patch('vlanmang.SNMPSwitch.getpvid') @mock.patch('vlanmang.SNMPSwitch.getportmapping') @mock.patch('importlib.import_module') - def test_checkchanges(self, imprt, portmapping, gpvid): + def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged): # that import returns the test data imprt.side_effect = itertools.repeat(self._test_data) @@ -346,30 +409,50 @@ class _TestMisc(unittest.TestCase): # delete the extra port del ports[31] + # that the egress data provided + gegress.side_effect = [ { + 1: '1' * 10, + 5: '1' * 10, + 283: '000000001111111111100110000001', + } ] + + # that the untagged data provided + guntagged.side_effect = [ { + 1: '1' * 10, + 5: '1' * 8, + 283: '00000000111111111110011', + } ] + res = checkchanges('data') validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \ [ ('setpvid', 20, 1, 283), ('setpvid', 21, 1, 283), ('setpvid', 30, 1, 5), - ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', ''), - ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', ''), - ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', ''), + ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), + ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10), + ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10), ] self.assertEqual(set(res), set(validres)) -_skipSwitchTests = True +_skipSwitchTests = False class _TestSwitch(unittest.TestCase): def setUp(self): + # If we don't have it, pretend it's true for now and + # we'll recheck it later + model = 'GS108T smartSwitch' + if getattr(self, 'switchmodel', model) != model or \ + _skipSwitchTests: # pragma: no cover + self.skipTest('Need a GS108T switch to run these tests') + args = open('test.creds').read().split() self.switch = SNMPSwitch(*args) - switchmodel = self.switch._get(('ENTITY-MIB', + self.switchmodel = self.switch._get(('ENTITY-MIB', 'entPhysicalModelName', 1)) - if switchmodel != 'GS108T smartSwitch' or \ - _skipSwitchTests: # pragma: no cover + if self.switchmodel != model: # pragma: no cover self.skipTest('Need a GS108T switch to run these tests') def test_misc(self): @@ -387,6 +470,32 @@ class _TestSwitch(unittest.TestCase): self.assertEqual(switch.getportmapping(), resp) + def test_egress(self): + switch = self.switch + + egress = switch.getegress(1, 2, 3) + + checkegress = { + 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23, + 2: '0' * 8 * 5, + 3: '0' * 8 * 5, + } + + self.assertEqual(egress, checkegress) + + def test_untagged(self): + switch = self.switch + + untagged = switch.getuntagged(1, 2, 3) + + checkuntagged = { + 1: '1' * 8 * 5, + 2: '1' * 8 * 5, + 3: '1' * 8 * 5, + } + + self.assertEqual(untagged, checkuntagged) + def test_vlan(self): switch = self.switch