Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions xrspatial/geotiff/_geotags.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,40 @@ def _extract_transform(ifd: IFD) -> GeoTransform:
"""Extract affine transform from ModelTransformation, or
ModelTiepoint + ModelPixelScale tags."""

# Try ModelTransformationTag (4x4 matrix)
# Try ModelTransformationTag (4x4 row-major matrix, 16 doubles).
# Per the GeoTIFF spec this tag wins over ModelPixelScale + ModelTiepoint
# when present.
#
# x = M[0]*col + M[1]*row + M[2]*z + M[3]
# y = M[4]*col + M[5]*row + M[6]*z + M[7]
#
# GeoTransform only carries the axis-aligned case. For rotated, sheared,
# or z-coupled transforms we raise NotImplementedError instead of silently
# dropping the off-diagonal terms.
transform_tag = ifd.get_value(TAG_MODEL_TRANSFORMATION)
if transform_tag is not None:
if isinstance(transform_tag, tuple) and len(transform_tag) >= 12:
# 4x4 row-major matrix
# x = M[0]*col + M[1]*row + M[3]
# y = M[4]*col + M[5]*row + M[7]
m = transform_tag
# Off-diagonal terms (rotation/skew) and z-coupling. Use a small
# tolerance scaled to the diagonal to absorb floating-point noise.
scale = max(abs(m[0]), abs(m[5]), 1.0)
tol = 1e-12 * scale
rotation_terms = (m[1], m[4])
z_terms = (m[2], m[6]) if len(m) >= 8 else (0.0, 0.0)
if any(abs(t) > tol for t in rotation_terms + z_terms):
raise NotImplementedError(
"ModelTransformationTag (34264) contains rotation, "
"skew, or z-coupling terms "
f"(M[1]={m[1]!r}, M[4]={m[4]!r}, "
f"M[2]={m[2] if len(m) > 2 else 0.0!r}, "
f"M[6]={m[6] if len(m) > 6 else 0.0!r}). "
"Only axis-aligned affine transforms are supported."
)
return GeoTransform(
origin_x=transform_tag[3],
origin_y=transform_tag[7],
pixel_width=transform_tag[0],
pixel_height=transform_tag[5],
origin_x=m[3],
origin_y=m[7],
pixel_width=m[0],
pixel_height=m[5],
)

# Try ModelTiepoint + ModelPixelScale
Expand Down
15 changes: 15 additions & 0 deletions xrspatial/geotiff/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,21 @@ def _decode_strip_or_tile(data_slice, compression, width, height, samples,
chunk = decompress(data_slice, compression, expected,
width=width, height=height, samples=samples)

# Validate the decompressed byte count. A truncated deflate stream or a
# buggy compressor can produce fewer or more bytes than expected. Without
# this check the downstream reshape raises an opaque "cannot reshape array
# of size N into shape (h, w)" that hides which tile/strip broke. Edge
# tiles in a valid TIFF still decompress to the full tile_height x
# tile_width (the caller slices the top-left region), so this only fires
# on genuine corruption.
if chunk.size != expected:
raise ValueError(
f"Decompressed tile/strip size mismatch: expected {expected} "
f"bytes for a {width} x {height} x {samples} block "
f"(bps={bps}, compression={compression}), got {chunk.size}. "
f"The TIFF data is likely truncated or corrupt."
)

if pred in (2, 3) and not is_sub_byte:
if not chunk.flags.writeable:
chunk = chunk.copy()
Expand Down
146 changes: 146 additions & 0 deletions xrspatial/geotiff/tests/test_geotags.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,149 @@ def test_projected_crs_geokey(self):
geokeys = tags[TAG_GEO_KEY_DIRECTORY]
# Flatten and check that ProjectedCSType is present
assert 3072 in geokeys # GEOKEY_PROJECTED_CS_TYPE


def _build_tiff_with_transformation_tag(matrix_16: tuple) -> bytes:
"""Build a tiny single-strip TIFF carrying a 4x4 ModelTransformationTag.

No ModelPixelScale or ModelTiepoint -- the reader has to use the
transformation tag.
"""
import struct

bo = '<'
width, height = 2, 2
pixels = np.zeros((height, width), dtype=np.uint8)

tag_list = []

def add_short(tag, val):
tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val)))

def add_long(tag, val):
tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val)))

def add_doubles(tag, vals):
tag_list.append(
(tag, 12, len(vals), struct.pack(f'{bo}{len(vals)}d', *vals)))

add_short(256, width) # ImageWidth
add_short(257, height) # ImageLength
add_short(258, 8) # BitsPerSample
add_short(259, 1) # Compression: none
add_short(262, 1) # PhotometricInterpretation: BlackIsZero
add_short(277, 1) # SamplesPerPixel
add_short(278, height) # RowsPerStrip
add_long(273, 0) # StripOffsets (placeholder)
add_long(279, len(pixels.tobytes())) # StripByteCounts
add_short(339, 1) # SampleFormat
# ModelTransformationTag (34264): 16 doubles, row-major 4x4.
add_doubles(34264, list(matrix_16))

tag_list.sort(key=lambda t: t[0])

num_entries = len(tag_list)
ifd_start = 8
ifd_size = 2 + 12 * num_entries + 4

# Build overflow buffer for tags whose value exceeds 4 bytes.
overflow = bytearray()
overflow_offsets = {}
for tag, _typ, _count, raw in tag_list:
if len(raw) > 4:
overflow_offsets[tag] = ifd_start + ifd_size + len(overflow)
overflow.extend(raw)
if len(overflow) % 2:
overflow.append(0)

pixel_start = ifd_start + ifd_size + len(overflow)

# Patch StripOffsets
patched = []
for tag, typ, count, raw in tag_list:
if tag == 273:
patched.append((tag, typ, count, struct.pack(f'{bo}I', pixel_start)))
else:
patched.append((tag, typ, count, raw))
tag_list = patched

out = bytearray()
out.extend(b'II')
out.extend(struct.pack(f'{bo}H', 42))
out.extend(struct.pack(f'{bo}I', ifd_start))
out.extend(struct.pack(f'{bo}H', num_entries))
for tag, typ, count, raw in tag_list:
out.extend(struct.pack(f'{bo}HHI', tag, typ, count))
if len(raw) <= 4:
out.extend(raw.ljust(4, b'\x00'))
else:
out.extend(struct.pack(f'{bo}I', overflow_offsets[tag]))
out.extend(struct.pack(f'{bo}I', 0))
out.extend(overflow)
out.extend(pixels.tobytes())
return bytes(out)


class TestModelTransformationTag_1486:
"""Issue #1486: handle ModelTransformationTag (34264) explicitly.

The reader previously read the matrix but silently discarded any
rotation, skew, or z-coupling. The fix raises NotImplementedError
instead of returning a corrupted transform.
"""

def test_axis_aligned_extracts_correctly(self, tmp_path):
# M = [[sx, 0, 0, ox], [0, sy, 0, oy], [0, 0, 1, 0], [0, 0, 0, 1]]
# Note pixel_height is M[5] verbatim -- the writer encodes a negative
# value here so it stays negative in the read transform.
ox, oy, sx, sy = 500000.0, 4500000.0, 30.0, -30.0
matrix = (
sx, 0.0, 0.0, ox,
0.0, sy, 0.0, oy,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
)
data = _build_tiff_with_transformation_tag(matrix)
path = tmp_path / 'transform_axis_aligned_1486.tif'
path.write_bytes(data)

from xrspatial.geotiff._reader import read_to_array
_, geo_info = read_to_array(str(path))
assert geo_info.transform.origin_x == pytest.approx(ox)
assert geo_info.transform.origin_y == pytest.approx(oy)
assert geo_info.transform.pixel_width == pytest.approx(sx)
assert geo_info.transform.pixel_height == pytest.approx(sy)

def test_rotation_raises(self, tmp_path):
# Non-zero rotation in M[1] / M[4]
matrix = (
30.0, 5.0, 0.0, 500000.0,
5.0, -30.0, 0.0, 4500000.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
)
data = _build_tiff_with_transformation_tag(matrix)
path = tmp_path / 'transform_rotated_1486.tif'
path.write_bytes(data)

from xrspatial.geotiff._reader import read_to_array
with pytest.raises(NotImplementedError) as exc:
read_to_array(str(path))
assert 'ModelTransformationTag' in str(exc.value)
assert 'rotation' in str(exc.value).lower() or 'skew' in str(exc.value).lower()

def test_z_coupling_raises(self, tmp_path):
# Non-zero z-coupling in M[2] / M[6]
matrix = (
30.0, 0.0, 0.5, 500000.0,
0.0, -30.0, 0.5, 4500000.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
)
data = _build_tiff_with_transformation_tag(matrix)
path = tmp_path / 'transform_z_coupled_1486.tif'
path.write_bytes(data)

from xrspatial.geotiff._reader import read_to_array
with pytest.raises(NotImplementedError):
read_to_array(str(path))
68 changes: 68 additions & 0 deletions xrspatial/geotiff/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,71 @@ def test_geo_info(self, tmp_path):
arr, geo_info = read_to_array(path)
assert geo_info.crs_epsg == 4326
assert geo_info.transform.origin_x == pytest.approx(-120.0)


class TestPartialTileValidation_1486:
"""Issue #1486: corrupt tile/strip data should raise a clear error.

Without validation a truncated deflate stream causes numpy.reshape to
raise an opaque "cannot reshape array of size N" with no hint of which
tile is at fault. These tests pin the new behaviour: a clear ValueError
naming the size mismatch.
"""

def _zero_out_last_tile(self, path):
"""Replace the last tile's compressed bytes with zeros so deflate
decodes a short stream."""
from xrspatial.geotiff._header import parse_all_ifds, parse_header
with open(path, 'rb') as f:
data = bytearray(f.read())
header = parse_header(bytes(data))
ifds = parse_all_ifds(bytes(data), header)
ifd = ifds[0]
if ifd.tile_offsets is not None:
offsets = ifd.tile_offsets
counts = ifd.tile_byte_counts
else:
offsets = ifd.strip_offsets
counts = ifd.strip_byte_counts
last_off = offsets[-1]
last_count = counts[-1]
# Zero deflate stream: header 0x78 0x9C followed by an empty stored
# block. zlib will return 0 bytes -- a clear undersized result.
zero_stream = b'\x78\x9c\x03\x00\x00\x00\x00\x01'
# Pad with zeros to original length so file layout stays intact.
padded = zero_stream + b'\x00' * max(0, last_count - len(zero_stream))
for i, b in enumerate(padded[:last_count]):
data[last_off + i] = b
with open(path, 'wb') as f:
f.write(bytes(data))

def test_truncated_tile_raises_clear_error(self, tmp_path):
from xrspatial.geotiff._writer import write

pixels = np.arange(256, dtype=np.float32).reshape(16, 16)
path = str(tmp_path / 'truncated_1486.tif')
write(pixels, path, compression='deflate', tiled=True, tile_size=8)

self._zero_out_last_tile(path)

with pytest.raises(ValueError) as exc:
read_to_array(path)
msg = str(exc.value)
assert 'size mismatch' in msg
assert 'expected' in msg
assert 'truncated or corrupt' in msg

def test_valid_edge_tile_still_works(self, tmp_path):
"""Edge tiles in a valid file decompress to full tile size; the
validation should not flag this as corrupt."""
from xrspatial.geotiff._writer import write

# 9 x 9 with tile_size=4 -> a 3x3 tile grid where the right and
# bottom tiles are partial. These exercise the legitimate
# "decompress full tile, slice top-left actual_h x actual_w" path.
pixels = np.arange(81, dtype=np.float32).reshape(9, 9)
path = str(tmp_path / 'edge_tiles_1486.tif')
write(pixels, path, compression='deflate', tiled=True, tile_size=4)

arr, _ = read_to_array(path)
np.testing.assert_array_equal(arr, pixels)
Loading