[英]How to patch method ''.join using the mock library
要创建给定函数的单元测试,我需要修补''.join(...)
。
我尝试了很多方法(使用模拟库),但即使我有一些使用该库创建单元测试的经验,我也无法让它工作。
出现的第一个问题是str
是一个内置类,因此它不能被嘲笑。 William John Bert的一篇文章展示了如何处理这个问题(在他的案例中是datetime.date
)。 图书馆官方文档的“部分模拟”部分也有可能的解决方案。
第二个问题是str
并没有真正直接使用。 相反,调用文字''
方法join
。 那么,补丁的路径应该是什么?
这些选项都不起作用:
patch('__builtin__.str', 'join')
patch('string.join')
patch('__builtin__.str', FakeStr)
(其中FakeStr
是str
的子类) 任何帮助将不胜感激。
你不能,因为无法在内置类中设置属性:
>>> str.join = lambda x: None
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: can't set attributes of built-in/extension type 'str'
并且你不能修补str
,因为''.join
使用文字,因此无论你如何尝试在__builtin__
替换str
,解释器总是会创建一个str
。
如果您读取生成的字节码,您可以看到这个:
>>> import dis
>>> def test():
... ''.join([1,2,3])
...
>>> dis.dis(test)
2 0 LOAD_CONST 1 ('')
3 LOAD_ATTR 0 (join)
6 LOAD_CONST 2 (1)
9 LOAD_CONST 3 (2)
12 LOAD_CONST 4 (3)
15 BUILD_LIST 3
18 CALL_FUNCTION 1
21 POP_TOP
22 LOAD_CONST 0 (None)
25 RETURN_VALUE
字节码是在编译时生成的,正如您所看到的,无论您在运行时如何更改str
的值,第一个LOAD_CONST
加载''
,这是一个str
。
你可以做的是使用一个可以模拟的包装函数,或者避免使用文字。 例如,使用str()
而不是''
允许您使用实现join
方法的子类来mock
str
类(尽管这可能会影响太多代码,并且根据您使用的模块可能不可行)。
如果您感到非常幸运,可以检查并修补代码对象中的字符串consts:
def patch_strings(fun, cls):
new_consts = tuple(
cls(c) if type(c) is str else c
for c in fun.func_code.co_consts)
code = type(fun.func_code)
fun.func_code = code(
fun.func_code.co_argcount,
fun.func_code.co_nlocals,
fun.func_code.co_stacksize,
fun.func_code.co_flags,
fun.func_code.co_code,
new_consts,
fun.func_code.co_names,
fun.func_code.co_varnames,
fun.func_code.co_filename,
fun.func_code.co_name,
fun.func_code.co_firstlineno,
fun.func_code.co_lnotab,
fun.func_code.co_freevars,
fun.func_code.co_cellvars)
def a():
return ''.join(['a', 'b'])
class mystr(str):
def join(self, s):
print 'join called!'
return super(mystr, self).join(s)
patch_strings(a, mystr)
print a() # prints "join called!\nab"
Python3版本:
def patch_strings(fun, cls):
new_consts = tuple(
cls(c) if type(c) is str else c
for c in fun.__code__.co_consts)
code = type(fun.__code__)
fun.__code__ = code(
fun.__code__.co_argcount,
fun.__code__.co_kwonlyargcount,
fun.__code__.co_nlocals,
fun.__code__.co_stacksize,
fun.__code__.co_flags,
fun.__code__.co_code,
new_consts,
fun.__code__.co_names,
fun.__code__.co_varnames,
fun.__code__.co_filename,
fun.__code__.co_name,
fun.__code__.co_firstlineno,
fun.__code__.co_lnotab,
fun.__code__.co_freevars,
fun.__code__.co_cellvars)
实际上没有办法可以使用字符串文字,因为它们总是使用内置的str
类,正如您所发现的那样,它不能以这种方式进行修补。
当然,您可以编写一个函数join(seq, sep='')
来代替''.join()
和patch,或者一个str
子类class Separator
,它总是用来显式构造将成为的字符串用于join
操作(例如Separator('').join(....)
)。 这些变通办法有点难看,但是你不能修补方法。
在这里,我正在修改我正在测试的模块中的变量。 我不喜欢这个想法,因为我正在改变我的代码以适应测试,但它有效。
import mock
from main import func
@mock.patch('main.patched_str')
def test(patched_str):
patched_str.join.return_value = "hello"
result = func('1', '2')
assert patched_str.join.called_with('1', '2')
assert result == "hello"
if __name__ == '__main__':
test()
patched_str = ''
def func(*args):
return patched_str.join(args)
我的解决方案有点棘手,但它适用于大多数情况。 它不使用模拟库BTW。 我的解决方案的优点是你继续使用''.join
而不会进行丑陋的修改。
当我必须在Python3.2中运行为Python3.3编写的代码时,我找到了这种方法(它用str(...).casefold
替换了str(...).casefold
str(...).lower
)
假设你有这个模块:
# my_module.py
def my_func():
"""Print some joined text"""
print('_'.join(str(n) for n in range(5)))
有一个测试它的单元测试示例。 请注意,它是为Python 2.7编写的,但可以很容易地修改为Python 3(请参阅注释):
import re
from imp import reload # for Python 3
import my_module
class CustomJoinTets(unittest.TestCase):
"""Test case using custom str(...).join method"""
def setUp(self):
"""Replace the join method with a custom function"""
with open(my_module.__file__.replace('.pyc', '.py')) as f:
# Replace `separator.join(` with `custom_join(separator)(`
contents = re.sub(r"""(?P<q>["'])(?P<sep>.*?)(?P=q)[.]join\(""",
r"custom_join(\g<q>\g<sep>\g<q>)(",
f.read())
# Replace the code in the module
# For Python 3 do `exec(contents, my_module.__dict__)`
exec contents in my_module.__dict__
# Create `custom_join` object in the module
my_module.custom_join = self._custom_join
def tearDown(self):
"""Reload the module"""
reload(my_module)
def _custom_join(self, separator):
"""A factory for a custom join"""
separator = '+{}+'.format(separator)
return separator.join
def test_smoke(self):
"""Do something"""
my_module.my_func()
if __name__ == '__main__':
unittest.main()
如果你真的想要mock
库,你可以让_custom_join
方法返回MagicMock对象:
def _custom_join(self, separator):
"""A factory for a custom join"""
import mock
return mock.MagicMock(name="{!r}.join".format(separator))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.