VLAN Manager tool
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

525 lines
14 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import importlib
  7. import itertools
  8. import mock
  9. import random
  10. import unittest
# Module wide MIB builder and view controller.  The view controller is
# handed to ObjectIdentity.resolveWithMib() by SNMPSwitch so that
# symbolic (MIB name, object name) identifiers can be resolved.
_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)
  13. #import data
  14. # received packets:
  15. # pvid: dot1qPvid
  16. #
  17. # tx packets:
  18. # dot1qVlanStaticEgressPorts
  19. # dot1qVlanStaticUntaggedPorts
  20. #
  21. # vlans:
  22. # dot1qVlanCurrentTable
  23. # lists ALL vlans, including baked in ones
  24. #
  25. # note that even though an snmpwalk of dot1qVlanStaticEgressPorts
  26. # skips over other vlans (only shows statics), the other vlans (1,2,3)
  27. # are still accessible via that oid
  28. #
  29. # LLDP:
  30. # 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
  31. class SwitchConfig(object):
  32. def __init__(self, host, community, vlanconf):
  33. self._host = host
  34. self._community = community
  35. self._vlanconf = vlanconf
  36. @property
  37. def host(self):
  38. return self._host
  39. @property
  40. def community(self):
  41. return self._community
  42. @property
  43. def vlanconf(self):
  44. return self._vlanconf
  45. def _octstrtobits(os):
  46. num = 1
  47. for i in str(os):
  48. num = (num << 8) | ord(i)
  49. return bin(num)[3:]
  50. def _intstobits(*ints):
  51. v = 0
  52. for i in ints:
  53. v |= 1 << i
  54. r = list(bin(v)[2:-1])
  55. r.reverse()
  56. return ''.join(r)
  57. import vlanmang
  58. def checkchanges(module):
  59. mod = importlib.import_module(module)
  60. mods = [ i for i in mod.__dict__.itervalues() if isinstance(i, vlanmang.SwitchConfig) ]
  61. res = []
  62. for i in mods:
  63. vlans = i.vlanconf.keys()
  64. switch = SNMPSwitch(i.host, i.community)
  65. portmapping = switch.getportmapping()
  66. invportmap = { y: x for x, y in portmapping.iteritems() }
  67. lufun = invportmap.__getitem__
  68. # get complete set of ports
  69. portlist = getportlist(i.vlanconf, lufun)
  70. ports = set(portmapping.iterkeys())
  71. # make sure switch agrees w/ them all
  72. if ports != portlist:
  73. raise ValueError('missing or extra ports found: %s' %
  74. `ports.symmetric_difference(portlist)`)
  75. # compare pvid
  76. pvidmap = getpvidmapping(i.vlanconf, lufun)
  77. switchpvid = switch.getpvid()
  78. res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in
  79. pvidmap.iteritems() if switchpvid[idx] != vlan)
  80. # compare egress & untagged
  81. switchegress = switch.getegress(*vlans)
  82. egress = getegress(i.vlanconf, lufun)
  83. switchuntagged = switch.getuntagged(*vlans)
  84. untagged = getuntagged(i.vlanconf, lufun)
  85. for i in vlans:
  86. if switchegress[i] != egress[i]:
  87. res.append(('setegress', i, egress[i], switchegress[i]))
  88. if switchuntagged[i] != untagged[i]:
  89. res.append(('setuntagged', i, untagged[i], switchuntagged[i]))
  90. return res
  91. def getidxs(lst, lookupfun):
  92. return [ lookupfun(i) if isinstance(i, str) else i for i in lst ]
  93. def getpvidmapping(data, lookupfun):
  94. '''Return a mapping from vlan based table to a port: vlan
  95. dictionary.'''
  96. res = []
  97. for id in data:
  98. for i in data[id].get('u', []):
  99. if isinstance(i, str):
  100. i = lookupfun(i)
  101. res.append((i, id))
  102. return dict(res)
  103. def getegress(data, lookupfun):
  104. r = {}
  105. for id in data:
  106. r[id] = _intstobits(*(getidxs(data[id].get('u', []),
  107. lookupfun) + getidxs(data[id].get('t', []), lookupfun)))
  108. return r
  109. def getuntagged(data, lookupfun):
  110. r = {}
  111. for id in data:
  112. r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun))
  113. return r
  114. def getportlist(data, lookupfun):
  115. '''Return a set of all the ports indexes in data.'''
  116. res = set()
  117. for id in data:
  118. res.update(data[id].get('u', []))
  119. res.update(data[id].get('t', []))
  120. # filter out the strings
  121. strports = set(x for x in res if isinstance(x, str))
  122. res.update(lookupfun(x) for x in strports)
  123. res.difference_update(strports)
  124. return res
  125. class SNMPSwitch(object):
  126. '''A class for manipulating switches via standard SNMP MIBs.'''
  127. def __init__(self, host, community):
  128. self._eng = SnmpEngine()
  129. self._cd = CommunityData(community, mpModel=0)
  130. self._targ = UdpTransportTarget((host, 161))
  131. def _getmany(self, *oids):
  132. oids = [ ObjectIdentity(*oid) for oid in oids ]
  133. [ oid.resolveWithMib(_mvc) for oid in oids ]
  134. errorInd, errorStatus, errorIndex, varBinds = \
  135. next(getCmd(self._eng, self._cd, self._targ, ContextData(), *(ObjectType(oid) for oid in oids)))
  136. if errorInd: # pragma: no cover
  137. raise ValueError(errorIndication)
  138. elif errorStatus: # pragma: no cover
  139. raise ValueError('%s at %s' %
  140. (errorStatus.prettyPrint(), errorIndex and
  141. varBinds[int(errorIndex)-1][0] or '?'))
  142. else:
  143. if len(varBinds) != len(oids): # pragma: no cover
  144. raise ValueError('too many return values')
  145. return varBinds
  146. def _get(self, oid):
  147. varBinds = self._getmany(oid)
  148. varBind = varBinds[0]
  149. return varBind[1]
  150. def _set(self, oid, value):
  151. oid = ObjectIdentity(*oid)
  152. oid.resolveWithMib(_mvc)
  153. if isinstance(value, (int, long)):
  154. value = Integer(value)
  155. elif isinstance(value, str):
  156. value = OctetString(value)
  157. errorInd, errorStatus, errorIndex, varBinds = \
  158. next(setCmd(self._eng, self._cd, self._targ, ContextData(), ObjectType(oid, value)))
  159. if errorInd: # pragma: no cover
  160. raise ValueError(errorIndication)
  161. elif errorStatus: # pragma: no cover
  162. raise ValueError('%s at %s' %
  163. (errorStatus.prettyPrint(), errorIndex and
  164. varBinds[int(errorIndex)-1][0] or '?'))
  165. else:
  166. for varBind in varBinds:
  167. if varBind[1] != value: # pragma: no cover
  168. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  169. def _walk(self, *oid):
  170. oid = ObjectIdentity(*oid)
  171. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  172. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  173. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  174. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  175. oid.resolveWithMib(_mvc)
  176. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  177. self._eng, self._cd, self._targ, ContextData(),
  178. ObjectType(oid),
  179. lexicographicMode=False):
  180. if errorInd: # pragma: no cover
  181. raise ValueError(errorInd)
  182. elif errorStatus: # pragma: no cover
  183. raise ValueError('%s at %s' % (errorStatus.prettyPrint(), errorIndex and varBinds[int(errorIndex)-1][0] or '?'))
  184. else:
  185. for varBind in varBinds:
  186. yield varBind
  187. def getportmapping(self):
  188. '''Return a port name mapping. Keys are the port index
  189. and the value is the name from the ifName entry.'''
  190. return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB', 'ifName') }
  191. def findport(self, name):
  192. '''Look up a port name and return it's port index. This
  193. looks up via the ifName table in IF-MIB.'''
  194. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if str(x[1]) == name ][0]
  195. def getvlanname(self, vlan):
  196. '''Return the name for the vlan.'''
  197. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  198. return str(v).decode('utf-8')
  199. def createvlan(self, vlan, name):
  200. # createAndGo(4)
  201. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  202. int(vlan)), 4)
  203. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  204. name)
  205. def deletevlan(self, vlan):
  206. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  207. int(vlan)), 6) # destroy(6)
  208. def getvlans(self):
  209. '''Return an iterator with all the vlan ids.'''
  210. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStatus'))
  211. def staticvlans(self):
  212. '''Return an iterator of the staticly defined/configured
  213. vlans. This sometimes excludes special built in vlans,
  214. like vlan 1.'''
  215. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB', 'dot1qVlanStaticName'))
  216. def getpvid(self):
  217. '''Returns a dictionary w/ the interface index as the key,
  218. and the pvid of the interface.'''
  219. return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB', 'dot1qPvid') }
  220. def getegress(self, *vlans):
  221. r = { x[-1]: _octstrtobits(y) for x, y in
  222. self._getmany(*(('Q-BRIDGE-MIB',
  223. 'dot1qVlanStaticEgressPorts', x) for x in vlans)) }
  224. return r
  225. def getuntagged(self, *vlans):
  226. r = { x[-1]: _octstrtobits(y) for x, y in
  227. self._getmany(*(('Q-BRIDGE-MIB',
  228. 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) }
  229. return r
  230. if __name__ == '__main__': # pragma: no cover
  231. print `checkchanges('data')`
# Tests for the helper functions and for checkchanges.  The switch is
# never contacted; the SNMPSwitch methods used by checkchanges are
# replaced with mocks.
class _TestMisc(unittest.TestCase):
    def setUp(self):
        # imported here, and handed out by the mocked import_module
        # in test_checkchanges
        import test_data

        self._test_data = test_data

    def test_intstobits(self):
        # port 1 is the first character of the result
        self.assertEqual(_intstobits(1, 5, 10), '1000100001')
        self.assertEqual(_intstobits(3, 4, 9), '001100001')

    def test_octstrtobits(self):
        # eight characters per octet, most significant bit first
        self.assertEqual(_octstrtobits('\x00'), '0' * 8)
        self.assertEqual(_octstrtobits('\xff'), '1' * 8)
        self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4)
        self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4)

    def test_pvidegressuntagged(self):
        # vlan: { 'u': untagged ports, 't': tagged ports }, where a
        # port is either an index or a name
        data = {
            1: {
                'u': [ 1, 5, 10 ] + range(13, 20),
                't': [ 'lag2', 6, 7 ],
            },
            10: {
                'u': [ 2, 3, 6, 7, 8, 'lag2' ],
            },
            13: {
                'u': [ 4, 9 ],
                't': [ 'lag2', 6, 7 ],
            },
            14: {
                't': [ 'lag2' ],
            },
        }
        # port name to index lookup
        lookup = {
            'lag2': 30
        }
        lufun = lookup.__getitem__

        # expected port: vlan mapping, ports 1-10, 13-19 and 30
        check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
            10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
            [ (30, 10) ]))

        # That a pvid mapping
        res = getpvidmapping(data, lufun)

        # is correct
        self.assertEqual(res, check)

        # that every port is listed, w/ lag2 as its index
        self.assertEqual(getportlist(data, lufun),
            set(xrange(1, 11)) | set(xrange(13, 20)) | set([30]))

        # egress is both the untagged and the tagged ports
        checkegress = {
            1: '1000111001001111111' + '0' * (30 - 20) + '1',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000101101' + '0' * (30 - 10) + '1',
            14: '0' * (30 - 1) + '1',
        }

        self.assertEqual(getegress(data, lufun), checkegress)

        # untagged is only the 'u' ports, vlan 14 has none
        checkuntagged = {
            1: '1000100001001111111',
            10: '01100111' + '0' * (30 - 9) + '1',
            13: '000100001',
            14: '',
        }

        self.assertEqual(getuntagged(data, lufun), checkuntagged)

    #@unittest.skip('foo')
    # The patches are applied bottom up, so the mock arguments are in
    # the reverse order of the decorators.
    @mock.patch('vlanmang.SNMPSwitch.getuntagged')
    @mock.patch('vlanmang.SNMPSwitch.getegress')
    @mock.patch('vlanmang.SNMPSwitch.getpvid')
    @mock.patch('vlanmang.SNMPSwitch.getportmapping')
    @mock.patch('importlib.import_module')
    def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged):
        # that import returns the test data
        imprt.side_effect = itertools.repeat(self._test_data)

        # that getportmapping returns the following dict
        ports = { x: 'g%d' % x for x in xrange(1, 24) }
        ports[30] = 'lag1'
        ports[31] = 'lag2'
        portmapping.side_effect = itertools.repeat(ports)

        # that the switch's pvid returns
        spvid = { x: 283 for x in xrange(1, 24) }
        spvid[30] = 5
        gpvid.side_effect = itertools.repeat(spvid)

        # that the extra port is caught
        self.assertRaises(ValueError, checkchanges, 'data')

        # that the functions were called
        imprt.assert_called_with('data')
        portmapping.assert_called()

        # XXX - check that an ignore statement is honored

        # delete the extra port
        del ports[31]

        # that the egress data provided
        # (a one item list, so only a single call can be answered)
        gegress.side_effect = [ {
            1: '1' * 10,
            5: '1' * 10,
            283: '000000001111111111100110000001',
        } ]

        # that the untagged data provided
        guntagged.side_effect = [ {
            1: '1' * 10,
            5: '1' * 8,
            283: '00000000111111111110011',
        } ]

        res = checkchanges('data')

        # each entry is (action, port or vlan, wanted, on switch)
        validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \
            [ ('setpvid', 20, 1, 283),
            ('setpvid', 21, 1, 283),
            ('setpvid', 30, 1, 5),
            ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
            ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1', '1' * 10),
            ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 + '1', '1' * 10),
        ]

        # the order of the changes does not matter
        self.assertEqual(set(res), set(validres))
  336. _skipSwitchTests = False
  337. class _TestSwitch(unittest.TestCase):
  338. def setUp(self):
  339. # If we don't have it, pretend it's true for now and
  340. # we'll recheck it later
  341. model = 'GS108T smartSwitch'
  342. if getattr(self, 'switchmodel', model) != model or \
  343. _skipSwitchTests: # pragma: no cover
  344. self.skipTest('Need a GS108T switch to run these tests')
  345. args = open('test.creds').read().split()
  346. self.switch = SNMPSwitch(*args)
  347. self.switchmodel = self.switch._get(('ENTITY-MIB',
  348. 'entPhysicalModelName', 1))
  349. if self.switchmodel != model: # pragma: no cover
  350. self.skipTest('Need a GS108T switch to run these tests')
  351. def test_misc(self):
  352. switch = self.switch
  353. self.assertEqual(switch.findport('g1'), 1)
  354. self.assertEqual(switch.findport('l1'), 14)
  355. def test_portnames(self):
  356. switch = self.switch
  357. resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
  358. resp.update({ 13: 'cpu' })
  359. resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
  360. self.assertEqual(switch.getportmapping(), resp)
  361. def test_egress(self):
  362. switch = self.switch
  363. egress = switch.getegress(1, 2, 3)
  364. checkegress = {
  365. 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23,
  366. 2: '0' * 8 * 5,
  367. 3: '0' * 8 * 5,
  368. }
  369. self.assertEqual(egress, checkegress)
  370. def test_untagged(self):
  371. switch = self.switch
  372. untagged = switch.getuntagged(1, 2, 3)
  373. checkuntagged = {
  374. 1: '1' * 8 * 5,
  375. 2: '1' * 8 * 5,
  376. 3: '1' * 8 * 5,
  377. }
  378. self.assertEqual(untagged, checkuntagged)
  379. def test_vlan(self):
  380. switch = self.switch
  381. existingvlans = set(switch.getvlans())
  382. while True:
  383. testvlan = random.randint(1,4095)
  384. if testvlan not in existingvlans:
  385. break
  386. # Test that getting a non-existant vlans raises an exception
  387. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  388. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  389. pvidres = { x: 1 for x in xrange(1, 9) }
  390. pvidres.update({ x: 1 for x in xrange(14, 18) })
  391. self.assertEqual(switch.getpvid(), pvidres)
  392. testname = 'Sometestname'
  393. # Create test vlan
  394. switch.createvlan(testvlan, testname)
  395. try:
  396. # make sure the test vlan was created
  397. self.assertIn(testvlan, set(switch.staticvlans()))
  398. self.assertEqual(testname, switch.getvlanname(testvlan))
  399. finally:
  400. switch.deletevlan(testvlan)