_test_utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
from __future__ import print_function
import atexit
import filecmp
import glob
import itertools
import os
import shutil
import sys
import sysconfig
import tempfile
import unittest
  10. project_dir = os.path.abspath(os.path.join(__file__, '..', '..', '..'))
  11. test_dir = os.getenv("BROTLI_TESTS_PATH")
  12. BRO_ARGS = [os.getenv("BROTLI_WRAPPER")]
  13. # Fallbacks
  14. if test_dir is None:
  15. test_dir = os.path.join(project_dir, 'tests')
  16. if BRO_ARGS[0] is None:
  17. python_exe = sys.executable or 'python'
  18. bro_path = os.path.join(project_dir, 'python', 'bro.py')
  19. BRO_ARGS = [python_exe, bro_path]
  20. # Get the platform/version-specific build folder.
  21. # By default, the distutils build base is in the same location as setup.py.
  22. platform_lib_name = 'lib.{platform}-{version[0]}.{version[1]}'.format(
  23. platform=sysconfig.get_platform(), version=sys.version_info)
  24. build_dir = os.path.join(project_dir, 'bin', platform_lib_name)
  25. # Prepend the build folder to sys.path and the PYTHONPATH environment variable.
  26. if build_dir not in sys.path:
  27. sys.path.insert(0, build_dir)
  28. TEST_ENV = os.environ.copy()
  29. if 'PYTHONPATH' not in TEST_ENV:
  30. TEST_ENV['PYTHONPATH'] = build_dir
  31. else:
  32. TEST_ENV['PYTHONPATH'] = build_dir + os.pathsep + TEST_ENV['PYTHONPATH']
  33. TESTDATA_DIR = os.path.join(test_dir, 'testdata')
  34. TESTDATA_FILES = [
  35. 'empty', # Empty file
  36. '10x10y', # Small text
  37. 'alice29.txt', # Large text
  38. 'random_org_10k.bin', # Small data
  39. 'mapsdatazrh', # Large data
  40. 'ukkonooa', # Poem
  41. 'cp1251-utf16le', # Codepage 1251 table saved in UTF16-LE encoding
  42. 'cp852-utf8', # Codepage 852 table saved in UTF8 encoding
  43. ]
  44. # Some files might be missing in a lightweight sources pack.
  45. TESTDATA_PATH_CANDIDATES = [
  46. os.path.join(TESTDATA_DIR, f) for f in TESTDATA_FILES
  47. ]
  48. TESTDATA_PATHS = [
  49. path for path in TESTDATA_PATH_CANDIDATES if os.path.isfile(path)
  50. ]
  51. TESTDATA_PATHS_FOR_DECOMPRESSION = glob.glob(
  52. os.path.join(TESTDATA_DIR, '*.compressed'))
  53. TEMP_DIR = tempfile.mkdtemp()
  54. def get_temp_compressed_name(filename):
  55. return os.path.join(TEMP_DIR, os.path.basename(filename + '.bro'))
  56. def get_temp_uncompressed_name(filename):
  57. return os.path.join(TEMP_DIR, os.path.basename(filename + '.unbro'))
  58. def bind_method_args(method, *args, **kwargs):
  59. return lambda self: method(self, *args, **kwargs)
  60. def generate_test_methods(test_case_class,
  61. for_decompression=False,
  62. variants=None):
  63. # Add test methods for each test data file. This makes identifying problems
  64. # with specific compression scenarios easier.
  65. if for_decompression:
  66. paths = TESTDATA_PATHS_FOR_DECOMPRESSION
  67. else:
  68. paths = TESTDATA_PATHS
  69. opts = []
  70. if variants:
  71. opts_list = []
  72. for k, v in variants.items():
  73. opts_list.append([r for r in itertools.product([k], v)])
  74. for o in itertools.product(*opts_list):
  75. opts_name = '_'.join([str(i) for i in itertools.chain(*o)])
  76. opts_dict = dict(o)
  77. opts.append([opts_name, opts_dict])
  78. else:
  79. opts.append(['', {}])
  80. for method in [m for m in dir(test_case_class) if m.startswith('_test')]:
  81. for testdata in paths:
  82. for (opts_name, opts_dict) in opts:
  83. f = os.path.splitext(os.path.basename(testdata))[0]
  84. name = 'test_{method}_{options}_{file}'.format(
  85. method=method, options=opts_name, file=f)
  86. func = bind_method_args(
  87. getattr(test_case_class, method), testdata, **opts_dict)
  88. setattr(test_case_class, name, func)
  89. class TestCase(unittest.TestCase):
  90. def tearDown(self):
  91. for f in TESTDATA_PATHS:
  92. try:
  93. os.unlink(get_temp_compressed_name(f))
  94. except OSError:
  95. pass
  96. try:
  97. os.unlink(get_temp_uncompressed_name(f))
  98. except OSError:
  99. pass
  100. def assertFilesMatch(self, first, second):
  101. self.assertTrue(
  102. filecmp.cmp(first, second, shallow=False),
  103. 'File {} differs from {}'.format(first, second))