diff --git a/ament_black/ament_black/main.py b/ament_black/ament_black/main.py
index 5af6523..d842a03 100755
--- a/ament_black/ament_black/main.py
+++ b/ament_black/ament_black/main.py
@@ -23,15 +23,19 @@
import time
from xml.sax.saxutils import escape, quoteattr
+import click
+from black import get_sources
+from black import main as black
+from black import re_compile_maybe_verbose
+from black.concurrency import maybe_install_uvloop
+from black.const import DEFAULT_EXCLUDES, DEFAULT_INCLUDES
+from black.report import Report
from unidiff import PatchSet
def patched_black(*args, **kwargs) -> None:
from multiprocessing import freeze_support
- from black import main as black
- from black.concurrency import maybe_install_uvloop
-
maybe_install_uvloop()
freeze_support()
black(*args, **kwargs)
@@ -46,7 +50,7 @@ def main(argv=sys.argv[1:]):
"paths",
nargs="*",
default=[os.curdir],
- help="The files or directories to check. this argument is directly passed to black",
+ help="The files or directories to check. this argument is directly passed to black",
)
parser.add_argument(
"--config",
@@ -56,11 +60,14 @@ def main(argv=sys.argv[1:]):
help="The config file",
)
parser.add_argument(
- "--reformat", action="store_true", help="Reformat the files in place"
+ "--reformat",
+ action="store_true",
+ help="Reformat the files in place",
+ )
+ parser.add_argument(
+ "--xunit-file",
+ help="Generate a xunit compliant XML file",
)
- # not using a file handle directly
- # in order to prevent leaving an empty file when something fails early
- parser.add_argument("--xunit-file", help="Generate a xunit compliant XML file")
args = parser.parse_args(argv)
# if we have specified a config file, make sure it exists and abort if not
@@ -68,9 +75,23 @@ def main(argv=sys.argv[1:]):
print("Could not find config file '%s'" % args.config_file, file=sys.stderr)
return 1
+ # TODO(Nacho): Inject the config file results into the ctx (use read_pyproject_toml)
+ sources = get_sources(
+ ctx=click.Context(black),
+ src=tuple(args.paths),
+ quiet=True,
+ verbose=False,
+ include=re_compile_maybe_verbose(DEFAULT_INCLUDES),
+ exclude=re_compile_maybe_verbose(DEFAULT_EXCLUDES),
+ extend_exclude=None,
+ force_exclude=None,
+ report=Report(),
+ stdin_filename="",
+ )
+ checked_files = [str(path) for path in sources]
+
if args.xunit_file:
start_time = time.time()
- report = []
# invoke black
black_args_withouth_path = []
@@ -124,7 +145,9 @@ def main(argv=sys.argv[1:]):
file_name = file_name.split(suffix)[0]
testname = "%s.%s" % (folder_name, file_name)
- xml = get_xunit_content(report, testname, time.time() - start_time)
+ xml = get_xunit_content(
+ report, testname, time.time() - start_time, checked_files
+ )
path = os.path.dirname(os.path.abspath(args.xunit_file))
if not os.path.exists(path):
os.makedirs(path)
@@ -154,7 +177,7 @@ def get_line_number(data, offset):
return data[0:offset].count("\n") + data[0:offset].count("\r") + 1
-def get_xunit_content(report, testname, elapsed):
+def get_xunit_content(report, testname, elapsed, checked_files):
test_count = sum(max(len(r), 1) for r in report.values())
error_count = sum(len(r) for r in report.values())
data = {
@@ -212,10 +235,10 @@ def get_xunit_content(report, testname, elapsed):
# output list of checked files
data = {
- "escaped_files": escape("".join(["\n* %s" % r for r in sorted(report.keys())]))
+ "checked_files": escape("".join(["\n* %s" % r for r in sorted(checked_files)]))
}
xml += (
- """ Checked files:%(escaped_files)s
+ """ Checked files:%(checked_files)s
"""
% data
)