Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 71 additions & 13 deletions indico_toolkit/results/predictionlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,42 @@ def groupby(
self, key: "Callable[[PredictionType], KeyType]"
) -> "dict[KeyType, Self]":
"""
Group predictions into a dictionary using `key`.
E.g. `key=attrgetter("label")` or `key=attrgetter("model")`
Group predictions into a dictionary using `key` to derive each prediction's key.
E.g. `key=attrgetter("label")` or `key=attrgetter("model")`.

If a derived key is an unhashable mutable collection (like set),
it's automatically converted to its hashable immutable variant (like frozenset).
This makes it easy to group by linked labels or unbundling pages.
"""
grouped = defaultdict(type(self)) # type: ignore[var-annotated]

for prediction in self:
grouped[key(prediction)].append(prediction)
derived_key = key(prediction)

if isinstance(derived_key, list):
derived_key = tuple(derived_key) # type: ignore[assignment]
elif isinstance(derived_key, set):
derived_key = frozenset(derived_key) # type: ignore[assignment]

grouped[derived_key].append(prediction)

return grouped

def groupbyiter(
    self, keys: "Callable[[PredictionType], Iterable[KeyType]]"
) -> "dict[KeyType, Self]":
    """
    Group predictions into a dictionary, filing each one under several keys.

    `keys` is called once per prediction and must return an iterable of keys.
    E.g. `keys=attrgetter("groups")` or `keys=attrgetter("pages")`.

    A prediction is appended to the group of every key its iterable yields
    (once per occurrence of that key). When the iterable is empty, the
    prediction ends up in no group at all. Each group is an instance of this
    list's own type.
    """
    buckets = defaultdict(type(self))  # type: ignore[var-annotated]

    # Lazily flatten to (key, prediction) pairs so `keys` is still invoked
    # once per prediction, in iteration order.
    memberships = (
        (bucket_key, member) for member in self for bucket_key in keys(member)
    )

    for bucket_key, member in memberships:
        buckets[bucket_key].append(member)

    return buckets

Expand All @@ -112,14 +141,17 @@ def where(
predicate: "Callable[[PredictionType], bool] | None" = None,
*,
document: "Document | None" = None,
document_in: "Container[Document] | None" = None,
model: "ModelGroup | TaskType | str | None" = None,
model_in: "Container[ModelGroup | TaskType | str] | None" = None,
review: "Review | ReviewType | None" = ReviewUnspecified,
review_in: "Container[Review | ReviewType | None]" = {ReviewUnspecified},
label: "str | None" = None,
label_in: "Container[str] | None" = None,
min_confidence: "float | None" = None,
max_confidence: "float | None" = None,
page: "int | None" = None,
page_in: "Collection[int] | None" = None,
min_confidence: "float | None" = None,
max_confidence: "float | None" = None,
accepted: "bool | None" = None,
rejected: "bool | None" = None,
checked: "bool | None" = None,
Expand All @@ -129,20 +161,23 @@ def where(
Return a new prediction list containing predictions that match
all of the specified filters.

predicate: predictions for which this function returns True.
predicate: predictions for which this function returns True,
document: predictions from this document,
document_in: predictions from these documents,
model: predictions from this model, task type, or name,
model_in: predictions from these models, task types, or names,
review: predictions from this review or review type,
review_in: predictions from these reviews or review types,
label: predictions with this label,
label_in: predictions with one of these labels,
label_in: predictions with these labels,
page: extractions/unbundlings on this page,
page_in: extractions/unbundlings on these pages,
min_confidence: predictions with confidence >= this threshold,
max_confidence: predictions with confidence <= this threshold,
page: extractions/unbundlings on this page,
page_in: extractions/unbundlings on one of these pages,
accepted: extractions that have been accepted,
rejected: extractions that have been rejected,
checked: form extractions that are checked,
signed: form extractions that are signed,
accepted: extractions that have or haven't been accepted,
rejected: extractions that have or haven't been rejected,
checked: form extractions that are or aren't checked,
signed: form extractions that are or aren't signed.
"""
predicates = []

Expand All @@ -152,6 +187,9 @@ def where(
if document is not None:
predicates.append(lambda prediction: prediction.document == document)

if document_in is not None:
predicates.append(lambda prediction: prediction.document in document_in)

if model is not None:
predicates.append(
lambda prediction: (
Expand All @@ -161,6 +199,15 @@ def where(
)
)

if model_in is not None:
predicates.append(
lambda prediction: (
prediction.model in model_in
or prediction.model.task_type in model_in
or prediction.model.name in model_in
)
)

if review is not ReviewUnspecified:
predicates.append(
lambda prediction: (
Expand All @@ -172,6 +219,17 @@ def where(
)
)

if review_in != {ReviewUnspecified}:
predicates.append(
lambda prediction: (
prediction.review in review_in
or (
prediction.review is not None
and prediction.review.type in review_in
)
)
)

if label is not None:
predicates.append(lambda prediction: prediction.label == label)

Expand Down
Loading
Loading