From bf32902973ef08359073242eda4320eee523fc76 Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Wed, 18 Feb 2026 23:00:33 +0000 Subject: [PATCH 1/4] Replace SubTextIO with SubBinaryIO --- src/multicsv/exceptions.py | 37 --- src/multicsv/file.py | 236 +++++++++++------- src/multicsv/open.py | 15 +- src/multicsv/subbinaryio.py | 279 +++++++++++++++++++++ src/multicsv/subtextio.py | 368 ---------------------------- tests/test_examples.py | 25 +- tests/test_file.py | 72 +++--- tests/test_open.py | 4 +- tests/test_subtextio.py | 470 ------------------------------------ 9 files changed, 474 insertions(+), 1032 deletions(-) create mode 100644 src/multicsv/subbinaryio.py delete mode 100644 src/multicsv/subtextio.py delete mode 100644 tests/test_subtextio.py diff --git a/src/multicsv/exceptions.py b/src/multicsv/exceptions.py index 39999f0..41dc26f 100644 --- a/src/multicsv/exceptions.py +++ b/src/multicsv/exceptions.py @@ -1,38 +1,5 @@ -class SubTextIOErrror(Exception): - """Base class for all SubTextIO custom exceptions.""" - pass - - -class OpOnClosedError(SubTextIOErrror, ValueError): - pass - - -class InvalidWhenceError(SubTextIOErrror, ValueError): - pass - - -class InvalidSubtextCoordinates(SubTextIOErrror, ValueError): - pass - - -class BaseMustBeSeekable(SubTextIOErrror, ValueError): - pass - - -class BaseMustBeReadable(SubTextIOErrror, ValueError): - pass - - -class EndsBeyondBaseContent(SubTextIOErrror, ValueError): - pass - - -class BaseIOClosed(SubTextIOErrror, ValueError): - pass - - class MultiCSVFileError(Exception): """Base class for all MultiCSVFile custom exceptions.""" pass @@ -48,7 +15,3 @@ class CSVFileBaseIOClosed(MultiCSVFileError, ValueError): class SectionNotFound(MultiCSVFileError, KeyError): pass - - -class BrokenTell(MultiCSVFileError, IOError): - pass diff --git a/src/multicsv/file.py b/src/multicsv/file.py index f3b158d..1f1aa26 100644 --- a/src/multicsv/file.py +++ b/src/multicsv/file.py @@ -1,12 +1,11 @@ -from typing import TextIO, 
Optional, Type, List, MutableMapping, Iterator +from typing import BinaryIO, TextIO, Optional, Type, List, MutableMapping, \ + Iterator import csv -import shutil -import os import io -from .subtextio import SubTextIO +from .subbinaryio import SubBinaryIO from .exceptions import OpOnClosedCSVFileError, CSVFileBaseIOClosed, \ - SectionNotFound, BrokenTell + SectionNotFound from .section import MultiCSVSection @@ -28,15 +27,17 @@ class MultiCSVFile(MutableMapping[str, TextIO]): Structure: ---------- - - The class initializes by reading the CSV file and identifying sections - encapsulated within bracketed headers (e.g., [section_name]). - - Each section is represented by a MultiCSVSection that holds a - descriptor to a SubTextIO object, allowing isolated operations - within that section. + - The class initialises by reading the CSV file in binary mode and + identifying sections encapsulated within bracketed headers + (e.g. [section_name]). + - Each section is represented by a MultiCSVSection whose descriptor is + an io.TextIOWrapper wrapping a SubBinaryIO. SubBinaryIO is a live view + into the base file keyed by byte offsets, so tell() always returns a + plain integer offset regardless of the file's encoding. - Operations like reading, writing, iterating, and deleting sections are - supported. + supported. - Changes in sections are committed back to the base CSV file when - the `flush` or `close` method is invoked. + the `flush` or `close` method is invoked. Use Cases: ---------- @@ -45,14 +46,14 @@ class MultiCSVFile(MutableMapping[str, TextIO]): - Structured log files where each segment represents a distinct log category. - Processing large CSV files by logically splitting them into independent - sections for easier manipulation. + sections for easier manipulation. Interface Functions: -------------------- - `__getitem__(key: str) -> TextIO`: Retrieves the TextIO object for the - specified section. + specified section. 
- `__setitem__(key: str, value: TextIO) -> None`: Sets the TextIO - object for the specified section. + object for the specified section. - `__delitem__(key: str) -> None`: Deletes the specified section. - `__iter__() -> Iterator[str]`: Iterates over the section names. - `__len__() -> int`: Returns the number of sections. @@ -61,58 +62,49 @@ class MultiCSVFile(MutableMapping[str, TextIO]): uncommitted changes. - `flush() -> None`: Commits changes in sections back to the base CSV file. - Context Management Support: Allows for usage with `with` statement for - automatic resource management. + automatic resource management. Examples: --------- ```python import io - from multicsv.multicsvfile import MultiCSVFile + from multicsv.file import MultiCSVFile - # Initialize the MultiCSVFile with a base CSV string - csv_content = io.StringIO("[section1]\na,b,c\n1,2,3\n" - "[section2]\nd,e,f\n4,5,6\n") + # Initialize the MultiCSVFile with a base CSV byte stream + csv_content = io.BytesIO( + b"[section1]\\na,b,c\\n1,2,3\\n[section2]\\nd,e,f\\n4,5,6\\n") csv_file = MultiCSVFile(csv_content) - # Accessing a section + # Accessing a section (returns TextIO decoded with the given encoding) section1 = csv_file["section1"] - print(section1.read()) # Should output 'a,b,c\n1,2,3\n' + print(section1.read()) # Should output 'a,b,c\\n1,2,3\\n' # Adding a new section - new_section = io.StringIO("g,h,i\n7,8,9\n") + new_section = io.StringIO("g,h,i\\n7,8,9\\n") csv_file["section3"] = new_section csv_file.flush() # Verify the new section is added csv_content.seek(0) print(csv_content.read()) - # Outputs - # [section1] - # a,b,c - # 1,2,3 - # [section2] - # d,e,f - # 4,5,6 - # [section3] - # g,h,i - # 7,8,9 ``` Caveats: -------- - - Writing to and reading from the base TextIO when it is used in - MultiCSVFile can lead to unexpected results. - - Always ensure to call `flush` or use context management to commit changes - back to the base CSV file. 
- - Mixing reading/writing operations from MultiCSVFile and the base TextIO - directly may cause inconsistencies. + - The base BinaryIO must remain open for the lifetime of MultiCSVFile. + - Always ensure to call `flush` or use context management to commit + changes back to the base CSV file. + - Mixing reads/writes on MultiCSVFile and the base BinaryIO directly + may cause inconsistencies. """ - def __init__(self, file: TextIO, own: bool = False): + def __init__(self, file: BinaryIO, own: bool = False, + encoding: str = 'utf-8') -> None: self._initialized = False self._need_flush = False self._own_file = own self._file = file + self._encoding = encoding self._closed = self._file.closed self._sections: List[MultiCSVSection] = [] self._initialize_sections() @@ -197,22 +189,27 @@ def close(self) -> None: else: self._closed = True - def _write_section(self, section: MultiCSVSection) -> None: - self._file.write(f"[{section.name}]\n") - - initial_section_pos = section.descriptor.tell() - try: - section.descriptor.seek(0) - shutil.copyfileobj(section.descriptor, self._file) - finally: - section.descriptor.seek(initial_section_pos) - def _write_file(self) -> None: + # Collect all section text BEFORE touching the base file. Each + # descriptor is a TextIOWrapper over a SubBinaryIO that reads directly + # from `self._file`; once we seek(0) + truncate() below those bytes + # would be gone. Pre-reading here makes the subsequent rewrite safe. 
+ sections_data: List[tuple[str, str]] = [] + for section in self._sections: + saved_pos = section.descriptor.tell() + try: + section.descriptor.seek(0) + text: str = section.descriptor.read() + finally: + section.descriptor.seek(saved_pos) + sections_data.append((section.name, text)) + self._file.seek(0) self._file.truncate() - for section in self._sections: - self._write_section(section) + for name, text in sections_data: + self._file.write(f"[{name}]\n".encode(self._encoding)) + self._file.write(text.encode(self._encoding)) def flush(self) -> None: if self._file.closed: @@ -221,12 +218,12 @@ def flush(self) -> None: if not self._need_flush: return - initial_file_pos = self._file.tell() + saved = self._file.tell() try: self._write_file() self._need_flush = False finally: - self._file.seek(initial_file_pos) + self._file.seek(saved) def __enter__(self) -> 'MultiCSVFile': return self @@ -238,37 +235,90 @@ def __exit__(self, self.close() def _initialize_sections_wrapped(self) -> None: - current_section: Optional[str] = None - section_start = 0 - previous_position = 0 + # Binary readline() gives exact byte offsets from tell() with no + # opaque codec-state cookies, but only when the encoding maps \n to + # the single byte 0x0a. EBCDIC encodings use 0x25 for \n, and + # UTF-16/32 encode \n as a multi-byte sequence; in those cases we + # fall back to a whole-file text-decode approach. 
+ try: + newline_byte = '\n'.encode(self._encoding) + except (LookupError, UnicodeEncodeError): + newline_byte = b'\n' - def end_section() -> None: - if current_section is not None: - descriptor = SubTextIO(self._file, - start=section_start, - end=previous_position) - section = MultiCSVSection(name=current_section, - descriptor=descriptor) - self._sections.append(section) - - self._file.seek(0, os.SEEK_END) - final_position = self._file.tell() - if final_position == 0: - return + if newline_byte == b'\n': + self._initialize_sections_binary() + else: + self._initialize_sections_text() + def _initialize_sections_binary(self) -> None: + """Section detection via binary readline + byte-offset SubBinaryIO.""" self._file.seek(0) + current_section: Optional[str] = None + section_start = 0 # byte offset where current section's data begins + while True: - line = self._file.readline() - if not line: + line_start: int = self._file.tell() + line_bytes: bytes = self._file.readline() + + if not line_bytes: + # EOF – close out the last section. + if current_section is not None: + self._sections.append(MultiCSVSection( + name=current_section, + descriptor=io.TextIOWrapper( + SubBinaryIO(self._file, section_start, line_start), + encoding=self._encoding, + ), + )) break - current_position = self._file.tell() - if current_position > final_position: - raise BrokenTell("Base file has a broken tell() function.") + line_text = line_bytes.decode(self._encoding, + errors='replace').strip() + if line_text: + row = next(csv.reader([line_text])) + if len(row) == 0: + break + + first = row[0].strip() + rest = row[1:] + + if first.startswith("[") and \ + first.endswith("]") and \ + all(not x for x in rest): + + # Close the previous section (data ran from + # section_start up to – but not including – this + # header line). 
+ if current_section is not None: + self._sections.append(MultiCSVSection( + name=current_section, + descriptor=io.TextIOWrapper( + SubBinaryIO(self._file, + section_start, line_start), + encoding=self._encoding, + ), + )) + current_section = first[1:-1] + # Section data starts right after the header line. + section_start = self._file.tell() + + def _initialize_sections_text(self) -> None: + """Fallback for encodings whose newline is not the single byte 0x0a + (EBCDIC, UTF-16, UTF-32, …). Reads the whole file, decodes to text, + and stores each section as an io.StringIO.""" + self._file.seek(0) + raw = self._file.read() + if not raw: + return + + full_text = raw.decode(self._encoding, errors='replace') + current_section: Optional[str] = None + section_lines: List[str] = [] - line = line.strip() - if line: - row = next(csv.reader([line])) + for line in full_text.splitlines(keepends=True): + stripped = line.strip() + if stripped: + row = next(csv.reader([stripped])) if len(row) == 0: break @@ -279,26 +329,36 @@ def end_section() -> None: first.endswith("]") and \ all(not x for x in rest): - end_section() + if current_section is not None: + self._sections.append(MultiCSVSection( + name=current_section, + descriptor=io.StringIO( + "".join(section_lines)), + )) current_section = first[1:-1] - section_start = current_position + section_lines = [] + continue - previous_position = current_position + if current_section is not None: + section_lines.append(line) - end_section() + if current_section is not None: + self._sections.append(MultiCSVSection( + name=current_section, + descriptor=io.StringIO("".join(section_lines)), + )) def _initialize_sections(self) -> None: - initial_file_pos = self._file.tell() + if not self._file.readable(): + return + + saved = self._file.tell() try: self._initialize_sections_wrapped() finally: - self._file.seek(initial_file_pos) + self._file.seek(saved) def _check_closed(self) -> None: - """ - Helper method to verify if the IO object is closed. 
- """ - if self._closed: raise OpOnClosedCSVFileError("I/O operation on closed file.") diff --git a/src/multicsv/open.py b/src/multicsv/open.py index e9410b3..b977672 100644 --- a/src/multicsv/open.py +++ b/src/multicsv/open.py @@ -1,5 +1,5 @@ -from typing import Union, Literal, TextIO +from typing import Union, Literal, BinaryIO from pathlib import Path from .file import MultiCSVFile @@ -9,14 +9,13 @@ def multicsv_open(path: OpenPath, - mode: Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+", - "rt", "wt", "at", "xt", "r+t", "w+t", "a+t", - "x+t"] = "rt") \ - -> MultiCSVFile: + mode: Literal["rb", "wb", "ab", "xb", + "r+b", "w+b", "a+b", "x+b"] = "rb", + encoding: str = 'utf-8') -> MultiCSVFile: file = open(path, mode=mode) - return MultiCSVFile(file, own=True) + return MultiCSVFile(file, own=True, encoding=encoding) -def multicsv_wrap(file: TextIO) -> MultiCSVFile: - return MultiCSVFile(file) +def multicsv_wrap(file: BinaryIO, encoding: str = 'utf-8') -> MultiCSVFile: + return MultiCSVFile(file, encoding=encoding) diff --git a/src/multicsv/subbinaryio.py b/src/multicsv/subbinaryio.py new file mode 100644 index 0000000..0e687b8 --- /dev/null +++ b/src/multicsv/subbinaryio.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, BinaryIO +import io +import os + +if TYPE_CHECKING: + from _typeshed import ReadableBuffer, WriteableBuffer + + +class SubBinaryIO(io.BufferedIOBase): + """ + A seekable read/write view of a contiguous byte range ``[start, end)`` + within a *base_io* binary file. + + SubBinaryIO does **not** copy bytes into memory. Every read, write, and + seek operation is forwarded directly to *base_io* after translating the + local offset into an absolute file position. This means: + + * The view is cheaply created even for multi-gigabyte files. + * Writes to the view are immediately visible through *base_io* and vice- + versa (they share the same underlying file descriptor). 
+ * ``tell()`` always returns a plain ``int`` byte offset (never an opaque + codec-state cookie), making it safe to use as a boundary marker when + slicing a file into sections. + + Coordinates + ----------- + *start* and *end* are **absolute** byte positions in *base_io* measured + from the beginning of the file (as returned by ``base_io.tell()`` after + ``base_io.seek(0)``). The view covers the half-open interval + ``[start, end)``. An empty view (``start == end``) is allowed. + + The internal *position* is always **relative** to *start* and runs in + ``[0, end - start]``. In other words, ``tell()`` returns ``0`` just + after construction and ``end - start`` once all bytes have been read. + + Closing + ------- + ``close()`` marks this view as closed but does **not** close *base_io*. + The owner is responsible for the lifetime of *base_io*. + + Compatibility with ``io.TextIOWrapper`` + ---------------------------------------- + SubBinaryIO satisfies every requirement of ``io.BufferedIOBase`` so it + can be wrapped directly: + + >>> import io + >>> raw = io.BytesIO(b"hello\\nworld\\n") + >>> sub = SubBinaryIO(raw, start=6, end=12) # b"world\\n" + >>> text = io.TextIOWrapper(sub, encoding="utf-8") + >>> text.read() + 'world\\n' + + Because ``tell()`` returns a plain integer, wrapping with + ``TextIOWrapper`` is safe regardless of the encoding of *base_io* – + there are no opaque codec-state cookies. 
+ + Read / write example + -------------------- + >>> raw = io.BytesIO(b"AAAABBBBCCCC") + >>> sub = SubBinaryIO(raw, start=4, end=8) # b"BBBB" + >>> sub.read(2) + b'BB' + >>> sub.write(b"XX") # overwrites the 3rd and 4th B + 2 + >>> raw.seek(0); raw.read() + b'AAAABBXXCCCC' + """ + + def __init__(self, base_io: BinaryIO, start: int, end: int) -> None: + super().__init__() + self._base_io = base_io + self._start = start + self._end = end + self._position = 0 # relative to _start + + # ------------------------------------------------------------------ # + # IO[bytes] protocol attributes # + # ------------------------------------------------------------------ # + + @property + def name(self) -> str: + """Nominal name of this byte-range view. + + Returns a fixed string since the view has no file-system path of its + own. The property exists to satisfy the ``typing.IO[bytes]`` protocol + required by ``io.TextIOWrapper``. + """ + return "" + + @property + def mode(self) -> str: + """Access mode string (always ``'rb+'``). + + SubBinaryIO supports both reading and writing so ``'rb+'`` is the + appropriate description. The property exists to satisfy the + ``typing.IO[bytes]`` protocol required by ``io.TextIOWrapper``. 
+ """ + return "rb+" + + # ------------------------------------------------------------------ # + # Capability flags # + # ------------------------------------------------------------------ # + + def readable(self) -> bool: + return True + + def writable(self) -> bool: + return True + + def seekable(self) -> bool: + return True + + # ------------------------------------------------------------------ # + # Internal helpers # + # ------------------------------------------------------------------ # + + @property + def _abs_pos(self) -> int: + """Absolute position in *base_io* corresponding to the current local position.""" + return self._start + self._position + + @property + def _region_size(self) -> int: + """Length of the byte region this view covers.""" + return self._end - self._start + + def _remaining(self) -> int: + """Number of bytes left between the current position and the end of the region.""" + return max(0, self._region_size - self._position) + + # ------------------------------------------------------------------ # + # Read operations # + # ------------------------------------------------------------------ # + + def readinto(self, b: WriteableBuffer) -> int: + """Read up to ``len(b)`` bytes into *b*, capped to the region boundary.""" + self._checkClosed() + mv = memoryview(b) + want = min(len(mv), self._remaining()) + if want == 0: + return 0 + self._base_io.seek(self._abs_pos) + chunk = self._base_io.read(want) + n = len(chunk) + mv[:n] = chunk + self._position += n + return n + + def read(self, size: int | None = -1) -> bytes: + """Read at most *size* bytes from the region (all remaining if *size* is ``None`` or negative).""" + self._checkClosed() + effective = -1 if size is None else size + want = self._remaining() if effective < 0 else min(effective, self._remaining()) + if want == 0: + return b"" + self._base_io.seek(self._abs_pos) + chunk = self._base_io.read(want) + self._position += len(chunk) + return chunk + + def read1(self, size: int | None 
= -1) -> bytes: + """Read up to *size* bytes in a single underlying call. + + Mirrors ``io.BufferedIOBase.read1()``, required by ``io.TextIOWrapper``. + For a stream with no internal read-ahead buffer this is identical to + :meth:`read`. + """ + return self.read(size) + + def readline(self, size: int | None = -1) -> bytes: + """Read up to the next ``\\n`` (inclusive) or EOF, capped by *size* and the region end. + + The scan is performed by reading the region incrementally in small + chunks so that only a single line's worth of data is held in memory at + any time. + """ + self._checkClosed() + effective = -1 if size is None else size + limit = self._remaining() if effective < 0 else min(effective, self._remaining()) + if limit == 0: + return b"" + # Scan for a newline inside the region without loading the whole region. + self._base_io.seek(self._abs_pos) + line = self._base_io.read(limit) + nl = line.find(b"\n") + if nl != -1: + line = line[: nl + 1] + self._position += len(line) + return line + + # ------------------------------------------------------------------ # + # Write operations # + # ------------------------------------------------------------------ # + + def write(self, b: ReadableBuffer) -> int: + """Write *b* into the region at the current position. + + Data is written directly to *base_io* at ``start + position``. + Writing beyond the current region end is **not** allowed; any + excess bytes are silently dropped and only the bytes that fit + within ``[start, end)`` are written. Returns the number of bytes + actually written. + """ + self._checkClosed() + data = bytes(b) + space = self._remaining() + if space == 0: + return 0 + chunk = data[:space] + self._base_io.seek(self._abs_pos) + n = self._base_io.write(chunk) + self._position += n + return n + + def truncate(self, size: int | None = None) -> int: + """Truncate the *logical* view to *size* bytes from the start of the region. 
+ + This updates the ``end`` boundary of the view; it does **not** + truncate *base_io*. + """ + self._checkClosed() + if size is None: + size = self._position + self._end = self._start + max(0, size) + if self._position > size: + self._position = size + return self._region_size + + # ------------------------------------------------------------------ # + # Position # + # ------------------------------------------------------------------ # + + def seek(self, pos: int, whence: int = os.SEEK_SET) -> int: + """Seek to *pos* within the region. + + All three standard *whence* values are supported: + ``SEEK_SET`` (0), ``SEEK_CUR`` (1), and ``SEEK_END`` (2). + The position is clamped to ``[0, region_size]``; seeking before + the start of the region is silently clamped to 0. + """ + self._checkClosed() + if whence == os.SEEK_SET: + target = pos + elif whence == os.SEEK_CUR: + target = self._position + pos + elif whence == os.SEEK_END: + target = self._region_size + pos + else: + raise OSError(f"Invalid whence value: {whence!r}") + self._position = max(0, min(target, self._region_size)) + return self._position + + def tell(self) -> int: + """Return the current position as a byte offset from the start of the region. + + Unlike ``io.TextIOWrapper.tell()`` the value is always a plain + integer and never an opaque codec-state cookie. + """ + self._checkClosed() + return self._position + + # ------------------------------------------------------------------ # + # Flush / close # + # ------------------------------------------------------------------ # + + def flush(self) -> None: + """Flush the underlying *base_io* without closing it.""" + if not self.closed: + self._base_io.flush() + + def close(self) -> None: + """Mark this view as closed. + + Does **not** close *base_io*; the owner is responsible for the + lifetime of the underlying file object. 
+ """ + super().close() diff --git a/src/multicsv/subtextio.py b/src/multicsv/subtextio.py deleted file mode 100644 index bf40408..0000000 --- a/src/multicsv/subtextio.py +++ /dev/null @@ -1,368 +0,0 @@ -from typing import TextIO, List, Optional, Type, Iterable -import io -import os -from .exceptions import OpOnClosedError, \ - InvalidWhenceError, InvalidSubtextCoordinates, \ - BaseMustBeReadable, BaseMustBeSeekable, \ - EndsBeyondBaseContent, BaseIOClosed - - -class SubTextIO(TextIO): - """ - SubTextIO provides an interface for reading, writing, and - manipulating a specified subsection of a base TextIO object. This - class allows for convenient and isolated operations within a given - range of the base TextIO buffer, while efficiently buffering - content to minimize repeated seeks. - - Purpose: - -------- - The primary aim of SubTextIO is to allow for editing, reading, and - writing operations on a specific segment of a TextIO object - without affecting other parts. This can be particularly useful in - scenarios where parts of a large text file need to be updated or - read independently. - - Structure: - ---------- - - The class initializes by reading and storing the relevant - segment of the base TextIO into an in-memory buffer. - - Operations (read, write, seek, etc.) are done on this buffer. - - Changes are committed back to the base TextIO when the `flush` - or `close` method is called. - - Use Cases: - ---------- - - Editing a specific section of a configuration file or log - without loading the entire file. - - Concurrent processing on different segments of a large file. - - Efficiently managing memory and I/O operations for large-scale - text processing tasks. - - Interface Functions: - -------------------- - - `read(size: int = -1) -> str`: Reads a specified number of - characters from the buffer. - - `readline(limit: int = -1) -> str`: Reads and returns one line - from the buffer. 
- - `readlines(hint: int = -1) -> List[str]`: Reads and returns all - remaining lines from the buffer. - - `write(s: str) -> int`: Writes a string to the buffer. - - `writelines(lines: List[str]) -> None`: Writes a list of lines - to the buffer. - - `truncate(size: int) -> int`: Resizes the section. - - `seek(offset: int, whence: int = 0) -> int`: Moves the buffer's - current position. - - `tell() -> int`: Returns the current position in the buffer. - - `flush() -> None`: Writes the buffer content back to the base - TextIO object. - - `close() -> None`: Flushes the buffer and closes this IO object. - - Context Management Support: Allows for usage with `with` - statement for automatic resource management. - - Examples: - --------- - ```python - import io - from subtextio import SubTextIO - - base_text = io.StringIO("Hello\nWorld\nThis\nIs\nA\nTest\n") - sub_text = SubTextIO(base_text, start=6, end=21) - - # Should output 'World\n' - print("Reading first line:", sub_text.readline()) - # Should output 'This\nIs\n' - print("Reading rest within range:", sub_text.read()) - - sub_text.seek(0) - sub_text.write("NewContent") # Overwrites 'World\nThis\n' - sub_text.seek(0) - # Should output 'NewContentIs\n' - print("After write operation:", sub_text.read()) - - sub_text.delete_section() - sub_text.write("Overwritten") - sub_text.seek(0) - # Should output 'Overwritten' - print("After delete and write operation:", sub_text.read()) - - # Make sure changes are committed to the base TextIO - sub_text.flush() - # Should reflect changes in original buffer - print("Base IO after SubTextIO flush:", base_text.getvalue()) - ``` - - Caveats: - -------- - - Writing to and reading from the base TextIO when it is used in SubTextIO - can lead to unexpected results. - - SubTextIO loads the subsection into memory. Thus be cautious - of buffer size when working with very large files. - - Always ensure to call `flush` or use context management to - commit changes back to the base TextIO. 
- """ - - def __init__(self, base_io: TextIO, start: int, end: int): - self._initialized = False - self._need_flush = False - self._base_io = base_io - self._start = start - self._end = end - self._position = 0 # Position within the SubTextIO - self._closed = base_io.closed - self._buffer = "" - - if end < start or start < 0: - raise InvalidSubtextCoordinates( - f"Invalid range [{start},{end}] passed to SubTextIO.") - - if not base_io.seekable(): - raise BaseMustBeSeekable("Base io must be seekable.") - - if end > start and not base_io.readable(): - # TODO: losen this requirement because if we override by - # the same length, then we dont need to read. - raise BaseMustBeReadable("Base io must be readable" - " if existing content is to be modified.") - - self._load() - self._initial_length = self.buffer_length - self._initialized = True - - def _load(self) -> None: - """ - Load the relevant part of the base_io into the buffer. - """ - - base_initial_position = self._base_io.tell() - - # - # Below we try to avoid calling `base_io.read()` as much as possible. 
- # - try: - self._base_io.seek(0, os.SEEK_END) - base_last_position = self._base_io.tell() - - if self.end > base_last_position: - raise EndsBeyondBaseContent( - "End position is greater than base TextIO length.") - - if self.end > self.start: - self._base_io.seek(self.end) - base_final_position = self._base_io.tell() - self.is_at_end = base_final_position == base_last_position - self._base_io.seek(self.start) - self._buffer = self._base_io.read(self.end - self.start) - else: - base_final_position = self.start - self._base_io.seek(0, os.SEEK_END) - self.is_at_end = base_final_position == self._base_io.tell() - finally: - self._base_io.seek(base_initial_position) - - @property - def start(self) -> int: - return self._start - - @property - def end(self) -> int: - return self._end - - @property - def buffer_length(self) -> int: - return len(self._buffer) - - @property - def mode(self) -> str: - return self._base_io.mode - - @property - def closed(self) -> bool: - return self._closed - - @property - def encoding(self) -> str: - return self._base_io.encoding - - def read(self, size: int = -1) -> str: - self._check_closed() - - if size < 0 or size > self.buffer_length - self._position: - size = self.buffer_length - self._position - - result = self._buffer[self._position:self._position + size] - self._position += len(result) - return result - - def readline(self, limit: int = -1) -> str: - self._check_closed() - - if self._position >= self.buffer_length: - return '' - - newline_pos = self._buffer.find('\n', self._position) - if newline_pos == -1 or newline_pos >= self.buffer_length: - newline_pos = self.buffer_length - - if limit < 0 or limit > newline_pos - self._position: - limit = newline_pos - self._position + 1 - - result = self._buffer[self._position:self._position + limit] - self._position += len(result) - return result - - def readlines(self, hint: int = -1) -> List[str]: - """ - The `hint` argument in the `readlines` method of the `TextIO` - interface serves as 
a performance hint rather than a strict - limit. It indicates the approximate number of bytes to read - from the file. If the hint is positive, the implementation may - read more than the hint value to complete a line but will aim - to read at least as many bytes as specified by the hint. - """ - - self._check_closed() - - remaining_buffer = self._buffer[self._position:] - lines = remaining_buffer.splitlines(keepends=True) - read_size = 0 - result = [] - - for line in lines: - result.append(line) - read_size += len(line) - if 0 <= hint <= read_size: - break - - self._position += read_size - return result - - def write(self, s: str) -> int: - self._check_closed() - - pre = self._buffer[:self._position] - post = self._buffer[self._position + len(s):] - written = len(s) - - self._buffer = pre + s + post - self._position += written - self._need_flush = True - - return written - - def writelines(self, lines: Iterable[str]) -> None: - for line in lines: - self.write(line) - - def truncate(self, size: Optional[int] = None) -> int: - self._check_closed() - - if size is None: - end = self._position - else: - end = size - - self._buffer = self._buffer[:end] - self._need_flush = True - return self.buffer_length - - def close(self) -> None: - if not self._closed: - try: - self.flush() - finally: - self._closed = True - - def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: - self._check_closed() - - if whence == os.SEEK_SET: # Absolute positioning - target = offset - elif whence == os.SEEK_CUR: # Relative to current position - target = self._position + offset - elif whence == os.SEEK_END: # Relative to the end - target = self.buffer_length + offset - else: - raise InvalidWhenceError( - f"Invalid value for whence: {repr(whence)}") - - self._position = max(0, min(target, self.buffer_length)) - return self._position - - def tell(self) -> int: - self._check_closed() - return self._position - - def flush(self) -> None: - if self._base_io.closed: - raise BaseIOClosed("Base 
io is closed in flush.") - - if not self._need_flush: - return - - if not self._closed: - base_initial_position = self._base_io.tell() - try: - if self.buffer_length == self._initial_length \ - or self.is_at_end: - self._base_io.seek(self.start) - self._base_io.write(self._buffer) - else: - self._base_io.seek(self.end) - content_after = self._base_io.read() - - self._base_io.seek(self.start) - self._base_io.write(self._buffer + content_after) - - self._base_io.flush() - self._need_flush = False - finally: - self._base_io.seek(base_initial_position) - - def isatty(self) -> bool: - return False - - def fileno(self) -> int: - raise io.UnsupportedOperation("Not a filesystem descriptor.") - - def readable(self) -> bool: - return self._base_io.readable() - - def writable(self) -> bool: - return self._base_io.writable() - - def seekable(self) -> bool: - return True - - def _check_closed(self) -> None: - """ - Helper method to verify if the IO object is closed. - """ - - if self._closed: - raise OpOnClosedError("I/O operation on closed file.") - - def __iter__(self) -> 'SubTextIO': - return self - - def __next__(self) -> str: - if self._position < self.buffer_length: - return self.readline() - else: - raise StopIteration - - def __enter__(self) -> 'SubTextIO': - return self - - def __exit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[object]) -> None: - self.close() - - def __del__(self) -> None: - if self._initialized: - try: - self.close() - except BaseIOClosed: - pass diff --git a/tests/test_examples.py b/tests/test_examples.py index 0bc47fd..3e8c1aa 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -44,7 +44,7 @@ def test_read_csv(example_file_1: Path) -> None: def test_write_csv(example_file_1): - with multicsv.open(example_file_1, mode='w+') as csv_file: + with multicsv.open(example_file_1, mode='w+b') as csv_file: # Write the CSV content to the file csv_file['section1'] = 
io.StringIO("header1,header2,header3\nvalue1,value2,value3\n") csv_file['section2'] = io.StringIO("header4,header5,header6\nvalue4,value5,value6\n") @@ -61,7 +61,7 @@ def test_write_csv(example_file_1): def test_write_csv_easier(example_file_1): - with multicsv.open(example_file_1, mode='w+') as csv_file: + with multicsv.open(example_file_1, mode='w+b') as csv_file: # Write the CSV content to the file csv_file.section('section1').write("header1,header2,header3\nvalue1,value2,value3\n") csv_file.section('section2').write("header4,header5,header6\nvalue4,value5,value6\n") @@ -167,8 +167,9 @@ def test_dict_read_csv2(example_file_2: Path) -> None: def test_open_csv(): - # Initialize the MultiCSVFile with a base CSV string - csv_content = io.StringIO("[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n") + # Initialize the MultiCSVFile with a base CSV byte stream + csv_content = io.BytesIO( + b"[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n") csv_file = multicsv.wrap(csv_content) # Accessing a section @@ -182,14 +183,8 @@ def test_open_csv(): # Verify the new section is added csv_content.seek(0) - assert csv_content.read() == """\ -[section1] -a,b,c -1,2,3 -[section2] -d,e,f -4,5,6 -[section3] -g,h,i -7,8,9 -""" + assert csv_content.read() == ( + b"[section1]\na,b,c\n1,2,3\n" + b"[section2]\nd,e,f\n4,5,6\n" + b"[section3]\ng,h,i\n7,8,9\n" + ) diff --git a/tests/test_file.py b/tests/test_file.py index 21ca327..7b02468 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -3,44 +3,41 @@ import pytest import csv from pathlib import Path -from typing import TextIO +from typing import BinaryIO from multicsv.file import MultiCSVFile from multicsv.exceptions import SectionNotFound, CSVFileBaseIOClosed, \ - OpOnClosedCSVFileError, BrokenTell + OpOnClosedCSVFileError @pytest.fixture -def simple_csv() -> TextIO: - content = """\ -[section1] +def simple_csv() -> BinaryIO: + content = b"""[section1] a,b,c 1,2,3 [section2] d,e,f 4,5,6 """ - return io.StringIO(content) + 
return io.BytesIO(content) @pytest.fixture -def empty_csv() -> TextIO: - return io.StringIO("") +def empty_csv() -> BinaryIO: + return io.BytesIO(b"") @pytest.fixture -def no_section_csv() -> TextIO: - content = """\ -a,b,c +def no_section_csv() -> BinaryIO: + content = b"""a,b,c 1,2,3 d,e,f 4,5,6 """ - return io.StringIO(content) + return io.BytesIO(content) -def make_encoded_csv(encoding: str) -> TextIO: - content = """\ -[section1] +def make_encoded_csv(encoding: str) -> BinaryIO: + content = """[section1] a,b,c 1,2,3 [section2] @@ -52,8 +49,7 @@ def make_encoded_csv(encoding: str) -> TextIO: """ binary = content.encode(encoding) - # Return as a TextIO, but not StringIO because that one only supports UTF-8. - return io.TextIOWrapper(io.BytesIO(binary), encoding=encoding) + return io.BytesIO(binary) # Get the list of currently supported encodings from the encodings module, and test them all. @@ -68,12 +64,11 @@ def try_encode(encoding: str) -> bool: TEXT_ENCODINGS = tuple(encoding for encoding in ENCODINGS if try_encode(encoding)) -@pytest.mark.skip(reason="Currently fails") @pytest.mark.parametrize("encoding", TEXT_ENCODINGS) def test_encoding_whole_content(encoding: str) -> None: content = "" - with MultiCSVFile(make_encoded_csv(encoding)) as csv_file: + with MultiCSVFile(make_encoded_csv(encoding), encoding=encoding) as csv_file: sections = list(csv_file) assert sections == ['section1', 'section2', 'some third\tsection'] for section in csv_file: @@ -102,13 +97,13 @@ def test_read_section(simple_csv): def test_read_section_from_file(simple_csv, tmp_path): path = tmp_path / "file1.txt" initial_content = simple_csv.read() - with open(path, "w") as writer: + with open(path, "wb") as writer: writer.write(initial_content) - with open(path, "r") as fd: + with open(path, "rb") as fd: assert fd.read() == initial_content - with open(path, "r") as fd: + with open(path, "rb") as fd: csv_file = MultiCSVFile(fd) section1 = csv_file["section1"] assert section1.read() == 
"a,b,c\n1,2,3\n" @@ -124,7 +119,7 @@ def test_write_section(simple_csv): csv_file.flush() simple_csv.seek(0) - assert simple_csv.read() == "[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n[section3]\ng,h,i\n7,8,9\n" + assert simple_csv.read() == b"[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n[section3]\ng,h,i\n7,8,9\n" def test_delete_section(simple_csv): @@ -133,7 +128,7 @@ def test_delete_section(simple_csv): csv_file.flush() simple_csv.seek(0) - assert simple_csv.read() == "[section2]\nd,e,f\n4,5,6\n" + assert simple_csv.read() == b"[section2]\nd,e,f\n4,5,6\n" def test_iterate_sections(simple_csv): @@ -200,7 +195,7 @@ def test_update_existing_section(simple_csv): csv_file.flush() simple_csv.seek(0) - expected_content = "[section1]\nnew,data\n[section2]\nd,e,f\n4,5,6\n" + expected_content = b"[section1]\nnew,data\n[section2]\nd,e,f\n4,5,6\n" assert simple_csv.read() == expected_content @@ -222,7 +217,7 @@ def test_multiple_writes_with_flush(simple_csv): csv_file.flush() simple_csv.seek(0) - expected_content = "[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n[section3]\nx,y,z\n10,11,12\n[section4]\np,q,r\n13,14,15\n" + expected_content = b"[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n[section3]\nx,y,z\n10,11,12\n[section4]\np,q,r\n13,14,15\n" assert simple_csv.read() == expected_content @@ -254,13 +249,13 @@ def test_section_not_found_for_deleted_section(simple_csv): @pytest.mark.parametrize("initial_content, expected_sections", [ - ("[first_section]\na,b,c\n1,2,3\n[second_section]\nd,e,f\n4,5,6\n", + (b"[first_section]\na,b,c\n1,2,3\n[second_section]\nd,e,f\n4,5,6\n", ["first_section", "second_section"]), - ("", []), - ("[lonely_section]\ng,h,i\n7,8,9\n", ["lonely_section"]), + (b"", []), + (b"[lonely_section]\ng,h,i\n7,8,9\n", ["lonely_section"]), ]) def test_various_initial_contents(initial_content, expected_sections): - file = io.StringIO(initial_content) + file = io.BytesIO(initial_content) csv_file = MultiCSVFile(file) assert 
list(iter(csv_file)) == expected_sections @@ -297,28 +292,17 @@ def test_open_nonpython_encoding(tmp_path: Path) -> None: with open(temp_file, "wb") as fd: fd.write(csv_content) - with MultiCSVFile(temp_file.open()) as csv_file: + with MultiCSVFile(temp_file.open("rb")) as csv_file: datasection = csv_file["section2"] csvdatasection = csv.DictReader(datasection) assert csvdatasection.fieldnames == ['d', 'e', 'f'] def test_no_newline_at_the_end(): - simple_csv = io.StringIO( - "[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6") + simple_csv = io.BytesIO( + b"[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6") with MultiCSVFile(simple_csv) as csv_file: datasection = csv_file["section2"] csvdatasection = csv.DictReader(datasection) assert csvdatasection.fieldnames == ['d', 'e', 'f'] - - -def test_broken_tell(tmp_path: Path) -> None: - csv_content = b'[section1]\ra,b,c\r1,2,3\r[section2]\r4,5,6\r' - - temp_file = tmp_path / "file1.csv" - with open(temp_file, "wb") as fd: - fd.write(csv_content) - - with pytest.raises(BrokenTell): - MultiCSVFile(temp_file.open()) diff --git a/tests/test_open.py b/tests/test_open.py index 9756551..9a2eebe 100644 --- a/tests/test_open.py +++ b/tests/test_open.py @@ -51,7 +51,7 @@ def test_open_write(tmp_path): path = tmp_path / "file2.txt" # Writing sections using multicsv_open - with multicsv_open(path, "wt") as csv_file: + with multicsv_open(path, "wb") as csv_file: csv_file["section1"] = io.StringIO("a,b,c\n1,2,3\n") csv_file["section2"] = io.StringIO("d,e,f\n4,5,6\n") # csv_file.flush() @@ -76,7 +76,7 @@ def test_open_append(tmp_path): writer.write(initial_content) # Appending new sections using multicsv_open - with multicsv_open(path, "a+t") as csv_file: + with multicsv_open(path, "a+b") as csv_file: csv_file["section2"] = io.StringIO("d,e,f\n4,5,6\n") # Validate the appended content diff --git a/tests/test_subtextio.py b/tests/test_subtextio.py deleted file mode 100644 index 52f3a19..0000000 --- a/tests/test_subtextio.py +++ 
/dev/null @@ -1,470 +0,0 @@ - -import io -from typing import TextIO -import pytest -import os -from multicsv.subtextio import SubTextIO -from multicsv.exceptions import OpOnClosedError, InvalidWhenceError, InvalidSubtextCoordinates, EndsBeyondBaseContent, BaseMustBeSeekable, BaseMustBeReadable - - -@pytest.fixture -def base_textio() -> TextIO: - return io.StringIO("""\ -Hello World, -this is a -test -""") - -@pytest.fixture -def large_textio() -> TextIO: - content = "a" * 10000 + "0123456789\n" + "a" * 5000 + "Python is Great\n" + "a" * 10000 - return io.StringIO(content) - -def test_read_1(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=15) - assert sub_text.read() == "World,\nth" - assert sub_text.read() == "" - assert sub_text.read() == "" - assert sub_text.read() == "" - -def test_read_2(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.read() == "World,\nthis is " - assert sub_text.read() == "" - assert sub_text.read() == "" - assert sub_text.read() == "" - -def test_read_when_ended(base_textio): - assert base_textio.read() - assert not base_textio.read() - assert not base_textio.read() - - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.read() == "World,\nthis is " - assert sub_text.read() == "" - assert sub_text.read() == "" - assert sub_text.read() == "" - -def test_read_at_end(base_textio): - initial_content = base_textio.read() - assert initial_content - end = base_textio.tell() - base_textio.seek(0) - - sub_text = SubTextIO(base_textio, start=6, end=end) - assert sub_text.read() == initial_content[6:] - -def test_readline_1(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=15) - assert sub_text.readline() == "World,\n" - assert sub_text.readline() == "th" - assert sub_text.readline() == "" # Ensure end of the segment is reached - assert sub_text.readline() == "" # Ensure end of the segment is reached - assert sub_text.readline() == "" # Ensure end of the segment is 
reached - -def test_readline_2(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readline() == "World,\n" - assert sub_text.readline() == "this is " - assert sub_text.readline() == "" # Ensure end of the segment is reached - assert sub_text.readline() == "" # Ensure end of the segment is reached - assert sub_text.readline() == "" # Ensure end of the segment is reached - -def test_readlines(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readlines() == ["World,\n", "this is "] - -def test_write(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("NewContent") - sub_text.seek(0) - assert sub_text.read() == "NewContents is " - - sub_text.flush() - base_textio.seek(0) - assert base_textio.read() == "Hello NewContents is a\ntest\n" - -def test_write_past_end(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - longstring = 'NewContent' * 20 - sub_text.write(longstring) - sub_text.seek(0) - assert sub_text.read() == longstring - - sub_text.flush() - base_textio.seek(0) - assert base_textio.read() == f"Hello {longstring}a\ntest\n" - -def test_seek_values(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - - sub_text.seek(-3, whence=os.SEEK_END) - assert sub_text.readline() == "is " - - sub_text.seek(+3, whence=os.SEEK_SET) - assert sub_text.readline() == "ld,\n" - - sub_text.seek(+3, whence=os.SEEK_CUR) - assert sub_text.readline() == "s is " - -def test_writelines(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.writelines(["Line1\n", "Line2\n"]) - sub_text.seek(0) - assert sub_text.read() == "Line1\nLine2\nis " - -def test_truncate(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.truncate() - sub_text.write("CleanSlate") - sub_text.seek(0) - assert sub_text.read() == "CleanSlate" - -def test_truncate_with_arg(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - 
sub_text.truncate(4) - assert sub_text.read() == "Worl" - sub_text.flush() - - base_textio.seek(0) - assert base_textio.read() == """\ -Hello Worla -test - is a -test -""" - -def test_truncate_past_end(base_textio): - initial_content = base_textio.read() - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.truncate(999) - assert sub_text.buffer_length == 15 == 21 - 6 - - base_textio.seek(0) - assert base_textio.read() == initial_content - -def test_seek_tell(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.seek(5) - assert sub_text.tell() == 5 - assert sub_text.read() == ",\nthis is " - -def test_flush(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("Overwritten") - sub_text.flush() - base_textio.seek(0) - assert base_textio.read() == "Hello Overwritten is a\ntest\n" - -def test_iter(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - lines = list(sub_text) - assert lines == ["World,\n", "this is "] - -def test_context_manager(base_textio): - with SubTextIO(base_textio, start=6, end=21) as sub_text: - sub_text.write("ContextWrite") - sub_text.seek(0) - assert sub_text.read() == "ContextWriteis " - base_textio.seek(0) - assert base_textio.read() == "Hello ContextWriteis a\ntest\n" - -# Edge case for empty subsections -def test_empty_section(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=6) - assert sub_text.read() == "" - assert sub_text.readline() == "" - assert sub_text.readlines() == [] - assert sub_text.tell() == 0 - -# Test for large buffer -def test_large_buffer(large_textio): - sub_text = SubTextIO(large_textio, start=10000, end=20000) - sub_text.seek(9999) - assert sub_text.tell() == 9999 - sub_text.seek(0) - assert sub_text.read(5) == "01234" - sub_text.seek(5, os.SEEK_CUR) - sub_text.seek(1, os.SEEK_CUR) - sub_text.seek(5000, os.SEEK_CUR) - assert sub_text.readline() == "Python is Great\n" - -# Test long text write -def test_long_write(base_textio): 
- long_text = "x" * 1000 - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write(long_text) - sub_text.seek(0) - assert sub_text.read(len(long_text)) == long_text - -def test_closed(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.closed is False - sub_text.close() - assert sub_text.closed is True - -def test_closed_after_context(base_textio): - with SubTextIO(base_textio, start=6, end=21) as sub_text: - assert sub_text.closed is False - assert sub_text.closed is True - -# Test seek and read after close -def test_operations_after_close(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.close() - with pytest.raises(OpOnClosedError): - sub_text.read() - with pytest.raises(OpOnClosedError): - sub_text.readline() - with pytest.raises(OpOnClosedError): - sub_text.readlines() - with pytest.raises(OpOnClosedError): - sub_text.truncate() - with pytest.raises(OpOnClosedError): - sub_text.write("Test") - with pytest.raises(OpOnClosedError): - sub_text.seek(os.SEEK_CUR) - with pytest.raises(OpOnClosedError): - sub_text.tell() - -# Test unexpected `whence` value in `seek` -def test_invalid_seek_whence(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - with pytest.raises(InvalidWhenceError): - sub_text.seek(0, whence=3) - -# Test `flush` without any open/close -def test_flush_without_open_close(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("FlushTest") - sub_text.flush() - base_textio.seek(0) - assert base_textio.read() == "Hello FlushTestis is a\ntest\n" - -# Test read after delete -def test_read_after_truncate(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.truncate() - assert sub_text.read() == "" - -# Test write after delete -def test_write_truncate(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.truncate() - assert sub_text.read() == "" - sub_text.write("AfterDelete") - 
sub_text.seek(0) - assert sub_text.read() == "AfterDelete" - -# Test multiple writes without flushing -def test_multiple_writes(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("First") - sub_text.write("Second") - sub_text.seek(0) - assert sub_text.read() == "FirstSecond is " - -# Test read and write interleaving -def test_read_write_interleaving(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("Interleave") - sub_text.seek(0) - assert sub_text.read(5) == "Inter" - sub_text.write("Change") - sub_text.seek(0) - assert sub_text.read() == "InterChange is " - -# Test context management with exception -def test_context_manager_with_exception(base_textio): - try: - with SubTextIO(base_textio, start=6, end=21) as sub_text: - sub_text.write("ExceptionTest") - raise IndexError("Testing Exception") - except IndexError: - base_textio.seek(0) - assert 'Hello ExceptionTests a\ntest\n' == base_textio.read() - -# Test `seek` past the end of buffer -def test_seek_past_end(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=15) - sub_text.seek(20) - assert sub_text.tell() == 9 - -# Test `read` with exact buffer end condition -def test_read_exact_buffer_end(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=15) - assert sub_text.read(size=9) == "World,\nth" - -# Test `read` after flushing and re-seeking -def test_read_after_flush_and_seek(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("AfterFlush") - sub_text.flush() - sub_text.seek(0) - assert sub_text.read() == "AfterFlushs is " - -# Test `close` without `flush` -def test_close_without_flush(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - sub_text.write("CloseWithoutFlush") - sub_text.close() - base_textio.seek(0) - assert base_textio.read() == "Hello CloseWithoutFlusha\ntest\n" - -def test_iterate(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - actual_lines = 
iter(["World,\n", "this is "]) - for line in sub_text: - expected_line = next(actual_lines) - assert line == expected_line - -def test_readline_with_limit(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readline(limit=3) == "Wor" - assert sub_text.readline(limit=4) == "ld,\n" - assert sub_text.readline(limit=4) == "this" - assert sub_text.readline(limit=4) == " is " - assert sub_text.readline(limit=4) == "" - assert sub_text.readline(limit=4) == "" - -def test_readline_with_huge_limit(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readline(limit=100) == "World,\n" - assert sub_text.readline(limit=100) == "this is " - assert sub_text.readline(limit=100) == "" - -def test_readline_with_zero_limit(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readline(limit=0) == "" - -def test_readlines_with_hint(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readlines(hint=3) == ["World,\n"] - assert sub_text.readlines(hint=3) == ["this is "] - assert sub_text.readlines(hint=3) == [] - -def test_readlines_with_huge_hint(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readlines(hint=100) == ["World,\n", "this is "] - assert sub_text.readlines(hint=100) == [] - -def test_readlines_with_zero_hint(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readlines(hint=0) == ["World,\n"] - -def test_encoding(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.encoding == base_textio.encoding - -def test_errors(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.errors is None - -def test_line_buffering(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert not sub_text.line_buffering - -def test_newlines(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert 
sub_text.newlines is None - -def test_isatty(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert not sub_text.isatty() - -def test_fileno(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - with pytest.raises(io.UnsupportedOperation): - sub_text.fileno() - -def test_readable(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.readable() - -def test_writable(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.writable() - -def test_seekable(base_textio): - sub_text = SubTextIO(base_textio, start=6, end=21) - assert sub_text.seekable() - -@pytest.mark.parametrize("mode", ["r", "w", "a", "x", "r+", "w+", "a+", "x+", "rt", "wt", "at", "xt", "r+t", "w+t", "a+t", "x+t"]) -def test_mode(mode): - import tempfile - - with tempfile.NamedTemporaryFile(mode=mode) as tmp: - if tmp.writable(): - tmp.truncate(30) - - if tmp.readable() and tmp.writable(): - sub_text = SubTextIO(tmp, start=0, end=21) - assert sub_text.mode == mode - - else: - sub_text = SubTextIO(tmp, start=0, end=0) - assert sub_text.mode == mode - -def test_invalid_range(base_textio): - with pytest.raises(InvalidSubtextCoordinates): - SubTextIO(base_textio, start=15, end=10) - -def test_invalid_range_past_initial(base_textio): - with pytest.raises(EndsBeyondBaseContent): - SubTextIO(base_textio, start=5, end=40) - -def test_no_readable_requirement(): - import tempfile - - initial_content = "start of file | end of file" - - with tempfile.NamedTemporaryFile(mode="w") as writer: - writer.write(initial_content) - writer.flush() - - with open(writer.name, "r") as reader: - assert reader.read() == initial_content - - assert not writer.readable() - current_pos = writer.tell() - with SubTextIO(writer, start=current_pos, end=current_pos) as sub_text: - sub_text.write(" | appendix") - - assert not writer.readable() - with open(writer.name, "r") as reader: - assert reader.read() == initial_content + " | appendix" - 
-def test_not_seekable(): - class NonSeekable: - @property - def closed(self): - return False - - def seekable(self): - return False - - file = NonSeekable() - - with pytest.raises(BaseMustBeSeekable): - SubTextIO(file, start=0, end=0) - -def test_not_readable(): - import tempfile - - with tempfile.NamedTemporaryFile(mode="w") as file: - with pytest.raises(BaseMustBeReadable): - SubTextIO(file, start=0, end=10) - -def test_not_writable(tmp_path): - path = tmp_path / "example.csv" - - with path.open("w") as writer: - writer.write("hello") - - with path.open("r") as file: - with SubTextIO(file, start=2, end=4) as sub_text: - assert sub_text.read() == "ll" - assert sub_text.read() == "" - assert sub_text.read() == "" - sub_text.flush() # should be a noop. - sub_text.flush() # should be a noop. From 2390a8860e090984900c8d15a8da359af10f2bb2 Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Wed, 18 Feb 2026 23:06:51 +0000 Subject: [PATCH 2/4] Add tests for SubBinaryIO --- tests/test_subbinaryio.py | 864 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 864 insertions(+) create mode 100644 tests/test_subbinaryio.py diff --git a/tests/test_subbinaryio.py b/tests/test_subbinaryio.py new file mode 100644 index 0000000..68447e8 --- /dev/null +++ b/tests/test_subbinaryio.py @@ -0,0 +1,864 @@ +""" +Comprehensive tests for SubBinaryIO. + +SubBinaryIO is a seekable read/write window into a contiguous byte range +[start, end) of a base BinaryIO object. Tests are written to verify the +*specification*, not to assume the implementation is correct. +""" + +from __future__ import annotations + +import io +import os + +import pytest + +from multicsv.subbinaryio import SubBinaryIO + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make(data: bytes, start: int, end: int) -> tuple[io.BytesIO, SubBinaryIO]: + """Return a (base, sub) pair. 
*base* contains *data*; *sub* covers [start, end).""" + base = io.BytesIO(data) + return base, SubBinaryIO(base, start, end) + + +# =========================================================================== +# Construction +# =========================================================================== + + +class TestConstruction: + def test_initial_position_is_zero(self) -> None: + _, sub = make(b"hello", 0, 5) + assert sub.tell() == 0 + + def test_does_not_seek_base_on_construction(self) -> None: + """Base position must not be disturbed during __init__.""" + base = io.BytesIO(b"hello world") + base.seek(7) + sub = SubBinaryIO(base, 0, 5) + # SubBinaryIO was created but base position should be untouched + assert base.tell() == 7 + # sub is usable and covers the correct region + assert sub.read() == b"hello" + + def test_empty_region(self) -> None: + _, sub = make(b"hello", 3, 3) + assert sub.tell() == 0 + assert sub.read() == b"" + assert sub.readline() == b"" + + def test_region_at_start(self) -> None: + _, sub = make(b"abcdef", 0, 3) + assert sub.read() == b"abc" + + def test_region_at_end(self) -> None: + _, sub = make(b"abcdef", 3, 6) + assert sub.read() == b"def" + + def test_whole_file_region(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.read() == b"abcdef" + + def test_single_byte_region(self) -> None: + _, sub = make(b"abcdef", 2, 3) + assert sub.read() == b"c" + + +# =========================================================================== +# Protocol / capability attributes +# =========================================================================== + + +class TestCapabilities: + def test_readable(self) -> None: + _, sub = make(b"x", 0, 1) + assert sub.readable() is True + + def test_writable(self) -> None: + _, sub = make(b"x", 0, 1) + assert sub.writable() is True + + def test_seekable(self) -> None: + _, sub = make(b"x", 0, 1) + assert sub.seekable() is True + + def test_name_property(self) -> None: + _, sub = make(b"x", 0, 1) + 
assert isinstance(sub.name, str) + + def test_mode_property(self) -> None: + _, sub = make(b"x", 0, 1) + assert isinstance(sub.mode, str) + + def test_is_buffered_io_base(self) -> None: + _, sub = make(b"x", 0, 1) + assert isinstance(sub, io.BufferedIOBase) + + def test_is_io_base(self) -> None: + _, sub = make(b"x", 0, 1) + assert isinstance(sub, io.IOBase) + + def test_not_closed_initially(self) -> None: + _, sub = make(b"x", 0, 1) + assert not sub.closed + + +# =========================================================================== +# read() +# =========================================================================== + + +class TestRead: + def test_read_all_negative_one(self) -> None: + _, sub = make(b"abcdef", 1, 5) + assert sub.read(-1) == b"bcde" + + def test_read_all_none(self) -> None: + _, sub = make(b"abcdef", 1, 5) + assert sub.read(None) == b"bcde" + + def test_read_all_no_arg(self) -> None: + _, sub = make(b"abcdef", 1, 5) + assert sub.read() == b"bcde" + + def test_read_zero(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.read(0) == b"" + + def test_read_partial(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.read(3) == b"abc" + + def test_read_advances_position(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.read(2) + assert sub.tell() == 2 + + def test_read_sequential(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.read(2) == b"ab" + assert sub.read(2) == b"cd" + assert sub.read(2) == b"ef" + + def test_read_at_eof_returns_empty(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.read() + assert sub.read() == b"" + assert sub.read(5) == b"" + + def test_read_does_not_exceed_region_end(self) -> None: + """Reading more bytes than remaining must stop at the region boundary.""" + _, sub = make(b"BEFORE_content_AFTER", 7, 14) # b"content" + data = sub.read(9999) + assert data == b"content" + + def test_read_clamped_to_remaining(self) -> None: + _, sub = make(b"abcdef", 0, 4) + sub.read(2) # position = 2 + 
data = sub.read(100) # only 2 bytes remain + assert data == b"cd" + + def test_read_does_not_read_past_region_into_base(self) -> None: + base = io.BytesIO(b"AAABBBCCC") + sub = SubBinaryIO(base, 3, 6) # b"BBB" + assert sub.read() == b"BBB" + + def test_read_position_correct_after_partial_read(self) -> None: + _, sub = make(b"0123456789", 2, 8) # b"234567" + sub.read(3) + assert sub.tell() == 3 + assert sub.read() == b"567" + assert sub.tell() == 6 + + def test_base_position_is_modified_after_read(self) -> None: + """SubBinaryIO seeks base_io before each operation; after a read base is left + at abs_pos + bytes_read (not necessarily restored).""" + base, sub = make(b"abcdef", 2, 5) + sub.read(2) + # base is left at 4 (2 + 2) after reading 2 bytes from position 2 + assert base.tell() == 4 + + def test_read_empty_region(self) -> None: + _, sub = make(b"abc", 2, 2) + assert sub.read() == b"" + assert sub.read(0) == b"" + assert sub.read(5) == b"" + + def test_read_returns_bytes_not_bytearray(self) -> None: + _, sub = make(b"abc", 0, 3) + result = sub.read() + assert type(result) is bytes + + +# =========================================================================== +# read1() +# =========================================================================== + + +class TestRead1: + def test_read1_returns_bytes(self) -> None: + _, sub = make(b"abc", 0, 3) + assert sub.read1() == b"abc" + + def test_read1_partial(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.read1(3) == b"abc" + + def test_read1_advances_position(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.read1(2) + assert sub.tell() == 2 + + def test_read1_at_eof(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.read() + assert sub.read1() == b"" + + def test_read1_none(self) -> None: + _, sub = make(b"abc", 0, 3) + assert sub.read1(None) == b"abc" + + +# =========================================================================== +# readinto() +# 
=========================================================================== + + +class TestReadInto: + def test_readinto_bytearray(self) -> None: + _, sub = make(b"abcdef", 0, 6) + buf = bytearray(6) + n = sub.readinto(buf) + assert n == 6 + assert buf == b"abcdef" + + def test_readinto_memoryview(self) -> None: + _, sub = make(b"abcdef", 0, 6) + buf = bytearray(6) + n = sub.readinto(memoryview(buf)) + assert n == 6 + assert buf == b"abcdef" + + def test_readinto_partial_region(self) -> None: + _, sub = make(b"XYZabcXYZ", 3, 6) # b"abc" + buf = bytearray(10) + n = sub.readinto(buf) + assert n == 3 + assert buf[:3] == b"abc" + + def test_readinto_buffer_larger_than_remaining(self) -> None: + _, sub = make(b"ab", 0, 2) + buf = bytearray(100) + n = sub.readinto(buf) + assert n == 2 + assert buf[:2] == b"ab" + + def test_readinto_empty_buffer(self) -> None: + _, sub = make(b"abc", 0, 3) + buf = bytearray(0) + n = sub.readinto(buf) + assert n == 0 + + def test_readinto_at_eof(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.read() + buf = bytearray(10) + n = sub.readinto(buf) + assert n == 0 + + def test_readinto_advances_position(self) -> None: + _, sub = make(b"abcdef", 0, 6) + buf = bytearray(3) + sub.readinto(buf) + assert sub.tell() == 3 + + def test_readinto_sequential(self) -> None: + _, sub = make(b"abcdef", 0, 6) + buf = bytearray(2) + sub.readinto(buf) + assert buf == b"ab" + sub.readinto(buf) + assert buf == b"cd" + + def test_readinto_does_not_modify_bytes_beyond_n(self) -> None: + """Bytes in buf at positions >= n must be unchanged.""" + _, sub = make(b"XY", 0, 2) + buf = bytearray(b"\xff" * 5) + n = sub.readinto(buf) + assert n == 2 + assert buf[2:] == b"\xff\xff\xff" + + +# =========================================================================== +# readline() +# =========================================================================== + + +class TestReadline: + def test_readline_stops_at_newline(self) -> None: + _, sub = make(b"line1\nline2\n", 0, 
12) + assert sub.readline() == b"line1\n" + + def test_readline_includes_newline(self) -> None: + _, sub = make(b"abc\ndef", 0, 7) + assert sub.readline() == b"abc\n" + + def test_readline_sequential(self) -> None: + _, sub = make(b"a\nb\nc\n", 0, 6) + assert sub.readline() == b"a\n" + assert sub.readline() == b"b\n" + assert sub.readline() == b"c\n" + assert sub.readline() == b"" + + def test_readline_no_trailing_newline(self) -> None: + _, sub = make(b"abc", 0, 3) + assert sub.readline() == b"abc" + + def test_readline_at_eof_returns_empty(self) -> None: + _, sub = make(b"abc\n", 0, 4) + sub.readline() + assert sub.readline() == b"" + + def test_readline_empty_region(self) -> None: + _, sub = make(b"abc", 1, 1) + assert sub.readline() == b"" + + def test_readline_newline_only(self) -> None: + _, sub = make(b"\n", 0, 1) + assert sub.readline() == b"\n" + + def test_readline_newline_at_first_position(self) -> None: + _, sub = make(b"\nabc", 0, 4) + assert sub.readline() == b"\n" + assert sub.readline() == b"abc" + + def test_readline_size_zero(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.readline(0) == b"" + assert sub.tell() == 0 + + def test_readline_size_limits_result(self) -> None: + _, sub = make(b"abcde\nfgh", 0, 9) + result = sub.readline(3) + assert result == b"abc" + assert sub.tell() == 3 + + def test_readline_size_spans_newline(self) -> None: + _, sub = make(b"abc\nde", 0, 6) + result = sub.readline(10) + assert result == b"abc\n" + + def test_readline_size_none(self) -> None: + _, sub = make(b"abc\ndef", 0, 7) + assert sub.readline(None) == b"abc\n" + + def test_readline_region_ends_before_newline_in_base(self) -> None: + """Region [0, 5) ends before base index 5; base has \n at index 5.
+ readline must NOT read past the region end.""" + _, sub = make(b"abcde\nfgh", 0, 5) # b"abcde", no \n inside + assert sub.readline() == b"abcde" + + def test_readline_region_ends_at_newline(self) -> None: + """Region ends exactly on a newline character.""" + _, sub = make(b"abc\ndef", 0, 4) # b"abc\n" + assert sub.readline() == b"abc\n" + + def test_readline_advances_position(self) -> None: + _, sub = make(b"line1\nline2\n", 0, 12) + sub.readline() + assert sub.tell() == 6 + + def test_readline_offset_region(self) -> None: + """Region starting in the middle of base.""" + _, sub = make(b"XYZline1\nline2\nXYZ", 3, 15) # b"line1\nline2\n" + assert sub.readline() == b"line1\n" + assert sub.readline() == b"line2\n" + assert sub.readline() == b"" + + +# =========================================================================== +# write() +# =========================================================================== + + +class TestWrite: + def test_write_bytes(self) -> None: + base, sub = make(b"AAAA", 0, 4) + n = sub.write(b"BB") + assert n == 2 + base.seek(0) + assert base.read() == b"BBAA" + + def test_write_bytearray(self) -> None: + base, sub = make(b"AAAA", 0, 4) + sub.write(bytearray(b"CC")) + base.seek(0) + assert base.read()[:2] == b"CC" + + def test_write_memoryview(self) -> None: + base, sub = make(b"AAAA", 0, 4) + sub.write(memoryview(b"DD")) + base.seek(0) + assert base.read()[:2] == b"DD" + + def test_write_advances_position(self) -> None: + _, sub = make(b"AAAA", 0, 4) + sub.write(b"XY") + assert sub.tell() == 2 + + def test_write_at_offset_region(self) -> None: + base, sub = make(b"AAABBBCCC", 3, 6) # covers BBB + sub.write(b"XY") + base.seek(0) + assert base.read() == b"AAAXYB" + b"CCC" + + def test_write_does_not_exceed_region_end(self) -> None: + """Writing past the region end must not touch bytes outside [start, end).""" + base, sub = make(b"AAABBBCCC", 3, 6) # covers BBB (3 bytes) + n = sub.write(b"XXXX") # 4 bytes, but only 3 fit + assert n == 3 + 
base.seek(0) + assert base.read() == b"AAAXXXCCC" + + def test_write_at_end_of_region_returns_zero(self) -> None: + _, sub = make(b"AAAA", 0, 4) + sub.seek(0, os.SEEK_END) + n = sub.write(b"XY") + assert n == 0 + + def test_write_visible_through_base(self) -> None: + base, sub = make(b"HELLO", 0, 5) + sub.write(b"WORLD") + base.seek(0) + assert base.read() == b"WORLD" + + def test_write_multiple_times(self) -> None: + base, sub = make(b"AAAAAAAA", 2, 7) # 5-byte region + sub.write(b"12") # writes at [2,4) + sub.write(b"345") # writes at [4,7) + base.seek(0) + assert base.read() == b"AA12345A" + + def test_write_does_not_affect_bytes_before_start(self) -> None: + base, sub = make(b"XYZAAA", 3, 6) + sub.write(b"BBB") + base.seek(0) + assert base.read()[:3] == b"XYZ" + + def test_write_does_not_affect_bytes_after_end(self) -> None: + base, sub = make(b"AAAXYZ", 0, 3) + sub.write(b"BBB") + base.seek(0) + assert base.read()[3:] == b"XYZ" + + def test_write_empty_bytes(self) -> None: + _, sub = make(b"AAAA", 0, 4) + n = sub.write(b"") + assert n == 0 + assert sub.tell() == 0 + + +# =========================================================================== +# truncate() +# =========================================================================== + + +class TestTruncate: + def test_truncate_explicit_zero(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.truncate(0) + assert sub.read() == b"" + + def test_truncate_explicit_size(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.truncate(3) + assert sub.read() == b"abc" + + def test_truncate_none_at_position(self) -> None: + """truncate(None) truncates at current position. 
Position is NOT reset.""" + _, sub = make(b"abcdef", 0, 6) + sub.seek(2) + sub.truncate() # logical end is now 2; position stays at 2 + sub.seek(0) # must seek explicitly to read from beginning + assert sub.read() == b"ab" + + def test_truncate_clamps_position(self) -> None: + """After truncate, position must not exceed new size.""" + _, sub = make(b"abcdef", 0, 6) + sub.seek(5) + sub.truncate(3) + assert sub.tell() <= 3 + + def test_truncate_returns_new_size(self) -> None: + _, sub = make(b"abcdef", 0, 6) + result = sub.truncate(4) + assert result == 4 + + def test_truncate_does_not_modify_base_content(self) -> None: + """Truncate adjusts logical end but must not write to base_io.""" + base, sub = make(b"abcdef", 0, 6) + sub.truncate(3) + # base still has all 6 bytes + base.seek(0) + assert base.read() == b"abcdef" + + def test_truncate_expand(self) -> None: + """truncate to a size larger than current region expands _end.""" + base = io.BytesIO(b"abcdefghij") + sub = SubBinaryIO(base, 0, 3) # initially b"abc" + sub.truncate(7) # expand end to 7 + sub.seek(0) + assert sub.read() == b"abcdefg" + + def test_truncate_at_region_start(self) -> None: + _, sub = make(b"abcdef", 2, 5) + sub.truncate(0) + assert sub.read() == b"" + assert sub.tell() == 0 + + +# =========================================================================== +# seek() and tell() +# =========================================================================== + + +class TestSeekTell: + def test_seek_set_zero(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.read(3) + sub.seek(0) + assert sub.tell() == 0 + assert sub.read() == b"abcdef" + + def test_seek_set_middle(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(3) + assert sub.tell() == 3 + assert sub.read() == b"def" + + def test_seek_set_to_end(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(6) + assert sub.tell() == 6 + assert sub.read() == b"" + + def test_seek_set_past_end_clamped(self) -> None: + _, sub = 
make(b"abcdef", 0, 6) + result = sub.seek(100) + assert result == 6 + assert sub.tell() == 6 + + def test_seek_set_negative_clamped_to_zero(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(3) + result = sub.seek(-5) + assert result == 0 + assert sub.tell() == 0 + + def test_seek_cur_forward(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(2) + sub.seek(2, os.SEEK_CUR) + assert sub.tell() == 4 + + def test_seek_cur_backward(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(4) + sub.seek(-2, os.SEEK_CUR) + assert sub.tell() == 2 + + def test_seek_cur_past_end_clamped(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(4) + result = sub.seek(100, os.SEEK_CUR) + assert result == 6 + + def test_seek_cur_before_start_clamped(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(2) + result = sub.seek(-100, os.SEEK_CUR) + assert result == 0 + + def test_seek_end_zero_goes_to_end(self) -> None: + _, sub = make(b"abcdef", 0, 6) + result = sub.seek(0, os.SEEK_END) + assert result == 6 + assert sub.tell() == 6 + + def test_seek_end_negative_offset(self) -> None: + _, sub = make(b"abcdef", 0, 6) + sub.seek(-2, os.SEEK_END) + assert sub.tell() == 4 + assert sub.read() == b"ef" + + def test_seek_end_positive_clamped(self) -> None: + _, sub = make(b"abcdef", 0, 6) + result = sub.seek(5, os.SEEK_END) + assert result == 6 + + def test_seek_end_before_start_clamped(self) -> None: + _, sub = make(b"abcdef", 0, 6) + result = sub.seek(-100, os.SEEK_END) + assert result == 0 + + def test_seek_invalid_whence_raises(self) -> None: + _, sub = make(b"abcdef", 0, 6) + with pytest.raises(OSError): + sub.seek(0, 99) + + def test_seek_returns_new_position(self) -> None: + _, sub = make(b"abcdef", 0, 6) + assert sub.seek(3) == 3 + assert sub.seek(1, os.SEEK_CUR) == 4 + assert sub.seek(-1, os.SEEK_END) == 5 + + def test_seek_set_on_offset_region(self) -> None: + """seek() positions are relative to the start of the region, not base_io.""" + _, sub = 
make(b"XXXYYY", 3, 6) # covers b"YYY" + sub.seek(1) + assert sub.tell() == 1 + assert sub.read() == b"YY" + + def test_tell_always_positive(self) -> None: + _, sub = make(b"abc", 0, 3) + for pos in [0, 1, 2, 3]: + sub.seek(pos) + assert sub.tell() >= 0 + + def test_seek_end_empty_region(self) -> None: + _, sub = make(b"abc", 2, 2) + result = sub.seek(0, os.SEEK_END) + assert result == 0 + + +# =========================================================================== +# flush() and close() +# =========================================================================== + + +class TestFlushClose: + def test_flush_does_not_close_base(self) -> None: + base, sub = make(b"abc", 0, 3) + sub.flush() + assert not base.closed + + def test_flush_when_closed_does_not_raise(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + sub.flush() # must not raise + + def test_close_marks_closed(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + assert sub.closed + + def test_close_does_not_close_base(self) -> None: + base, sub = make(b"abc", 0, 3) + sub.close() + assert not base.closed + + def test_close_is_idempotent(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + sub.close() # must not raise + assert sub.closed + + def test_read_after_close_raises(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + with pytest.raises(ValueError): + sub.read() + + def test_write_after_close_raises(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + with pytest.raises(ValueError): + sub.write(b"x") + + def test_seek_after_close_raises(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + with pytest.raises(ValueError): + sub.seek(0) + + def test_tell_after_close_raises(self) -> None: + _, sub = make(b"abc", 0, 3) + sub.close() + with pytest.raises(ValueError): + sub.tell() + + def test_readline_after_close_raises(self) -> None: + _, sub = make(b"abc\n", 0, 4) + sub.close() + with pytest.raises(ValueError): + sub.readline() + + def 
test_context_manager_closes(self) -> None: + base = io.BytesIO(b"abc") + with SubBinaryIO(base, 0, 3) as sub: + sub.read() + assert sub.closed + assert not base.closed + + +# =========================================================================== +# Multiple views on same base_io +# =========================================================================== + + +class TestMultipleViews: + def test_two_views_read_independently(self) -> None: + base = io.BytesIO(b"AAABBBCCC") + sub1 = SubBinaryIO(base, 0, 3) # b"AAA" + sub2 = SubBinaryIO(base, 3, 6) # b"BBB" + assert sub1.read() == b"AAA" + assert sub2.read() == b"BBB" + + def test_two_views_interleaved_reads(self) -> None: + base = io.BytesIO(b"abcdef") + sub1 = SubBinaryIO(base, 0, 3) # b"abc" + sub2 = SubBinaryIO(base, 3, 6) # b"def" + assert sub1.read(1) == b"a" + assert sub2.read(1) == b"d" + assert sub1.read(1) == b"b" + assert sub2.read(1) == b"e" + assert sub1.read(1) == b"c" + assert sub2.read(1) == b"f" + + def test_write_in_one_view_visible_in_another(self) -> None: + base = io.BytesIO(b"AAABBB") + sub1 = SubBinaryIO(base, 0, 3) + sub2 = SubBinaryIO(base, 0, 3) + sub1.write(b"XYZ") + sub2.seek(0) + assert sub2.read() == b"XYZ" + + def test_first_view_write_not_visible_in_non_overlapping_view(self) -> None: + base = io.BytesIO(b"AAABBB") + sub1 = SubBinaryIO(base, 0, 3) # b"AAA" + sub2 = SubBinaryIO(base, 3, 6) # b"BBB" + sub1.write(b"ZZZ") + sub2.seek(0) + assert sub2.read() == b"BBB" + + +# =========================================================================== +# TextIOWrapper integration +# =========================================================================== + + +class TestTextIOWrapper: + def test_textiowrapper_read(self) -> None: + base = io.BytesIO("hello\nworld\n".encode("utf-8")) + sub = SubBinaryIO(base, 0, len(base.getvalue())) + wrapper = io.TextIOWrapper(sub, encoding="utf-8") + assert wrapper.read() == "hello\nworld\n" + + def test_textiowrapper_region(self) -> None: + raw = 
"HEADER\ncontent line\n" + data = raw.encode("utf-8") + base = io.BytesIO(data) + sub = SubBinaryIO(base, 7, len(data)) # skip "HEADER\n" + wrapper = io.TextIOWrapper(sub, encoding="utf-8") + assert wrapper.read() == "content line\n" + + def test_textiowrapper_readline(self) -> None: + base = io.BytesIO(b"line1\nline2\nline3\n") + sub = SubBinaryIO(base, 0, 18) + wrapper = io.TextIOWrapper(sub, encoding="utf-8") + assert wrapper.readline() == "line1\n" + assert wrapper.readline() == "line2\n" + assert wrapper.readline() == "line3\n" + assert wrapper.readline() == "" + + def test_textiowrapper_seek_and_reread(self) -> None: + base = io.BytesIO(b"hello world") + sub = SubBinaryIO(base, 0, 11) + wrapper = io.TextIOWrapper(sub, encoding="utf-8") + wrapper.read() + wrapper.seek(0) + assert wrapper.read() == "hello world" + + +# =========================================================================== +# Edge cases and interactions +# =========================================================================== + + +class TestEdgeCases: + def test_read_write_seek_roundtrip(self) -> None: + base = io.BytesIO(b"AAAAAAA") + sub = SubBinaryIO(base, 1, 6) # 5-byte region + sub.write(b"12345") + sub.seek(0) + assert sub.read() == b"12345" + + def test_read_after_write_sees_written_data(self) -> None: + base = io.BytesIO(b"AAAAAAA") + sub = SubBinaryIO(base, 0, 7) + sub.write(b"XY") + sub.seek(0) + assert sub.read(2) == b"XY" + + def test_region_size_zero_all_ops_safe(self) -> None: + _, sub = make(b"abc", 2, 2) + assert sub.read() == b"" + assert sub.readline() == b"" + buf = bytearray(10) + assert sub.readinto(buf) == 0 + assert sub.write(b"x") == 0 + assert sub.tell() == 0 + sub.seek(0) + assert sub.tell() == 0 + + def test_write_then_readline(self) -> None: + base = io.BytesIO(b"\x00" * 11) # 11 bytes: 6 for "hello\n" + 5 for "world" + sub = SubBinaryIO(base, 0, 11) + sub.write(b"hello\n") + sub.write(b"world") + sub.seek(0) + assert sub.readline() == b"hello\n" + assert 
sub.readline() == b"world" + + def test_truncate_then_write(self) -> None: + _, sub = make(b"AAAAAAA", 0, 7) + sub.truncate(3) + sub.seek(0) + sub.write(b"XY") + sub.seek(0) + assert sub.read() == b"XY" + b"A" + + def test_region_at_byte_offset_boundaries(self) -> None: + """Verify byte offset math on a non-zero region start.""" + base = io.BytesIO(bytes(range(256))) + sub = SubBinaryIO(base, 10, 20) + assert sub.read() == bytes(range(10, 20)) + + def test_large_region_no_copy(self) -> None: + """SubBinaryIO must not hold a copy of the region data. + Verify by mutating base after construction and seeing the change on read.""" + base = io.BytesIO(b"AAAA") + sub = SubBinaryIO(base, 0, 4) + # Mutate base after SubBinaryIO is created + base.seek(0) + base.write(b"ZZZZ") + # sub must reflect the mutation (no stale copy held) + sub.seek(0) + assert sub.read() == b"ZZZZ" + + def test_iterate_lines(self) -> None: + base = io.BytesIO(b"a\nb\nc\n") + sub = SubBinaryIO(base, 0, 6) + lines = list(sub) + assert lines == [b"a\n", b"b\n", b"c\n"] + + def test_readlines(self) -> None: + base = io.BytesIO(b"x\ny\nz\n") + sub = SubBinaryIO(base, 0, 6) + assert sub.readlines() == [b"x\n", b"y\n", b"z\n"] From 6a1b3a03f87cd0ed693173eb0209cae77f920de1 Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Wed, 18 Feb 2026 23:17:19 +0000 Subject: [PATCH 3/4] Small improvements to typing and handling --- src/multicsv/file.py | 87 ++++++++++++++++++--------------- src/multicsv/open.py | 31 ++++++++++-- src/multicsv/subbinaryio.py | 55 ++++++++++++++++----- tests/test_examples.py | 4 +- tests/test_open.py | 4 +- tests/test_subbinaryio.py | 96 ++++++++++++++++++++++++++++++------- 6 files changed, 200 insertions(+), 77 deletions(-) diff --git a/src/multicsv/file.py b/src/multicsv/file.py index 1f1aa26..ce48c68 100644 --- a/src/multicsv/file.py +++ b/src/multicsv/file.py @@ -304,49 +304,58 @@ def _initialize_sections_binary(self) -> None: def _initialize_sections_text(self) -> None: """Fallback 
for encodings whose newline is not the single byte 0x0a - (EBCDIC, UTF-16, UTF-32, …). Reads the whole file, decodes to text, - and stores each section as an io.StringIO.""" - self._file.seek(0) - raw = self._file.read() - if not raw: - return - - full_text = raw.decode(self._encoding, errors='replace') - current_section: Optional[str] = None - section_lines: List[str] = [] - - for line in full_text.splitlines(keepends=True): - stripped = line.strip() - if stripped: - row = next(csv.reader([stripped])) - if len(row) == 0: - break - - first = row[0].strip() - rest = row[1:] + (EBCDIC, UTF-16, UTF-32, …). Wraps the file in a TextIOWrapper to + iterate line by line without loading everything into memory at once, + then stores each section as an io.StringIO. - if first.startswith("[") and \ - first.endswith("]") and \ - all(not x for x in rest): + TextIOWrapper is detached (not closed) at the end so that *self._file* + remains open for subsequent operations. + """ + self._file.seek(0) + wrapper = io.TextIOWrapper( + self._file, + encoding=self._encoding, + errors='replace', + line_buffering=False, + ) + try: + current_section: Optional[str] = None + section_lines: List[str] = [] + + for line in wrapper: + stripped = line.strip() + if stripped: + row = next(csv.reader([stripped])) + if len(row) == 0: + break + + first = row[0].strip() + rest = row[1:] + + if first.startswith("[") and \ + first.endswith("]") and \ + all(not x for x in rest): + + if current_section is not None: + self._sections.append(MultiCSVSection( + name=current_section, + descriptor=io.StringIO( + "".join(section_lines)), + )) + current_section = first[1:-1] + section_lines = [] + continue - if current_section is not None: - self._sections.append(MultiCSVSection( - name=current_section, - descriptor=io.StringIO( - "".join(section_lines)), - )) - current_section = first[1:-1] - section_lines = [] - continue + if current_section is not None: + section_lines.append(line) if current_section is not None: 
- section_lines.append(line) - - if current_section is not None: - self._sections.append(MultiCSVSection( - name=current_section, - descriptor=io.StringIO("".join(section_lines)), - )) + self._sections.append(MultiCSVSection( + name=current_section, + descriptor=io.StringIO("".join(section_lines)), + )) + finally: + wrapper.detach() # release self._file without closing it def _initialize_sections(self) -> None: if not self._file.readable(): diff --git a/src/multicsv/open.py b/src/multicsv/open.py index b977672..ed08019 100644 --- a/src/multicsv/open.py +++ b/src/multicsv/open.py @@ -1,5 +1,5 @@ -from typing import Union, Literal, BinaryIO +from typing import Union, Literal, BinaryIO, cast from pathlib import Path from .file import MultiCSVFile @@ -7,13 +7,34 @@ OpenPath = Union[str, int, bytes, Path] +# Textual mode strings accepted by multicsv_open. +OpenMode = Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+"] + + +def _to_binary_mode(mode: str) -> str: + """Translate a textual mode string to its binary equivalent. + + Strips any ``'t'`` flag and appends ``'b'`` if not already present. + Examples: ``"r"`` \u2192 ``"rb"``, ``"w+"`` \u2192 ``"w+b"``, + ``"a+t"`` \u2192 ``"a+b"``. + """ + m = mode.replace('t', '') + if 'b' not in m: + m += 'b' + return m + def multicsv_open(path: OpenPath, - mode: Literal["rb", "wb", "ab", "xb", - "r+b", "w+b", "a+b", "x+b"] = "rb", + mode: OpenMode = "r", encoding: str = 'utf-8') -> MultiCSVFile: - - file = open(path, mode=mode) + """Open a multi-CSV file at *path*. + + *mode* is a standard text-mode string (``"r"``, ``"w"``, ``"a"``, + ``"x"``, or any of those with ``"+"``). The file is always opened in + binary mode internally; the text encoding is handled by + :class:`~multicsv.file.MultiCSVFile` using *encoding*. 
+ """ + file = cast(BinaryIO, open(path, mode=_to_binary_mode(mode))) return MultiCSVFile(file, own=True, encoding=encoding) diff --git a/src/multicsv/subbinaryio.py b/src/multicsv/subbinaryio.py index 0e687b8..7ba20e4 100644 --- a/src/multicsv/subbinaryio.py +++ b/src/multicsv/subbinaryio.py @@ -81,36 +81,65 @@ def __init__(self, base_io: BinaryIO, start: int, end: int) -> None: @property def name(self) -> str: - """Nominal name of this byte-range view. + """Nominal name that includes the underlying *base_io* name. - Returns a fixed string since the view has no file-system path of its - own. The property exists to satisfy the ``typing.IO[bytes]`` protocol - required by ``io.TextIOWrapper``. + Formats as ``" [start:end]>"`` when *base_io* + exposes a ``name`` attribute (real files, named streams), or + ``""`` otherwise (e.g. ``BytesIO``). + The property satisfies the ``typing.IO[bytes]`` protocol required by + ``io.TextIOWrapper``. """ - return "" + base_name: object = getattr(self._base_io, 'name', None) + if base_name is not None: + return f"" + return f"" @property def mode(self) -> str: - """Access mode string (always ``'rb+'``). + """Access mode string, derived from *base_io*. - SubBinaryIO supports both reading and writing so ``'rb+'`` is the - appropriate description. The property exists to satisfy the - ``typing.IO[bytes]`` protocol required by ``io.TextIOWrapper``. + If *base_io* exposes a ``mode`` attribute the value is normalised to + a binary mode string (``'t'`` replaced by ``'b'``; ``'b'`` added if + absent). Otherwise the mode is inferred from the capabilities + reported by *base_io*: + + * readable **and** writable → ``"rb+"`` + * writable only → ``"wb"`` + * readable only → ``"rb"`` + + The property satisfies the ``typing.IO[bytes]`` protocol required by + ``io.TextIOWrapper``. 
""" - return "rb+" + base_mode: str | None = getattr(self._base_io, 'mode', None) + if base_mode is not None: + m = base_mode.replace('t', '') + if 'b' not in m: + m += 'b' + return m + # Fallback: derive from runtime capabilities. + r = self._base_io.readable() + w = self._base_io.writable() + if r and w: + return "rb+" + if w: + return "wb" + return "rb" # ------------------------------------------------------------------ # # Capability flags # # ------------------------------------------------------------------ # def readable(self) -> bool: - return True + """Return whether *base_io* is readable.""" + return self._base_io.readable() def writable(self) -> bool: - return True + """Return whether *base_io* is writable.""" + return self._base_io.writable() def seekable(self) -> bool: - return True + """Return whether *base_io* is seekable.""" + return self._base_io.seekable() # ------------------------------------------------------------------ # # Internal helpers # diff --git a/tests/test_examples.py b/tests/test_examples.py index 3e8c1aa..8268f17 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -44,7 +44,7 @@ def test_read_csv(example_file_1: Path) -> None: def test_write_csv(example_file_1): - with multicsv.open(example_file_1, mode='w+b') as csv_file: + with multicsv.open(example_file_1, mode='w+') as csv_file: # Write the CSV content to the file csv_file['section1'] = io.StringIO("header1,header2,header3\nvalue1,value2,value3\n") csv_file['section2'] = io.StringIO("header4,header5,header6\nvalue4,value5,value6\n") @@ -61,7 +61,7 @@ def test_write_csv(example_file_1): def test_write_csv_easier(example_file_1): - with multicsv.open(example_file_1, mode='w+b') as csv_file: + with multicsv.open(example_file_1, mode='w+') as csv_file: # Write the CSV content to the file csv_file.section('section1').write("header1,header2,header3\nvalue1,value2,value3\n") csv_file.section('section2').write("header4,header5,header6\nvalue4,value5,value6\n") diff 
--git a/tests/test_open.py b/tests/test_open.py index 9a2eebe..202c2d0 100644 --- a/tests/test_open.py +++ b/tests/test_open.py @@ -51,7 +51,7 @@ def test_open_write(tmp_path): path = tmp_path / "file2.txt" # Writing sections using multicsv_open - with multicsv_open(path, "wb") as csv_file: + with multicsv_open(path, "w") as csv_file: csv_file["section1"] = io.StringIO("a,b,c\n1,2,3\n") csv_file["section2"] = io.StringIO("d,e,f\n4,5,6\n") # csv_file.flush() @@ -76,7 +76,7 @@ def test_open_append(tmp_path): writer.write(initial_content) # Appending new sections using multicsv_open - with multicsv_open(path, "a+b") as csv_file: + with multicsv_open(path, "a+") as csv_file: csv_file["section2"] = io.StringIO("d,e,f\n4,5,6\n") # Validate the appended content diff --git a/tests/test_subbinaryio.py b/tests/test_subbinaryio.py index 68447e8..6d5a8a0 100644 --- a/tests/test_subbinaryio.py +++ b/tests/test_subbinaryio.py @@ -10,6 +10,7 @@ import io import os +from pathlib import Path import pytest @@ -76,25 +77,88 @@ def test_single_byte_region(self) -> None: class TestCapabilities: - def test_readable(self) -> None: - _, sub = make(b"x", 0, 1) - assert sub.readable() is True - - def test_writable(self) -> None: - _, sub = make(b"x", 0, 1) - assert sub.writable() is True - - def test_seekable(self) -> None: - _, sub = make(b"x", 0, 1) - assert sub.seekable() is True - - def test_name_property(self) -> None: + def test_readable_delegates_to_base(self) -> None: + """readable() must reflect base_io capability, not always True.""" + base = io.BytesIO(b"x") + sub = SubBinaryIO(base, 0, 1) + assert sub.readable() == base.readable() + + def test_writable_delegates_to_base(self) -> None: + """writable() must reflect base_io capability.""" + base = io.BytesIO(b"x") + sub = SubBinaryIO(base, 0, 1) + assert sub.writable() == base.writable() + + def test_seekable_delegates_to_base(self) -> None: + """seekable() must reflect base_io capability.""" + base = io.BytesIO(b"x") + sub = 
SubBinaryIO(base, 0, 1) + assert sub.seekable() == base.seekable() + + def test_readable_false_for_write_only(self, tmp_path: Path) -> None: + """A write-only file-backed SubBinaryIO must report readable=False.""" + p = tmp_path / "wo.bin" + p.write_bytes(b"hello") + with open(p, "wb") as f: + sub = SubBinaryIO(f, 0, 0) + assert sub.readable() is False + + def test_writable_false_for_read_only(self, tmp_path: Path) -> None: + """A read-only file-backed SubBinaryIO must report writable=False.""" + p = tmp_path / "ro.bin" + p.write_bytes(b"hello") + with open(p, "rb") as f: + sub = SubBinaryIO(f, 0, 5) + assert sub.writable() is False + + def test_name_without_base_name(self) -> None: + """BytesIO has no name attribute; format must still return a string.""" _, sub = make(b"x", 0, 1) assert isinstance(sub.name, str) - - def test_mode_property(self) -> None: + # Must include the byte offsets + assert "0" in sub.name + assert "1" in sub.name + + def test_name_includes_base_name(self, tmp_path: Path) -> None: + """For a real file, name must reflect the underlying file path.""" + p = tmp_path / "named.bin" + p.write_bytes(b"hello") + with open(p, "r+b") as f: + sub = SubBinaryIO(f, 1, 4) + assert str(p) in sub.name + # Offsets also present + assert "1" in sub.name + assert "4" in sub.name + + def test_mode_inherits_from_base_bytesio(self) -> None: + """BytesIO has no mode; mode must be derived from capabilities.""" _, sub = make(b"x", 0, 1) - assert isinstance(sub.mode, str) + m = sub.mode + assert isinstance(m, str) + assert 'b' in m + + def test_mode_inherits_from_base_file(self, tmp_path: Path) -> None: + """Real file mode must propagate (with 'b' added if necessary).""" + p = tmp_path / "mode.bin" + p.write_bytes(b"hello") + with open(p, "r+b") as f: + sub = SubBinaryIO(f, 0, 5) + assert 'b' in sub.mode + assert 'r' in sub.mode + + def test_mode_read_only_file(self, tmp_path: Path) -> None: + p = tmp_path / "ro.bin" + p.write_bytes(b"hello") + with open(p, "rb") as f: 
+ sub = SubBinaryIO(f, 0, 5) + assert sub.mode == "rb" + + def test_mode_write_only_file(self, tmp_path: Path) -> None: + p = tmp_path / "wo.bin" + p.write_bytes(b"hello") + with open(p, "wb") as f: + sub = SubBinaryIO(f, 0, 0) + assert sub.mode == "wb" def test_is_buffered_io_base(self) -> None: _, sub = make(b"x", 0, 1) From e2d5e4e17b3ebe7fa99290815dc9323ac940a82c Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Wed, 18 Feb 2026 23:19:56 +0000 Subject: [PATCH 4/4] Improve typing in open --- src/multicsv/open.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/multicsv/open.py b/src/multicsv/open.py index ed08019..a0f0f46 100644 --- a/src/multicsv/open.py +++ b/src/multicsv/open.py @@ -1,5 +1,5 @@ -from typing import Union, Literal, BinaryIO, cast +from typing import Union, Literal, BinaryIO, NoReturn from pathlib import Path from .file import MultiCSVFile @@ -9,19 +9,36 @@ # Textual mode strings accepted by multicsv_open. OpenMode = Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+"] +OpenBinaryMode = Literal["rb", "wb", "ab", "xb", "r+b", "w+b", "a+b", "x+b"] -def _to_binary_mode(mode: str) -> str: +def _to_binary_mode(mode: OpenMode) -> OpenBinaryMode: """Translate a textual mode string to its binary equivalent. Strips any ``'t'`` flag and appends ``'b'`` if not already present. Examples: ``"r"`` \u2192 ``"rb"``, ``"w+"`` \u2192 ``"w+b"``, ``"a+t"`` \u2192 ``"a+b"``. 
""" - m = mode.replace('t', '') - if 'b' not in m: - m += 'b' - return m + + if mode == "r": + return "rb" + elif mode == "w": + return "wb" + elif mode == "a": + return "ab" + elif mode == "x": + return "xb" + elif mode == "r+": + return "r+b" + elif mode == "w+": + return "w+b" + elif mode == "a+": + return "a+b" + elif mode == "x+": + return "x+b" + else: + _other: NoReturn = mode + raise ValueError(f"Invalid mode: {mode!r}") def multicsv_open(path: OpenPath, @@ -34,7 +51,7 @@ def multicsv_open(path: OpenPath, binary mode internally; the text encoding is handled by :class:`~multicsv.file.MultiCSVFile` using *encoding*. """ - file = cast(BinaryIO, open(path, mode=_to_binary_mode(mode))) + file = open(path, mode=_to_binary_mode(mode)) return MultiCSVFile(file, own=True, encoding=encoding)