You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

337 lines
12 KiB

  1. ############################################################################
  2. # David W. Robertson, LBNL
  3. # See Copyright for copyright notice!
  4. ###########################################################################
  5. import sys, os.path, pickle
  6. import StringIO, copy, re
  7. import unittest, ConfigParser
  8. from ZSI.wstools.WSDLTools import WSDLReader
  9. """
  10. utils:
  11. This module contains utility functions for use by test case modules, a
  12. class facilitating the use of ConfigParser with multiple test cases, a
  13. class encapsulating comparisons against a test file, and a test loader
  14. class with a different loading strategy than the default
  15. unittest.TestLoader.
  16. """
  17. thisFileName = sys.modules[__name__].__file__
  18. class ConfigHandler(ConfigParser.ConfigParser):
  19. def __init__(self, name="config.py"):
  20. ConfigParser.ConfigParser.__init__(self)
  21. # first, look for one in this directory
  22. try:
  23. self.read(name)
  24. except IOError:
  25. self.read(os.path.dirname(thisFileName) + os.sep + name)
  26. def getConfigNames(self, sections, numMethods, valueFunc=None):
  27. """A generator which returns one value from a given config
  28. file section at a time. It also optionally calls a
  29. passed-function for that value, and yields the result as well.
  30. """
  31. result = None
  32. for section in sections:
  33. for name, value in self.items(section):
  34. for i in range(0, numMethods):
  35. yield value # indicate which test in all cases
  36. if i == 0:
  37. result = None
  38. if valueFunc:
  39. try:
  40. result = valueFunc(value)
  41. except KeyboardInterrupt:
  42. sys.exit(-1) # for now
  43. except: # don't care, test will be skipped
  44. pass
  45. if valueFunc:
  46. yield result
  47. def length(self, sections):
  48. """Determines the total number of items in all the
  49. chosen sections from a config file.
  50. """
  51. total = 0
  52. for section in sections:
  53. total += len(self.options(section))
  54. return total
  55. def setUpWsdl(path):
  56. """Load a WSDL given a file path or a URL.
  57. """
  58. if path[:7] == 'http://':
  59. wsdl = WSDLReader().loadFromURL(path)
  60. else:
  61. wsdl = WSDLReader().loadFromFile(path)
  62. return wsdl
  63. def loadPickledObj(fname):
  64. """Not currently used.
  65. """
  66. fname = os.path.dirname(thisFileName) + os.sep + fname + ".obj"
  67. f = open(fname, "r")
  68. obj = pickle.load(f)
  69. f.close()
  70. return obj
  71. def dumpPickledObj(obj, fname):
  72. """Not currently used"""
  73. fname = os.path.dirname(thisFileName) + os.sep + fname + ".obj"
  74. f = open(fname, "w")
  75. pickle.dump(obj, f)
  76. f.close()
  77. class TestDiff:
  78. """TestDiff encapsulates comparing a string or StringIO object
  79. against text in a test file. Test files are expected to
  80. be located in a subdirectory of the current directory,
  81. named data (if one doesn't exist, it will be created).
  82. If used in a test case, this should be instantiated in setUp and
  83. closed in tearDown. The calling unittest.TestCase instance is passed
  84. in on object creation. Optional compiled regular expressions
  85. can also be passed in, which are used to ignore strings
  86. that one knows in advance will be different, for example
  87. id="<hex digits>" .
  88. The initial running of the test will create the test
  89. files. When the tests are run again, the new output
  90. is compared against the old, line by line. To generate
  91. a new test file, remove the old one from data.
  92. """
  93. def __init__(self, testInst, *ignoreList):
  94. self.dataFile = None
  95. self.testInst = testInst
  96. self.origStrFile = None
  97. # used to divide separate test blocks within the same
  98. # test file.
  99. self.divider = "#" + ">" * 75 + "\n"
  100. self.expectedFailures = copy.copy(ignoreList)
  101. self.testFilePath = "data" + os.sep
  102. if not os.path.exists(self.testFilePath):
  103. os.mkdir(self.testFilePath)
  104. def setDiffFile(self, fname):
  105. """setDiffFile attempts to open the test file with the
  106. given name, and read it into a StringIO instance.
  107. If the file does not exist, it opens the file for
  108. writing.
  109. """
  110. filename = fname
  111. if self.dataFile and not self.dataFile.closed:
  112. self.dataFile.close()
  113. try:
  114. self.dataFile = open(self.testFilePath + filename, "r")
  115. self.origStrFile = StringIO.StringIO(self.dataFile.read())
  116. except IOError:
  117. try:
  118. self.dataFile = open(self.testFilePath + filename, "w")
  119. except IOError:
  120. print "exception"
  121. def failUnlessEqual(self, buffer):
  122. """failUnlessEqual takes either a string or a StringIO
  123. instance as input, and compares it against the original
  124. output from the test file.
  125. """
  126. # if not already a string IO
  127. if not isinstance(buffer, StringIO.StringIO):
  128. testStrFile = StringIO.StringIO(buffer)
  129. else:
  130. testStrFile = buffer
  131. testStrFile.seek(0)
  132. if self.dataFile.mode == "r":
  133. for testLine in testStrFile:
  134. origLine = self.origStrFile.readline()
  135. # skip divider
  136. if origLine == self.divider:
  137. origLine = self.origStrFile.readline()
  138. # take out expected failure strings before
  139. # comparing original against new output
  140. for cexpr in self.expectedFailures:
  141. origLine = cexpr.sub('', origLine)
  142. testLine = cexpr.sub('', testLine)
  143. self.testInst.failUnlessEqual(origLine, testLine)
  144. else: # write new test file
  145. for line in testStrFile:
  146. self.dataFile.write(line)
  147. self.dataFile.write(self.divider)
  148. def close(self):
  149. """Closes handle to original test file.
  150. """
  151. if self.dataFile and not self.dataFile.closed:
  152. self.dataFile.close()
  153. class MatchTestLoader(unittest.TestLoader):
  154. """Overrides unittest.TestLoader.loadTestsFromNames to provide a
  155. simpler and less verbose way to select a subset of tests to run.
  156. If all tests will always be run, use unittest.TestLoader instead.
  157. If a top-level test invokes test cases in other modules,
  158. MatchTestLoader should be created with topLevel set to True
  159. to get the correct results. For example,
  160. def main():
  161. loader = utils.MatchTestLoader(True, None, "makeTestSuite")
  162. unittest.main(defaultTest="makeTestSuite", testLoader=loader)
  163. The defaultTest argument in the constructor indicates the test to run
  164. if no additional arguments beyond the test script name are provided.
  165. """
  166. def __init__(self, topLevel, configName, defaultTest):
  167. unittest.TestLoader.__init__(self)
  168. self.testMethodPrefix = "test"
  169. self.defaultTest = defaultTest
  170. self.topLevel = topLevel
  171. if configName:
  172. self.config = ConfigHandler(configName)
  173. self.sections = []
  174. self.nameGenerator = None
  175. def setUpArgs(self):
  176. """Sets up the use of arguments from the command-line to select
  177. tests to run. There can be multiple names, both in full or as
  178. a substring, on the command-line.
  179. """
  180. sectionList = self.config.sections()
  181. self.testArgs = []
  182. argv = []
  183. # ignore section names in determining what to
  184. # load (sys.argv can be passed into setSection,
  185. # where any section names are extracted)
  186. for name in sys.argv:
  187. if name not in sectionList:
  188. argv.append(name)
  189. if not self.topLevel or (len(argv) != 1):
  190. for arg in argv[1:]:
  191. if arg.find("-") != 0:
  192. self.testArgs.append(arg)
  193. # has the effect of loading all tests
  194. if not self.testArgs:
  195. self.testArgs = [None]
  196. def loadTestsFromNames(self, unused, module=None):
  197. """Hard-wires using the default test. It ignores the names
  198. passed into it from unittest.TestProgram, because the
  199. default loader would fail on substrings or section names.
  200. """
  201. suites = unittest.TestLoader.loadTestsFromNames(self,
  202. (self.defaultTest,), module)
  203. return suites
  204. def setSection(self, args):
  205. """Sets section(s) of config file to read.
  206. """
  207. sectionList = self.config.sections()
  208. if ((type(args) is list) or
  209. (type(args) is tuple)):
  210. for arg in args:
  211. if arg in sectionList:
  212. self.sections.append(arg)
  213. if self.sections:
  214. return True
  215. elif type(args) is str:
  216. if args in sectionList:
  217. self.sections.append(args)
  218. return True
  219. return False
  220. def loadTestsFromConfig(self, testCaseClass, valueFunc=None):
  221. """Loads n number of instances of testCaseClass, where
  222. n is the number of items in the config file section(s).
  223. getConfigNames is a generator which is used to parcel
  224. out the values in the section(s) to the testCaseClass
  225. instances.
  226. """
  227. self.setUpArgs()
  228. numTestCases = self.getTestCaseNumber(testCaseClass)
  229. self.nameGenerator = self.config.getConfigNames(self.sections,
  230. numTestCases, valueFunc)
  231. configLen = self.config.length(self.sections)
  232. suite = unittest.TestSuite()
  233. for i in range(0, configLen):
  234. suite.addTest(self.loadTestsFromTestCase(testCaseClass))
  235. return suite
  236. def getTestCaseNumber(self, testCaseClass):
  237. """Looks for any test methods whose name contains testStr, checking
  238. if a test method has already been added. If there is not a match,
  239. it checks for an exact match with the test case name, and
  240. returns the number of test cases.
  241. """
  242. methods = self.getTestCaseNames(testCaseClass)
  243. prevAdded = []
  244. counter = 0
  245. for testStr in self.testArgs:
  246. if testStr:
  247. for m in methods:
  248. if m.find(testStr) >= 0 and m not in prevAdded:
  249. counter = counter + 1
  250. prevAdded.append(m)
  251. if counter:
  252. return counter
  253. if (not testStr) or (testCaseClass.__name__ == testStr):
  254. for m in methods:
  255. counter = counter + 1
  256. prevAdded.append(m)
  257. # print "found %d cases" % counter
  258. return counter
  259. def loadTestsFromTestCase(self, testCaseClass):
  260. """looks for any test methods whose name contains testStr, checking
  261. if a test method has already been added. If there is not a match,
  262. it checks for an exact match with the test case name, and loads
  263. all methods if so.
  264. """
  265. methods = self.getTestCaseNames(testCaseClass)
  266. prevAdded = []
  267. suites = unittest.TestSuite()
  268. for testStr in self.testArgs:
  269. # print testStr
  270. if testStr:
  271. for m in methods:
  272. if m.find(testStr) >= 0 and m not in prevAdded:
  273. suites.addTest(testCaseClass(m))
  274. prevAdded.append(m)
  275. if suites.countTestCases():
  276. return suites
  277. for testStr in self.testArgs:
  278. if (not testStr) or (testCaseClass.__name__ == testStr):
  279. for m in methods:
  280. suites.addTest(testCaseClass(m))
  281. prevAdded.append(m)
  282. if suites.countTestCases():
  283. return suites
  284. return suites