diff --git a/pytest_black.py b/pytest_black.py index 04c80cb..23b461f 100644 --- a/pytest_black.py +++ b/pytest_black.py @@ -20,13 +20,10 @@ def pytest_addoption(parser): ) -def pytest_collect_file(path, parent): +def pytest_collect_file(file_path, path, parent): config = parent.config if config.option.black and path.ext == ".py": - if hasattr(BlackItem, "from_parent"): - return BlackItem.from_parent(parent, fspath=path) - else: - return BlackItem(path, parent) + return BlackFile.from_parent(parent, path=file_path) def pytest_configure(config): @@ -42,10 +39,17 @@ def pytest_unconfigure(config): config.cache.set(HISTKEY, config._blackmtimes) -class BlackItem(pytest.Item, pytest.File): - def __init__(self, fspath, parent): - super(BlackItem, self).__init__(fspath, parent) - self._nodeid += "::BLACK" +class BlackFile(pytest.File): + def collect(self): + """ returns a list of children (items and collectors) + for this collection node. + """ + yield BlackItem.from_parent(self, name="black") + + +class BlackItem(pytest.Item): + def __init__(self, **kwargs): + super(BlackItem, self).__init__(**kwargs) self.add_marker("black") try: with open("pyproject.toml") as toml_file: @@ -61,8 +65,8 @@ def __init__(self, fspath, parent): def setup(self): pytest.importorskip("black") mtimes = getattr(self.config, "_blackmtimes", {}) - self._blackmtime = self.fspath.mtime() - old = mtimes.get(str(self.fspath), 0) + self._blackmtime = self.path.stat().st_mtime + old = mtimes.get(str(self.path), 0) if self._blackmtime == old: pytest.skip("file(s) previously passed black format checks") @@ -70,7 +74,7 @@ def setup(self): pytest.skip("file(s) excluded by pyproject.toml") def runtest(self): - cmd = [sys.executable, "-m", "black", "--check", "--diff", "--quiet", str(self.fspath)] + cmd = [sys.executable, "-m", "black", "--check", "--diff", "--quiet", str(self.path)] try: subprocess.run( cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True @@ -79,7 +83,7 @@ def runtest(self): raise BlackError(e) mtimes = getattr(self.config, "_blackmtimes", {}) - mtimes[str(self.fspath)] = self._blackmtime + mtimes[str(self.path)] = self._blackmtime def repr_failure(self, excinfo): if excinfo.errisinstance(BlackError): @@ -87,7 +91,7 @@ def repr_failure(self, excinfo): return super(BlackItem, self).repr_failure(excinfo) def reportinfo(self): - return (self.fspath, -1, "Black format check") + return (self.path, -1, "Black format check") def _skip_test(self): return self._excluded() or (not self._included()) @@ -95,24 +99,18 @@ def _skip_test(self): def _included(self): if "include" not in self.pyproject: return True - return re.search(self.pyproject["include"], str(self.fspath)) + return re.search(self.pyproject["include"], str(self.path)) def _excluded(self): if "exclude" not in self.pyproject: return False - return re.search(self.pyproject["exclude"], str(self.fspath)) + return re.search(self.pyproject["exclude"], str(self.path)) def _re_fix_verbose(self, regex): if "\n" in regex: regex = "(?x)" + regex return re.compile(regex) - def collect(self): - """ returns a list of children (items and collectors) - for this collection node. - """ - return (self,) - class BlackError(Exception): pass diff --git a/setup.py b/setup.py index 527d23d..88c745f 100644 --- a/setup.py +++ b/setup.py @@ -25,8 +25,8 @@ def read(fname): py_modules=["pytest_black"], python_requires=">=3.5", install_requires=[ - "pytest>=3.5.0", - 'black>=22.1.0"', + "black>=22.1.0", + "pytest>=7.0.0", "toml", ], use_scm_version=True,