Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module contains the :class:`Battle` class, which speciefies how each type of battle is fought and scored,
some basic battle types, and related classed.
"""

from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass, field
Expand Down Expand Up @@ -201,8 +202,7 @@ async def run(
*,
with_results: Literal[False] = False,
**kwargs: Unpack[RunKwargs],
) -> Fight:
...
) -> Fight: ...

@overload
async def run(
Expand All @@ -211,8 +211,7 @@ async def run(
*,
with_results: Literal[True],
**kwargs: Unpack[RunKwargs],
) -> tuple[Fight, GeneratorResult, SolverResult | None]:
...
) -> tuple[Fight, GeneratorResult, SolverResult | None]: ...

async def run(
self,
Expand Down Expand Up @@ -250,12 +249,7 @@ async def run(
The resulting info about the executed fight, and the results if the flag has been set.
"""
gen_result, sol_result = await self.run_raw(max_size=max_size, **kwargs)
if gen_result.instance is None or gen_result.solution is None:
score = 1
elif sol_result is None or sol_result.solution is None:
score = 0
else:
score = self.calculate_score(gen_result, sol_result)
score = self.calculate_score(gen_result, sol_result)
fight = Fight.from_results(
score=score,
max_size=max_size,
Expand Down Expand Up @@ -321,7 +315,7 @@ async def run_raw(
)
return gen_result, sol_result

def calculate_score(self, gen_result: GeneratorResult, sol_result: SolverResult) -> float:
def calculate_score(self, gen_result: GeneratorResult, sol_result: SolverResult | None) -> float:
"""Calculates the score achieved by the solver in this fight.

Both results need to contain all instance and/or solution data required.
Expand All @@ -333,9 +327,25 @@ def calculate_score(self, gen_result: GeneratorResult, sol_result: SolverResult)
Returns:
A number in [0, 1] with higher numbers meaning the solver performed better.
"""
assert gen_result.instance is not None
assert sol_result.solution is not None
# We first need to check whether the generator somehow failed. This can happen in three cases:
# it doesn't produce an instance, it creates an invalid instance or it doesn't create a needed solution
# note that gen_result.instance and gen_result.error will both contain data if the generator outputted an
# invalid instance and/or solution
if (
gen_result.instance is None
or gen_result.error is not None
or (self.problem.with_solution and gen_result.solution is None)
):
return 1
# sol_result is None only if the generator failed, which needs to be caught by the above check
assert sol_result is not None
# The solver failed if it didn't output a solution or the solution contains some error
if sol_result.solution is None or sol_result.error is not None:
return 0

if self.problem.with_solution:
# we need this assert since type checkers can't see the dependency between self.problem.with_solution
# and gen_result.solution that is created by the last check in the first "if"
assert gen_result.solution is not None
score = self.problem.score(
gen_result.instance, solver_solution=sol_result.solution, generator_solution=gen_result.solution
Expand Down Expand Up @@ -449,8 +459,7 @@ class FallbackConfig(Config):

if TYPE_CHECKING:
# to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type
def __getattr__(self, attr: str, /) -> Any:
...
def __getattr__(self, attr: str, /) -> Any: ...

class UiData(BaseModel):
"""Object containing custom diplay data.
Expand Down Expand Up @@ -705,13 +714,13 @@ class Config(Battle.Config):
"""Number of fights that will be fought."""
weighting: Annotated[float, Ge(0)] = 1.1
"""How much each successive fight should be weighted more than the previous."""
scores: set[Role] = {Role.generator, Role.solver} # noqa: RUF012
scores: set[Role] = {Role.generator, Role.solver} # noqa: RUF012
"""Who to show each fight's scores to."""
instances: set[Role] = {Role.generator, Role.solver} # noqa: RUF012
instances: set[Role] = {Role.generator, Role.solver} # noqa: RUF012
"""Who to show the instances to."""
generator_solutions: set[Role] = {Role.generator} # noqa: RUF012
generator_solutions: set[Role] = {Role.generator} # noqa: RUF012
"""Who to show the generator's solutions to, if the problem requires them."""
solver_solutions: set[Role] = {Role.solver} # noqa: RUF012
solver_solutions: set[Role] = {Role.solver} # noqa: RUF012
"""Who to show the solver's solutions to."""

class UiData(Battle.UiData):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "pdm.backend"

[project]
name = "algobattle-base"
version = "4.3.3"
version = "4.3.4"
description = "The Algobattle lab course package."
readme = "README.md"
requires-python = ">=3.11"
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.