2016-02-06 03:12:25 +01:00

319 lines
8.8 KiB
Python

# This module contains some code copied from unittest2/loader.py and other
# code developed in reference to that module and others within unittest2.
# unittest2 is Copyright (c) 2001-2010 Python Software Foundation; All
# Rights Reserved. See: http://docs.python.org/license.html
import logging
import os
import re
import sys
import traceback
import platform
try:
    from inspect import isgeneratorfunction  # new in 2.6
except ImportError:
    import inspect
    try:
        from compiler.consts import CO_GENERATOR
    except ImportError:
        # IronPython doesn't have a compiler module
        CO_GENERATOR = 0x20

    # backported from Python 2.6
    def isgeneratorfunction(func):
        """Return True if ``func`` is a generator function or method.

        Fallback used only when ``inspect`` does not provide
        ``isgeneratorfunction``; it tests the ``CO_GENERATOR`` bit of the
        callable's code flags.
        """
        if not (inspect.isfunction(func) or inspect.ismethod(func)):
            return False
        return bool(func.func_code.co_flags & CO_GENERATOR)
import six
# Marker consulted by _is_relevant_tb_level(): a traceback frame whose
# globals contain '__unittest' is skipped when format_traceback() trims
# leading frames, and this assignment puts that marker in this module.
__unittest = True
# A whole-string Python identifier (anchored at both ends); ispackage()
# applies it to the last component of a directory path.
IDENT_RE = re.compile(r'^[_a-zA-Z]\w*$', re.UNICODE)
# An identifier followed by '.py' at the end of the string. The pattern is
# not anchored at the start, so when used with search() anything preceding
# the identifier (such as a directory prefix) is ignored.
VALID_MODULE_RE = re.compile(r'[_a-zA-Z]\w*\.py$', re.UNICODE)
def ln(label, char='-', width=70):
    """Return ``label`` centred inside a divider line.

    >>> ln('hello there')
    '---------------------------- hello there -----------------------------'

    ``char`` is the fill character and ``width`` the target line length;
    they default to '-' and 70.
    """
    # The extra two accounts for the spaces either side of the label.
    side = char * ((width - (len(label) + 2)) // 2)
    line = '%s %s %s' % (side, label, side)
    # When the space cannot be split evenly the remainder goes on the right.
    shortfall = width - len(line)
    if shortfall > 0:
        line += char * shortfall
    return line
def valid_module_name(path):
    """Check whether ``path`` ends in a legal module file name.

    Returns the result of ``VALID_MODULE_RE.search(path)``: a match object
    when the pattern is found, otherwise ``None``.
    """
    match = VALID_MODULE_RE.search(path)
    return match
def name_from_path(path):
    """Convert a file system path into a dotted module name.

    The extension is dropped, then parent directories are climbed for as
    long as ``ispackage`` accepts them; the collected names are joined with
    dots, outermost package first.
    """
    base, _ext = os.path.splitext(os.path.normpath(path))
    candidate, leaf = os.path.split(base)
    names = [leaf]
    # back up to find module root
    while candidate and ispackage(candidate):
        candidate, leaf = os.path.split(candidate)
        names.append(leaf)
    names.reverse()
    return '.'.join(names)
def module_from_name(name):
    """Import the module called ``name`` and return the module object.

    The result is looked up in ``sys.modules`` under the full ``name``
    after the import has run.
    """
    __import__(name)
    loaded = sys.modules
    return loaded[name]
def test_from_name(name, module):
    """Resolve a test name that may carry a ``:<int>`` suffix.

    Returns ``(parent, obj, name, index)``. When the text after the first
    ':' parses as an integer, ``index`` is that integer and ``name`` is the
    part before the ':'; otherwise ``index`` is ``None`` and ``name`` is
    passed on unchanged. ``parent`` and ``obj`` come from
    ``object_from_name(name, module)``.
    """
    index = None
    head, sep, tail = name.partition(':')
    if sep:
        try:
            index = int(tail)
        except ValueError:
            # Not an integer suffix: leave the name exactly as given.
            pass
        else:
            name = head
    parent, obj = object_from_name(name, module)
    return parent, obj, name, index
def object_from_name(name, module=None):
    """Resolve a dotted ``name`` to an object and its parent.

    When ``module`` is ``None`` the longest importable dotted prefix of
    ``name`` is imported and the remaining parts are looked up as
    attributes, starting from the top-level package. When ``module`` is
    given, every part of ``name`` is looked up as an attribute starting
    from ``module``.

    Returns ``(parent, obj)``; ``parent`` is ``None`` when no attribute
    lookup was needed. Re-raises the ``ImportError`` if not even the first
    part of the name can be imported.
    """
    parts = name.split('.')
    if module is None:
        # Try ever shorter dotted prefixes until one of them imports.
        for end in range(len(parts), 0, -1):
            try:
                module = __import__('.'.join(parts[:end]))
            except ImportError:
                if end == 1:
                    raise
            else:
                break
        # __import__ returns the top-level package, so skip its own name.
        parts = parts[1:]
    parent, obj = None, module
    for attr in parts:
        parent, obj = obj, getattr(obj, attr)
    return parent, obj
def name_from_args(name, index, args):
    """Build a test name from a base name, a zero-based index and arguments.

    The result is ``"<name>:<index + 1>"`` followed by a newline and the
    comma-separated ``repr`` of each argument, cut to 79 characters.
    """
    summary = ', '.join(map(repr, args))
    return '%s:%s\n%s' % (name, index + 1, summary[:79])
def test_name(test):
# XXX does not work for test funcs, test.id() lacks module
if hasattr(test, '_funcName'):
tid = test._funcName
elif hasattr(test, '_testFunc'):
tid = "%s.%s" % (test._testFunc.__module__, test._testFunc.__name__)
else:
tid = test.id()
if '\n' in tid:
tid = tid.split('\n')[0]
return tid
def ispackage(path):
    """Return True if ``path`` is a directory that looks like a package.

    The directory's base name must match ``IDENT_RE`` and the directory
    must contain one of ``__init__.py``, ``__init__.pyc`` or
    ``__init__.pyo``; on platforms whose ``sys.platform`` starts with
    'java', ``__init__$py.class`` is accepted as well.
    """
    if not os.path.isdir(path):
        return False
    # at least the end of the path must be a legal python identifier
    # and __init__.py[co] must exist
    if not IDENT_RE.match(os.path.basename(path)):
        return False
    init_files = ['__init__.py', '__init__.pyc', '__init__.pyo']
    if sys.platform.startswith('java'):
        init_files.append('__init__$py.class')
    return any(os.path.isfile(os.path.join(path, init))
               for init in init_files)
def ensure_importable(dirname):
    """Put ``dirname`` at the front of ``sys.path`` unless already listed."""
    if dirname in sys.path:
        return
    sys.path.insert(0, dirname)
def isgenerator(obj):
    """Return True if ``obj`` should be treated as a test generator.

    That is the case for generator functions and for any object whose
    ``testGenerator`` attribute exists and is not ``None``.
    """
    if isgeneratorfunction(obj):
        return True
    marker = getattr(obj, 'testGenerator', None)
    return marker is not None
def has_module_fixtures(test):
    """Report whether the test's module defines module-level fixtures.

    Returns True when the module named by ``test.__class__.__module__`` has
    a ``setUpModule`` or ``tearDownModule`` attribute, False when it has
    neither, and ``None`` when that module is not in ``sys.modules``.
    """
    modname = test.__class__.__module__
    if modname not in sys.modules:
        return
    mod = sys.modules[modname]
    return any(hasattr(mod, hook)
               for hook in ('setUpModule', 'tearDownModule'))
def has_class_fixtures(test):
    """Return True if the test's class hierarchy defines class fixtures.

    A class fixture is a ``setUpClass`` or ``tearDownClass`` entry in the
    ``__dict__`` of any class in ``test.__class__.__mro__``, ignoring
    classes whose ``__module__`` contains the framework's case module name
    ('unittest.case' on Python 2.7 and later, 'unittest2.case' before).
    """
    # hasattr would be the obvious thing to use here, unfortunately all tests
    # inherit from unittest2.case.TestCase and that *always* has setUpClass and
    # tearDownClass methods. So look at what each class defines itself and
    # skip the classes that belong to the framework.
    #
    # Compare version tuples rather than a float built from "major.minor":
    # as a float, a minor version of 10 reads as ".1" and sorts below ".7".
    if sys.version_info >= (2, 7):
        framework = 'unittest.case'
    else:
        framework = 'unittest2.case'
    for klass in test.__class__.__mro__:
        if framework in klass.__module__:
            continue
        defined = klass.__dict__
        if 'setUpClass' in defined or 'tearDownClass' in defined:
            return True
    return False
def safe_decode(string):
    """Decode a byte string to text without raising on bad input.

    ``None`` and objects without a ``decode`` method are returned as they
    are. Otherwise the default codec is tried first and UTF-8 second; when
    both fail a fixed placeholder string is returned.
    """
    if string is None:
        return None
    try:
        return string.decode()
    except AttributeError:
        # Nothing to decode: hand the object back untouched.
        return string
    except UnicodeDecodeError:
        try:
            return string.decode('utf-8')
        except UnicodeDecodeError:
            return six.u('<unable to decode>')
def exc_info_to_string(err, test):
    """Render exception info ``err`` as text for output.

    If ``test`` has a ``formatTraceback`` attribute that is not ``None``
    it is called with ``err``; otherwise ``format_traceback(test, err)``
    produces the result.
    """
    if getattr(test, 'formatTraceback', None) is None:
        return format_traceback(test, err)
    return test.formatTraceback(err)
def format_traceback(test, err):
    """Convert a ``sys.exc_info()``-style triple into a single string.

    When the third item has no ``tb_next`` attribute it is taken to be an
    iterable of already formatted lines and is simply joined. Otherwise
    leading frames marked with ``__unittest`` are skipped, and if the
    exception type is the test's ``failureException`` (``AssertionError``
    when the test has none) the output is limited to the frames before the
    next marked one.
    """
    exctype, value, tb = err
    if not hasattr(tb, 'tb_next'):
        return ''.join(tb)

    # Skip test runner traceback levels
    while tb and _is_relevant_tb_level(tb):
        tb = tb.tb_next

    limit = None
    if exctype is getattr(test, 'failureException', AssertionError):
        # Skip assert*() traceback levels
        limit = _count_relevant_tb_levels(tb)
    return ''.join(traceback.format_exception(exctype, value, tb, limit))
def transplant_class(cls, module):
    """Return a subclass of ``cls`` that claims to live in ``module``.

    :param cls: A class
    :param module: A module name
    :returns: A new, empty subclass of ``cls`` whose ``__name__`` equals
              ``cls.__name__`` and whose ``__module__`` equals ``module``.

    """
    class C(cls):
        pass

    C.__name__ = cls.__name__
    C.__module__ = module
    return C
def parse_log_level(lvl):
    """Return the numeric log level described by the string ``lvl``.

    ``lvl`` may be an integer in string form (returned as an ``int``) or
    the name of a ``logging`` level in any letter case. Anything else
    yields ``logging.WARN``.
    """
    try:
        return int(lvl)
    except ValueError:
        pass
    level = getattr(logging, lvl.upper(), None)
    # logging has upper-case attributes that are not levels (BASIC_FORMAT
    # is a format string), so accept only integers as a level.
    if isinstance(level, int):
        return level
    return logging.WARN
def _is_relevant_tb_level(tb):
    """Return True if the frame of ``tb`` has '__unittest' in its globals."""
    frame_globals = tb.tb_frame.f_globals
    return '__unittest' in frame_globals
def _count_relevant_tb_levels(tb):
    """Count traceback levels from ``tb`` up to the first marked frame.

    Walks the ``tb_next`` chain and stops at the end of the chain or at
    the first level for which ``_is_relevant_tb_level`` is true; that level
    is not counted.
    """
    count = 0
    node = tb
    while node:
        if _is_relevant_tb_level(node):
            break
        count += 1
        node = node.tb_next
    return count
class _WritelnDecorator(object):
    """Wrap a file-like object and add a convenient ``writeln`` method."""

    def __init__(self, stream):
        self.stream = stream

    def __getattr__(self, attr):
        # Called only when normal lookup fails: 'stream' and '__getstate__'
        # are never forwarded, everything else is fetched from the stream.
        if attr == 'stream' or attr == '__getstate__':
            raise AttributeError(attr)
        return getattr(self.stream, attr)

    def write(self, arg):
        """Write ``arg`` to the wrapped stream."""
        self.stream.write(arg)

    def writeln(self, arg=None):
        """Write ``arg`` (if truthy) and then a newline."""
        out = self.stream
        if arg:
            out.write(arg)
        # text-mode streams translate to \r\n if needed
        out.write('\n')
def ancestry(layer):
    """Return the ancestors of ``layer`` grouped by generation.

    The result is a list of lists, most distant ancestors first and
    ``[layer]`` last. Each generation holds the ``bases_and_mixins`` of the
    generation after it, with ``object`` left out.
    """
    layers = [[layer]]
    current = [base for base in bases_and_mixins(layer)
               if base is not object]
    while current:
        layers.append(current)
        current = [parent
                   for base in current
                   for parent in bases_and_mixins(base)
                   if parent is not object]
    layers.reverse()
    return layers
def bases_and_mixins(layer):
    """Return ``layer.__bases__`` followed by its ``mixins`` (if any)."""
    bases = layer.__bases__
    extra = getattr(layer, 'mixins', ())
    return bases + extra