diff --git a/setup.py b/setup.py index 42a46c3..e416b08 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ setup(name='vlanmang', 'pysnmp-mibs', 'mock', 'pyasn1==0.4.8', + 'asyncssh @ git+https://github.com/ronf/asyncssh', ], extras_require = { 'dev': [ 'coverage' ], diff --git a/vlanmang/__init__.py b/vlanmang/__init__.py index ded4a15..07d235d 100644 --- a/vlanmang/__init__.py +++ b/vlanmang/__init__.py @@ -35,10 +35,13 @@ if False: from pysnmp import debug debug.setLogger(debug.Debug('mibbuild')) +import asyncio +import asyncssh import importlib import itertools import mock import random +import sys import unittest __author__ = 'John-Mark Gurney' @@ -51,6 +54,12 @@ __all__ = [ 'SNMPSwitch', ] +if False: + from pysnmp import debug + + # use specific flags or 'all' for full debugging + debug.setLogger(debug.Debug('dsp', 'msgproc', 'secmod')) + _mbuilder = MibBuilder() _mvc = MibViewController(_mbuilder) @@ -83,6 +92,11 @@ class SwitchConfig(object): host -- The host of the switch you are maintaining configuration of. + switchfactory -- Default is SNMPSwitch. This is the factory function + that will be called to create the switch object. It will be + passed the host and the authargs like: + switchfactory(host, **authargs). + authargs -- This is a dictionary of kwargs to pass to SNMPSwitch. If SNMPv1 (insecure) is used, pass dict(community='communitystr'). @@ -235,8 +249,9 @@ def checkchanges(module): res = [] + print(repr(mods)) for name, i in mods: - #print('probing %s' % repr(name)) + print('probing %s' % repr(name)) vlans = i.vlanconf.keys() switch = SNMPSwitch(i.host, **i.authargs) switchpvid = switch.getpvid() @@ -280,6 +295,7 @@ def checkchanges(module): # compare pvid pvidmap = getpvidmapping(i.vlanconf, lufun) switchpvid = switch.getpvid() + print(repr(switchpvid)) res.extend((switch, name, 'setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in pvidmap.items() if switchpvid[idx] != vlan) @@ -290,9 +306,8 @@ def checkchanges(module): switchuntagged = switch.getuntagged(*vlans) untagged = getuntagged(i.vlanconf, lufun) for i in vlans: - if not _cmpbits(switchegress[i], egress[i]): + if not _cmpbits(switchegress[i], egress[i]) or not _cmpbits(switchuntagged[i], untagged[i]): res.append((switch, name, 'setegress', i, egress[i], switchegress[i])) - if not _cmpbits(switchuntagged[i], untagged[i]): res.append((switch, name, 'setuntagged', i, untagged[i], switchuntagged[i])) return res @@ -344,6 +359,8 @@ def getuntagged(data, lookupfun): return r +_lenovo_ce0128t = (1, 3, 6, 1, 4, 1, 19046, 1, 7, 43) # LENOVO-REF-MIB::ce0128t + class SNMPSwitch(object): '''A class for manipulating switches via standard SNMP MIBs.''' @@ -385,6 +402,8 @@ class SNMPSwitch(object): self._eng = SnmpEngine() + self._getmanybroken = False + if community is not None: self._auth = CommunityData(community, mpModel=0) else: @@ -399,10 +418,29 @@ class SNMPSwitch(object): self._targ = UdpTransportTarget((host, 161), timeout=10) + r = self._get(('SNMPv2-MIB', 'sysObjectID', 0)) + if tuple(r) == _lenovo_ce0128t: + self._getmanybroken = True + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + def __repr__(self): # pragma: no cover return '' % (repr(self._auth), repr(self._targ)) def _getmany(self, *oids): + if not oids: + return [] + + if self._getmanybroken: + return [ self._getmany_real(x)[0] for x in oids ] + else: + return self._getmany_real(*oids) + + def _getmany_real(self, *oids): woids = [ ObjectIdentity(*oid) for oid in oids ] [ oid.resolveWithMib(_mvc) for oid in woids ] @@ -619,9 +657,19 @@ class SNMPSwitch(object): value is a bit string that preresents what ports that particular VLAN will be transmitted on.''' + vlans = list(vlans) + print(repr(vlans)) + for x, y in self._getmany(*(('Q-BRIDGE-MIB', + 'dot1qVlanStaticEgressPorts', x) for x in vlans)): + print(repr(x), repr(y)) + + r = { x[-1]: _octstrtobits(y) for x, y in + self._getmany(*(('Q-BRIDGE-MIB', + 'dot1qVlanCurrentEgressPorts', 0, x) for x in vlans)) } + r = { x[-1]: _octstrtobits(y) for x, y in self._getmany(*(('Q-BRIDGE-MIB', - 'dot1qVlanStaticEgressPorts', x) for x in vlans)) } + 'dot1qVlanCurrentEgressPorts', 0, x) for x in vlans)) } return r @@ -655,6 +703,76 @@ class SNMPSwitch(object): self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts', int(vlan)), value) +class LenovoCampusNOS(SNMPSwitch): + ''' + Interface w/ the Lenovo Campus NOS switches. Specifically + the: + Lenovo CE0128TB Switch, Campus NOS 8.4.3.14 + + It requires sshuser and sshpass to be provided to change settings. + ''' + + def __init__(self, host, community=None, sshuser=None, sshpass=None, + username=None, authKey=None, authProtocol=usmHMACSHAAuthProtocol, + privKey=None, privProtocol=None): + + if sshuser is None: + raise ValueError('sshuser is required') + + if sshpass is None: + raise ValueError('sshpass is required') + + super().__init__(host, community, username, authKey, + authProtocol, privKey, privProtocol) + + self._conn = None + self._proc = None + + self._host = host + self._sshuser = sshuser + self._sshpass = sshpass + + self._loop = asyncio.get_event_loop() + + async def _setupproc(self): + self._conn = await asyncssh.connect(self._host, + username=self._sshuser, password=self._sshpass) + self._proc = await self._conn.create_process() + + def __enter__(self): + self._loop.run_until_complete(self._setupproc()) + + def __exit__(self, *args): + self._conn.close() + + self._loop.run_until_complete(self._conn.wait_closed()) + + def _set(self): + raise RuntimeError('should not be called') + + # Note: setpvid works via SNMP + #def setpvid(self, port, vlan): + + # Note: createvlan works via SNMP + #def createvlan(self, vlan, name): + + def setegress(self, vlan, ports): + ''' + interface 1/0/x + vlan participation include/exclude x + exit + ''' + raise NotImplementedError('todo') + + def setuntagged(self, vlan, ports): + ''' + interface 1/0/x + vlan participation include x + [no] vlan tagging x + exit + ''' + raise NotImplementedError('todo') + def main(): import pprint import sys @@ -675,19 +793,21 @@ def main(): print('applying...') failed = [] prevname = None - for switch, name, verb, arg1, arg2, oldarg in changes: - if prevname != name: - print('Configuring switch %s...' % repr(name)) - prevname = name - - print('%s: %s %s' % (verb, arg1, repr(arg2))) - try: - fun = getattr(switch, verb) - fun(arg1, arg2) - pass - except Exception as e: - print('failed') - failed.append((verb, arg1, arg2, e)) + for switch, group in itertools.groupby(changes, lambda x: x[0]): + with switch: + for _, name, verb, arg1, arg2, oldarg in group: + if prevname != name: + print('Configuring switch %s...' % repr(name)) + prevname = name + + print('%s: %s %s' % (verb, arg1, repr(arg2))) + try: + fun = getattr(switch, verb) + fun(arg1, arg2) + pass + except Exception as e: + print('failed') + failed.append((verb, arg1, arg2, e)) if failed: print('%d failed to apply, they are:' % len(failed)) @@ -907,6 +1027,7 @@ class _TestMisc(unittest.TestCase): gvlans.return_value = iter([ 1, 5 ]) res = checkchanges('data') + validres = [ ('createvlan', 283, '', '') ] # make sure it needs to get created @@ -982,6 +1103,7 @@ class _TestMisc(unittest.TestCase): '1' * 10), ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10), + ('setuntagged', 5, '1' * 8, '1' * 8 + '0' * 10), ] self.assertEqual(set(res), set(validres)) @@ -1062,6 +1184,54 @@ class _TestSNMPSwitch(unittest.TestCase): self.assertEqual(res, { x: _octstrtobits(lookup[x]) for x in range(1, 10) }) +class TestLenovoCampus(unittest.IsolatedAsyncioTestCase): + def test_reqs(self): + with self.assertRaises(ValueError): + LenovoCampusNOS('somehost') + + with self.assertRaises(ValueError): + LenovoCampusNOS('somehost', sshuser='admin') + + with self.assertRaises(ValueError): + LenovoCampusNOS('somehost', sshpass='foo') + + @mock.patch('asyncssh.connect') + @mock.patch('pysnmp.hlapi.ContextData') + def test_basic(self, cd, asconn): + sshuser = 'bogus' + sshpass = 'foo' + a = LenovoCampusNOS(None, sshuser=sshuser, sshpass=sshpass) + + self.assertRaises(RuntimeError, a.setegress, 1, 1) + + connobj = mock.AsyncMock() + + connfuncalled = [ False ] + async def connfun(host, username, password, connfuncalled=connfuncalled): + connfuncalled[0] = True + + # called with the correct arguments + self.assertEqual(host, None) + + self.assertEqual(username, sshuser) + self.assertEqual(password, sshpass) + + return connobj + + asconn.side_effect = connfun + + with a: + # that connect was called + self.assertTrue(connfuncalled[0]) + + # that create_process was called + connobj.create_process.assert_called() + + a.setegress(1, 1) + +class TestMikroTikSwitch(unittest.TestCase): + pass + _skipSwitchTests = True class _TestSwitch(unittest.TestCase):