VLAN Manager tool
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

800 lines
22 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from pysnmp.hlapi import *
  4. from pysnmp.smi.builder import MibBuilder
  5. from pysnmp.smi.view import MibViewController
  6. import importlib
  7. import itertools
  8. import mock
  9. import random
  10. import unittest
# Module-wide MIB builder and the view controller built on top of it;
# _mvc is what gets handed to resolveWithMib() whenever a symbolic OID
# (e.g. ('IF-MIB', 'ifName')) has to be resolved.
_mbuilder = MibBuilder()
_mvc = MibViewController(_mbuilder)
  13. #import data
  14. # received packets
  15. # pvid: dot1qPvid
  16. #
  17. # tx packets:
  18. # dot1qVlanStaticEgressPorts
  19. # dot1qVlanStaticUntaggedPorts
  20. #
  21. # vlans:
  22. # dot1qVlanCurrentTable
  23. # lists ALL vlans, including baked in ones
  24. #
  25. # note that even though an snmpwalk of dot1qVlanStaticEgressPorts
  26. # skips over other vlans (only shows statics), the other vlans (1,2,3)
  27. # are still accessible via that oid
  28. #
  29. # LLDP:
  30. # 1.0.8802.1.1.2.1.4.1.1 aka LLDP-MIB, lldpRemTable
  31. class SwitchConfig(object):
  32. '''This is a simple object to store switch configuration for
  33. the checkchanges() function.
  34. host -- The host of the switch you are maintaining configuration of.
  35. community -- Either the SNMPv1 community name or a
  36. pysnmp.hlapi.UsmUserData object, either of which can write to the
  37. necessary MIBs to configure the VLANs of the switch.
  38. vlanconf -- This is a dictionary w/ vlans as the key. Each value has
  39. a dictionary that contains keys, 'u' or 't', each of which
  40. contains the port that traffic should be sent untagged ('u') or
  41. tagged ('t'). Note that the Pvid (vlan of traffic that is
  42. received when untagged), is set to match the 'u' definition. The
  43. port is either an integer, which maps directly to the switch's
  44. index number, or it can be a string, which will be looked up via
  45. the IF-MIB::ifName table.
  46. ignports -- Ports that will be ignored and not required to be
  47. configured. List any ports that will not be active here, such as
  48. any unused lag ports.
  49. '''
  50. def __init__(self, host, community, vlanconf, ignports):
  51. self._host = host
  52. self._community = community
  53. self._vlanconf = vlanconf
  54. self._ignports = ignports
  55. @property
  56. def host(self):
  57. return self._host
  58. @property
  59. def community(self):
  60. return self._community
  61. @property
  62. def vlanconf(self):
  63. return self._vlanconf
  64. @property
  65. def ignports(self):
  66. return self._ignports
  67. def getportlist(self, lookupfun):
  68. '''Return a set of all the ports indexes in data. This
  69. includes, both vlanconf and ignports. Any ports using names
  70. will be resolved by being passed to the provided lookupfun.'''
  71. res = []
  72. for id in self._vlanconf:
  73. res.extend(self._vlanconf[id].get('u', []))
  74. res.extend(self._vlanconf[id].get('t', []))
  75. # add in the ignore ports
  76. res.extend(self.ignports)
  77. # eliminate dups so that lookupfun isn't called as often
  78. res = set(res)
  79. return set(getidxs(res, lookupfun))
  80. def _octstrtobits(os):
  81. '''Convert a string into a list of bits. Easier to figure out what
  82. ports are set.'''
  83. num = 1 # leading 1 to make sure leading zeros are not stripped
  84. for i in str(os):
  85. num = (num << 8) | ord(i)
  86. return bin(num)[3:]
  87. def _intstobits(*ints):
  88. '''Convert the int args to a string of bits in the expected format
  89. that SNMP expects for them. The results will be a string of '1's
  90. and '0's where the first one represents 1, and second one
  91. representing 2 and so on.'''
  92. v = 0
  93. for i in ints:
  94. v |= 1 << i
  95. r = list(bin(v)[2:-1])
  96. r.reverse()
  97. return ''.join(r)
  98. def _cmpbits(a, b):
  99. '''Compare two strings of bits to make sure they are equal.
  100. Trailing 0's are ignored.'''
  101. try:
  102. last1a = a.rindex('1')
  103. except ValueError:
  104. last1a = -1
  105. a = ''
  106. try:
  107. last1b = b.rindex('1')
  108. except ValueError:
  109. last1b = -1
  110. b = ''
  111. if last1a != -1:
  112. a = a[:last1a + 1]
  113. if last1b != -1:
  114. b = b[:last1b + 1]
  115. return a == b
  116. import vlanmang
  117. def checkchanges(module):
  118. '''Function to check for any differences between the switch, and the
  119. configured state.
  120. The parameter module is a string to the name of a python module. It
  121. will be imported, and any names that reference a vlanmang.SwitchConfig
  122. class will be validate that the configuration matches. If it does not,
  123. the returned list will contain a set of tuples, each one containing
  124. (verb, arg1, arg2, switcharg2). verb is what needs to be changed.
  125. arg1 is either the port (for setting Pvid), or the VLAN that needs to
  126. be configured. arg2 is what it needs to be set to. switcharg2 is
  127. what the switch is currently configured to, so that you can easily
  128. see what the effect of the configuration change is.
  129. '''
  130. mod = importlib.import_module(module)
  131. mods = [ i for i in mod.__dict__.itervalues() if isinstance(i, vlanmang.SwitchConfig) ]
  132. res = []
  133. for i in mods:
  134. vlans = i.vlanconf.keys()
  135. switch = SNMPSwitch(i.host, i.community)
  136. portmapping = switch.getportmapping()
  137. invportmap = { y: x for x, y in portmapping.iteritems() }
  138. lufun = invportmap.__getitem__
  139. # get complete set of ports
  140. portlist = i.getportlist(lufun)
  141. ports = set(portmapping.iterkeys())
  142. # make sure switch agrees w/ them all
  143. if ports != portlist:
  144. raise ValueError('missing or extra ports found: %s' %
  145. `ports.symmetric_difference(portlist)`)
  146. # compare pvid
  147. pvidmap = getpvidmapping(i.vlanconf, lufun)
  148. switchpvid = switch.getpvid()
  149. res.extend(('setpvid', idx, vlan, switchpvid[idx]) for idx, vlan in
  150. pvidmap.iteritems() if switchpvid[idx] != vlan)
  151. # compare egress & untagged
  152. switchegress = switch.getegress(*vlans)
  153. egress = getegress(i.vlanconf, lufun)
  154. switchuntagged = switch.getuntagged(*vlans)
  155. untagged = getuntagged(i.vlanconf, lufun)
  156. for i in vlans:
  157. if not _cmpbits(switchegress[i], egress[i]):
  158. res.append(('setegress', i, egress[i], switchegress[i]))
  159. if not _cmpbits(switchuntagged[i], untagged[i]):
  160. res.append(('setuntagged', i, untagged[i], switchuntagged[i]))
  161. return res, switch
  162. def getidxs(lst, lookupfun):
  163. '''Take a list of ports, and if any are a string, replace them w/
  164. the value returned by lookupfun(s).
  165. Note that duplicates are not detected or removed, both in the
  166. original list, and the values returned by the lookup function
  167. may duplicate other values in the list.'''
  168. return [ lookupfun(i) if isinstance(i, str) else i for i in lst ]
  169. def getpvidmapping(data, lookupfun):
  170. '''Return a mapping from vlan based table to a port: vlan
  171. dictionary. This only looks at that untagged part of the vlan
  172. configuration, and is used for finding what a port's Pvid should
  173. be.'''
  174. res = []
  175. for id in data:
  176. for i in data[id].get('u', []):
  177. if isinstance(i, str):
  178. i = lookupfun(i)
  179. res.append((i, id))
  180. return dict(res)
  181. def getegress(data, lookupfun):
  182. '''Return a dictionary, keyed by VLAN id with a bit string of ports
  183. that need to be enabled for egress. This include both tagged and
  184. untagged traffic.'''
  185. r = {}
  186. for id in data:
  187. r[id] = _intstobits(*(getidxs(data[id].get('u', []),
  188. lookupfun) + getidxs(data[id].get('t', []), lookupfun)))
  189. return r
  190. def getuntagged(data, lookupfun):
  191. '''Return a dictionary, keyed by VLAN id with a bit string of ports
  192. that need to be enabled for untagged egress.'''
  193. r = {}
  194. for id in data:
  195. r[id] = _intstobits(*getidxs(data[id].get('u', []), lookupfun))
  196. return r
  197. class SNMPSwitch(object):
  198. '''A class for manipulating switches via standard SNMP MIBs.'''
  199. def __init__(self, host, auth):
  200. self._eng = SnmpEngine()
  201. if isinstance(auth, str):
  202. self._cd = CommunityData(auth, mpModel=0)
  203. else:
  204. self._cd = auth
  205. self._targ = UdpTransportTarget((host, 161))
  206. def _getmany(self, *oids):
  207. woids = [ ObjectIdentity(*oid) for oid in oids ]
  208. [ oid.resolveWithMib(_mvc) for oid in woids ]
  209. errorInd, errorStatus, errorIndex, varBinds = \
  210. next(getCmd(self._eng, self._cd, self._targ,
  211. ContextData(), *(ObjectType(oid) for oid in woids)))
  212. if errorInd: # pragma: no cover
  213. raise ValueError(errorIndication)
  214. elif errorStatus:
  215. if str(errorStatus) == 'tooBig' and len(oids) > 1:
  216. # split the request in two
  217. pivot = len(oids) / 2
  218. a = self._getmany(*oids[:pivot])
  219. b = self._getmany(*oids[pivot:])
  220. return a + b
  221. raise ValueError('%s at %s' %
  222. (errorStatus.prettyPrint(), errorIndex and
  223. varBinds[int(errorIndex)-1][0] or '?'))
  224. else:
  225. if len(varBinds) != len(oids): # pragma: no cover
  226. raise ValueError('too many return values')
  227. return varBinds
  228. def _get(self, oid):
  229. varBinds = self._getmany(oid)
  230. varBind = varBinds[0]
  231. return varBind[1]
  232. def _set(self, oid, value):
  233. oid = ObjectIdentity(*oid)
  234. oid.resolveWithMib(_mvc)
  235. if isinstance(value, (int, long)):
  236. value = Integer(value)
  237. elif isinstance(value, str):
  238. value = OctetString(value)
  239. errorInd, errorStatus, errorIndex, varBinds = \
  240. next(setCmd(self._eng, self._cd, self._targ,
  241. ContextData(), ObjectType(oid, value)))
  242. if errorInd: # pragma: no cover
  243. raise ValueError(errorIndication)
  244. elif errorStatus: # pragma: no cover
  245. raise ValueError('%s at %s' %
  246. (errorStatus.prettyPrint(), errorIndex and
  247. varBinds[int(errorIndex)-1][0] or '?'))
  248. else:
  249. for varBind in varBinds:
  250. if varBind[1] != value: # pragma: no cover
  251. raise RuntimeError('failed to set: %s' % ' = '.join([x.prettyPrint() for x in varBind]))
  252. def _walk(self, *oid):
  253. oid = ObjectIdentity(*oid)
  254. # XXX - keep these, this might stop working, no clue what managed to magically make things work
  255. # ref: http://snmplabs.com/pysnmp/examples/smi/manager/browsing-mib-tree.html#mib-objects-to-pdu-var-binds
  256. # mibdump.py --mib-source '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs' --mib-source /usr/share/snmp/mibs --rebuild rfc1212 pbridge vlan
  257. #oid.addAsn1MibSource('/usr/share/snmp/mibs', '/Users/jmg/Nextcloud/Documents/user manuals/netgear/gs7xxt-v6.3.1.19-mibs')
  258. oid.resolveWithMib(_mvc)
  259. for (errorInd, errorStatus, errorIndex, varBinds) in nextCmd(
  260. self._eng, self._cd, self._targ, ContextData(),
  261. ObjectType(oid),
  262. lexicographicMode=False):
  263. if errorInd: # pragma: no cover
  264. raise ValueError(errorInd)
  265. elif errorStatus: # pragma: no cover
  266. raise ValueError('%s at %s' %
  267. (errorStatus.prettyPrint(), errorIndex and
  268. varBinds[int(errorIndex)-1][0] or '?'))
  269. else:
  270. for varBind in varBinds:
  271. yield varBind
  272. def getportmapping(self):
  273. '''Return a port name mapping. Keys are the port index
  274. and the value is the name from the IF-MIB::ifName entry.'''
  275. return { x[0][-1]: str(x[1]) for x in self._walk('IF-MIB',
  276. 'ifName') }
  277. def findport(self, name):
  278. '''Look up a port name and return it's port index. This
  279. looks up via the ifName table in IF-MIB.'''
  280. return [ x[0][-1] for x in self._walk('IF-MIB', 'ifName') if
  281. str(x[1]) == name ][0]
  282. def getvlanname(self, vlan):
  283. '''Return the name for the vlan. This returns the value in
  284. Q-BRIDGE-MIB:dot1qVlanStaticName.'''
  285. v = self._get(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', vlan))
  286. return str(v).decode('utf-8')
  287. def createvlan(self, vlan, name):
  288. # createAndGo(4)
  289. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  290. int(vlan)), 4)
  291. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticName', int(vlan)),
  292. name)
  293. def deletevlan(self, vlan):
  294. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticRowStatus',
  295. int(vlan)), 6) # destroy(6)
  296. def getvlans(self):
  297. '''Return an iterator with all the vlan ids.'''
  298. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB',
  299. 'dot1qVlanStatus'))
  300. def staticvlans(self):
  301. '''Return an iterator of the staticly defined/configured
  302. vlans. This sometimes excludes special built in vlans,
  303. like vlan 1.'''
  304. return (x[0][-1] for x in self._walk('Q-BRIDGE-MIB',
  305. 'dot1qVlanStaticName'))
  306. def getpvid(self):
  307. '''Returns a dictionary w/ the interface index as the key,
  308. and the pvid of the interface.'''
  309. return { x[0][-1]: int(x[1]) for x in self._walk('Q-BRIDGE-MIB',
  310. 'dot1qPvid') }
  311. def setpvid(self, port, vlan):
  312. '''Set the port's Pvid to vlan. This means that any packet
  313. received by the port that is untagged, will be routed the
  314. the vlan.'''
  315. self._set(('Q-BRIDGE-MIB', 'dot1qPvid', int(port)), Gauge32(vlan))
  316. def getegress(self, *vlans):
  317. '''Get a dictionary keyed by the specified VLANs, where each
  318. value is a bit string that preresents what ports that
  319. particular VLAN will be transmitted on.'''
  320. r = { x[-1]: _octstrtobits(y) for x, y in
  321. self._getmany(*(('Q-BRIDGE-MIB',
  322. 'dot1qVlanStaticEgressPorts', x) for x in vlans)) }
  323. return r
  324. def setegress(self, vlan, ports):
  325. '''Set the ports which the specified VLAN will have packets
  326. transmitted as either tagged, if unset in untagged, or
  327. untagged, if set in untagged, to bit bit string specified
  328. by ports.'''
  329. value = OctetString.fromBinaryString(ports)
  330. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticEgressPorts',
  331. int(vlan)), value)
  332. def getuntagged(self, *vlans):
  333. '''Get a dictionary keyed by the specified VLANs, where each
  334. value is a bit string that preresents what ports that
  335. particular VLAN will be transmitted on as an untagged
  336. packet.'''
  337. r = { x[-1]: _octstrtobits(y) for x, y in
  338. self._getmany(*(('Q-BRIDGE-MIB',
  339. 'dot1qVlanStaticUntaggedPorts', x) for x in vlans)) }
  340. return r
  341. def setuntagged(self, vlan, ports):
  342. '''Set the ports which the specified VLAN will have packets
  343. transmitted as untagged to the bit string specified by ports.'''
  344. value = OctetString.fromBinaryString(ports)
  345. self._set(('Q-BRIDGE-MIB', 'dot1qVlanStaticUntaggedPorts',
  346. int(vlan)), value)
  347. if __name__ == '__main__': # pragma: no cover
  348. import pprint
  349. import sys
  350. changes, switch = checkchanges('data')
  351. if not changes:
  352. print 'No changes to apply.'
  353. sys.exit(0)
  354. pprint.pprint(changes)
  355. res = raw_input('Apply the changes? (type yes to apply): ')
  356. if res != 'yes':
  357. print 'not applying changes.'
  358. sys.exit(1)
  359. print 'applying...'
  360. failed = []
  361. for verb, arg1, arg2, oldarg in changes:
  362. print '%s: %s %s' % (verb, arg1, `arg2`)
  363. try:
  364. fun = getattr(switch, verb)
  365. fun(arg1, arg2)
  366. pass
  367. except Exception as e:
  368. print 'failed'
  369. failed.append((verb, arg1, arg2, e))
  370. if failed:
  371. print '%d failed to apply, they are:' % len(failed)
  372. for verb, arg1, arg2, e in failed:
  373. print '%s: %s %s: %s' % (verb, arg1, arg2, `e`)
  374. class _TestMisc(unittest.TestCase):
  375. def setUp(self):
  376. import test_data
  377. self._test_data = test_data
  378. def test_intstobits(self):
  379. self.assertEqual(_intstobits(1, 5, 10), '1000100001')
  380. self.assertEqual(_intstobits(3, 4, 9), '001100001')
  381. def test_octstrtobits(self):
  382. self.assertEqual(_octstrtobits('\x00'), '0' * 8)
  383. self.assertEqual(_octstrtobits('\xff'), '1' * 8)
  384. self.assertEqual(_octstrtobits('\xf0'), '1' * 4 + '0' * 4)
  385. self.assertEqual(_octstrtobits('\x0f'), '0' * 4 + '1' * 4)
  386. def test_cmpbits(self):
  387. self.assertTrue(_cmpbits('111000', '111'))
  388. self.assertTrue(_cmpbits('000111000', '000111'))
  389. self.assertTrue(_cmpbits('11', '11'))
  390. self.assertTrue(_cmpbits('0', '000'))
  391. self.assertFalse(_cmpbits('0011', '11'))
  392. self.assertFalse(_cmpbits('11', '0011'))
  393. self.assertFalse(_cmpbits('10', '000'))
  394. self.assertFalse(_cmpbits('0', '1000'))
  395. self.assertFalse(_cmpbits('00010', '000'))
  396. self.assertFalse(_cmpbits('0', '001000'))
  397. def test_pvidegressuntagged(self):
  398. data = {
  399. 1: {
  400. 'u': [ 1, 5, 10 ] + range(13, 20),
  401. 't': [ 'lag2', 6, 7 ],
  402. },
  403. 10: {
  404. 'u': [ 2, 3, 6, 7, 8, 'lag2' ],
  405. },
  406. 13: {
  407. 'u': [ 4, 9 ],
  408. 't': [ 'lag2', 6, 7 ],
  409. },
  410. 14: {
  411. 't': [ 'lag2' ],
  412. },
  413. }
  414. swconf = SwitchConfig('', '', data, [ 'lag3' ])
  415. lookup = {
  416. 'lag2': 30,
  417. 'lag3': 31,
  418. }
  419. lufun = lookup.__getitem__
  420. check = dict(itertools.chain(enumerate([ 1, 10, 10, 13, 1, 10,
  421. 10, 10, 13, 1 ], 1), enumerate([ 1 ] * 7, 13),
  422. [ (30, 10) ]))
  423. # That a pvid mapping
  424. res = getpvidmapping(data, lufun)
  425. # is correct
  426. self.assertEqual(res, check)
  427. self.assertEqual(swconf.getportlist(lufun),
  428. set(xrange(1, 11)) | set(xrange(13, 20)) |
  429. set(lookup.values()))
  430. checkegress = {
  431. 1: '1000111001001111111' + '0' * (30 - 20) + '1',
  432. 10: '01100111' + '0' * (30 - 9) + '1',
  433. 13: '000101101' + '0' * (30 - 10) + '1',
  434. 14: '0' * (30 - 1) + '1',
  435. }
  436. self.assertEqual(getegress(data, lufun), checkegress)
  437. checkuntagged = {
  438. 1: '1000100001001111111',
  439. 10: '01100111' + '0' * (30 - 9) + '1',
  440. 13: '000100001',
  441. 14: '',
  442. }
  443. self.assertEqual(getuntagged(data, lufun), checkuntagged)
  444. #@unittest.skip('foo')
  445. @mock.patch('vlanmang.SNMPSwitch.getuntagged')
  446. @mock.patch('vlanmang.SNMPSwitch.getegress')
  447. @mock.patch('vlanmang.SNMPSwitch.getpvid')
  448. @mock.patch('vlanmang.SNMPSwitch.getportmapping')
  449. @mock.patch('importlib.import_module')
  450. def test_checkchanges(self, imprt, portmapping, gpvid, gegress, guntagged):
  451. # that import returns the test data
  452. imprt.side_effect = itertools.repeat(self._test_data)
  453. # that getportmapping returns the following dict
  454. ports = { x: 'g%d' % x for x in xrange(1, 24) }
  455. ports[30] = 'lag1'
  456. ports[31] = 'lag2'
  457. ports[32] = 'lag3'
  458. portmapping.side_effect = itertools.repeat(ports)
  459. # that the switch's pvid returns
  460. spvid = { x: 283 for x in xrange(1, 24) }
  461. spvid[30] = 5
  462. gpvid.side_effect = itertools.repeat(spvid)
  463. # the the extra port is caught
  464. self.assertRaises(ValueError, checkchanges, 'data')
  465. # that the functions were called
  466. imprt.assert_called_with('data')
  467. portmapping.assert_called()
  468. # XXX - check that an ignore statement is honored
  469. # delete the extra port
  470. del ports[32]
  471. # that the egress data provided
  472. gegress.side_effect = [ {
  473. 1: '1' * 10,
  474. 5: '1' * 10,
  475. 283: '00000000111111111110011000000100000',
  476. } ]
  477. # that the untagged data provided
  478. guntagged.side_effect = [ {
  479. 1: '1' * 10,
  480. 5: '1' * 8 + '0' * 10,
  481. 283: '00000000111111111110011',
  482. } ]
  483. res, switch = checkchanges('data')
  484. self.assertIsInstance(switch, SNMPSwitch)
  485. validres = [ ('setpvid', x, 5, 283) for x in xrange(1, 9) ] + \
  486. [ ('setpvid', 20, 1, 283),
  487. ('setpvid', 21, 1, 283),
  488. ('setpvid', 30, 1, 5),
  489. ('setegress', 1, '0' * 19 + '11' + '0' * 8 + '1',
  490. '1' * 10),
  491. ('setuntagged', 1, '0' * 19 + '11' + '0' * 8 + '1',
  492. '1' * 10),
  493. ('setegress', 5, '1' * 8 + '0' * 11 + '11' + '0' * 8 +
  494. '1', '1' * 10),
  495. ]
  496. self.assertEqual(set(res), set(validres))
  497. class _TestSNMPSwitch(unittest.TestCase):
  498. def test_splitmany(self):
  499. # make sure that if we get a tooBig error that we split the
  500. # _getmany request
  501. switch = SNMPSwitch(None, None)
  502. @mock.patch('vlanmang.SNMPSwitch._getmany')
  503. def test_get(self, gm):
  504. # that a switch
  505. switch = SNMPSwitch(None, None)
  506. # when _getmany returns this structure
  507. retval = object()
  508. gm.side_effect = [[[ None, retval ]]]
  509. arg = object()
  510. # will return the correct value
  511. self.assertIs(switch._get(arg), retval)
  512. # and call _getmany w/ the correct arg
  513. gm.assert_called_with(arg)
  514. @mock.patch('pysnmp.hlapi.ContextData')
  515. @mock.patch('vlanmang.getCmd')
  516. def test_getmany(self, gc, cd):
  517. # that a switch
  518. switch = SNMPSwitch(None, None)
  519. lookup = { x: chr(x) for x in xrange(1, 10) }
  520. # when getCmd returns tooBig when too many oids are asked for
  521. def custgetcmd(eng, cd, targ, contextdata, *oids):
  522. # induce a too big error
  523. if len(oids) > 3:
  524. res = ( None, 'tooBig', None, None )
  525. else:
  526. #import pdb; pdb.set_trace()
  527. [ oid.resolveWithMib(_mvc) for oid in oids ]
  528. oids = [ ObjectType(x[0],
  529. OctetString(lookup[x[0][-1]])) for x in oids ]
  530. [ oid.resolveWithMib(_mvc) for oid in oids ]
  531. res = ( None, None, None, oids )
  532. return iter([res])
  533. gc.side_effect = custgetcmd
  534. #import pdb; pdb.set_trace()
  535. res = switch.getegress(*xrange(1, 10))
  536. # will still return the complete set of results
  537. self.assertEqual(res, { x: _octstrtobits(lookup[x]) for x in
  538. xrange(1, 10) })
  539. _skipSwitchTests = True
  540. class _TestSwitch(unittest.TestCase):
  541. def setUp(self):
  542. # If we don't have it, pretend it's true for now and
  543. # we'll recheck it later
  544. model = 'GS108T smartSwitch'
  545. if getattr(self, 'switchmodel', model) != model or \
  546. _skipSwitchTests: # pragma: no cover
  547. self.skipTest('Need a GS108T switch to run these tests')
  548. args = open('test.creds').read().split()
  549. self.switch = SNMPSwitch(*args)
  550. self.switchmodel = self.switch._get(('ENTITY-MIB',
  551. 'entPhysicalModelName', 1))
  552. if self.switchmodel != model: # pragma: no cover
  553. self.skipTest('Need a GS108T switch to run these tests')
  554. def test_misc(self):
  555. switch = self.switch
  556. self.assertEqual(switch.findport('g1'), 1)
  557. self.assertEqual(switch.findport('l1'), 14)
  558. def test_portnames(self):
  559. switch = self.switch
  560. resp = dict((x, 'g%d' % x) for x in xrange(1, 9))
  561. resp.update({ 13: 'cpu' })
  562. resp.update((x, 'l%d' % (x - 13)) for x in xrange(14, 18))
  563. self.assertEqual(switch.getportmapping(), resp)
  564. def test_egress(self):
  565. switch = self.switch
  566. egress = switch.getegress(1, 2, 3)
  567. checkegress = {
  568. 1: '1' * 8 + '0' * 5 + '1' * 4 + '0' * 23,
  569. 2: '0' * 8 * 5,
  570. 3: '0' * 8 * 5,
  571. }
  572. self.assertEqual(egress, checkegress)
  573. def test_untagged(self):
  574. switch = self.switch
  575. untagged = switch.getuntagged(1, 2, 3)
  576. checkuntagged = {
  577. 1: '1' * 8 * 5,
  578. 2: '1' * 8 * 5,
  579. 3: '1' * 8 * 5,
  580. }
  581. self.assertEqual(untagged, checkuntagged)
  582. def test_vlan(self):
  583. switch = self.switch
  584. existingvlans = set(switch.getvlans())
  585. while True:
  586. testvlan = random.randint(1,4095)
  587. if testvlan not in existingvlans:
  588. break
  589. # Test that getting a non-existant vlans raises an exception
  590. self.assertRaises(ValueError, switch.getvlanname, testvlan)
  591. self.assertTrue(set(switch.staticvlans()).issubset(existingvlans))
  592. pvidres = { x: 1 for x in xrange(1, 9) }
  593. pvidres.update({ x: 1 for x in xrange(14, 18) })
  594. self.assertEqual(switch.getpvid(), pvidres)
  595. testname = 'Sometestname'
  596. # Create test vlan
  597. switch.createvlan(testvlan, testname)
  598. testport = None
  599. try:
  600. # make sure the test vlan was created
  601. self.assertIn(testvlan, set(switch.staticvlans()))
  602. self.assertEqual(testname, switch.getvlanname(testvlan))
  603. switch.setegress(testvlan, '00100')
  604. pvidmap = switch.getpvid()
  605. testport = 3
  606. egressports = switch.getegress(testvlan)
  607. self.assertEqual(egressports[testvlan], '00100000' +
  608. '0' * 8 * 4)
  609. switch.setuntagged(testvlan, '00100')
  610. untaggedports = switch.getuntagged(testvlan)
  611. self.assertEqual(untaggedports[testvlan], '00100000' +
  612. '0' * 8 * 4)
  613. switch.setpvid(testport, testvlan)
  614. self.assertEqual(switch.getpvid()[testport], testvlan)
  615. finally:
  616. if testport:
  617. switch.setpvid(testport, pvidmap[3])
  618. switch.deletevlan(testvlan)