简体   繁体   中英

How to run a tf.test.TestCase from a Jupyter notebook - UnrecognizedFlagError: Unknown command line flag 'f'

My test case runs fine when run from the command line with: py foo_test.py (please see below).

When I run the following notebook cell:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import tensorflow_datasets as tfds
from importlib import reload
from tensorflow.python.framework import test_util
import foo_test
# Re-import foo_test so edits to foo_test.py take effect without restarting the kernel.
foo_test=reload(foo_test)
# Raises UnrecognizedFlagError (traceback below): tf.test.main() hands sys.argv to
# absl's flag parser, and the kernel's argv contains a -f option absl does not define.
foo_test.main()

I get:

in: foo_test.py::main()
Running tests under Python 3.8.3: d:\pyvenvs\tf2.4\scripts\python.exe
-----------------------------------------------------------------
UnrecognizedFlagError           Traceback (most recent call last)
<ipython-input-1-8f6431fbb106> in <module>
      9 import foo_test
     10 foo_test=reload(foo_test)
---> 11 foo_test.main()

D:\ray\ml\newdlaicourse\foo_test.py in main()
     23   import os
     24   print(f'in: {os.path.basename(__file__)}::main()')
---> 25   tf.test.main()
     26 if __name__ == '__main__':
     27   main()

d:\pyvenvs\tf2.4\lib\site-packages\tensorflow\python\platform\test.py in main(argv)
     56   """Runs all unit tests."""
     57   _test_util.InstallStackTraceHandler()
---> 58   return _googletest.main(argv)
     59 
     60 

d:\pyvenvs\tf2.4\lib\site-packages\tensorflow\python\platform\googletest.py in main(argv)
     64       args = sys.argv
     65     return app.run(main=g_main, argv=args)
---> 66   benchmark.benchmarks_main(true_main=main_wrapper)
     67 
     68 

d:\pyvenvs\tf2.4\lib\site-packages\tensorflow\python\platform\benchmark.py in benchmarks_main(true_main, argv)
    484     app.run(lambda _: _run_benchmarks(regex), argv=argv)
    485   else:
--> 486     true_main()

d:\pyvenvs\tf2.4\lib\site-packages\tensorflow\python\platform\googletest.py in main_wrapper()
     63     if args is None:
     64       args = sys.argv
---> 65     return app.run(main=g_main, argv=args)
     66   benchmark.benchmarks_main(true_main=main_wrapper)
     67 

d:\pyvenvs\tf2.4\lib\site-packages\tensorflow\python\platform\app.py in run(main, argv)
     38   main = main or _sys.modules['__main__'].main
     39 
---> 40   _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)

d:\pyvenvs\tf2.4\lib\site-packages\absl\app.py in run(main, argv, flags_parser)
    301       callback()
    302     try:
--> 303       _run_main(main, args)
    304     except UsageError as error:
    305       usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)

d:\pyvenvs\tf2.4\lib\site-packages\absl\app.py in _run_main(main, argv)
    249     sys.exit(retval)
    250   else:
--> 251     sys.exit(main(argv))
    252 
    253 

d:\pyvenvs\tf2.4\lib\site-packages\tensorflow\python\platform\googletest.py in g_main(argv)
     54   """Delegate to absltest.main."""
     55 
---> 56   absltest_main(argv=argv)
     57 
     58 

d:\pyvenvs\tf2.4\lib\site-packages\absl\testing\absltest.py in main(*args, **kwargs)
   2000   """
   2001   print_python_version()
-> 2002   _run_in_app(run_tests, args, kwargs)
   2003 
   2004 

d:\pyvenvs\tf2.4\lib\site-packages\absl\testing\absltest.py in _run_in_app(function, args, kwargs)
   2103     # after the command-line has been parsed. So we have the for loop below
   2104     # to change back flags to their old values.
-> 2105     argv = FLAGS(sys.argv)
   2106     for saved_flag in six.itervalues(saved_flags):
   2107       saved_flag.restore_flag()

d:\pyvenvs\tf2.4\lib\site-packages\absl\flags\_flagvalues.py in __call__(self, argv, known_only)
    652     for name, value in unknown_flags:
    653       suggestions = _helpers.get_flag_suggestions(name, list(self))
--> 654       raise _exceptions.UnrecognizedFlagError(
    655           name, value, suggestions=suggestions)
    656 

UnrecognizedFlagError: Unknown command line flag 'f'

Edit 1: Trying DorElias' suggestion (trimming sys.argv before calling main()), I get:

['d:\\pyvenvs\\tf2.4\\lib\\site-packages\\ipykernel_launcher.py']
in: foo_test.py::main()
Running tests under Python 3.8.3: d:\pyvenvs\tf2.4\scripts\python.exe
----------------------------------------------------------------------
Ran 0 tests in 0.000s

OK
An exception has occurred, use %tb to see the full traceback.

SystemExit: False


d:\pyvenvs\tf2.4\lib\site-packages\IPython\core\interactiveshell.py:3426: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

Edit 2: Using the sys.exit workaround from the answer:

# Keep only argv[0] (the launcher path) so absl does not see the kernel's own flags such as -f.
# NOTE(review): sys.argv is never restored, so later cells see the truncated list.
sys.argv = sys.argv[:1] # the first arg in argv is the name of the script and maybe we want to keep it
old_sysexit = sys.exit
try:
    # Temporarily make sys.exit a no-op: the sys.exit() calls inside
    # tf.test.main() then return instead of raising SystemExit.
    sys.exit = lambda *args: None
    foo_test.main()
finally:
    sys.exit = old_sysexit  # always put the real sys.exit back

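I get the output below. The run now finishes without an exception, but no tests are found ("Ran 0 tests"):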

['d:\\pyvenvs\\tf2.4\\lib\\site-packages\\ipykernel_launcher.py']
in: foo_test.py::main()
after: tf.test.main()
Running tests under Python 3.8.3: d:\pyvenvs\tf2.4\scripts\python.exe
----------------------------------------------------------------------
Ran 0 tests in 0.000s

OK

This is foo_test.py. It works when run from the command line:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.python.framework import test_util
class MyTestCase(tf.test.TestCase):  # Example tests; only methods whose names start with "test" are collected.
  #@run_in_graph_and_eager_modes  (disabled; only the test_util module is imported, so this would need the test_util. prefix)
  def setUp(self):  # Only delegates to the base-class setUp.
    super(MyTestCase,self).setUp()
  def tearDown(self):  # Only delegates to the base-class tearDown.
    super(MyTestCase,self).tearDown()
  def a_test(self):  # NOTE(review): never run -- the default test-name prefix is "test", and it is absent from the command-line output; rename (e.g. test_add) to collect it.
    print("a test")
    x = tf.constant([1, 2])
    y = tf.constant([3, 4])
    z = tf.add(x, y)  # element-wise sum, expected to be [4, 6]
    self.assertAllEqual([4, 6], self.evaluate(z))
  def test2(self):  # Always passes.
    print("test 2")
    self.assertEqual(3,3)
  def test3(self):  # Always fails (3 != 4), as the command-line run shows.
    print("test 3")
    self.assertEqual(3,4)
def main():  # Prints a banner, then runs tf.test.main(), which parses sys.argv as absl flags and finishes with sys.exit().
  import os  # only needed for the banner below
  print(f'in: {os.path.basename(__file__)}::main()')    
  tf.test.main()  # NOTE(review): discovery appears to use __main__ (unittest's default module); from a notebook that is the notebook itself, which would explain "Ran 0 tests" -- verify.
if __name__ == '__main__':  # command-line entry point (py foo_test.py)
  main()


(tf2.4) D:\ray\ml\newdlaicourse>py foo_test.py
in: foo_test.py::main()
Running tests under Python 3.8.3: D:\pyvenvs\tf2.4\Scripts\python.exe
[ RUN      ] MyTestCase.test2
test 2
INFO:tensorflow:time(__main__.MyTestCase.test2): 0.0s
I1110 18:35:10.862683  9316 test_util.py:2075] time(__main__.MyTestCase.test2): 0.0s
[       OK ] MyTestCase.test2
[ RUN      ] MyTestCase.test3
test 3
INFO:tensorflow:time(__main__.MyTestCase.test3): 0.0s
I1110 18:35:10.863683  9316 test_util.py:2075] time(__main__.MyTestCase.test3): 0.0s
[  FAILED  ] MyTestCase.test3
[ RUN      ] MyTestCase.test_session
[  SKIPPED ] MyTestCase.test_session
======================================================================
FAIL: test3 (__main__.MyTestCase)
MyTestCase.test3
----------------------------------------------------------------------
Traceback (most recent call last):
  File "foo_test.py", line 21, in test3
    self.assertEqual(3,4)
AssertionError: 3 != 4

----------------------------------------------------------------------
Ran 3 tests in 0.003s

FAILED (failures=1, skipped=1)

It seems that when you run from a notebook, the kernel's own command-line arguments (here the -f option) are still in sys.argv, and tf.test.main() tries to parse them as its own flags.

Have you tried simply removing all arguments from sys.argv before calling main()?

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import tensorflow_datasets as tfds
from importlib import reload
from tensorflow.python.framework import test_util
import foo_test
foo_test = reload(foo_test)

import sys

# absl (used by tf.test.main()) parses sys.argv as flags. In a notebook,
# sys.argv holds the kernel launcher's own options (such as -f), which absl
# rejects. Keep only argv[0] (the script name) for the call, then restore the
# original list so later cells still see the real command line.
# Note: tf.test.main() still ends with sys.exit(), so SystemExit is raised
# once the run finishes.
saved_argv = sys.argv
sys.argv = sys.argv[:1]
try:
    foo_test.main()
finally:
    sys.argv = saved_argv

EDIT: Apparently the unittest library has an option (exit, on by default) to exit the program when the test run is finished. The notebook catches the resulting SystemExit so that the whole kernel does not exit, and reports it as an exception instead, which stops the cell.

It seems that TensorFlow doesn't pass the exit parameter through to unittest (tf.test.main() only accepts argv), so it always exits.

You can use this workaround (though it's a bit hacky):

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import tensorflow_datasets as tfds
from importlib import reload
from tensorflow.python.framework import test_util
import foo_test
foo_test = reload(foo_test)

import sys

# tf.test.main() collects tests from __main__ (the unittest default). In a
# notebook, __main__ is the notebook's own namespace, so only calling
# foo_test.main() reports "Ran 0 tests". Binding the test class here makes
# it visible to that discovery; repeat for every TestCase class in foo_test.
MyTestCase = foo_test.MyTestCase

# absl parses sys.argv as flags and rejects the kernel launcher's options
# (such as -f). Pass only argv[0], and restore the real list afterwards so
# later cells are unaffected.
saved_argv = sys.argv
sys.argv = sys.argv[:1]
try:
    foo_test.main()
except SystemExit:
    # tf.test.main() always finishes with sys.exit(). Catching SystemExit here
    # keeps the cell from ending in an exception without replacing the global
    # sys.exit. The test runner has already printed the results.
    pass
finally:
    sys.argv = saved_argv

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM