import importlib import os import sys blacklist = {"test_wasserstein_distance", "test_wasserstein_barycenter"} testdir = sys.argv[1] sys.path.append(testdir) results = [] with os.scandir(testdir) as it: for entry in it: if entry.is_file() and entry.name.startswith("test_") and entry.name.endswith(".py"): name = entry.name.rstrip(".py") if name in blacklist: print("Skipping tests in %s due to blacklist." %(name)) else: print("Running tests from %s." %(name)) module = importlib.import_module(name) tests = [f for f in dir(module) if str(f).startswith("test_")] for t in tests: func = getattr(module, t) if callable(func): print(" ", t) ok = True try: func() except AssertionError: ok = False if ok: print(" OK!") else: print(" FAIL!") results.append(ok) if all(results): exit(0) else: exit(1)