You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

343 lines
12 KiB

  1. ############################################################################
  2. # David W. Robertson, LBNL
  3. # See Copyright for copyright notice!
  4. ###########################################################################
  5. import sys, os.path, pickle
  6. import StringIO, copy, re
  7. import unittest, ConfigParser
  8. from ZSI.wstools.WSDLTools import WSDLReader
  9. """
  10. utils:
  11. This module contains utility functions for use by test case modules, a
  12. class facilitating the use of ConfigParser with multiple test cases, a
  13. class encapsulating comparisons against a test file, and a test loader
  14. class with a different loading strategy than the default
  15. unittest.TestLoader.
  16. """
  17. thisFileName = sys.modules[__name__].__file__
  18. class ConfigHandler(ConfigParser.ConfigParser):
  19. def __init__(self, name="config.py"):
  20. ConfigParser.ConfigParser.__init__(self)
  21. # first, look for one in this directory
  22. try:
  23. self.read(name)
  24. except IOError:
  25. self.read(os.path.dirname(thisFileName) + os.sep + name)
  26. def getConfigNames(self, sections, numMethods, valueFunc=None):
  27. """A generator which returns one value from a given config
  28. file section at a time. It also optionally calls a
  29. passed-function for that value, and yields the result as well.
  30. """
  31. result = None
  32. for section in sections:
  33. for name, value in self.items(section):
  34. for i in range(0, numMethods):
  35. yield value # indicate which test in all cases
  36. if i == 0:
  37. result = None
  38. if valueFunc:
  39. try:
  40. result = valueFunc(value)
  41. except KeyboardInterrupt:
  42. sys.exit(-1) # for now
  43. except: # don't care, test will be skipped
  44. pass
  45. if valueFunc:
  46. yield result
  47. def length(self, sections):
  48. """Determines the total number of items in all the
  49. chosen sections from a config file.
  50. """
  51. total = 0
  52. for section in sections:
  53. total += len(self.options(section))
  54. return total
  55. def setUpWsdl(path):
  56. """Load a WSDL given a file path or a URL.
  57. """
  58. if path[:7] == 'http://':
  59. wsdl = WSDLReader().loadFromURL(path)
  60. else:
  61. wsdl = WSDLReader().loadFromFile(path)
  62. return wsdl
  63. def loadPickledObj(fname):
  64. """Not currently used.
  65. """
  66. fname = os.path.dirname(thisFileName) + os.sep + fname + ".obj"
  67. f = open(fname, "r")
  68. obj = pickle.load(f)
  69. f.close()
  70. return obj
  71. def dumpPickledObj(obj, fname):
  72. """Not currently used"""
  73. fname = os.path.dirname(thisFileName) + os.sep + fname + ".obj"
  74. f = open(fname, "w")
  75. pickle.dump(obj, f)
  76. f.close()
  77. class TestDiff:
  78. """TestDiff encapsulates comparing a string or StringIO object
  79. against text in a test file. Test files are expected to
  80. be located in a subdirectory of the current directory,
  81. named data (if one doesn't exist, it will be created).
  82. If used in a test case, this should be instantiated in setUp and
  83. closed in tearDown. The calling unittest.TestCase instance is passed
  84. in on object creation. Optional compiled regular expressions
  85. can also be passed in, which are used to ignore strings
  86. that one knows in advance will be different, for example
  87. id="<hex digits>" .
  88. The initial running of the test will create the test
  89. files. When the tests are run again, the new output
  90. is compared against the old, line by line. To generate
  91. a new test file, remove the old one from data.
  92. """
  93. def __init__(self, testInst, *ignoreList):
  94. self.dataFile = None
  95. self.testInst = testInst
  96. self.origStrFile = None
  97. # used to divide separate test blocks within the same
  98. # test file.
  99. self.divider = "#" + ">" * 75 + "\n"
  100. self.expectedFailures = copy.copy(ignoreList)
  101. self.testFilePath = "data" + os.sep
  102. if not os.path.exists(self.testFilePath):
  103. os.mkdir(self.testFilePath)
  104. def setDiffFile(self, fname):
  105. """setDiffFile attempts to open the test file with the
  106. given name, and read it into a StringIO instance.
  107. If the file does not exist, it opens the file for
  108. writing.
  109. """
  110. filename = fname
  111. if self.dataFile and not self.dataFile.closed:
  112. self.dataFile.close()
  113. try:
  114. self.dataFile = open(self.testFilePath + filename, "r")
  115. self.origStrFile = StringIO.StringIO(self.dataFile.read())
  116. except IOError:
  117. try:
  118. self.dataFile = open(self.testFilePath + filename, "w")
  119. except IOError:
  120. print "exception"
  121. def failUnlessEqual(self, buffer):
  122. """failUnlessEqual takes either a string or a StringIO
  123. instance as input, and compares it against the original
  124. output from the test file.
  125. """
  126. # if not already a string IO
  127. if not isinstance(buffer, StringIO.StringIO):
  128. testStrFile = StringIO.StringIO(buffer)
  129. else:
  130. testStrFile = buffer
  131. testStrFile.seek(0)
  132. if self.dataFile.mode == "r":
  133. for testLine in testStrFile:
  134. origLine = self.origStrFile.readline()
  135. # skip divider
  136. if origLine == self.divider:
  137. origLine = self.origStrFile.readline()
  138. # take out expected failure strings before
  139. # comparing original against new output
  140. for cexpr in self.expectedFailures:
  141. origLine = cexpr.sub('', origLine)
  142. testLine = cexpr.sub('', testLine)
  143. if origLine != testLine: # fails
  144. # advance test file to next test
  145. line = origLine
  146. while line and line != self.divider:
  147. line = self.origStrFile.readline()
  148. self.testInst.failUnlessEqual(origLine, testLine)
  149. else: # write new test file
  150. for line in testStrFile:
  151. self.dataFile.write(line)
  152. self.dataFile.write(self.divider)
  153. def close(self):
  154. """Closes handle to original test file.
  155. """
  156. if self.dataFile and not self.dataFile.closed:
  157. self.dataFile.close()
  158. class MatchTestLoader(unittest.TestLoader):
  159. """Overrides unittest.TestLoader.loadTestsFromNames to provide a
  160. simpler and less verbose way to select a subset of tests to run.
  161. If all tests will always be run, use unittest.TestLoader instead.
  162. If a top-level test invokes test cases in other modules,
  163. MatchTestLoader should be created with topLevel set to True
  164. to get the correct results. For example,
  165. def main():
  166. loader = utils.MatchTestLoader(True, None, "makeTestSuite")
  167. unittest.main(defaultTest="makeTestSuite", testLoader=loader)
  168. The defaultTest argument in the constructor indicates the test to run
  169. if no additional arguments beyond the test script name are provided.
  170. """
  171. def __init__(self, topLevel, configName, defaultTest):
  172. unittest.TestLoader.__init__(self)
  173. self.testMethodPrefix = "test"
  174. self.defaultTest = defaultTest
  175. self.topLevel = topLevel
  176. if configName:
  177. self.config = ConfigHandler(configName)
  178. self.sections = []
  179. self.nameGenerator = None
  180. def setUpArgs(self):
  181. """Sets up the use of arguments from the command-line to select
  182. tests to run. There can be multiple names, both in full or as
  183. a substring, on the command-line.
  184. """
  185. sectionList = self.config.sections()
  186. self.testArgs = []
  187. argv = []
  188. # ignore section names in determining what to
  189. # load (sys.argv can be passed into setSection,
  190. # where any section names are extracted)
  191. for name in sys.argv:
  192. if name not in sectionList:
  193. argv.append(name)
  194. if not self.topLevel or (len(argv) != 1):
  195. for arg in argv[1:]:
  196. if arg.find("-") != 0:
  197. self.testArgs.append(arg)
  198. # has the effect of loading all tests
  199. if not self.testArgs:
  200. self.testArgs = [None]
  201. def loadTestsFromNames(self, unused, module=None):
  202. """Hard-wires using the default test. It ignores the names
  203. passed into it from unittest.TestProgram, because the
  204. default loader would fail on substrings or section names.
  205. """
  206. suites = unittest.TestLoader.loadTestsFromNames(self,
  207. (self.defaultTest,), module)
  208. return suites
  209. def setSection(self, args):
  210. """Sets section(s) of config file to read.
  211. """
  212. sectionList = self.config.sections()
  213. if ((type(args) is list) or
  214. (type(args) is tuple)):
  215. for arg in args:
  216. if arg in sectionList:
  217. self.sections.append(arg)
  218. if self.sections:
  219. return True
  220. elif type(args) is str:
  221. if args in sectionList:
  222. self.sections.append(args)
  223. return True
  224. return False
  225. def loadTestsFromConfig(self, testCaseClass, valueFunc=None):
  226. """Loads n number of instances of testCaseClass, where
  227. n is the number of items in the config file section(s).
  228. getConfigNames is a generator which is used to parcel
  229. out the values in the section(s) to the testCaseClass
  230. instances.
  231. """
  232. self.setUpArgs()
  233. numTestCases = self.getTestCaseNumber(testCaseClass)
  234. self.nameGenerator = self.config.getConfigNames(self.sections,
  235. numTestCases, valueFunc)
  236. configLen = self.config.length(self.sections)
  237. suite = unittest.TestSuite()
  238. for i in range(0, configLen):
  239. suite.addTest(self.loadTestsFromTestCase(testCaseClass))
  240. return suite
  241. def getTestCaseNumber(self, testCaseClass):
  242. """Looks for any test methods whose name contains testStr, checking
  243. if a test method has already been added. If there is not a match,
  244. it checks for an exact match with the test case name, and
  245. returns the number of test cases.
  246. """
  247. methods = self.getTestCaseNames(testCaseClass)
  248. prevAdded = []
  249. counter = 0
  250. for testStr in self.testArgs:
  251. if testStr:
  252. for m in methods:
  253. if m.find(testStr) >= 0 and m not in prevAdded:
  254. counter = counter + 1
  255. prevAdded.append(m)
  256. if counter:
  257. return counter
  258. if (not testStr) or (testCaseClass.__name__ == testStr):
  259. for m in methods:
  260. counter = counter + 1
  261. prevAdded.append(m)
  262. # print "found %d cases" % counter
  263. return counter
  264. def loadTestsFromTestCase(self, testCaseClass):
  265. """looks for any test methods whose name contains testStr, checking
  266. if a test method has already been added. If there is not a match,
  267. it checks for an exact match with the test case name, and loads
  268. all methods if so.
  269. """
  270. methods = self.getTestCaseNames(testCaseClass)
  271. prevAdded = []
  272. suites = unittest.TestSuite()
  273. for testStr in self.testArgs:
  274. # print testStr
  275. if testStr:
  276. for m in methods:
  277. if m.find(testStr) >= 0 and m not in prevAdded:
  278. suites.addTest(testCaseClass(m))
  279. prevAdded.append(m)
  280. if suites.countTestCases():
  281. return suites
  282. for testStr in self.testArgs:
  283. if (not testStr) or (testCaseClass.__name__ == testStr):
  284. for m in methods:
  285. suites.addTest(testCaseClass(m))
  286. prevAdded.append(m)
  287. if suites.countTestCases():
  288. return suites
  289. return suites