diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e5cbbd9..9f502a18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,34 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and versions match the minimum IPA version required to use functionality. +## [v7.2.2] - 2025-10-14 + +### Added + +- Parse table spans from ETL Output as `Table.spans`. +- `NULL_CELL`, `NULL_RANGE`, `NULL_TABLE`, and `NULL_TOKEN` constants. +- Document Extraction attributes for assigning tokens, tables, and cells from OCR: + - `DocumentExtraction.tokens`, `DocumentExtraction.tables`, `DocumentExtraction.cells` +- Document Extraction convenience properties for singular token, table, and cell access: + - `DocumentExtraction.token`, `DocumentExtraction.table`, `DocumentExtraction.cell` +- `PredictionList.assign_ocr(etl_outputs, tokens=True, tables=True)` method. +- Custom `__hash__` methods for tables and cells to speed up `.groupby(...)`. +- Prediction `.copy()` methods that only copy mutable state. + +### Changed + +- Move `Box` and `Span` from results to etloutput to avoid circular imports. + (Both can still be imported from either module.) +- Return `NULL_TOKEN` instead of raising an exception from `EtlOutput.token_for(span)`. +- Rewrite table cell lookup `EtlOutput.table_cells_for(span)` using a fast, span-based, + binary search algorithm that can return multiple overlapped table cells. + +### Removed + +- Custom `results` and `etloutput` error classes that are nearly never used. + (Replaced with idiomatic Python error classes.) + + ## [v7.2.1] - 2025-09-09 ### Fixed @@ -265,6 +293,7 @@ This is the first major version release tested to work on Indico 6.X. - Row Association now also sorting on 'bbtop'. 
+[v7.2.2]: https://github.com/IndicoDataSolutions/indico-toolkit-python/compare/v7.2.1...v7.2.2 [v7.2.1]: https://github.com/IndicoDataSolutions/indico-toolkit-python/compare/v7.2.0...v7.2.1 [v7.2.0]: https://github.com/IndicoDataSolutions/indico-toolkit-python/compare/v6.14.2...v7.2.0 [v6.14.2]: https://github.com/IndicoDataSolutions/indico-toolkit-python/compare/v6.14.1...v6.14.2 diff --git a/indico_toolkit/__init__.py b/indico_toolkit/__init__.py index 57b848f9..a10b5a2b 100644 --- a/indico_toolkit/__init__.py +++ b/indico_toolkit/__init__.py @@ -21,4 +21,4 @@ "ToolkitStaggeredLoopError", "ToolkitStatusError", ) -__version__ = "7.2.1" +__version__ = "7.2.2" diff --git a/indico_toolkit/etloutput/__init__.py b/indico_toolkit/etloutput/__init__.py index 9c7f3fef..5d2329b6 100644 --- a/indico_toolkit/etloutput/__init__.py +++ b/indico_toolkit/etloutput/__init__.py @@ -1,13 +1,13 @@ from typing import TYPE_CHECKING, TypeAlias, TypeVar -from ..results import NULL_BOX, NULL_SPAN, Box, Span -from ..results.utils import get, has, json_loaded, str_decoded -from .cell import Cell, CellType -from .errors import EtlOutputError, TableCellNotFoundError, TokenNotFoundError +from .box import NULL_BOX, Box +from .cell import NULL_CELL, Cell, CellType from .etloutput import EtlOutput -from .range import Range -from .table import Table -from .token import Token +from .range import NULL_RANGE, Range +from .span import NULL_SPAN, Span +from .table import NULL_TABLE, Table +from .token import NULL_TOKEN, Token +from .utils import get, has, json_loaded, str_decoded if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -17,17 +17,18 @@ "Cell", "CellType", "EtlOutput", - "EtlOutputError", "load", "load_async", "NULL_BOX", + "NULL_CELL", + "NULL_RANGE", "NULL_SPAN", + "NULL_TABLE", + "NULL_TOKEN", "Range", "Span", "Table", - "TableCellNotFoundError", "Token", - "TokenNotFoundError", ) Loadable: TypeAlias = "dict[str, object] | list[object] | str | bytes" diff --git 
a/indico_toolkit/results/predictions/box.py b/indico_toolkit/etloutput/box.py similarity index 68% rename from indico_toolkit/results/predictions/box.py rename to indico_toolkit/etloutput/box.py index fb781474..8c01adfb 100644 --- a/indico_toolkit/results/predictions/box.py +++ b/indico_toolkit/etloutput/box.py @@ -1,10 +1,7 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Final -from ..utils import get - -if TYPE_CHECKING: - from typing import Final +from .utils import get @dataclass(frozen=True) @@ -18,6 +15,30 @@ class Box: def __bool__(self) -> bool: return self != NULL_BOX + def __and__(self, other: "Box") -> "Box": + """ + Return a new `Box` for the overlap between `self` and `other` + or `NULL_BOX` if they don't overlap. + + Supports set-like `extraction.box & cell.box` syntax. + """ + if ( + self.page != other.page + or self.bottom <= other.top # `self` is above `other` + or self.top >= other.bottom # `self` is below `other` + or self.right <= other.left # `self` is to the left of `other` + or self.left >= other.right # `self` is to the right of `other` + ): + return NULL_BOX + else: + return Box( + page=self.page, + top=max(self.top, other.top), + left=max(self.left, other.left), + right=min(self.right, other.right), + bottom=min(self.bottom, other.bottom), + ) + def __lt__(self, other: "Box") -> bool: """ Bounding boxes are sorted with vertical hysteresis. Those on the same line are @@ -58,4 +79,4 @@ def from_dict(box: object) -> "Box": # object rather than using `None` or raising an error. This lets you e.g. sort by the # `box` attribute without having to constantly check for `None`, while still allowing # you do a "None check" with `bool(extraction.box)` or `extraction.box == NULL_BOX`. 
-NULL_BOX: "Final" = Box(page=0, top=0, left=0, right=0, bottom=0) +NULL_BOX: Final = Box(page=0, top=0, left=0, right=0, bottom=0) diff --git a/indico_toolkit/etloutput/cell.py b/indico_toolkit/etloutput/cell.py index 292a78eb..3476644a 100644 --- a/indico_toolkit/etloutput/cell.py +++ b/indico_toolkit/etloutput/cell.py @@ -1,9 +1,11 @@ from dataclasses import dataclass from enum import Enum +from typing import Final -from ..results import NULL_SPAN, Box, Span -from ..results.utils import get -from .range import Range +from .box import NULL_BOX, Box +from .range import NULL_RANGE, Range +from .span import NULL_SPAN, Span +from .utils import get class CellType(Enum): @@ -19,6 +21,18 @@ class Cell: range: Range spans: "tuple[Span, ...]" + def __bool__(self) -> bool: + return self != NULL_CELL + + def __hash__(self) -> int: + """ + Uniquely identify cells by hashing their bounding box and spans. + + This is small speedup for `.groupby(attrgetter("cell"))` compared to + dataclasses's default __hash__ implementation. + """ + return hash((self.box, self.spans)) + @property def span(self) -> Span: """ @@ -45,3 +59,16 @@ def from_dict(cell: object, page: int) -> "Cell": range=Range.from_dict(cell), spans=tuple(map(Span.from_dict, get(cell, list, "doc_offsets"))), ) + + +# It's more ergonomic to represent the lack of cells with a special null cell object +# rather than using `None` or raising an error. This lets you e.g. sort by the `cell` +# attribute without having to constantly check for `None`, while still allowing you do +# a "None check" with `bool(extraction.cell)` or `extraction.cell == NULL_CELL`. 
+NULL_CELL: Final = Cell( + type=CellType.CONTENT, + text="", + box=NULL_BOX, + range=NULL_RANGE, + spans=tuple(), +) diff --git a/indico_toolkit/etloutput/errors.py b/indico_toolkit/etloutput/errors.py deleted file mode 100644 index d4688034..00000000 --- a/indico_toolkit/etloutput/errors.py +++ /dev/null @@ -1,16 +0,0 @@ -class EtlOutputError(Exception): - """ - Raised when an error occurs accessing `EtlOutput` values. - """ - - -class TokenNotFoundError(EtlOutputError): - """ - Raised when a `Token` can't be found for a `Span`. - """ - - -class TableCellNotFoundError(EtlOutputError): - """ - Raised when a `Table` and `Cell` can't be found for a `Token`. - """ diff --git a/indico_toolkit/etloutput/etloutput.py b/indico_toolkit/etloutput/etloutput.py index 899dadeb..da7e5460 100644 --- a/indico_toolkit/etloutput/etloutput.py +++ b/indico_toolkit/etloutput/etloutput.py @@ -1,18 +1,20 @@ import itertools from bisect import bisect_left, bisect_right +from collections import namedtuple from dataclasses import dataclass +from functools import cached_property from operator import attrgetter from typing import TYPE_CHECKING -from ..results import Box, Span -from .errors import TableCellNotFoundError, TokenNotFoundError +from .box import Box from .table import Table -from .token import Token +from .token import NULL_TOKEN, Token if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator from .cell import Cell + from .span import Span @dataclass(frozen=True) @@ -54,18 +56,19 @@ def from_pages( tables_on_page=table_pages, ) - def token_for(self, span: Span) -> Token: + def token_for(self, span: "Span") -> Token: """ - Return a `Token` that contains every character from `span`. - Raise `TokenNotFoundError` if one can't be produced. + Return a `Token` that contains every character from `span` + or `NULL_TOKEN` if one doesn't exist. 
""" try: tokens = self.tokens_on_page[span.page] first = bisect_right(tokens, span.start, key=attrgetter("span.end")) last = bisect_left(tokens, span.end, lo=first, key=attrgetter("span.start")) tokens = tokens[first:last] - except (IndexError, ValueError) as error: - raise TokenNotFoundError(f"no token contains {span!r}") from error + assert tokens + except (AssertionError, IndexError, ValueError): + return NULL_TOKEN return Token( text=self.text[span.slice], @@ -79,28 +82,53 @@ def token_for(self, span: Span) -> Token: span=span, ) - def table_cell_for(self, token: Token) -> "tuple[Table, Cell]": + _TableCellSpan = namedtuple("_TableCellSpan", ["table", "cell", "span"]) + + @cached_property + def _table_cell_spans_on_page(self) -> "tuple[tuple[_TableCellSpan, ...], ...]": + """ + Order table cells on each page by their spans such that they can be bisected. + """ + return tuple( + tuple( + sorted( + ( + self._TableCellSpan(table, cell, span) + for table in page_tables + for cell in table.cells + for span in cell.spans + if span + ), + key=attrgetter("span"), + ) + ) + for page_tables in self.tables_on_page + ) + + def table_cells_for(self, span: "Span") -> "Iterator[tuple[Table, Cell]]": """ - Return the `Table` and `Cell` that contain the midpoint of `token`. - Raise `TableCellNotFoundError` if it's not inside a table cell. + Yield the table cells that overlap with `span`. + + Note: a single span may overlap the same cell multiple times causing it to be + yielded multiple times. Deduplication in `DocumentExtraction.table_cells` + accounts for this when OCR is assigned with `PredictionList.assign_ocr()`. 
""" - token_vmid = (token.box.top + token.box.bottom) // 2 - token_hmid = (token.box.left + token.box.right) // 2 - - for table in self.tables_on_page[token.box.page]: - if ( - (table.box.top <= token_vmid <= table.box.bottom) and - (table.box.left <= token_hmid <= table.box.right) - ): # fmt: skip - break - else: - raise TableCellNotFoundError(f"no table contains {token!r}") - - for cell in table.cells: - if ( - (cell.box.top <= token_vmid <= cell.box.bottom) and - (cell.box.left <= token_hmid <= cell.box.right) - ): # fmt: skip - return table, cell - else: - raise TableCellNotFoundError(f"no cell contains {token!r}") + try: + page_table_cell_spans = self._table_cell_spans_on_page[span.page] + first = bisect_right( + page_table_cell_spans, + span.start, + key=attrgetter("span.end"), + ) + last = bisect_left( + page_table_cell_spans, + span.end, + lo=first, + key=attrgetter("span.start"), + ) + table_cell_spans = page_table_cell_spans[first:last] + except (IndexError, ValueError): + table_cell_spans = tuple() + + for table, cell, span in table_cell_spans: + yield table, cell diff --git a/indico_toolkit/etloutput/range.py b/indico_toolkit/etloutput/range.py index 7d26431b..0d0f8eff 100644 --- a/indico_toolkit/etloutput/range.py +++ b/indico_toolkit/etloutput/range.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +from typing import Final -from ..results.utils import get +from .utils import get @dataclass(order=True, frozen=True) @@ -12,6 +13,9 @@ class Range: rows: "tuple[int, ...]" columns: "tuple[int, ...]" + def __bool__(self) -> bool: + return self != NULL_RANGE + @staticmethod def from_dict(cell: object) -> "Range": """ @@ -28,3 +32,17 @@ def from_dict(cell: object) -> "Range": rows=tuple(rows), columns=tuple(columns), ) + + +# It's more ergonomic to represent the lack of ranges with a special null range object +# rather than using `None` or raising an error. This lets you e.g. 
sort by the `range` +# attribute without having to constantly check for `None`, while still allowing you do +# a "None check" with `bool(cell.range)` or `cell.range == NULL_RANGE`. +NULL_RANGE: Final = Range( + row=0, + column=0, + rowspan=0, + columnspan=0, + rows=tuple(), + columns=tuple(), +) diff --git a/indico_toolkit/results/predictions/span.py b/indico_toolkit/etloutput/span.py similarity index 56% rename from indico_toolkit/results/predictions/span.py rename to indico_toolkit/etloutput/span.py index 8ca2c310..4e3d2502 100644 --- a/indico_toolkit/results/predictions/span.py +++ b/indico_toolkit/etloutput/span.py @@ -1,10 +1,7 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Final -from ..utils import get - -if TYPE_CHECKING: - from typing import Any, Final +from .utils import get @dataclass(order=True, frozen=True) @@ -20,6 +17,26 @@ def slice(self) -> slice: def __bool__(self) -> bool: return self != NULL_SPAN + def __and__(self, other: "Span") -> "Span": + """ + Return a new `Span` for the overlap between `self` and `other` + or `NULL_SPAN` if they don't overlap. + + Supports set-like `extraction.span & cell.span` syntax. + """ + if ( + self.page != other.page + or self.end <= other.start # `self` is to the left of `other` + or self.start >= other.end # `self` is to the right of `other` + ): + return NULL_SPAN + else: + return Span( + page=self.page, + start=max(self.start, other.start), + end=min(self.end, other.end), + ) + @staticmethod def from_dict(span: object) -> "Span": return Span( @@ -40,4 +57,4 @@ def to_dict(self) -> "dict[str, Any]": # rather than using `None` or raising an error. This lets you e.g. sort by the `span` # attribute without having to constantly check for `None`, while still allowing you do # a "None check" with `bool(extraction.span)` or `extraction.span == NULL_SPAN`. 
-NULL_SPAN: "Final" = Span(page=0, start=0, end=0) +NULL_SPAN: Final = Span(page=0, start=0, end=0) diff --git a/indico_toolkit/etloutput/table.py b/indico_toolkit/etloutput/table.py index 807f7ea5..1b278503 100644 --- a/indico_toolkit/etloutput/table.py +++ b/indico_toolkit/etloutput/table.py @@ -1,18 +1,40 @@ from dataclasses import dataclass from operator import attrgetter +from typing import Final -from ..results import Box -from ..results.utils import get +from .box import NULL_BOX, Box from .cell import Cell +from .span import NULL_SPAN, Span +from .utils import get @dataclass(frozen=True) class Table: box: Box + spans: "tuple[Span, ...]" cells: "tuple[Cell, ...]" rows: "tuple[tuple[Cell, ...], ...]" columns: "tuple[tuple[Cell, ...], ...]" + def __bool__(self) -> bool: + return self != NULL_TABLE + + def __hash__(self) -> int: + """ + Uniquely identify tables by hashing their bounding box and spans. + + This is an order of magnitude speedup for `.groupby(attrgetter("table"))` + compared to dataclasses's default __hash__ implementation. + """ + return hash((self.box, self.spans)) + + @property + def span(self) -> Span: + """ + Return the first `Span` the table covers or `NULL_SPAN` otherwise. + """ + return self.spans[0] if self.spans else NULL_SPAN + @staticmethod def from_dict(table: object) -> "Table": """ @@ -50,9 +72,26 @@ def from_dict(table: object) -> "Table": for column in range(column_count) ) # fmt: skip + for doc_offset in get(table, list, "doc_offsets"): + doc_offset["page_num"] = page + return Table( box=Box.from_dict(get(table, dict, "position")), + spans=tuple(map(Span.from_dict, get(table, list, "doc_offsets"))), cells=cells, rows=rows, columns=columns, ) + + +# It's more ergonomic to represent the lack of tables with a special null table object +# rather than using `None` or raising an error. This lets you e.g. 
group by the `table` +# attribute without having to constantly check for `None`, while still allowing you do +# a "None check" with `bool(extraction.table)` or `extraction.table == NULL_TABLE`. +NULL_TABLE: Final = Table( + box=NULL_BOX, + spans=tuple(), + cells=tuple(), + rows=tuple(), + columns=tuple(), +) diff --git a/indico_toolkit/etloutput/token.py b/indico_toolkit/etloutput/token.py index 9478d395..8feb8e59 100644 --- a/indico_toolkit/etloutput/token.py +++ b/indico_toolkit/etloutput/token.py @@ -1,7 +1,9 @@ from dataclasses import dataclass +from typing import Final -from ..results import Box, Span -from ..results.utils import get +from .box import NULL_BOX, Box +from .span import NULL_SPAN, Span +from .utils import get @dataclass(frozen=True) @@ -10,6 +12,9 @@ class Token: box: Box span: Span + def __bool__(self) -> bool: + return self != NULL_TOKEN + @staticmethod def from_dict(token: object) -> "Token": """ @@ -23,3 +28,14 @@ def from_dict(token: object) -> "Token": box=Box.from_dict(get(token, dict, "position")), span=Span.from_dict(get(token, dict, "doc_offset")), ) + + +# It's more ergonomic to represent the lack of tokens with a special null token object +# rather than using `None` or raising an error. This lets you e.g. sort by the `token` +# attribute without having to constantly check for `None`, while still allowing you do +# a "None check" with `bool(extraction.token)` or `extraction.token == NULL_TOKEN`. +NULL_TOKEN: Final = Token( + text="", + box=NULL_BOX, + span=NULL_SPAN, +) diff --git a/indico_toolkit/etloutput/utils.py b/indico_toolkit/etloutput/utils.py new file mode 100644 index 00000000..0f7c3e5c --- /dev/null +++ b/indico_toolkit/etloutput/utils.py @@ -0,0 +1,69 @@ +import json +from typing import Any, TypeVar + +Value = TypeVar("Value") + + +def get(value: object, value_type: "type[Value]", *keys: "str | int") -> Value: + """ + Return the value of type `value_type` obtained by traversing `value` using `keys`. 
+ Raise an error if a key doesn't exist or the value has the wrong type. + """ + for key in keys: + if isinstance(value, dict): + if key in value: + value = value[key] + else: + raise KeyError(f"{key!r} not in {value.keys()!r}") + elif isinstance(value, list): + if isinstance(key, int): + if 0 <= key < len(value): + value = value[key] + else: + raise IndexError(f"{key} out of range [0,{len(value)})") + else: + raise TypeError(f"list can't be indexed with {key!r}") + else: + raise TypeError(f"{type(value)} can't be traversed") + + if isinstance(value, value_type): + return value + else: + raise TypeError(f"value `{value!r}` doesn't have type {value_type}") + + +def has(value: object, value_type: "type[Value]", *keys: "str | int") -> bool: + """ + Check if `value` can be traversed using `keys` to a value of type `value_type`. + """ + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + elif isinstance(value, list) and isinstance(key, int) and 0 <= key < len(value): # fmt: skip # noqa: E501 + value = value[key] + else: + return False + + return isinstance(value, value_type) + + +def json_loaded(value: "Any") -> "Any": + """ + Ensure `value` has been loaded as JSON. + """ + value = str_decoded(value) + + if isinstance(value, str): + value = json.loads(value) + + return value + + +def str_decoded(value: str | bytes) -> str: + """ + Ensure `value` has been decoded to a string. 
+ """ + if isinstance(value, bytes): + value = value.decode() + + return value diff --git a/indico_toolkit/results/__init__.py b/indico_toolkit/results/__init__.py index ea9a4e8b..53b9811f 100644 --- a/indico_toolkit/results/__init__.py +++ b/indico_toolkit/results/__init__.py @@ -1,13 +1,10 @@ from typing import TYPE_CHECKING, TypeAlias, TypeVar, overload +from ..etloutput import NULL_BOX, NULL_SPAN, Box, Span from .document import Document -from .errors import ResultError from .predictionlist import PredictionList from .predictions import ( - NULL_BOX, NULL_CITATION, - NULL_SPAN, - Box, Classification, DocumentExtraction, Extraction, @@ -15,7 +12,6 @@ FormExtractionType, Group, Prediction, - Span, Summarization, Unbundling, ) @@ -45,7 +41,6 @@ "Prediction", "PredictionList", "Result", - "ResultError", "Review", "ReviewType", "Span", diff --git a/indico_toolkit/results/errors.py b/indico_toolkit/results/errors.py deleted file mode 100644 index abebb540..00000000 --- a/indico_toolkit/results/errors.py +++ /dev/null @@ -1,4 +0,0 @@ -class ResultError(Exception): - """ - Raised when an error occurs while loading or dumping a result file. 
- """ diff --git a/indico_toolkit/results/normalization.py b/indico_toolkit/results/normalization.py index 44f4da16..72cec52a 100644 --- a/indico_toolkit/results/normalization.py +++ b/indico_toolkit/results/normalization.py @@ -1,12 +1,9 @@ import re -from typing import TYPE_CHECKING +from typing import Any from .task import TaskType from .utils import get, has -if TYPE_CHECKING: - from typing import Any - def normalize_result_dict(result: "Any") -> None: """ diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index 312fd9e5..b85b4a43 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -1,6 +1,7 @@ from collections import defaultdict +from itertools import chain from operator import attrgetter -from typing import TYPE_CHECKING, List, TypeVar, overload +from typing import TYPE_CHECKING, Any, Final, List, SupportsIndex, TypeVar, overload from .predictions import ( Classification, @@ -13,25 +14,24 @@ Unbundling, ) from .review import Review, ReviewType -from .task import TaskType from .utils import nfilter if TYPE_CHECKING: - from collections.abc import Callable, Collection, Container, Iterable - from typing import Any, Final, SupportsIndex + from collections.abc import Callable, Collection, Container, Iterable, Mapping from typing_extensions import Self + from ..etloutput import EtlOutput from .document import Document from .result import Result - from .task import Task + from .task import Task, TaskType PredictionType = TypeVar("PredictionType", bound=Prediction) OfType = TypeVar("OfType", bound=Prediction) KeyType = TypeVar("KeyType") # Non-None sentinel value to support `PredictionList.where(review=None)`. 
-REVIEW_UNSPECIFIED: "Final" = Review( +REVIEW_UNSPECIFIED: Final = Review( id=None, reviewer_id=None, notes=None, rejected=None, type=None # type: ignore[arg-type] ) @@ -84,6 +84,43 @@ def apply(self, function: "Callable[[PredictionType], None]") -> "Self": return self + def assign_ocr( + self, + etl_outputs: "Mapping[Document, EtlOutput]", + *, + tokens: bool = True, + tables: bool = True, + ) -> "Self": + """ + Assign OCR tokens, tables, and/or cells using `etl_outputs`. + + Use `tokens` or `tables` to skip lookup and assignment of those attributes. + """ + extractions_by_document = self.oftype( + DocumentExtraction, + ).groupby( + attrgetter("document"), + ) + + for document, extractions in extractions_by_document.items(): + etl_output = etl_outputs[document] + + for extraction in extractions: + if tokens: + extraction.tokens = list( + filter( + None, + map(etl_output.token_for, extraction.spans), + ) + ) + + if tables: + extraction.table_cells = chain.from_iterable( + map(etl_output.table_cells_for, extraction.spans) + ) + + return self + def groupby( self, key: "Callable[[PredictionType], KeyType]" ) -> "dict[KeyType, Self]": diff --git a/indico_toolkit/results/predictions/__init__.py b/indico_toolkit/results/predictions/__init__.py index 5a29151f..dff7d51a 100644 --- a/indico_toolkit/results/predictions/__init__.py +++ b/indico_toolkit/results/predictions/__init__.py @@ -1,9 +1,7 @@ from typing import TYPE_CHECKING -from ..errors import ResultError from ..normalization import normalize_prediction_dict from ..task import TaskType -from .box import NULL_BOX, Box from .citation import NULL_CITATION, Citation from .classification import Classification from .documentextraction import DocumentExtraction @@ -11,7 +9,6 @@ from .formextraction import FormExtraction, FormExtractionType from .group import Group from .prediction import Prediction -from .span import NULL_SPAN, Span from .summarization import Summarization from .unbundling import Unbundling @@ -21,7 +18,6 
@@ from ..task import Task __all__ = ( - "Box", "Citation", "Classification", "DocumentExtraction", @@ -29,11 +25,8 @@ "FormExtraction", "FormExtractionType", "Group", - "NULL_BOX", "NULL_CITATION", - "NULL_SPAN", "Prediction", - "Span", "Summarization", "Unbundling", ) @@ -61,4 +54,4 @@ def from_dict( elif task.type == TaskType.UNBUNDLING: return Unbundling.from_dict(document, task, review, prediction) else: - raise ResultError(f"unsupported task type {task.type!r}") + raise ValueError(f"unsupported task type {task.type!r}") diff --git a/indico_toolkit/results/predictions/citation.py b/indico_toolkit/results/predictions/citation.py index 177cf60c..7e765ec7 100644 --- a/indico_toolkit/results/predictions/citation.py +++ b/indico_toolkit/results/predictions/citation.py @@ -1,11 +1,8 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Final +from ...etloutput import NULL_SPAN, Span from ..utils import get -from .span import NULL_SPAN, Span - -if TYPE_CHECKING: - from typing import Any, Final @dataclass(order=True, frozen=True) @@ -44,4 +41,4 @@ def to_dict(self) -> "dict[str, Any]": # `citation` attribute without having to constantly check for `None`, while still # allowing you do a "None check" with `bool(summarization.citation)` or # `summarization.citation == NULL_CITATION`. 
-NULL_CITATION: "Final" = Citation(start=0, end=0, span=NULL_SPAN) +NULL_CITATION: Final = Citation(start=0, end=0, span=NULL_SPAN) diff --git a/indico_toolkit/results/predictions/classification.py b/indico_toolkit/results/predictions/classification.py index 8ef33ea9..b76b17ee 100644 --- a/indico_toolkit/results/predictions/classification.py +++ b/indico_toolkit/results/predictions/classification.py @@ -1,14 +1,12 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ..review import Review from ..utils import get, omit from .prediction import Prediction if TYPE_CHECKING: - from typing import Any - from ..document import Document + from ..review import Review from ..task import Task diff --git a/indico_toolkit/results/predictions/documentextraction.py b/indico_toolkit/results/predictions/documentextraction.py index 5a112d86..c942ab7b 100644 --- a/indico_toolkit/results/predictions/documentextraction.py +++ b/indico_toolkit/results/predictions/documentextraction.py @@ -1,16 +1,28 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING +from copy import copy, deepcopy +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Any -from ..review import Review +from ...etloutput import ( + NULL_CELL, + NULL_SPAN, + NULL_TABLE, + NULL_TOKEN, + Cell, + Span, + Table, + Token, +) from ..utils import get, has, omit from .extraction import Extraction from .group import Group -from .span import NULL_SPAN, Span if TYPE_CHECKING: - from typing import Any + from collections.abc import Iterable, Iterator + + from typing_extensions import Self from ..document import Document + from ..review import Review from ..task import Task @@ -19,12 +31,16 @@ class DocumentExtraction(Extraction): groups: "set[Group]" spans: "list[Span]" + tokens: "list[Token]" = field(default_factory=list) + tables: "list[Table]" = field(default_factory=list) + cells: "list[Cell]" = 
field(default_factory=list) + @property def span(self) -> Span: """ Return the first `Span` the document extraction covers else `NULL_SPAN`. - Post-review, document extractions have no spans. + Predictions added in review may not have spans. """ return self.spans[0] if self.spans else NULL_SPAN @@ -33,12 +49,88 @@ def span(self, span: Span) -> None: """ Overwrite all spans with the one provided, handling `NULL_SPAN`. - This is implemented under the assumption that if you're setting a single span, - you want it to be the only one. And if you're working in a context that's - multiple-span sensetive, you'll set `extraction.spans` instead. + This assumes if you're setting a single span you want it to be the only one. + Multiple-span sensitive contexts should work with `extraction.spans` instead. """ self.spans = [span] if span else [] + @property + def token(self) -> Token: + """ + Return the first `Token` the document extraction covers + or `NULL_TOKEN` if it doesn't cover a token or OCR hasn't been assigned yet. + """ + return self.tokens[0] if self.tokens else NULL_TOKEN + + @token.setter + def token(self, token: Token) -> None: + """ + Overwrite all tokens with the one provided, handling `NULL_TOKEN`. + + This assumes if you're setting a single token you want it to be the only one. + Multiple-token sensitive contexts should work with `extraction.tokens` instead. + """ + self.tokens = [token] if token else [] + + @property + def table(self) -> Table: + """ + Return the first `Table` the document extraction is in + or `NULL_TABLE` if it's not in a table or OCR hasn't been assigned yet. + """ + return self.tables[0] if self.tables else NULL_TABLE + + @table.setter + def table(self, table: Table) -> None: + """ + Overwrite all tables with the one provided, handling `NULL_TABLE`. + + This assumes if you're setting a single table you want it to be the only one. + Multiple-table sensitive contexts should work with `extraction.tables` instead. 
+ """ + self.tables = [table] if table else [] + + @property + def cell(self) -> Cell: + """ + Return the first `Cell` the document extraction is in + or `NULL_CELL` if it's not in a cell or OCR hasn't been assigned yet. + """ + return self.cells[0] if self.cells else NULL_CELL + + @cell.setter + def cell(self, cell: Cell) -> None: + """ + Overwrite all cells with the one provided, handling `NULL_CELL`. + + This assumes if you're setting a single cell you want it to be the only one. + Multiple-cell sensitive contexts should work with `extraction.cells` instead. + """ + self.cells = [cell] if cell else [] + + @property + def table_cells(self) -> "Iterator[tuple[Table, Cell]]": + """ + Yield the table cells the document extraction is in. + """ + yield from zip(self.tables, self.cells) + + @table_cells.setter + def table_cells(self, table_cells: "Iterable[tuple[Table, Cell]]") -> None: + """ + Set the table cells the document extraction is in. + + Deduplicate cells to handle the case where multiple + spans are contained within the same cell. 
+ """ + self.tables = [] + self.cells = [] + + for table, cell in table_cells: + if cell not in self.cells: + self.tables.append(table) + self.cells.append(cell) + @staticmethod def from_dict( document: "Document", @@ -98,3 +190,15 @@ def to_dict(self) -> "dict[str, Any]": prediction["rejected"] = True return prediction + + def copy(self) -> "Self": + return replace( + self, + groups=copy(self.groups), + spans=copy(self.spans), + tokens=copy(self.tokens), + tables=copy(self.tables), + cells=copy(self.cells), + confidences=copy(self.confidences), + extras=deepcopy(self.extras), + ) diff --git a/indico_toolkit/results/predictions/formextraction.py b/indico_toolkit/results/predictions/formextraction.py index 15a01a93..daa41245 100644 --- a/indico_toolkit/results/predictions/formextraction.py +++ b/indico_toolkit/results/predictions/formextraction.py @@ -1,16 +1,14 @@ from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ..review import Review +from ...etloutput import Box from ..utils import get, has, omit -from .box import Box from .extraction import Extraction if TYPE_CHECKING: - from typing import Any - from ..document import Document + from ..review import Review from ..task import Task diff --git a/indico_toolkit/results/predictions/group.py b/indico_toolkit/results/predictions/group.py index e9e000d0..3f72bdaa 100644 --- a/indico_toolkit/results/predictions/group.py +++ b/indico_toolkit/results/predictions/group.py @@ -1,11 +1,8 @@ from dataclasses import dataclass, replace -from typing import TYPE_CHECKING +from typing import Any from ..utils import get -if TYPE_CHECKING: - from typing import Any - @dataclass(frozen=True, order=True) class Group: diff --git a/indico_toolkit/results/predictions/prediction.py b/indico_toolkit/results/predictions/prediction.py index 216e0700..6edf03c0 100644 --- a/indico_toolkit/results/predictions/prediction.py +++ 
b/indico_toolkit/results/predictions/prediction.py @@ -1,12 +1,12 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from ..review import Review +from copy import copy, deepcopy +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from typing import Any + from typing_extensions import Self from ..document import Document + from ..review import Review from ..task import Task @@ -33,3 +33,10 @@ def to_dict(self) -> "dict[str, Any]": Create a prediction dictionary for auto review changes. """ raise NotImplementedError() + + def copy(self) -> "Self": + return replace( + self, + confidences=copy(self.confidences), + extras=deepcopy(self.extras), + ) diff --git a/indico_toolkit/results/predictions/summarization.py b/indico_toolkit/results/predictions/summarization.py index 57fe56ac..6f2fa3e7 100644 --- a/indico_toolkit/results/predictions/summarization.py +++ b/indico_toolkit/results/predictions/summarization.py @@ -1,17 +1,18 @@ +from copy import copy, deepcopy from dataclasses import dataclass, replace -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ..review import Review from ..utils import get, has, omit from .citation import NULL_CITATION, Citation from .extraction import Extraction if TYPE_CHECKING: - from typing import Any + from typing_extensions import Self + from ...etloutput import Span from ..document import Document + from ..review import Review from ..task import Task - from .span import Span @dataclass @@ -23,7 +24,7 @@ def citation(self) -> Citation: """ Return the first `Citation` the summarization covers else `NULL_CITATION`. - Post-review, summarizations have no citations. + Predictions added in review may not have citations. """ return self.citations[0] if self.citations else NULL_CITATION @@ -32,17 +33,15 @@ def citation(self, citation: Citation) -> None: """ Overwrite all citations with the one provided, handling `NULL_CITATION`. 
- This is implemented under the assumption that if you're setting a single - citation, you want it to be the only one. And if you're working in a context - that's multiple-citation sensetive, you'll set `summarization.citations` - instead. + This assumes if you're setting a single citation it should be the only one. + Multiple-citation sensitive contexts should work with `summarization.citations`. """ self.citations = [citation] if citation else [] @property def spans(self) -> "tuple[Span, ...]": """ - Return the spans covered by `self.citations`. + Return the `Span`s covered by `self.citations`. """ return tuple(citation.span for citation in self.citations) @@ -51,7 +50,7 @@ def span(self) -> "Span": """ Return the `Span` the first citation covers else `NULL_SPAN`. - Post-review, summarizations have no citations/spans. + Predictions added in review may not have citations. """ return self.citation.span @@ -64,9 +63,9 @@ def span(self, span: "Span") -> None: Using `NULL_SPAN` for a citation is not explicitly handled, and should be considered undefined behavior. - This is implemented under the assumption that if you're setting a single span, - there's only one citation and you want to update its span. And if you're - working in a context that's multiple-citation/span sensetive, you'll set + This assumes if you're setting a single span, + there's only one citation and you want to update its span. + Multiple-citation/span sensitive contexts should work with + `summarization.citations` instead. 
""" self.citation = replace(self.citation, span=span) @@ -126,3 +125,11 @@ def to_dict(self) -> "dict[str, Any]": prediction["rejected"] = True return prediction + + def copy(self) -> "Self": + return replace( + self, + citations=copy(self.citations), + confidences=copy(self.confidences), + extras=deepcopy(self.extras), + ) diff --git a/indico_toolkit/results/predictions/unbundling.py b/indico_toolkit/results/predictions/unbundling.py index 4e15042c..7e998dda 100644 --- a/indico_toolkit/results/predictions/unbundling.py +++ b/indico_toolkit/results/predictions/unbundling.py @@ -1,15 +1,16 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING +from copy import copy, deepcopy +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Any -from ..review import Review +from ...etloutput import Span from ..utils import get, omit from .prediction import Prediction -from .span import Span if TYPE_CHECKING: - from typing import Any + from typing_extensions import Self from ..document import Document + from ..review import Review from ..task import Task @@ -54,3 +55,11 @@ def to_dict(self) -> "dict[str, Any]": "confidence": self.confidences, "spans": [span.to_dict() for span in self.spans], } + + def copy(self) -> "Self": + return replace( + self, + spans=copy(self.spans), + confidences=copy(self.confidences), + extras=deepcopy(self.extras), + ) diff --git a/indico_toolkit/results/result.py b/indico_toolkit/results/result.py index 0508cd98..cb552593 100644 --- a/indico_toolkit/results/result.py +++ b/indico_toolkit/results/result.py @@ -4,7 +4,6 @@ from . 
import predictions as prediction from .document import Document -from .errors import ResultError from .normalization import normalize_result_dict from .predictionlist import PredictionList from .predictions import Prediction @@ -53,7 +52,7 @@ def from_dict(result: object) -> "Result": file_version = get(result, int, "file_version") if file_version != 3: - raise ResultError(f"unsupported file version `{file_version}`") + raise ValueError(f"unsupported result file version `{file_version}`") normalize_result_dict(result) diff --git a/indico_toolkit/results/utils.py b/indico_toolkit/results/utils.py index 9c8e5f0a..32ce4ea4 100644 --- a/indico_toolkit/results/utils.py +++ b/indico_toolkit/results/utils.py @@ -1,70 +1,21 @@ -import json from collections.abc import Iterable, Iterator -from typing import TYPE_CHECKING, TypeVar +from typing import Callable -if TYPE_CHECKING: - from typing import Any, Callable +from ..etloutput.utils import Value, get, has, json_loaded, str_decoded -Value = TypeVar("Value") - - -def get(result: object, value_type: "type[Value]", *keys: "str | int") -> Value: - """ - Return the value of type `value_type` obtained by traversing `result` using `keys`. - Raise an error if a key doesn't exist or the value has the wrong type. 
- """ - for key in keys: - if isinstance(result, dict): - if key in result: - result = result[key] - else: - raise KeyError(f"{key!r} not in {result.keys()!r}") - elif isinstance(result, list): - if isinstance(key, int): - if 0 <= key < len(result): - result = result[key] - else: - raise IndexError(f"{key} out of range [0,{len(result)})") - else: - raise TypeError(f"list can't be indexed with {key!r}") - else: - raise TypeError(f"{type(result)} can't be traversed") - - if isinstance(result, value_type): - return result - else: - raise TypeError(f"value `{result!r}` doesn't have type {value_type}") - - -def has(result: object, value_type: "type[Value]", *keys: "str | int") -> bool: - """ - Check if `result` can be traversed using `keys` to a value of type `value_type`. - """ - for key in keys: - if isinstance(result, dict) and key in result: - result = result[key] - elif isinstance(result, list) and isinstance(key, int) and 0 <= key < len(result): # fmt: skip # noqa: E501 - result = result[key] - else: - return False - - return isinstance(result, value_type) - - -def json_loaded(value: "Any") -> "Any": - """ - Ensure `value` has been loaded as JSON. - """ - value = str_decoded(value) - - if isinstance(value, str): - value = json.loads(value) - - return value +__all__ = ( + "get", + "has", + "json_loaded", + "nfilter", + "omit", + "str_decoded", +) def nfilter( - predicates: "Iterable[Callable[[Value], bool]]", values: "Iterable[Value]" + predicates: "Iterable[Callable[[Value], bool]]", + values: "Iterable[Value]", ) -> "Iterator[Value]": """ Apply multiple filter predicates to an iterable of values. @@ -89,13 +40,3 @@ def omit(dictionary: object, *keys: str) -> "dict[str, Value]": for key, value in dictionary.items() if key not in keys } # fmt: skip - - -def str_decoded(value: str | bytes) -> str: - """ - Ensure `value` has been decoded to a string. 
- """ - if isinstance(value, bytes): - value = value.decode() - - return value diff --git a/pyproject.toml b/pyproject.toml index 7b1bfe59..046fd46b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ readme = "README.md" urls = { source = "https://github.com/IndicoDataSolutions/Indico-Solutions-Toolkit" } requires-python = ">=3.10" -version = "7.2.1" +version = "7.2.2" dependencies = ["indico-client (>=6.14.0,<7.0.0)"] [project.optional-dependencies] diff --git a/tests/data/etloutput/4725/112731/112257/tables_0.json b/tests/data/etloutput/4725/112731/112257/tables_0.json index 8800fc7f..014ffc2b 100644 --- a/tests/data/etloutput/4725/112731/112257/tables_0.json +++ b/tests/data/etloutput/4725/112731/112257/tables_0.json @@ -1 +1 @@ -[{"cells":[{"cell_type":"header","columns":[0],"doc_offsets":[{"end":29,"start":25}],"page_offsets":[{"end":29,"start":25}],"position":{"bottom":562,"left":713,"right":1052,"top":435},"rows":[0],"text":"Alfa"},{"cell_type":"header","columns":[1],"doc_offsets":[{"end":35,"start":30}],"page_offsets":[{"end":35,"start":30}],"position":{"bottom":561,"left":1051,"right":1301,"top":434},"rows":[0],"text":"Bravo"},{"cell_type":"header","columns":[2],"doc_offsets":[{"end":43,"start":36}],"page_offsets":[{"end":43,"start":36}],"position":{"bottom":560,"left":1300,"right":1578,"top":433},"rows":[0],"text":"Charlie"},{"cell_type":"header","columns":[3],"doc_offsets":[{"end":49,"start":44}],"page_offsets":[{"end":49,"start":44}],"position":{"bottom":561,"left":1580,"right":1821,"top":430},"rows":[0],"text":"Delta"},{"cell_type":"content","columns":[0],"doc_offsets":[{"end":54,"start":50}],"page_offsets":[{"end":54,"start":50}],"position":{"bottom":778,"left":712,"right":1052,"top":561},"rows":[1,2],"text":"Echo"},{"cell_type":"content","columns":[1,2],"doc_offsets":[{"end":62,"start":55}],"page_offsets":[{"end":62,"start":55}],"position":{"bottom":670,"left":1052,"right":1580,"top":561},"rows":[1],"text":"Foxtrot"},{"cell
_type":"content","columns":[3],"doc_offsets":[{"end":68,"start":64}],"page_offsets":[{"end":68,"start":64}],"position":{"bottom":669,"left":1580,"right":1821,"top":560},"rows":[1],"text":"Golf"},{"cell_type":"content","columns":[1],"doc_offsets":[{"end":75,"start":70}],"page_offsets":[{"end":75,"start":70}],"position":{"bottom":778,"left":1052,"right":1301,"top":670},"rows":[2],"text":"Hotel"},{"cell_type":"content","columns":[2],"doc_offsets":[{"end":81,"start":76}],"page_offsets":[{"end":81,"start":76}],"position":{"bottom":777,"left":1301,"right":1581,"top":669},"rows":[2],"text":"India"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":89,"start":82}],"page_offsets":[{"end":89,"start":82}],"position":{"bottom":776,"left":1580,"right":1822,"top":668},"rows":[2],"text":"Juliett"},{"cell_type":"content","columns":[0],"doc_offsets":[{"end":94,"start":90}],"page_offsets":[{"end":94,"start":90}],"position":{"bottom":891,"left":712,"right":1053,"top":778},"rows":[3],"text":"Kilo"},{"cell_type":"content","columns":[1,2],"doc_offsets":[],"page_offsets":[],"position":{"bottom":997,"left":1052,"right":1582,"top":776},"rows":[3,4],"text":"Lima"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":101,"start":97}],"page_offsets":[{"end":101,"start":97}],"position":{"bottom":889,"left":1581,"right":1822,"top":775},"rows":[3],"text":"Mike"},{"cell_type":"content","columns":[0],"doc_offsets":[{"end":110,"start":102}],"page_offsets":[{"end":110,"start":102}],"position":{"bottom":998,"left":713,"right":1053,"top":889},"rows":[4],"text":"November"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":122,"start":117}],"page_offsets":[{"end":122,"start":117}],"position":{"bottom":995,"left":1581,"right":1824,"top":888},"rows":[4],"text":"Oscar"}],"doc_offsets":[{"end":94,"start":25},{"end":122,"start":97}],"num_columns":4,"num_rows":5,"page_num":0,"page_offsets":[{"end":94,"start":25},{"end":122,"start":97}],"position":{"bottom":998,"left":711,"right":1824,"t
op":430},"table_id":0,"table_offset":{"column":0,"row":0}}] +[{"cells":[{"cell_type":"header","columns":[0],"doc_offsets":[{"end":29,"start":25}],"page_offsets":[{"end":29,"start":25}],"position":{"bottom":562,"left":713,"right":1052,"top":435},"rows":[0],"text":"Alfa"},{"cell_type":"header","columns":[1],"doc_offsets":[{"end":35,"start":30}],"page_offsets":[{"end":35,"start":30}],"position":{"bottom":561,"left":1051,"right":1301,"top":434},"rows":[0],"text":"Bravo"},{"cell_type":"header","columns":[2],"doc_offsets":[{"end":43,"start":36}],"page_offsets":[{"end":43,"start":36}],"position":{"bottom":560,"left":1300,"right":1578,"top":433},"rows":[0],"text":"Charlie"},{"cell_type":"header","columns":[3],"doc_offsets":[{"end":49,"start":44}],"page_offsets":[{"end":49,"start":44}],"position":{"bottom":561,"left":1580,"right":1821,"top":430},"rows":[0],"text":"Delta"},{"cell_type":"content","columns":[0],"doc_offsets":[{"end":54,"start":50}],"page_offsets":[{"end":54,"start":50}],"position":{"bottom":778,"left":712,"right":1052,"top":561},"rows":[1,2],"text":"Echo"},{"cell_type":"content","columns":[1,2],"doc_offsets":[{"end":62,"start":55}],"page_offsets":[{"end":62,"start":55}],"position":{"bottom":670,"left":1052,"right":1580,"top":561},"rows":[1],"text":"Foxtrot"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":68,"start":64}],"page_offsets":[{"end":68,"start":64}],"position":{"bottom":669,"left":1580,"right":1821,"top":560},"rows":[1],"text":"Golf"},{"cell_type":"content","columns":[1],"doc_offsets":[{"end":75,"start":70}],"page_offsets":[{"end":75,"start":70}],"position":{"bottom":778,"left":1052,"right":1301,"top":670},"rows":[2],"text":"Hotel"},{"cell_type":"content","columns":[2],"doc_offsets":[{"end":81,"start":76}],"page_offsets":[{"end":81,"start":76}],"position":{"bottom":777,"left":1301,"right":1581,"top":669},"rows":[2],"text":"India"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":89,"start":82}],"page_offsets":[{"end":89,"start":82
}],"position":{"bottom":776,"left":1580,"right":1822,"top":668},"rows":[2],"text":"Juliett"},{"cell_type":"content","columns":[0],"doc_offsets":[{"end":94,"start":90}],"page_offsets":[{"end":94,"start":90}],"position":{"bottom":891,"left":712,"right":1053,"top":778},"rows":[3],"text":"Kilo"},{"cell_type":"content","columns":[1,2],"doc_offsets":[{"end":115,"start":111}],"page_offsets":[{"end":115, "start":111}],"position":{"bottom":997,"left":1052,"right":1582,"top":776},"rows":[3,4],"text":"Lima"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":101,"start":97}],"page_offsets":[{"end":101,"start":97}],"position":{"bottom":889,"left":1581,"right":1822,"top":775},"rows":[3],"text":"Mike"},{"cell_type":"content","columns":[0],"doc_offsets":[{"end":110,"start":102}],"page_offsets":[{"end":110,"start":102}],"position":{"bottom":998,"left":713,"right":1053,"top":889},"rows":[4],"text":"November"},{"cell_type":"content","columns":[3],"doc_offsets":[{"end":122,"start":117}],"page_offsets":[{"end":122,"start":117}],"position":{"bottom":995,"left":1581,"right":1824,"top":888},"rows":[4],"text":"Oscar"}],"doc_offsets":[{"end":94,"start":25},{"end":122,"start":97}],"num_columns":4,"num_rows":5,"page_num":0,"page_offsets":[{"end":94,"start":25},{"end":122,"start":97}],"position":{"bottom":998,"left":711,"right":1824,"top":430},"table_id":0,"table_offset":{"column":0,"row":0}}] diff --git a/tests/etloutput/test_rowspan_colspan.py b/tests/etloutput/test_rowspan_colspan.py index b21409c3..58bde1c4 100644 --- a/tests/etloutput/test_rowspan_colspan.py +++ b/tests/etloutput/test_rowspan_colspan.py @@ -3,8 +3,7 @@ import pytest from indico_toolkit import etloutput -from indico_toolkit.etloutput import EtlOutput, Table -from indico_toolkit.results import Span +from indico_toolkit.etloutput import EtlOutput, Span, Table data_folder = Path(__file__).parent.parent / "data" / "etloutput" etl_output_file = data_folder / "4725" / "112731" / "112257" / "etl_output_rs_cs.json" @@ 
-121,6 +120,5 @@ def test_columns(table: Table) -> None: ], ) def test_table_cell_for(etl_output: EtlOutput, span: Span, expected_text: str) -> None: - token = etl_output.token_for(span) - table, cell = etl_output.table_cell_for(token) + (table, cell), *_ = etl_output.table_cells_for(span) assert cell.text == expected_text diff --git a/tests/etloutput/test_token_table_cell.py b/tests/etloutput/test_token_table_cell.py index b994d1a4..e9f56947 100644 --- a/tests/etloutput/test_token_table_cell.py +++ b/tests/etloutput/test_token_table_cell.py @@ -4,14 +4,7 @@ import pytest from indico_toolkit import etloutput -from indico_toolkit.etloutput import ( - NULL_SPAN, - CellType, - EtlOutput, - Span, - TableCellNotFoundError, - TokenNotFoundError, -) +from indico_toolkit.etloutput import NULL_SPAN, NULL_TOKEN, CellType, EtlOutput, Span data_folder = Path(__file__).parent.parent / "data" / "etloutput" etl_output_file = data_folder / "4725" / "111924" / "110239" / "etl_output.json" @@ -29,6 +22,11 @@ def etl_output() -> EtlOutput: return etloutput.load(etl_output_file, reader=read_uri) +@pytest.fixture(scope="module") +def etl_output_no_tokens_tables() -> EtlOutput: + return etloutput.load(etl_output_file, reader=read_uri, tokens=False, tables=False) + + @pytest.fixture def header_span() -> Span: return Span(page=1, start=1281, end=1285) @@ -39,6 +37,21 @@ def content_span() -> Span: return Span(page=1, start=1343, end=1349) +@pytest.fixture +def line_item_span() -> Span: + return Span(page=1, start=1311, end=1344) + + +@pytest.fixture +def mulitple_table_span() -> Span: + return Span(page=1, start=1217, end=1299) + + +@pytest.fixture +def outside_table_span() -> Span: + return Span(page=1, start=1056, end=1067) + + def test_text_slice( etl_output: EtlOutput, header_span: Span, content_span: Span ) -> None: @@ -58,18 +71,20 @@ def test_token(etl_output: EtlOutput, header_span: Span, content_span: Span) -> def test_token_not_found(etl_output: EtlOutput, header_span: Span) -> 
None: - with pytest.raises(TokenNotFoundError): - etl_output.token_for(replace(header_span, page=3)) + assert etl_output.token_for(replace(header_span, page=3)) == NULL_TOKEN + assert etl_output.token_for(NULL_SPAN) == NULL_TOKEN + + +def test_no_tokens(etl_output_no_tokens_tables: EtlOutput, header_span: Span) -> None: + assert etl_output_no_tokens_tables.token_for(header_span) == NULL_TOKEN + assert etl_output_no_tokens_tables.token_for(NULL_SPAN) == NULL_TOKEN def test_table_cell( etl_output: EtlOutput, header_span: Span, content_span: Span ) -> None: - header_token = etl_output.token_for(header_span) - content_token = etl_output.token_for(content_span) - - header_table, header_cell = etl_output.table_cell_for(header_token) - content_table, content_cell = etl_output.table_cell_for(content_token) + (header_table, header_cell), *_ = etl_output.table_cells_for(header_span) + (content_table, content_cell), *_ = etl_output.table_cells_for(content_span) assert header_cell.span == header_span assert content_cell.span == content_span @@ -81,10 +96,34 @@ def test_table_cell( assert content_cell.text == "720.00" -def test_table_cell_not_found(etl_output: EtlOutput) -> None: - with pytest.raises(TableCellNotFoundError): - token = etl_output.token_for(Span(page=0, start=0, end=8)) - etl_output.table_cell_for(token) +def test_table_cells(etl_output: EtlOutput, line_item_span: Span) -> None: + table_cells = etl_output.table_cells_for(line_item_span) + correct_table = etl_output.tables[3] + correct_row = correct_table.rows[1] + correct_cells = correct_row[1:4] + + for (table, cell), correct_cell in zip(table_cells, correct_cells): + assert table == correct_table + assert cell == correct_cell + + +def test_multiple_tables(etl_output: EtlOutput, mulitple_table_span: Span) -> None: + table_cells = etl_output.table_cells_for(mulitple_table_span) + cells = [cell for (table, cell) in table_cells] + _correct_cells = etl_output.tables[2].rows[-1] + etl_output.tables[3].rows[0] + 
correct_cells = [cell for cell in _correct_cells if cell.text] + assert cells == correct_cells + + +def test_table_cell_not_found(etl_output: EtlOutput, outside_table_span: Span) -> None: + assert not tuple(etl_output.table_cells_for(outside_table_span)) + assert not tuple(etl_output.table_cells_for(NULL_SPAN)) + assert not tuple(etl_output.table_cells_for(Span(-1, -1, -1))) + + +def test_no_tables(etl_output_no_tokens_tables: EtlOutput, header_span: Span) -> None: + assert not tuple(etl_output_no_tokens_tables.table_cells_for(header_span)) + assert not tuple(etl_output_no_tokens_tables.table_cells_for(NULL_SPAN)) def test_empty_cell(etl_output: EtlOutput) -> None: diff --git a/tests/etloutput/test_utils.py b/tests/etloutput/test_utils.py new file mode 100644 index 00000000..4da4d0d6 --- /dev/null +++ b/tests/etloutput/test_utils.py @@ -0,0 +1,54 @@ +import pytest + +from indico_toolkit.etloutput.utils import get, has + + +@pytest.fixture +def cell() -> "dict[str, object]": + return { + "cell_type": "header", + "columns": [0], + "rows": [0], + "doc_offsets": [{"start": 285, "end": 289}], + "position": {"bottom": 1209, "left": 150, "right": 848, "top": 1107}, + "text": "Item", + } + + +def test_get_has(cell: "dict[str, object]") -> None: + assert has(cell, str, "text") + assert get(cell, str, "text") == "Item" + + assert has(cell, dict, "position") + assert has(cell, int, "position", "top") + assert get(cell, int, "position", "top") == 1107 + + assert has(cell, list, "doc_offsets") + assert has(cell, int, "doc_offsets", 0, "start") + assert get(cell, int, "doc_offsets", 0, "start") == 285 + + +def test_get_has_not(cell: object) -> None: + assert not has(cell, str, "missing") + with pytest.raises(KeyError): + get(cell, str, "missing") + + assert not has(cell, int, "text") + with pytest.raises(TypeError): + get(cell, int, "text") + + assert not has(cell, float, "position", "top", 0) + with pytest.raises(TypeError): + get(cell, float, "position", "top", 0) + + assert 
not has(cell, int, "doc_offsets", "0", "start") + with pytest.raises(TypeError): + get(cell, int, "doc_offsets", "0", "start") + + assert not has(cell, int, "doc_offsets", -1, "start") + with pytest.raises(IndexError): + get(cell, int, "doc_offsets", -1, "start") + + assert not has(cell, int, "doc_offsets", 1, "start") + with pytest.raises(IndexError): + get(cell, int, "doc_offsets", 1, "start") diff --git a/tests/results/test_files.py b/tests/results/test_files.py index 8d1513c4..04e3fc50 100644 --- a/tests/results/test_files.py +++ b/tests/results/test_files.py @@ -3,7 +3,6 @@ import pytest from indico_toolkit import results -from indico_toolkit.results import ResultError data_folder = Path(__file__).parent.parent / "data" / "results" @@ -26,5 +25,5 @@ async def path_read_bytes_async(path: Path) -> bytes: def test_usupported_version() -> None: - with pytest.raises(ResultError): + with pytest.raises(ValueError): results.load({"file_version": 1})