VLAN Manager tool
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

554 lines
14 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import importlib
  7. import itertools
  8. import mock
  9. import random
  10. import unittest
# Module-wide MIB builder and view controller. SNMPSwitch resolves every
# ObjectIdentity it builds against _mvc (resolveWithMib), which is what
# lets callers pass symbolic OIDs such as ('Q-BRIDGE-MIB', 'dot1qPvid').
_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)
#import data
# received packets:
# pvid: dot1qPvid
#
# tx packets:
# dot1qVlanStaticEgressPorts
# dot1qVlanStaticUntaggedPorts
#
# vlans:
# dot1qVlanCurrentTable
# lists ALL vlans, including built-in ones
#
# note that even though an snmpwalk of dot1qVlanStaticEgressPorts
# skips over the other vlans (it only shows static ones), the other
# vlans (1, 2, 3) are still accessible via that oid
#
# LLDP:
# 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
  31. class SwitchConfig(object):
  32. def __init__(self, host, community, vlanconf, ignports):
  33. self._host = host
  34. self._community = community
  35. self._vlanconf = vlanconf
  36. self._ignports = ignports
  37. @property
  38. def host(self):
  39. return self._host
  40. @property
  41. def community(self):
  42. return self._community
  43. @property
  44. def vlanconf(self):
  45. return self._vlanconf
  46. @property
  47. def ignports(self):
  48. return self._ignports
  49. def getportlist(self, lookupfun):
  50. '''Return a set of all the ports indexes in data.'''
  51. res = set()
  52. for id in self._vlanconf:
  53. res.update(self._vlanconf[id].get('u', []))
  54. res.update(self._vlanconf[id].get('t', []))
  55. # add in the ignore ports
  56. res.update(self._ignports)
  57. # filter out the strings
  58. strports = set(x for x in res if isinstance(x, str))
  59. res.update(lookupfun(x) for x in strports)
  60. res.difference_update(strports)
  61. return res
  62. def _octstrtobits(os):
  63. num = 1
  64. for i in str(os):
  65. num = (num << 8) | ord(i)
  66. return bin(num)[3:]
  67. def _intstobits(*ints):
  68. v = 0
  69. for i in ints:
  70. v |= 1 << i
  71. r = list(bin(v)[2:-1])
  72. r.reverse()
  73. return ''.join(r)
  74. def _cmpbits(a, b):
  75. last1a = a.rindex('1')
  76. last1b = b.rindex('1')
  77. if last1a != -1:
  78. a = a[:last1a]
  79. if last1b != -1:
  80. b = b[:last1b]
  81. return a == b
  82. import vlanmang
  83. def checkchanges(module):
  84. mod = importlib.import_module(module)
  85. mods = [ i for i in mod.__dict__.itervalues() if isinstance(i, vlanmang.SwitchConfig) ]
  86. res = []
  87. for i in mods:
  88. vlans = i.vlanconf.keys()
  89. switch = SNMPSwitch(i.host, i.community)
  90. portmapping = switch.getportmapping()
  91. invportmap = { y: x for x, y in portmapping.iteritems() }
  92. lufun = invportmap.__getitem__
  93. # get complete set of ports
  94. portlist = i.getportlist(lufun)
  95. ports = set(portmapping.iterkeys())
  96. # make sure switch agrees w/ them all
  97. if ports != portlist:
  98. raise ValueError('missing or extra ports found: %s' %
  99. `ports.symmetric_difference(portlist)`)
  100. # compare pvid
  101. pvidmap = getpvidmapping(i.vlanconf, lufun)
  102. switchpvid = switch.getpvid()
  103. res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in
  104. pvidmap.iteritems() if switchpvid[idx] != vlan)
  105. # compare egress & untagged
  106. switchegress = switch.getegress(*vlans)
  107. egress = getegress(i.vlanconf, lufun)
  108. switchuntagged = switch.getuntagged(*vlans)
  109. untagged = getuntagged(i.vlanconf, lufun)
  110. for i in vlans:
  111. if not _cmpbits(switchegress[i], egress[i]):
  112. res.append(('setegress', i, egress[i], switchegress[i]))
  113. if not _cmpbits(switchuntagged[i], untagged[i]):
  114. res.append(('setuntagged', i, untagged[i], switchuntagged[i]))
  115. return res
  116. def getidxs(lst, lookupfun):
  117. return [ lookupfun(i) if isinstance(i, str) else i for i in lst ]
  118. def getpvidmapping(data, lookupfun):
  119. '''Return a mapping from vlan based table to a port: vlan
  120. dictionary.'''
  121. res = []
  122. for id in data:
  123. for i in data[id].get('u', []):
  124. if isinstance(i, str):
  125. i = lookupfun(i)
  126. res.append((i, id))
  127. return dict(res)
  128. def getegress(data, lookupfun):
  129. r = {}
  130. for id in data:
  131. r[id] = _intstobits(*(getidxs(data[id].get('u', []),
  132. lookupfun) + getidxs(data[id].get('t', []), lookupfun)))
  133. return r
  134. def getuntagged(data, lookupfun):
  135. r = {}
  136. for id in data:
  137. r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun))
  138. return r
  139. class SNMPSwitch(object):
  140. '''A class for manipulating switches via standard SNMP MIBs.'''
  141. def __init__(self, host, community):
  142. self._eng = SnmpEngine()
  143. self._cd = CommunityData(community, mpModel=0)
  144. self._targ = UdpTransportTarget((host, 161))
  145. def _getmany(self, *oids):
  146. oids = [ ObjectIdentity(*oid) for oid in oids ]
  147. [ oid.resolveWithMib(_mvc) for oid in oids ]
  148. errorInd, errorStatus, errorIndex, varBinds = \
  149. next(getCmd(self._eng, self._cd, self._targ, ContextData(), *(ObjectType(oid) for oid in oids)))
  150. if errorInd: # pragma: no cover
  151. raise ValueError(errorIndication)
  152. elif errorStatus: # pragma: no cover
  153. raise ValueError('%s at %s' %
  154. (errorStatus.prettyPrint(), errorIndex and
  155. varBinds[int(errorIndex)-1][0] or '?'))
  156. else:
  157. if len(varBinds) != len(oids): # pragma: no cover
  158. raise ValueError('too many return values')
  159. return varBinds
  160. def _get(self, oid):
  161. varBinds = self._getmany(oid)
  162. varBind = varBinds[0]
  163. return varBind[1]
  164. def _set(self, oid, value):
  165. oid = ObjectIdentity(*oid)
  166. oid.resolveWithMib(_mvc)
  167. if isinstance(value, (int, long)):
  168. value = Integer(value)
  169. elif isinstance(value, str):
  170. value = OctetString(value)
  171. errorInd, errorStatus, errorIndex, varBinds = \
  172. next(setCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid, value)))
  173. if errorInd: # pragma: no cover
  174. raise ValueError(errorIndication)
  175. elif errorStatus: # pragma: no cover
  176. raise ValueError('%s at %s' %
  177. (errorStatus.prettyPrint(), errorIndex and
  178. varBinds[int(errorIndex)-1][0] or '?'))
  179. else:
  180. for varBind in varBinds:
  181. if varBind[1] != value: # pragma: no cover
  182. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  183. def _walk(self, *oid):
  184. oid = ObjectIdentity(*oid)
  185. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  186. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  187. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  188. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  189. oid.resolveWithMib(_mvc)
  190. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  191. self._eng, self._cd, self._targ, ContextData(),
  192. ObjectType(oid),
  193. lexicographicMode=False):
  194. if errorInd: # pragma: no cover
  195. raise ValueError(errorInd)
  196. elif errorStatus: # pragma: no cover
  197. raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?'))
  198. else:
  199. for varBind in varBinds:
  200. yield varBind
  201. def getportmapping(self):
  202. '''Return a port name mapping. Keys are the port index
  203. and the value is the name from the ifName entry.'''
  204. return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB', 'ifName') }
  205. def findport(self, name):
  206. '''Look up a port name and return it's port index. This
  207. looks up via the ifName table in IF-MIB.'''
  208. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0]
  209. def getvlanname(self, vlan):
  210. '''Return the name for the vlan.'''
  211. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  212. return str(v).decode('utf-8')
  213. def createvlan(self, vlan, name):
  214. # createAndGo(4)
  215. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  216. int(vlan)), 4)
  217. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  218. name)
  219. def deletevlan(self, vlan):
  220. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  221. int(vlan)), 6) # destroy(6)
  222. def getvlans(self):
  223. '''Return an iterator with all the vlan ids.'''
  224. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus'))
  225. def staticvlans(self):
  226. '''Return an iterator of the staticly defined/configured
  227. vlans. This sometimes excludes special built in vlans,
  228. like vlan 1.'''
  229. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName'))
  230. def getpvid(self):
  231. '''Returns a dictionary w/ the interface index as the key,
  232. and the pvid of the interface.'''
  233. return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') }
  234. def getegress(self, *vlans):
  235. r = { x[-1]: _octstrtobits(y) for x, y in
  236. self._getmany(*(('Q-BRIDGE-MIB',
  237. 'dot1qVlanStaticEgressPorts', x) for x in vlans)) }
  238. return r
  239. def getuntagged(self, *vlans):
  240. r = { x[-1]: _octstrtobits(y) for x, y in
  241. self._getmany(*(('Q-BRIDGE-MIB',
  242. 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) }
  243. return r
  244. if __name__ == '__main__': # pragma: no cover
  245. print `checkchanges('data')`
class _TestMisc(unittest.TestCase):
    '''Offline tests for the bit-string helpers, the config-to-bits
    conversions and checkchanges (with the SNMP calls mocked out).'''

    def setUp(self):
        # test_data is a project module with the SwitchConfig objects
        # that test_checkchanges feeds in; its contents are not visible
        # here.
        import test_data
        self._test_data = test_data

    def test_intstobits(self):
        self.assertEqual(_intstobits(1, 5, 10), '1000100001')
        self.assertEqual(_intstobits(3, 4, 9), '001100001')

    def test_octstrtobits(self):
        self.assertEqual(_octstrtobits('\x00'), '0' * 8)
        self.assertEqual(_octstrtobits('\xff'), '1' * 8)
        self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4)
        self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4)

    def test_cmpbits(self):
        # trailing zeros are insignificant, leading ones are not
        self.assertTrue(_cmpbits('111000', '111'))
        self.assertTrue(_cmpbits('000111000', '000111'))
        self.assertTrue(_cmpbits('11', '11'))

        self.assertFalse(_cmpbits('0011', '11'))
        self.assertFalse(_cmpbits('11', '0011'))

    def test_pvidegressuntagged(self):
        # NOTE: list + range() relies on Python 2, where range()
        # returns a list.
        data = {
            1: {
                'u': [ 1, 5, 10 ] + range(13, 20),
                't': [ 'lag2', 6, 7 ],
            },
            10: {
                'u': [ 2, 3, 6, 7, 8, 'lag2' ],
            },
            13: {
                'u': [ 4, 9 ],
                't': [ 'lag2', 6, 7 ],
            },
            14: {
                't': [ 'lag2' ],
            },
        }
        # lag3 only appears in the ignore list but must still be
        # counted by getportlist
        swconf = SwitchConfig('', '', data, [ 'lag3' ])

        lookup = {
            'lag2': 30,
            'lag3': 31,
        }
        lufun = lookup.__getitem__

        # expected port -> pvid: ports 1-10, ports 13-19 on vlan 1,
        # and lag2 (index 30) on vlan 10
        check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
            10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
            [ (30, 10) ]))

        # That a pvid mapping
        res = getpvidmapping(data, lufun)

        # is correct
        self.assertEqual(res, check)

        self.assertEqual(swconf.getportlist(lufun),
            set(xrange(1, 11)) | set(xrange(13, 20)) | set(lookup.values()))

        # character k of each string is port k + 1; lag2 is index 30
        checkegress = {
            1: '1000111001001111111' + '0' * (30 - 20) + '1',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000101101' + '0' * (30 - 10) + '1',
            14: '0' * (30 - 1) + '1',
        }

        self.assertEqual(getegress(data, lufun), checkegress)

        # vlan 14 has no untagged ports, hence the empty string
        checkuntagged = {
            1: '1000100001001111111',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000100001',
            14: '',
        }

        self.assertEqual(getuntagged(data, lufun), checkuntagged)

    #@unittest.skip('foo')
    # mock.patch decorators apply bottom-up, so the mock arguments
    # arrive in the reverse order of the decorators
    @mock.patch('vlanmang.SNMPSwitch.getuntagged')
    @mock.patch('vlanmang.SNMPSwitch.getegress')
    @mock.patch('vlanmang.SNMPSwitch.getpvid')
    @mock.patch('vlanmang.SNMPSwitch.getportmapping')
    @mock.patch('importlib.import_module')
    def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged):
        # that import returns the test data
        imprt.side_effect = itertools.repeat(self._test_data)

        # that getportmapping returns the following dict
        ports = { x: 'g%d' % x for x in xrange(1, 24) }
        ports[30] = 'lag1'
        ports[31] = 'lag2'
        ports[32] = 'lag3'

        portmapping.side_effect = itertools.repeat(ports)

        # that the switch's pvid returns
        spvid = { x: 283 for x in xrange(1, 24) }
        spvid[30] = 5
        gpvid.side_effect = itertools.repeat(spvid)

        # that the extra port is caught
        self.assertRaises(ValueError, checkchanges, 'data')

        # that the functions were called
        imprt.assert_called_with('data')
        portmapping.assert_called()

        # XXX - check that an ignore statement is honored

        # delete the extra port
        del ports[32]

        # that the egress data provided
        gegress.side_effect = [ {
            1: '1' * 10,
            5: '1' * 10,
            283: '00000000111111111110011000000100000',
            } ]

        # that the untagged data provided
        guntagged.side_effect = [ {
            1: '1' * 10,
            5: '1' * 8 + '0' * 10,
            283: '00000000111111111110011',
            } ]

        res = checkchanges('data')

        validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \
            [ ('setpvid', 20, 1, 283),
            ('setpvid', 21, 1, 283),
            ('setpvid', 30, 1, 5),
            ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
            ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
            ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10),
            ]

        # result order is not part of the contract, compare as sets
        self.assertEqual(set(res), set(validres))
# While True, every _TestSwitch test is skipped in setUp; those tests
# need a real GS108T switch whose credentials are read from test.creds.
_skipSwitchTests = True
  360. class _TestSwitch(unittest.TestCase):
  361. def setUp(self):
  362. # If we don't have it, pretend it's true for now and
  363. # we'll recheck it later
  364. model = 'GS108T smartSwitch'
  365. if getattr(self, 'switchmodel', model) != model or \
  366. _skipSwitchTests: # pragma: no cover
  367. self.skipTest('Need a GS108T switch to run these tests')
  368. args = open('test.creds').read().split()
  369. self.switch = SNMPSwitch(*args)
  370. self.switchmodel = self.switch._get(('ENTITY-MIB',
  371. 'entPhysicalModelName', 1))
  372. if self.switchmodel != model: # pragma: no cover
  373. self.skipTest('Need a GS108T switch to run these tests')
  374. def test_misc(self):
  375. switch = self.switch
  376. self.assertEqual(switch.findport('g1'), 1)
  377. self.assertEqual(switch.findport('l1'), 14)
  378. def test_portnames(self):
  379. switch = self.switch
  380. resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
  381. resp.update({ 13: 'cpu' })
  382. resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
  383. self.assertEqual(switch.getportmapping(), resp)
  384. def test_egress(self):
  385. switch = self.switch
  386. egress = switch.getegress(1, 2, 3)
  387. checkegress = {
  388. 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23,
  389. 2: '0' * 8 * 5,
  390. 3: '0' * 8 * 5,
  391. }
  392. self.assertEqual(egress, checkegress)
  393. def test_untagged(self):
  394. switch = self.switch
  395. untagged = switch.getuntagged(1, 2, 3)
  396. checkuntagged = {
  397. 1: '1' * 8 * 5,
  398. 2: '1' * 8 * 5,
  399. 3: '1' * 8 * 5,
  400. }
  401. self.assertEqual(untagged, checkuntagged)
  402. def test_vlan(self):
  403. switch = self.switch
  404. existingvlans = set(switch.getvlans())
  405. while True:
  406. testvlan = random.randint(1,4095)
  407. if testvlan not in existingvlans:
  408. break
  409. # Test that getting a non-existant vlans raises an exception
  410. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  411. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  412. pvidres = { x: 1 for x in xrange(1, 9) }
  413. pvidres.update({ x: 1 for x in xrange(14, 18) })
  414. self.assertEqual(switch.getpvid(), pvidres)
  415. testname = 'Sometestname'
  416. # Create test vlan
  417. switch.createvlan(testvlan, testname)
  418. try:
  419. # make sure the test vlan was created
  420. self.assertIn(testvlan, set(switch.staticvlans()))
  421. self.assertEqual(testname, switch.getvlanname(testvlan))
  422. finally:
  423. switch.deletevlan(testvlan)