diff --git a/ament_black/ament_black/main.py b/ament_black/ament_black/main.py index 5af6523..d842a03 100755 --- a/ament_black/ament_black/main.py +++ b/ament_black/ament_black/main.py @@ -23,15 +23,19 @@ import time from xml.sax.saxutils import escape, quoteattr +import click +from black import get_sources +from black import main as black +from black import re_compile_maybe_verbose +from black.concurrency import maybe_install_uvloop +from black.const import DEFAULT_EXCLUDES, DEFAULT_INCLUDES +from black.report import Report from unidiff import PatchSet def patched_black(*args, **kwargs) -> None: from multiprocessing import freeze_support - from black import main as black - from black.concurrency import maybe_install_uvloop - maybe_install_uvloop() freeze_support() black(*args, **kwargs) @@ -46,7 +50,7 @@ def main(argv=sys.argv[1:]): "paths", nargs="*", default=[os.curdir], - help="The files or directories to check. this argument is directly passed to black", + help="The files or directories to check. this argument is directly passed to black", ) parser.add_argument( "--config", @@ -56,11 +60,14 @@ def main(argv=sys.argv[1:]): help="The config file", ) parser.add_argument( - "--reformat", action="store_true", help="Reformat the files in place" + "--reformat", + action="store_true", + help="Reformat the files in place", + ) + parser.add_argument( + "--xunit-file", + help="Generate a xunit compliant XML file", ) - # not using a file handle directly - # in order to prevent leaving an empty file when something fails early - parser.add_argument("--xunit-file", help="Generate a xunit compliant XML file") args = parser.parse_args(argv) # if we have specified a config file, make sure it exists and abort if not @@ -68,9 +75,23 @@ def main(argv=sys.argv[1:]): print("Could not find config file '%s'" % args.config_file, file=sys.stderr) return 1 + # TODO(Nacho): Inject the config file results into the ctx (use read_pyproject_toml) + sources = get_sources( + ctx=click.Context(black), + src=tuple(args.paths), + quiet=True, + verbose=False, + include=re_compile_maybe_verbose(DEFAULT_INCLUDES), + exclude=re_compile_maybe_verbose(DEFAULT_EXCLUDES), + extend_exclude=None, + force_exclude=None, + report=Report(), + stdin_filename="", + ) + checked_files = [str(path) for path in sources] + if args.xunit_file: start_time = time.time() - report = [] # invoke black black_args_withouth_path = [] @@ -124,7 +145,9 @@ def main(argv=sys.argv[1:]): file_name = file_name.split(suffix)[0] testname = "%s.%s" % (folder_name, file_name) - xml = get_xunit_content(report, testname, time.time() - start_time) + xml = get_xunit_content( + report, testname, time.time() - start_time, checked_files + ) path = os.path.dirname(os.path.abspath(args.xunit_file)) if not os.path.exists(path): os.makedirs(path) @@ -154,7 +177,7 @@ def get_line_number(data, offset): return data[0:offset].count("\n") + data[0:offset].count("\r") + 1 -def get_xunit_content(report, testname, elapsed): +def get_xunit_content(report, testname, elapsed, checked_files): test_count = sum(max(len(r), 1) for r in report.values()) error_count = sum(len(r) for r in report.values()) data = { @@ -212,10 +235,10 @@ def get_xunit_content(report, testname, elapsed): # output list of checked files data = { - "escaped_files": escape("".join(["\n* %s" % r for r in sorted(report.keys())])) + "checked_files": escape("".join(["\n* %s" % r for r in sorted(checked_files)])) } xml += ( - """ Checked files:%(escaped_files)s + """ Checked files:%(checked_files)s """ % data )