diff --git a/algobattle/battle.py b/algobattle/battle.py index 5fbb5b08..46f8b543 100644 --- a/algobattle/battle.py +++ b/algobattle/battle.py @@ -3,6 +3,7 @@ This module contains the :class:`Battle` class, which speciefies how each type of battle is fought and scored, some basic battle types, and related classed. """ + from abc import abstractmethod from collections.abc import Iterable from dataclasses import dataclass, field @@ -201,8 +202,7 @@ async def run( *, with_results: Literal[False] = False, **kwargs: Unpack[RunKwargs], - ) -> Fight: - ... + ) -> Fight: ... @overload async def run( @@ -211,8 +211,7 @@ async def run( *, with_results: Literal[True], **kwargs: Unpack[RunKwargs], - ) -> tuple[Fight, GeneratorResult, SolverResult | None]: - ... + ) -> tuple[Fight, GeneratorResult, SolverResult | None]: ... async def run( self, @@ -250,12 +249,7 @@ async def run( The resulting info about the executed fight, and the results if the flag has been set. """ gen_result, sol_result = await self.run_raw(max_size=max_size, **kwargs) - if gen_result.instance is None or gen_result.solution is None: - score = 1 - elif sol_result is None or sol_result.solution is None: - score = 0 - else: - score = self.calculate_score(gen_result, sol_result) + score = self.calculate_score(gen_result, sol_result) fight = Fight.from_results( score=score, max_size=max_size, @@ -321,7 +315,7 @@ async def run_raw( ) return gen_result, sol_result - def calculate_score(self, gen_result: GeneratorResult, sol_result: SolverResult) -> float: + def calculate_score(self, gen_result: GeneratorResult, sol_result: SolverResult | None) -> float: """Calculates the score achieved by the solver in this fight. Both results need to contain all instance and/or solution data required. @@ -333,9 +327,25 @@ def calculate_score(self, gen_result: GeneratorResult, sol_result: SolverResult) Returns: A number in [0, 1] with higher numbers meaning the solver performed better. """ - assert gen_result.instance is not None - assert sol_result.solution is not None + # We first need to check whether the generator somehow failed. This can happen in three cases: + # it doesn't produce an instance, it creates an invalid instance or it doesn't create a needed solution + # note that gen_result.instance and gen_result.error will both contain data if the generator outputted an + # invalid instance and/or solution + if ( + gen_result.instance is None + or gen_result.error is not None + or (self.problem.with_solution and gen_result.solution is None) + ): + return 1 + # sol_result is None only if the generator failed, which needs to be caught by the above check + assert sol_result is not None + # The solver failed if it didn't output a solution or the solution contains some error + if sol_result.solution is None or sol_result.error is not None: + return 0 + if self.problem.with_solution: + # we need this assert since type checkers can't see the dependency between self.problem.with_solution + # and gen_result.solution that is created by the last check in the first "if" assert gen_result.solution is not None score = self.problem.score( gen_result.instance, solver_solution=sol_result.solution, generator_solution=gen_result.solution @@ -449,8 +459,7 @@ class FallbackConfig(Config): if TYPE_CHECKING: # to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type - def __getattr__(self, attr: str, /) -> Any: - ... + def __getattr__(self, attr: str, /) -> Any: ... class UiData(BaseModel): """Object containing custom diplay data. @@ -705,13 +714,13 @@ class Config(Battle.Config): """Number of fights that will be fought.""" weighting: Annotated[float, Ge(0)] = 1.1 """How much each successive fight should be weighted more than the previous.""" - scores: set[Role] = {Role.generator, Role.solver} # noqa: RUF012 + scores: set[Role] = {Role.generator, Role.solver} # noqa: RUF012 """Who to show each fight's scores to.""" - instances: set[Role] = {Role.generator, Role.solver} # noqa: RUF012 + instances: set[Role] = {Role.generator, Role.solver} # noqa: RUF012 """Who to show the instances to.""" - generator_solutions: set[Role] = {Role.generator} # noqa: RUF012 + generator_solutions: set[Role] = {Role.generator} # noqa: RUF012 """Who to show the generator's solutions to, if the problem requires them.""" - solver_solutions: set[Role] = {Role.solver} # noqa: RUF012 + solver_solutions: set[Role] = {Role.solver} # noqa: RUF012 """Who to show the solver's solutions to.""" class UiData(Battle.UiData): diff --git a/pyproject.toml b/pyproject.toml index 7ebbc91a..da89458d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "pdm.backend" [project] name = "algobattle-base" -version = "4.3.3" +version = "4.3.4" description = "The Algobattle lab course package." readme = "README.md" requires-python = ">=3.11" diff --git a/uv.lock b/uv.lock index 1b241d91..fd46aa9d 100644 --- a/uv.lock +++ b/uv.lock @@ -4,7 +4,7 @@ requires-python = ">=3.11" [[package]] name = "algobattle-base" -version = "4.3.2" +version = "4.3.4" source = { editable = "." } dependencies = [ { name = "anyio" },