Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions pytest_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@ def pytest_addoption(parser):
)


def pytest_collect_file(path, parent):
def pytest_collect_file(file_path, path, parent):
config = parent.config
if config.option.black and path.ext == ".py":
if hasattr(BlackItem, "from_parent"):
return BlackItem.from_parent(parent, fspath=path)
else:
return BlackItem(path, parent)
return BlackFile.from_parent(parent, path=file_path)


def pytest_configure(config):
Expand All @@ -42,10 +39,17 @@ def pytest_unconfigure(config):
config.cache.set(HISTKEY, config._blackmtimes)


class BlackItem(pytest.Item, pytest.File):
def __init__(self, fspath, parent):
super(BlackItem, self).__init__(fspath, parent)
self._nodeid += "::BLACK"
class BlackFile(pytest.File):
def collect(self):
""" returns a list of children (items and collectors)
for this collection node.
"""
yield BlackItem.from_parent(self, name="black")


class BlackItem(pytest.Item):
def __init__(self, **kwargs):
super(BlackItem, self).__init__(**kwargs)
self.add_marker("black")
try:
with open("pyproject.toml") as toml_file:
Expand All @@ -61,16 +65,16 @@ def __init__(self, fspath, parent):
def setup(self):
pytest.importorskip("black")
mtimes = getattr(self.config, "_blackmtimes", {})
self._blackmtime = self.fspath.mtime()
old = mtimes.get(str(self.fspath), 0)
self._blackmtime = self.path.stat().st_mtime
old = mtimes.get(str(self.path), 0)
if self._blackmtime == old:
pytest.skip("file(s) previously passed black format checks")

if self._skip_test():
pytest.skip("file(s) excluded by pyproject.toml")

def runtest(self):
cmd = [sys.executable, "-m", "black", "--check", "--diff", "--quiet", str(self.fspath)]
cmd = [sys.executable, "-m", "black", "--check", "--diff", "--quiet", str(self.path)]
try:
subprocess.run(
cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True
Expand All @@ -79,40 +83,34 @@ def runtest(self):
raise BlackError(e)

mtimes = getattr(self.config, "_blackmtimes", {})
mtimes[str(self.fspath)] = self._blackmtime
mtimes[str(self.path)] = self._blackmtime

def repr_failure(self, excinfo):
if excinfo.errisinstance(BlackError):
return excinfo.value.args[0].stdout
return super(BlackItem, self).repr_failure(excinfo)

def reportinfo(self):
return (self.fspath, -1, "Black format check")
return (self.path, -1, "Black format check")

def _skip_test(self):
return self._excluded() or (not self._included())

def _included(self):
if "include" not in self.pyproject:
return True
return re.search(self.pyproject["include"], str(self.fspath))
return re.search(self.pyproject["include"], str(self.path))

def _excluded(self):
if "exclude" not in self.pyproject:
return False
return re.search(self.pyproject["exclude"], str(self.fspath))
return re.search(self.pyproject["exclude"], str(self.path))

def _re_fix_verbose(self, regex):
if "\n" in regex:
regex = "(?x)" + regex
return re.compile(regex)

def collect(self):
""" returns a list of children (items and collectors)
for this collection node.
"""
return (self,)


class BlackError(Exception):
pass
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def read(fname):
py_modules=["pytest_black"],
python_requires=">=3.5",
install_requires=[
"pytest>=3.5.0",
'black>=22.1.0"',
"black>=22.1.0",
"pytest>=7.0.0",
"toml",
],
use_scm_version=True,
Expand Down