From e42054d2233aef01c05e1d94a28bccc686c8c21f Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Fri, 20 Sep 2019 20:31:23 -0700 Subject: [PATCH] allow config to ignore some ports.. trim trailing zeros when comparing so that we don't have to make the bits the same length as the switch.. --- vlanmang.py | 85 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/vlanmang.py b/vlanmang.py index 2c18d91..92f3b36 100644 --- a/vlanmang.py +++ b/vlanmang.py @@ -35,10 +35,11 @@ _mvc = MibViewController(_mbuilder) # 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable class SwitchConfig(object): - def __init__(self, host, community, vlanconf): + def __init__(self, host, community, vlanconf, ignports): self._host = host self._community = community self._vlanconf = vlanconf + self._ignports = ignports @property def host(self): @@ -52,6 +53,30 @@ class SwitchConfig(object): def vlanconf(self): return self._vlanconf + @property + def ignports(self): + return self._ignports + + def getportlist(self, lookupfun): + '''Return a set of all the ports indexes in data.''' + + res = set() + + for id in self._vlanconf: + res.update(self._vlanconf[id].get('u', [])) + res.update(self._vlanconf[id].get('t', [])) + + # add in the ignore ports + res.update(self._ignports) + + # filter out the strings + strports = set(x for x in res if isinstance(x, str)) + + res.update(lookupfun(x) for x in strports) + res.difference_update(strports) + + return res + def _octstrtobits(os): num = 1 for i in str(os): @@ -69,6 +94,16 @@ def _intstobits(*ints): return ''.join(r) +def _cmpbits(a, b): + last1a = a.rindex('1') + last1b = b.rindex('1') + if last1a != -1: + a = a[:last1a] + if last1b != -1: + b = b[:last1b] + + return a == b + import vlanmang def checkchanges(module): @@ -85,7 +120,7 @@ def checkchanges(module): lufun = invportmap.__getitem__ # get complete set of ports - portlist = getportlist(i.vlanconf, lufun) + portlist = i.getportlist(lufun) ports = set(portmapping.iterkeys()) @@ -107,9 +142,9 @@ def checkchanges(module): switchuntagged = switch.getuntagged(*vlans) untagged = getuntagged(i.vlanconf, lufun) for i in vlans: - if switchegress[i] != egress[i]: + if not _cmpbits(switchegress[i], egress[i]): res.append(('setegress', i, egress[i], switchegress[i])) - if switchuntagged[i] != untagged[i]: + if not _cmpbits(switchuntagged[i], untagged[i]): res.append(('setuntagged', i, untagged[i], switchuntagged[i])) return res @@ -145,23 +180,6 @@ def getuntagged(data, lookupfun): return r -def getportlist(data, lookupfun): - '''Return a set of all the ports indexes in data.''' - - res = set() - - for id in data: - res.update(data[id].get('u', [])) - res.update(data[id].get('t', [])) - - # filter out the strings - strports = set(x for x in res if isinstance(x, str)) - - res.update(lookupfun(x) for x in strports) - res.difference_update(strports) - - return res - class SNMPSwitch(object): '''A class for manipulating switches via standard SNMP MIBs.''' @@ -320,6 +338,13 @@ class _TestMisc(unittest.TestCase): self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4) self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4) + def test_cmpbits(self): + self.assertTrue(_cmpbits('111000', '111')) + self.assertTrue(_cmpbits('000111000', '000111')) + self.assertTrue(_cmpbits('11', '11')) + self.assertFalse(_cmpbits('0011', '11')) + self.assertFalse(_cmpbits('11', '0011')) + def test_pvidegressuntagged(self): data = { 1: { @@ -337,8 +362,11 @@ class _TestMisc(unittest.TestCase): 't': [ 'lag2' ], }, } + swconf = SwitchConfig('', '', data, [ 'lag3' ]) + lookup = { - 'lag2': 30 + 'lag2': 30, + 'lag3': 31, } lufun = lookup.__getitem__ @@ -352,8 +380,8 @@ class _TestMisc(unittest.TestCase): # is correct self.assertEqual(res, check) - self.assertEqual(getportlist(data, lufun), - set(xrange(1, 11)) | set(xrange(13, 20)) | set([30])) + self.assertEqual(swconf.getportlist(lufun), + set(xrange(1, 11)) | set(xrange(13, 20)) | set(lookup.values())) checkegress = { 1: '1000111001001111111' + '0' * (30 - 20) + '1', @@ -386,6 +414,7 @@ class _TestMisc(unittest.TestCase): ports = { x: 'g%d' % x for x in xrange(1, 24) } ports[30] = 'lag1' ports[31] = 'lag2' + ports[32] = 'lag3' portmapping.side_effect = itertools.repeat(ports) # that the switch's pvid returns @@ -403,19 +432,19 @@ class _TestMisc(unittest.TestCase): # XXX - check that an ignore statement is honored # delete the extra port - del ports[31] + del ports[32] # that the egress data provided gegress.side_effect = [ { 1: '1' * 10, 5: '1' * 10, - 283: '000000001111111111100110000001', + 283: '00000000111111111110011000000100000', } ] # that the untagged data provided guntagged.side_effect = [ { 1: '1' * 10, - 5: '1' * 8, + 5: '1' * 8 + '0' * 10, 283: '00000000111111111110011', } ] @@ -432,7 +461,7 @@ class _TestMisc(unittest.TestCase): self.assertEqual(set(res), set(validres)) -_skipSwitchTests = False +_skipSwitchTests = True class _TestSwitch(unittest.TestCase): def setUp(self):