diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index abf2c9c..14882a2 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -80,13 +80,42 @@ def groupby( self, key: "Callable[[PredictionType], KeyType]" ) -> "dict[KeyType, Self]": """ - Group predictions into a dictionary using `key`. - E.g. `key=attrgetter("label")` or `key=attrgetter("model")` + Group predictions into a dictionary using `key` to derive each prediction's key. + E.g. `key=attrgetter("label")` or `key=attrgetter("model")`. + + If a derived key is an unhashable mutable collection (like set), + it's automatically converted to its hashable immutable variant (like frozenset). + This makes it easy to group by linked labels or unbundling pages. """ grouped = defaultdict(type(self)) # type: ignore[var-annotated] for prediction in self: - grouped[key(prediction)].append(prediction) + derived_key = key(prediction) + + if isinstance(derived_key, list): + derived_key = tuple(derived_key) # type: ignore[assignment] + elif isinstance(derived_key, set): + derived_key = frozenset(derived_key) # type: ignore[assignment] + + grouped[derived_key].append(prediction) + + return grouped + + def groupbyiter( + self, keys: "Callable[[PredictionType], Iterable[KeyType]]" + ) -> "dict[KeyType, Self]": + """ + Group predictions into a dictionary using `keys` to derive an iterable of keys. + E.g. `keys=attrgetter("groups")` or `keys=attrgetter("pages")`. + + Each prediction is associated with every key in the iterable individually. + If the iterable is empty, the prediction is not included in any group. 
+ """ + grouped = defaultdict(type(self)) # type: ignore[var-annotated] + + for prediction in self: + for derived_key in keys(prediction): + grouped[derived_key].append(prediction) return grouped @@ -112,14 +141,17 @@ def where( predicate: "Callable[[PredictionType], bool] | None" = None, *, document: "Document | None" = None, + document_in: "Container[Document] | None" = None, model: "ModelGroup | TaskType | str | None" = None, + model_in: "Container[ModelGroup | TaskType | str] | None" = None, review: "Review | ReviewType | None" = ReviewUnspecified, + review_in: "Container[Review | ReviewType | None]" = {ReviewUnspecified}, label: "str | None" = None, label_in: "Container[str] | None" = None, - min_confidence: "float | None" = None, - max_confidence: "float | None" = None, page: "int | None" = None, page_in: "Collection[int] | None" = None, + min_confidence: "float | None" = None, + max_confidence: "float | None" = None, accepted: "bool | None" = None, rejected: "bool | None" = None, checked: "bool | None" = None, @@ -129,20 +161,23 @@ def where( Return a new prediction list containing predictions that match all of the specified filters. - predicate: predictions for which this function returns True. 
+ predicate: predictions for which this function returns True, document: predictions from this document, + document_in: predictions from these documents, model: predictions from this model, task type, or name, + model_in: predictions from these models, task types, or names, review: predictions from this review or review type, + review_in: predictions from these reviews or review types, label: predictions with this label, - label_in: predictions with one of these labels, + label_in: predictions with these labels, + page: extractions/unbundlings on this page, + page_in: extractions/unbundlings on these pages, min_confidence: predictions with confidence >= this threshold, max_confidence: predictions with confidence <= this threshold, - page: extractions/unbundlings on this page, - page_in: extractions/unbundlings on one of these pages, - accepted: extractions that have been accepted, - rejected: extractions that have been rejected, - checked: form extractions that are checked, - signed: form extractions that are signed, + accepted: extractions that have or haven't been accepted, + rejected: extractions that have or haven't been rejected, + checked: form extractions that are or aren't checked, + signed: form extractions that are or aren't signed. 
""" predicates = [] @@ -152,6 +187,9 @@ def where( if document is not None: predicates.append(lambda prediction: prediction.document == document) + if document_in is not None: + predicates.append(lambda prediction: prediction.document in document_in) + if model is not None: predicates.append( lambda prediction: ( @@ -161,6 +199,15 @@ def where( ) ) + if model_in is not None: + predicates.append( + lambda prediction: ( + prediction.model in model_in + or prediction.model.task_type in model_in + or prediction.model.name in model_in + ) + ) + if review is not ReviewUnspecified: predicates.append( lambda prediction: ( @@ -172,6 +219,17 @@ def where( ) ) + if review_in != {ReviewUnspecified}: + predicates.append( + lambda prediction: ( + prediction.review in review_in + or ( + prediction.review is not None + and prediction.review.type in review_in + ) + ) + ) + if label is not None: predicates.append(lambda prediction: prediction.label == label) diff --git a/tests/results/test_predictionlist.py b/tests/results/test_predictionlist.py index 7dd37e2..bbf9b53 100644 --- a/tests/results/test_predictionlist.py +++ b/tests/results/test_predictionlist.py @@ -6,6 +6,7 @@ Classification, Document, DocumentExtraction, + Group, ModelGroup, Prediction, PredictionList, @@ -53,6 +54,16 @@ def manual_review() -> Review: ) +@pytest.fixture +def group_alpha() -> Group: + return Group(id=12345, name="Alpha", index=0) + + +@pytest.fixture +def group_bravo() -> Group: + return Group(id=12345, name="Bravo", index=0) + + @pytest.fixture def predictions( document: Document, @@ -60,6 +71,8 @@ def predictions( extraction_model: ModelGroup, auto_review: Review, manual_review: Review, + group_alpha: Group, + group_bravo: Group, ) -> "PredictionList[Prediction]": return PredictionList( [ @@ -84,7 +97,7 @@ def predictions( start=352, end=356, page=0, - groups=set(), + groups={group_alpha}, ), DocumentExtraction( document=document, @@ -99,7 +112,7 @@ def predictions( start=357, end=360, page=1, - 
groups=set(), + groups={group_alpha, group_bravo}, ), ] ) @@ -111,7 +124,7 @@ def test_classifications(predictions: "PredictionList[Prediction]") -> None: def test_extractions(predictions: "PredictionList[Prediction]") -> None: - (first_extraction, second_extraction) = predictions.document_extractions + first_extraction, second_extraction = predictions.extractions assert isinstance(first_extraction, DocumentExtraction) assert isinstance(second_extraction, DocumentExtraction) @@ -122,9 +135,32 @@ def test_slice_is_prediction_list(predictions: "PredictionList[Prediction]") -> assert isinstance(predictions, PredictionList) +def test_groupby( + predictions: "PredictionList[Prediction]", group_alpha: Group, group_bravo: Group +) -> None: + first_name, last_name = predictions.extractions + predictions_by_groups = predictions.extractions.groupby(attrgetter("groups")) + assert predictions_by_groups == { + frozenset({group_alpha}): [first_name], + frozenset({group_alpha, group_bravo}): [last_name], + } + + +def test_groupbyiter( + predictions: "PredictionList[Prediction]", group_alpha: Group, group_bravo: Group +) -> None: + first_name, last_name = predictions.extractions + predictions_by_group = predictions.extractions.groupbyiter(attrgetter("groups")) + assert predictions_by_group == { + group_alpha: [first_name, last_name], + group_bravo: [last_name], + } + + def test_orderby(predictions: "PredictionList[Prediction]") -> None: + classification, first_name, last_name = predictions predictions = predictions.orderby(attrgetter("confidence"), reverse=True) - assert predictions[0].confidence == 0.9 + assert predictions == [last_name, first_name, classification] def test_where_document( @@ -133,115 +169,109 @@ def test_where_document( assert predictions.where(document=document) == predictions +def test_where_document_in( + predictions: "PredictionList[Prediction]", document: Document +) -> None: + assert predictions.where(document_in={document}) == predictions + assert 
predictions.where(document_in={}) == [] + + def test_where_model( predictions: "PredictionList[Prediction]", classification_model: ModelGroup ) -> None: - classification, first_name, last_name = predictions + (classification,) = predictions.classifications + assert predictions.where(model=classification_model) == [classification] + assert predictions.where(model=TaskType.CLASSIFICATION) == [classification] + assert predictions.where(model="Tax Classification") == [classification] - filtered = predictions.where(model=classification_model) - assert classification in filtered - assert first_name not in filtered - assert last_name not in filtered - filtered = predictions.where(model=TaskType.CLASSIFICATION) - assert classification in filtered - assert first_name not in filtered - assert last_name not in filtered +def test_where_model_in( + predictions: "PredictionList[Prediction]", classification_model: ModelGroup +) -> None: + classification, first_name, last_name = predictions + assert predictions.where(model_in={classification_model}) == [classification] + assert predictions.where(model_in={TaskType.CLASSIFICATION}) == [classification] + assert predictions.where( + model_in={TaskType.CLASSIFICATION, TaskType.DOCUMENT_EXTRACTION} + ) == [classification, first_name, last_name] + assert predictions.where(model_in={"Tax Classification"}) == [classification] + assert predictions.where( + model_in={"Tax Classification", "1040 Document Extraction"} + ) == [classification, first_name, last_name] + assert predictions.where(model_in={}) == [] + + +def test_where_review( + predictions: "PredictionList[Prediction]", auto_review: Review +) -> None: + classification, first_name, last_name = predictions + assert predictions.where(review=None) == [classification] + assert predictions.where(review=auto_review) == [first_name] + assert predictions.where(review=ReviewType.MANUAL) == [last_name] - filtered = predictions.where(model="Tax Classification") - assert classification in 
filtered - assert first_name not in filtered - assert last_name not in filtered +def test_where_review_in( + predictions: "PredictionList[Prediction]", auto_review: Review +) -> None: + classification, first_name, last_name = predictions + assert predictions.where(review_in={None}) == [classification] + assert predictions.where( + review_in={None, auto_review} + ) == [classification, first_name] + assert predictions.where( + review_in={auto_review, ReviewType.MANUAL} + ) == [first_name, last_name] + assert predictions.where(review_in={}) == [] def test_where_label(predictions: "PredictionList[Prediction]") -> None: - classification, first_name, last_name = predictions + first_name, _ = predictions.extractions + assert predictions.where(label="First Name") == [first_name] - filtered = predictions.where(label="First Name") - assert classification not in filtered - assert first_name in filtered - assert last_name not in filtered - filtered = predictions.where(label_in=("First Name", "Last Name")) - assert classification not in filtered - assert first_name in filtered - assert last_name in filtered +def test_where_label_in(predictions: "PredictionList[Prediction]") -> None: + first_name, last_name = predictions.extractions + assert predictions.where( + label_in=("First Name", "Last Name") + ) == [first_name, last_name] def test_where_confidence(predictions: "PredictionList[Prediction]") -> None: conf_70, conf_80, conf_90 = predictions - - filtered = predictions.where(min_confidence=0.9) - assert conf_70 not in filtered - assert conf_80 not in filtered - assert conf_90 in filtered - - filtered = predictions.where(min_confidence=0.75, max_confidence=0.85) - assert conf_70 not in filtered - assert conf_80 in filtered - assert conf_90 not in filtered - - filtered = predictions.where(max_confidence=0.7) - assert conf_70 in filtered - assert conf_80 not in filtered - assert conf_90 not in filtered + assert predictions.where(min_confidence=0.9) == [conf_90] + assert 
predictions.where(min_confidence=0.75, max_confidence=0.85) == [conf_80] + assert predictions.where(max_confidence=0.7) == [conf_70] def test_where_page(predictions: "PredictionList[Prediction]") -> None: - classification, first_name, last_name = predictions + first_name, _ = predictions.extractions + assert predictions.where(page=0) == [first_name] - filtered = predictions.where(page=0) - assert classification not in filtered - assert first_name in filtered - assert last_name not in filtered - filtered = predictions.where(page_in=(0, 1)) - assert classification not in filtered - assert first_name in filtered - assert last_name in filtered +def test_where_page_in(predictions: "PredictionList[Prediction]") -> None: + first_name, last_name = predictions.extractions + assert predictions.where(page_in=(0, 1)) == [first_name, last_name] def test_where_accepted(predictions: "PredictionList[Prediction]") -> None: - _, first_name, last_name = predictions + first_name, last_name = predictions.extractions predictions.unaccept() - filtered = predictions.where(accepted=True) - assert first_name not in filtered - assert last_name not in filtered - - filtered = predictions.where(accepted=False) - assert first_name in filtered - assert last_name in filtered + assert predictions.where(accepted=True) == [] + assert predictions.where(accepted=False) == [first_name, last_name] predictions.accept() - filtered = predictions.where(accepted=False) - assert first_name not in filtered - assert last_name not in filtered - - filtered = predictions.where(accepted=True) - assert first_name in filtered - assert last_name in filtered - + assert predictions.where(accepted=False) == [] + assert predictions.where(accepted=True) == [first_name, last_name] def test_where_rejected(predictions: "PredictionList[Prediction]") -> None: - _, first_name, last_name = predictions + first_name, last_name = predictions.extractions predictions.unreject() - filtered = predictions.where(rejected=True) - assert 
first_name not in filtered - assert last_name not in filtered - - filtered = predictions.where(rejected=False) - assert first_name in filtered - assert last_name in filtered + assert predictions.where(rejected=True) == [] + assert predictions.where(rejected=False) == [first_name, last_name] predictions.reject() - filtered = predictions.where(rejected=False) - assert first_name not in filtered - assert last_name not in filtered - - filtered = predictions.where(rejected=True) - assert first_name in filtered - assert last_name in filtered + assert predictions.where(rejected=False) == [] + assert predictions.where(rejected=True) == [first_name, last_name]