#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pysnmp.hlapi import *
from pysnmp.smi.builder import MibBuilder
from pysnmp.smi.view import MibViewController

import importlib
import itertools
import mock
import random
import unittest

_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)

#import data

# received packets
#	pvid: dot1qPvid
#
# tx packets:
#	dot1qVlanStaticEgressPorts
#	dot1qVlanStaticUntaggedPorts
#
# vlans:
#	dot1qVlanCurrentTable
#	lists ALL vlans, including baked in ones
#
# note that even though an snmpwalk of dot1qVlanStaticEgressPorts
# skips over other vlans (only shows statics), the other vlans (1,2,3)
# are still accessible via that oid
#
# LLDP:
#	1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable

class SwitchConfig(object):
	'''Desired configuration for a single switch.

	host and community are used to reach the switch via SNMP.
	vlanconf is a dict keyed by VLAN id; each value is a dict that
	may contain a 'u' (untagged) and/or 't' (tagged) list of ports.
	A port is either an interface index (int) or an interface name
	(str) that is resolved through the switch's ifName table.'''

	def __init__(self, host, community, vlanconf):
		self._host = host
		self._community = community
		self._vlanconf = vlanconf

	@property
	def host(self):
		return self._host

	@property
	def community(self):
		return self._community

	@property
	def vlanconf(self):
		return self._vlanconf

def _octstrtobits(os):
	'''Convert an octet string into a string of '0'/'1' characters.

	Each octet expands to 8 characters, most significant bit first,
	so the result is always 8 * len(os) characters long and leading
	zero bits are preserved.'''

	num = 1		# sentinel bit so leading zero bits survive bin()
	for ch in str(os):
		num = (num << 8) | ord(ch)

	return bin(num)[3:]	# strip '0b' and the sentinel bit

def _intstobits(*ints):
	'''Convert port indexes into a string of '0'/'1' characters.

	Character 0 of the result corresponds to index 1, character 1 to
	index 2, and so on (index 0 is dropped).  The string ends at the
	highest index given, i.e. it has no trailing '0's, and is empty
	when no (non-zero) indexes are given.'''

	v = 0
	for i in ints:
		v |= 1 << i

	# bin() is MSB first: drop '0b' and bit 0, then reverse so that
	# the lowest index comes first.
	return bin(v)[2:-1][::-1]

# checkchanges matches configs against vlanmang.SwitchConfig, and the
# tests patch vlanmang.SNMPSwitch, so refer to this module by that name
# (presumably so it also works when this file is run as __main__).
import vlanmang

def checkchanges(module):
	'''Compare the switches configured in a module with their live state.

	module is the name of a module to import; every
	vlanmang.SwitchConfig object at its top level is checked.  For each
	switch, the set of configured ports must exactly match the ports in
	the switch's ifName table, otherwise ValueError is raised.

	Returns a list of tuples describing the needed changes:
	  ('setpvid', portidx, wantedvlan, currentvlan)
	  ('setegress', vlan, wantedbits, currentbits)
	  ('setuntagged', vlan, wantedbits, currentbits)
	'''

	mod = importlib.import_module(module)
	configs = [ x for x in mod.__dict__.itervalues() if
	    isinstance(x, vlanmang.SwitchConfig) ]

	res = []
	for cfg in configs:
		vlans = cfg.vlanconf.keys()
		switch = SNMPSwitch(cfg.host, cfg.community)

		# name -> index lookup, built from the switch's ifName table
		portmapping = switch.getportmapping()
		invportmap = { y: x for x, y in portmapping.iteritems() }
		lufun = invportmap.__getitem__

		# get complete set of ports
		portlist = getportlist(cfg.vlanconf, lufun)

		ports = set(portmapping.iterkeys())

		# make sure switch agrees w/ them all
		if ports != portlist:
			raise ValueError('missing or extra ports found: %s' %
			    repr(ports.symmetric_difference(portlist)))

		# compare pvid
		pvidmap = getpvidmapping(cfg.vlanconf, lufun)
		switchpvid = switch.getpvid()

		res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx,
		    vlan in pvidmap.iteritems() if switchpvid[idx] != vlan)

		# compare egress & untagged
		# NOTE(review): switch values come from _octstrtobits (padded
		# to whole octets) while configured values come from
		# _intstobits (no trailing zeros), so they may differ only by
		# trailing '0's -- TODO confirm whether that should count as
		# a change.
		switchegress = switch.getegress(*vlans)
		egress = getegress(cfg.vlanconf, lufun)
		switchuntagged = switch.getuntagged(*vlans)
		untagged = getuntagged(cfg.vlanconf, lufun)
		for vlan in vlans:
			if switchegress[vlan] != egress[vlan]:
				res.append(('setegress', vlan, egress[vlan],
				    switchegress[vlan]))
			if switchuntagged[vlan] != untagged[vlan]:
				res.append(('setuntagged', vlan, untagged[vlan],
				    switchuntagged[vlan]))

	return res

def getidxs(lst, lookupfun):
	'''Return lst with every str entry replaced by lookupfun(entry).

	Non-str entries (port indexes) are passed through unchanged.'''

	return [ lookupfun(i) if isinstance(i, str) else i for i in lst ]

def getpvidmapping(data, lookupfun):
	'''Return a mapping from vlan based table to a port: vlan dictionary.

	Only untagged ('u') ports define a pvid.  Port names are converted
	to indexes via lookupfun.'''

	# NOTE: if a port is listed as untagged in more than one vlan, the
	# vlan that comes last in data's iteration order wins.
	res = []
	for vlan in data:
		for idx in getidxs(data[vlan].get('u', []), lookupfun):
			res.append((idx, vlan))

	return dict(res)

def getegress(data, lookupfun):
	'''Return a vlan: bit string dictionary of the egress ports.

	Egress ports are the union of the untagged ('u') and tagged ('t')
	ports; see _intstobits for the bit string format.'''

	r = {}
	for vlan in data:
		cfg = data[vlan]
		r[vlan] = _intstobits(*(getidxs(cfg.get('u', []), lookupfun) +
		    getidxs(cfg.get('t', []), lookupfun)))

	return r

def getuntagged(data, lookupfun):
	'''Return a vlan: bit string dictionary of the untagged ports.

	See _intstobits for the bit string format.'''

	r = {}
	for vlan in data:
		r[vlan] = _intstobits(*getidxs(data[vlan].get('u', []),
		    lookupfun))

	return r

def getportlist(data, lookupfun):
	'''Return a set of all the port indexes in data.

	Both untagged ('u') and tagged ('t') ports are included; port names
	are converted to indexes via lookupfun.'''

	res = set()

	for vlan in data:
		res.update(getidxs(data[vlan].get('u', []), lookupfun))
		res.update(getidxs(data[vlan].get('t', []), lookupfun))

	return res

class SNMPSwitch(object):
	'''A class for manipulating switches via standard SNMP MIBs.'''

	def __init__(self, host, community):
		self._eng = SnmpEngine()
		# mpModel=0 selects SNMPv1
		self._cd = CommunityData(community, mpModel=0)
		self._targ = UdpTransportTarget((host, 161))

	@staticmethod
	def _checkerr(errorInd, errorStatus, errorIndex, varBinds):
		'''Raise ValueError if a pysnmp command reported an error.'''

		if errorInd: # pragma: no cover
			raise ValueError(errorInd)
		elif errorStatus: # pragma: no cover
			raise ValueError('%s at %s' % (errorStatus.prettyPrint(),
			    errorIndex and varBinds[int(errorIndex)-1][0] or '?'))

	def _getmany(self, *oids):
		'''Fetch several objects with a single SNMP GET.

		Each oid is a tuple of ObjectIdentity arguments, e.g.
		('IF-MIB', 'ifName', 1).  Returns the list of var binds, one
		per requested oid.  Raises ValueError on any SNMP error.'''

		oids = [ ObjectIdentity(*oid) for oid in oids ]
		for oid in oids:
			oid.resolveWithMib(_mvc)

		errorInd, errorStatus, errorIndex, varBinds = \
		    next(getCmd(self._eng, self._cd, self._targ, ContextData(),
		    *(ObjectType(oid) for oid in oids)))

		self._checkerr(errorInd, errorStatus, errorIndex, varBinds)

		if len(varBinds) != len(oids): # pragma: no cover
			raise ValueError('expected %d return values, got %d' %
			    (len(oids), len(varBinds)))

		return varBinds

	def _get(self, oid):
		'''Return the value of a single object (see _getmany for oid).'''

		varBinds = self._getmany(oid)
		varBind = varBinds[0]
		return varBind[1]

	def _set(self, oid, value):
		'''Set a single object via SNMP SET.

		ints/longs are sent as Integer and strs as OctetString; any
		other value is passed through unchanged.  Raises ValueError on
		an SNMP error, and RuntimeError if the returned value differs
		from the one that was sent.'''

		oid = ObjectIdentity(*oid)
		oid.resolveWithMib(_mvc)

		if isinstance(value, (int, long)):
			value = Integer(value)
		elif isinstance(value, str):
			value = OctetString(value)

		errorInd, errorStatus, errorIndex, varBinds = \
		    next(setCmd(self._eng, self._cd, self._targ, ContextData(),
		    ObjectType(oid, value)))

		self._checkerr(errorInd, errorStatus, errorIndex, varBinds)

		for varBind in varBinds:
			if varBind[1] != value: # pragma: no cover
				raise RuntimeError('failed to set: %s' %
				    ' = '.join([x.prettyPrint() for x in varBind]))

	def _walk(self, *oid):
		'''Yield every var bind under oid (stays within that subtree).'''

		oid = ObjectIdentity(*oid)
		# XXX - keep these, this might stop working, no clue what managed to magically make things work
		# ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
		# mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan

		#oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')

		oid.resolveWithMib(_mvc)

		for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
		    self._eng, self._cd, self._targ, ContextData(),
		    ObjectType(oid), lexicographicMode=False):
			self._checkerr(errorInd, errorStatus, errorIndex,
			    varBinds)

			for varBind in varBinds:
				yield varBind

	def getportmapping(self):
		'''Return a port name mapping.  Keys are the port index
		and the value is the name from the ifName entry.'''

		return { name[-1]: str(val) for name, val in
		    self._walk('IF-MIB', 'ifName') }

	def findport(self, name):
		'''Look up a port name and return its port index.  This
		looks up via the ifName table in IF-MIB.

		Raises IndexError if no port has that name.'''

		for oidname, ifname in self._walk('IF-MIB', 'ifName'):
			if str(ifname) == name:
				return oidname[-1]

		raise IndexError('port %r not found' % (name,))

	def getvlanname(self, vlan):
		'''Return the name for the vlan.'''

		v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))

		return str(v).decode('utf-8')

	def createvlan(self, vlan, name):
		'''Create a static vlan and set its name.'''

		# createAndGo(4)
		self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
		    int(vlan)), 4)
		self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
		    name)

	def deletevlan(self, vlan):
		'''Delete a static vlan.'''

		self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
		    int(vlan)), 6)	# destroy(6)

	def getvlans(self):
		'''Return an iterator with all the vlan ids.'''

		return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB',
		    'dot1qVlanStatus'))

	def staticvlans(self):
		'''Return an iterator of the staticly defined/configured vlans.
		This sometimes excludes special built in vlans, like vlan 1.'''

		return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB',
		    'dot1qVlanStaticName'))

	def getpvid(self):
		'''Returns a dictionary w/ the interface index as the key,
		and the pvid of the interface.'''

		return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB',
		    'dot1qPvid') }

	def _getportbits(self, column, vlans):
		'''Fetch a per-vlan Q-BRIDGE-MIB port list column and return
		a vlan: bit string dictionary (see _octstrtobits).'''

		return { name[-1]: _octstrtobits(val) for name, val in
		    self._getmany(*(('Q-BRIDGE-MIB', column, vlan) for vlan in
		    vlans)) }

	def getegress(self, *vlans):
		'''Return a vlan: bit string dictionary of the egress ports.'''

		return self._getportbits('dot1qVlanStaticEgressPorts', vlans)

	def getuntagged(self, *vlans):
		'''Return a vlan: bit string dictionary of the untagged ports.'''

		return self._getportbits('dot1qVlanStaticUntaggedPorts', vlans)

if __name__ == '__main__': # pragma: no cover
	print(repr(checkchanges('data')))

class _TestMisc(unittest.TestCase):
	'''Tests for the helpers that do not need a switch.'''

	def setUp(self):
		import test_data

		self._test_data = test_data

	def test_intstobits(self):
		self.assertEqual(_intstobits(1, 5, 10), '1000100001')
		self.assertEqual(_intstobits(3, 4, 9), '001100001')

	def test_octstrtobits(self):
		self.assertEqual(_octstrtobits('\x00'), '0' * 8)
		self.assertEqual(_octstrtobits('\xff'), '1' * 8)
		self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4)
		self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4)

	def test_pvidegressuntagged(self):
		data = {
			1: {
				'u': [ 1, 5, 10 ] + range(13, 20),
				't': [ 'lag2', 6, 7 ],
			},
			10: {
				'u': [ 2, 3, 6, 7, 8, 'lag2' ],
			},
			13: {
				'u': [ 4, 9 ],
				't': [ 'lag2', 6, 7 ],
			},
			14: {
				't': [ 'lag2' ],
			},
		}
		lookup = { 'lag2': 30 }
		lufun = lookup.__getitem__

		check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
		    10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13), [ (30, 10) ]))

		# That a pvid mapping
		res = getpvidmapping(data, lufun)

		# is correct
		self.assertEqual(res, check)

		self.assertEqual(getportlist(data, lufun),
		    set(xrange(1, 11)) | set(xrange(13, 20)) | set([30]))

		checkegress = {
			1: '1000111001001111111' + '0' * (30 - 20) + '1',
			10: '01100111' + '0' * (30 - 9) + '1',
			13: '000101101' + '0' * (30 - 10) + '1',
			14: '0' * (30 - 1) + '1',
		}

		self.assertEqual(getegress(data, lufun), checkegress)

		checkuntagged = {
			1: '1000100001001111111',
			10: '01100111' + '0' * (30 - 9) + '1',
			13: '000100001',
			14: '',
		}

		self.assertEqual(getuntagged(data, lufun), checkuntagged)

	#@unittest.skip('foo')
	@mock.patch('vlanmang.SNMPSwitch.getuntagged')
	@mock.patch('vlanmang.SNMPSwitch.getegress')
	@mock.patch('vlanmang.SNMPSwitch.getpvid')
	@mock.patch('vlanmang.SNMPSwitch.getportmapping')
	@mock.patch('importlib.import_module')
	def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged):
		# that import returns the test data
		imprt.side_effect = itertools.repeat(self._test_data)

		# that getportmapping returns the following dict
		ports = { x: 'g%d' % x for x in xrange(1, 24) }
		ports[30] = 'lag1'
		ports[31] = 'lag2'
		portmapping.side_effect = itertools.repeat(ports)

		# that the switch's pvid returns
		spvid = { x: 283 for x in xrange(1, 24) }
		spvid[30] = 5
		gpvid.side_effect = itertools.repeat(spvid)

		# that the extra port is caught
		self.assertRaises(ValueError, checkchanges, 'data')

		# that the functions were called
		imprt.assert_called_with('data')
		portmapping.assert_called()

		# XXX - check that an ignore statement is honored

		# delete the extra port
		del ports[31]

		# that the egress data provided
		gegress.side_effect = [ {
			1: '1' * 10,
			5: '1' * 10,
			283: '000000001111111111100110000001',
		} ]

		# that the untagged data provided
		guntagged.side_effect = [ {
			1: '1' * 10,
			5: '1' * 8,
			283: '00000000111111111110011',
		} ]

		res = checkchanges('data')

		validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \
		    [ ('setpvid', 20, 1, 283),
		    ('setpvid', 21, 1, 283),
		    ('setpvid', 30, 1, 5),
		    ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
		    ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
		    ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10),
		    ]

		self.assertEqual(set(res), set(validres))

_skipSwitchTests = False

class _TestSwitch(unittest.TestCase):
	'''Tests against a live GS108T switch (credentials in test.creds).'''

	def setUp(self):
		# If we don't have it, pretend it's true for now and
		# we'll recheck it later
		model = 'GS108T smartSwitch'
		if getattr(self, 'switchmodel', model) != model or \
		    _skipSwitchTests: # pragma: no cover
			self.skipTest('Need a GS108T switch to run these tests')

		with open('test.creds') as fp:
			args = fp.read().split()
		self.switch = SNMPSwitch(*args)

		self.switchmodel = self.switch._get(('ENTITY-MIB',
		    'entPhysicalModelName', 1))
		if self.switchmodel != model: # pragma: no cover
			self.skipTest('Need a GS108T switch to run these tests')

	def test_misc(self):
		switch = self.switch

		self.assertEqual(switch.findport('g1'), 1)
		self.assertEqual(switch.findport('l1'), 14)

	def test_portnames(self):
		switch = self.switch

		resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
		resp.update({ 13: 'cpu' })
		resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
		self.assertEqual(switch.getportmapping(), resp)

	def test_egress(self):
		switch = self.switch

		egress = switch.getegress(1, 2, 3)

		checkegress = {
			1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23,
			2: '0' * 8 * 5,
			3: '0' * 8 * 5,
		}

		self.assertEqual(egress, checkegress)

	def test_untagged(self):
		switch = self.switch

		untagged = switch.getuntagged(1, 2, 3)

		checkuntagged = {
			1: '1' * 8 * 5,
			2: '1' * 8 * 5,
			3: '1' * 8 * 5,
		}

		self.assertEqual(untagged, checkuntagged)

	def test_vlan(self):
		switch = self.switch

		existingvlans = set(switch.getvlans())

		while True:
			testvlan = random.randint(1,4095)
			if testvlan not in existingvlans:
				break

		# Test that getting a non-existent vlan raises an exception
		self.assertRaises(ValueError, switch.getvlanname, testvlan)

		self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))

		pvidres = { x: 1 for x in xrange(1, 9) }
		pvidres.update({ x: 1 for x in xrange(14, 18) })

		self.assertEqual(switch.getpvid(), pvidres)

		testname = 'Sometestname'

		# Create test vlan
		switch.createvlan(testvlan, testname)
		try:
			# make sure the test vlan was created
			self.assertIn(testvlan, set(switch.staticvlans()))

			self.assertEqual(testname, switch.getvlanname(testvlan))
		finally:
			switch.deletevlan(testvlan)