You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

338 lines
12 KiB

  1. ############################################################################
  2. # David W. Robertson, LBNL
  3. # See Copyright for copyright notice!
  4. ###########################################################################
  5. import sys, os.path, pickle
  6. import StringIO, copy, re
  7. import unittest, ConfigParser
  8. from ZSI.wstools.WSDLTools import WSDLReader
  # utils:
  # Helpers shared by the test-case modules:
  #   * small utility functions (WSDL loading, pickling helpers),
  #   * ConfigHandler, which lets several test cases share one
  #     ConfigParser-style config file,
  #   * TestDiff, which compares output against a stored test file,
  #   * MatchTestLoader, a unittest.TestLoader with a different,
  #     name-fragment based loading strategy.
  # (Written as comments: a string placed after the imports would not
  # become the module's __doc__ anyway.)

  # Path of this module; used to locate files stored next to it.
  thisFileName = getattr(sys.modules[__name__], "__file__")
  class ConfigHandler(ConfigParser.ConfigParser):
      """ConfigParser that locates its file and hands config values to tests.

      The file ``name`` is read relative to the current working directory
      first; if no file is read that way, the same name is tried in the
      directory containing this module.
      """

      def __init__(self, name="config.py"):
          ConfigParser.ConfigParser.__init__(self)
          # read() does not raise for a missing/unreadable file: it skips it
          # and returns the list of files it actually parsed.  An empty list
          # is therefore the cue to try the fallback location (the former
          # "except IOError" branch could never run).
          if not self.read(name):
              self.read(os.path.join(os.path.dirname(thisFileName), name))

      def getConfigNames(self, sections, numMethods, valueFunc=None):
          """Generate config values for test-case instances.

          For each option in each section of ``sections`` (in order), the
          option's value is yielded ``numMethods`` times.  When
          ``valueFunc`` is given, every yielded value is followed by
          ``valueFunc(value)``.  That call is made once per option, right
          after its first value has been consumed, and the result is reused
          for the remaining methods.  If it raises an ordinary exception the
          result is None (the consuming test will be skipped);
          KeyboardInterrupt exits the process.
          """
          for section in sections:
              for option, value in self.items(section):
                  for i in range(numMethods):
                      yield value  # tells the test case which entry it runs
                      if i == 0:
                          result = None
                          if valueFunc:
                              try:
                                  result = valueFunc(value)
                              except KeyboardInterrupt:
                                  sys.exit(-1)  # for now
                              except Exception:
                                  # best effort: a None result skips the test
                                  pass
                      if valueFunc:
                          yield result

      def length(self, sections):
          """Return the total number of options across all of ``sections``."""
          total = 0
          for section in sections:
              total += len(self.options(section))
          return total
  def setUpWsdl(path):
      """Load and return a WSDL description from a file path or a URL.

      ``path`` values starting with "http://" or "https://" are fetched
      with WSDLReader.loadFromURL; anything else is treated as a local file
      and read with WSDLReader.loadFromFile.
      """
      reader = WSDLReader()
      # https:// URLs used to fall through to loadFromFile, which (judging
      # by its name) expects a local path -- NOTE(review): confirm.
      if path.startswith('http://') or path.startswith('https://'):
          return reader.loadFromURL(path)
      return reader.loadFromFile(path)
  def loadPickledObj(fname):
      """Return the object unpickled from ``<fname>.obj``.

      The file is looked up in the directory containing this module.
      Not currently used.

      Only load files written by dumpPickledObj: unpickling can execute
      arbitrary code, so never point this at untrusted data.
      """
      path = os.path.join(os.path.dirname(thisFileName), fname + ".obj")
      # Binary mode is valid for every pickle protocol and avoids newline
      # translation; try/finally closes the handle even if unpickling fails.
      f = open(path, "rb")
      try:
          return pickle.load(f)
      finally:
          f.close()
  def dumpPickledObj(obj, fname):
      """Pickle ``obj`` into ``<fname>.obj`` in this module's directory.

      Any existing file of that name is overwritten.  Not currently used.
      """
      path = os.path.join(os.path.dirname(thisFileName), fname + ".obj")
      # Binary mode is valid for every pickle protocol; try/finally closes
      # the handle even if pickling fails part-way.
      f = open(path, "wb")
      try:
          pickle.dump(obj, f)
      finally:
          f.close()
  class TestDiff:
      """Compare test output against expected text stored in a data file.

      Expected output lives in ``data/<fname>`` relative to the current
      working directory (the ``data`` directory is created if missing).
      One data file may hold several output blocks, each terminated by
      ``self.divider``.

      Typical use inside a unittest.TestCase: create the TestDiff in setUp,
      passing the TestCase instance; call setDiffFile() and then
      failUnlessEqual() once per output block; call close() in tearDown.

      Extra constructor arguments are compiled regular expressions; every
      match is removed from both the expected and the actual line before
      they are compared, so text known to vary between runs (for example
      id="<hex digits>") is ignored.

      When the data file does not exist yet, the output is recorded instead
      of compared.  Delete the data file to record new expected output.
      """

      def __init__(self, testInst, *ignoreList):
          self.dataFile = None
          self.testInst = testInst
          self.origStrFile = None
          # separates consecutive output blocks within one data file
          self.divider = "#" + ">" * 75 + "\n"
          self.expectedFailures = copy.copy(ignoreList)
          self.testFilePath = "data" + os.sep
          if not os.path.exists(self.testFilePath):
              os.mkdir(self.testFilePath)

      def setDiffFile(self, fname):
          """Select data file ``fname`` for the following comparisons.

          An existing file is read completely for comparison; a file that
          cannot be opened for reading is opened for writing, so the next
          output gets recorded.  Any previously selected file is closed.
          """
          self.close()
          # Forget the previous file so a failed open cannot leave stale
          # expected output behind.
          self.dataFile = None
          self.origStrFile = None
          path = os.path.join(self.testFilePath, fname)
          try:
              self.dataFile = open(path, "r")
          except IOError:
              # no expected output yet: record this run's output
              try:
                  self.dataFile = open(path, "w")
              except IOError:
                  print("exception")
              return
          # Reading happens outside the try above: a read error must not
          # fall through to "w" and truncate the stored expected output.
          self.origStrFile = StringIO.StringIO(self.dataFile.read())

      def failUnlessEqual(self, buffer):
          """Check ``buffer`` against the next block of the data file.

          ``buffer`` is either a string or a file-like object supporting
          seek() and line iteration (e.g. a StringIO instance).  If the data
          file was opened for reading, every line of ``buffer`` is compared
          with the next expected line via the test instance's assertEqual,
          and the expected block must end where ``buffer`` ends.  If the data
          file is being recorded, ``buffer`` is appended to it, followed by
          the divider.  File objects created here are closed afterwards; a
          caller's buffer is left open.
          """
          if self.dataFile is None:
              self.testInst.fail("no data file open; call setDiffFile() first")
              return
          ownBuffer = not hasattr(buffer, "seek")
          if ownBuffer:
              testStrFile = StringIO.StringIO(buffer)
          else:
              testStrFile = buffer
          testStrFile.seek(0)
          try:
              if self.dataFile.mode == "r":
                  self._compareBlock(testStrFile)
              else:  # write new test file
                  for line in testStrFile:
                      self.dataFile.write(line)
                  self.dataFile.write(self.divider)
          finally:
              if ownBuffer:
                  testStrFile.close()

      def _compareBlock(self, testStrFile):
          # Compare each line of testStrFile with the next expected line,
          # then require the expected block to end at the same point.
          for testLine in testStrFile:
              origLine = self.origStrFile.readline()
              # skip the divider that closes the previous block
              if origLine == self.divider:
                  origLine = self.origStrFile.readline()
              # take out expected failure strings before comparing
              for cexpr in self.expectedFailures:
                  origLine = cexpr.sub('', origLine)
                  testLine = cexpr.sub('', testLine)
              self.testInst.assertEqual(origLine, testLine)
          # Output ran out: anything other than the divider or end of file
          # means expected lines are missing from the output.
          trailing = self.origStrFile.readline()
          if trailing not in (self.divider, ""):
              self.testInst.assertEqual(trailing, "")

      def close(self):
          """Close the currently selected data file, if it is open."""
          if self.dataFile and not self.dataFile.closed:
              self.dataFile.close()
  class MatchTestLoader(unittest.TestLoader):
      """TestLoader that selects tests by (partial) names from the command line.

      Overrides loadTestsFromNames so a test script can run a subset of
      tests using short, substring-style arguments instead of the dotted
      names the default loader requires.  If all tests always run, use a
      plain unittest.TestLoader instead.

      Top-level scripts that gather test cases from other modules have
      traditionally passed topLevel=True, for example::

          def main():
              loader = utils.MatchTestLoader(True, None, "makeTestSuite")
              unittest.main(defaultTest="makeTestSuite", testLoader=loader)

      ``defaultTest`` names the test that loadTestsFromNames always loads,
      whatever names follow the script on the command line.
      """

      def __init__(self, topLevel, configName, defaultTest):
          unittest.TestLoader.__init__(self)
          self.testMethodPrefix = "test"
          self.defaultTest = defaultTest
          # Stored for callers/subclasses; the selection logic in this class
          # does not consult it (the old check in setUpArgs had no effect).
          self.topLevel = topLevel
          # Without a config name there is no ConfigHandler; section-based
          # methods then see an empty section list.
          self.config = None
          if configName:
              self.config = ConfigHandler(configName)
          self.sections = []
          self.nameGenerator = None
          # [None] selects every test until setUpArgs() narrows it down.
          self.testArgs = [None]

      def _configSections(self):
          # Section names of the config file, or [] when there is no config.
          if self.config is None:
              return []
          return self.config.sections()

      def setUpArgs(self):
          """Collect test-selection arguments from sys.argv.

          Every argument after the script name that is neither a config
          section name (those are handled by setSection) nor an option
          (leading "-") is kept in self.testArgs as a full or partial test
          name.  Without such arguments self.testArgs is [None], which
          selects all tests.
          """
          sectionList = self._configSections()
          self.testArgs = [arg for arg in sys.argv[1:]
                           if arg not in sectionList and not arg.startswith("-")]
          if not self.testArgs:
              # has the effect of loading all tests
              self.testArgs = [None]

      def loadTestsFromNames(self, unused, module=None):
          """Load self.defaultTest, ignoring the names passed in.

          unittest.TestProgram hands over the raw command-line names; the
          default loader would fail on substrings or section names, so those
          are interpreted by setUpArgs/setSection instead.
          """
          return unittest.TestLoader.loadTestsFromNames(
              self, (self.defaultTest,), module)

      def setSection(self, args):
          """Select config-file section(s) to draw test values from.

          ``args`` is a single section name or a list/tuple of candidates
          (for example sys.argv); entries that are not sections of the
          config file are ignored.  For a string, returns True if it was
          added.  For a list/tuple, returns True if self.sections is
          non-empty afterwards (including sections added by earlier calls).
          Otherwise returns False.
          """
          sectionList = self._configSections()
          if isinstance(args, (list, tuple)):
              for arg in args:
                  if arg in sectionList:
                      self.sections.append(arg)
              if self.sections:
                  return True
          elif isinstance(args, str):
              if args in sectionList:
                  self.sections.append(args)
                  return True
          return False

      def loadTestsFromConfig(self, testCaseClass, valueFunc=None):
          """Build a suite holding one round of testCaseClass tests per config item.

          The number of rounds is the number of options in self.sections.
          self.nameGenerator is set to ConfigHandler.getConfigNames for
          those sections, sized to the number of selected test methods, so
          the values can be parcelled out to the testCaseClass instances
          (presumably each instance pulls from loader.nameGenerator when it
          is constructed -- confirm against the test modules).

          Raises ValueError if the loader was created without a config file.
          """
          if self.config is None:
              raise ValueError("loadTestsFromConfig requires a config file")
          self.setUpArgs()
          numTestCases = self.getTestCaseNumber(testCaseClass)
          self.nameGenerator = self.config.getConfigNames(
              self.sections, numTestCases, valueFunc)
          suite = unittest.TestSuite()
          for i in range(self.config.length(self.sections)):
              suite.addTest(self.loadTestsFromTestCase(testCaseClass))
          return suite

      def _selectMethods(self, testCaseClass):
          # Test-method names of testCaseClass chosen by self.testArgs, in
          # load order.  First every method whose name contains any
          # non-empty argument (each method once).  If none match, all
          # methods when some argument is None/empty or equals the class
          # name exactly.  Otherwise nothing.
          methods = self.getTestCaseNames(testCaseClass)
          selected = []
          for testStr in self.testArgs:
              if testStr:
                  for m in methods:
                      if testStr in m and m not in selected:
                          selected.append(m)
          if selected:
              return selected
          for testStr in self.testArgs:
              if not testStr or testStr == testCaseClass.__name__:
                  return list(methods)
          return []

      def getTestCaseNumber(self, testCaseClass):
          """Return how many tests loadTestsFromTestCase loads for testCaseClass.

          The same selection routine backs both methods, so the count always
          matches the suite that is built.  Previously only the last
          argument was compared with the class name here.
          """
          return len(self._selectMethods(testCaseClass))

      def loadTestsFromTestCase(self, testCaseClass):
          """Return a TestSuite of the testCaseClass methods chosen by self.testArgs.

          Methods whose names contain any command-line argument are loaded.
          If none match, all methods are loaded when an argument equals the
          class name exactly, or when no arguments were given.
          """
          suite = unittest.TestSuite()
          for m in self._selectMethods(testCaseClass):
              suite.addTest(testCaseClass(m))
          return suite