Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 69 additions & 14 deletions xrspatial/geotiff/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import threading
import urllib.request
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import numpy as np

Expand Down Expand Up @@ -236,6 +237,39 @@ def read_range(self, start: int, length: int) -> bytes:
with urllib.request.urlopen(req) as resp:
return resp.read()

def read_ranges(
    self,
    ranges: list[tuple[int, int]],
    max_workers: int = 8,
) -> list[bytes]:
    """Fetch several ``(start, length)`` byte ranges in parallel.

    Every range still gets its own HTTP range request, but the requests
    are dispatched through a thread pool, so total latency tracks the
    slowest single request rather than ``len(ranges) * RTT``.

    Returns
    -------
    list[bytes]
        The payload of each range, in the same order as *ranges*.
    """
    # Fast paths: nothing to fetch, or a single request that would gain
    # nothing from thread-pool start-up overhead.
    if not ranges:
        return []
    if len(ranges) == 1:
        sole_start, sole_length = ranges[0]
        return [self.read_range(sole_start, sole_length)]

    pool_size = min(max_workers, len(ranges))
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Executor.map yields results in input order, so no explicit
        # index bookkeeping is needed; a worker exception propagates
        # while the result list is being drained.
        return list(pool.map(lambda span: self.read_range(*span), ranges))

def read_all(self) -> bytes:
if self._pool is not None:
resp = self._pool.request('GET', self._url)
Expand Down Expand Up @@ -690,6 +724,11 @@ def _read_cog_http(url: str, overview_level: int | None = None,
) -> tuple[np.ndarray, GeoInfo]:
"""Read a COG via HTTP range requests.

Tile fetches run concurrently through a small thread pool so that the
total wall time is bounded by the slowest tile request rather than
``num_tiles * RTT``. The pool size can be overridden with the
``XRSPATIAL_COG_HTTP_WORKERS`` environment variable (default 8).

Parameters
----------
url : str
Expand Down Expand Up @@ -774,31 +813,47 @@ def _read_cog_http(url: str, overview_level: int | None = None,
else:
result = np.empty((height, width), dtype=dtype)

# Pass 1: collect every tile's range and where it lands in the output.
# Empty tiles (byte_count == 0) and any tile_idx beyond the offsets
# array are skipped here so the fetch list stays exactly aligned with
# the placements list.
fetch_ranges: list[tuple[int, int]] = []
placements: list[tuple[int, int]] = [] # (tr, tc) per fetched tile
for tr in range(tiles_down):
for tc in range(tiles_across):
tile_idx = tr * tiles_across + tc
if tile_idx >= len(offsets):
continue

off = offsets[tile_idx]
bc = byte_counts[tile_idx]
if bc == 0:
continue
fetch_ranges.append((off, bc))
placements.append((tr, tc))

tile_data = source.read_range(off, bc)
tile_pixels = _decode_strip_or_tile(
tile_data, compression, tw, th, samples,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order=header.byte_order)
# Pass 2: fetch all tile bytes in parallel. Worker pool size is tunable
# via XRSPATIAL_COG_HTTP_WORKERS so users on very slow links can dial
# it up without code changes.
try:
workers = max(1, int(_os_module.environ.get('XRSPATIAL_COG_HTTP_WORKERS', '8')))
except ValueError:
workers = 8
tile_bytes_list = source.read_ranges(fetch_ranges, max_workers=workers)

# Pass 3: decode each tile and place it.
for (tr, tc), tile_data in zip(placements, tile_bytes_list):
tile_pixels = _decode_strip_or_tile(
tile_data, compression, tw, th, samples,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order=header.byte_order)

# Place tile
y0 = tr * th
x0 = tc * tw
y1 = min(y0 + th, height)
x1 = min(x0 + tw, width)
actual_h = y1 - y0
actual_w = x1 - x0
result[y0:y1, x0:x1] = tile_pixels[:actual_h, :actual_w]
y0 = tr * th
x0 = tc * tw
y1 = min(y0 + th, height)
x1 = min(x0 + tw, width)
actual_h = y1 - y0
actual_w = x1 - x0
result[y0:y1, x0:x1] = tile_pixels[:actual_h, :actual_w]

source.close()
return result, geo_info
Expand Down
180 changes: 180 additions & 0 deletions xrspatial/geotiff/tests/test_cog_http_concurrent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""Tests for concurrent tile fetching in _read_cog_http (issue #1480)."""
from __future__ import annotations

import http.server
import socketserver
import threading
import time

import numpy as np
import pytest

from xrspatial.geotiff._reader import (
_HTTPSource,
_read_cog_http,
read_to_array,
)
from xrspatial.geotiff._writer import write


# ---------------------------------------------------------------------------
# read_ranges: ordering and concurrency
# ---------------------------------------------------------------------------

class _FakeHTTPSource(_HTTPSource):
    """_HTTPSource whose read_range sleeps instead of hitting the network.

    Records the total number of calls and the peak number of calls that
    were in flight simultaneously, letting tests assert on threadpool
    dispatch directly instead of on wall-clock timing (which is flaky on
    busy CI runners).
    """

    def __init__(self, per_request_sleep: float = 0.05):
        # Deliberately do NOT call super().__init__ -- there is no real
        # HTTP endpoint behind this object.
        self._url = 'fake://test'
        self._size = None
        self._pool = None
        self._per_request_sleep = per_request_sleep
        self.call_count = 0
        self.in_flight = 0
        self.max_in_flight = 0
        self._lock = threading.Lock()

    def read_range(self, start: int, length: int) -> bytes:
        # Bump the counters atomically, remembering the concurrency peak.
        with self._lock:
            self.call_count += 1
            self.in_flight += 1
            self.max_in_flight = max(self.max_in_flight, self.in_flight)
        try:
            time.sleep(self._per_request_sleep)
            return f'{start}:{length}'.encode('ascii')
        finally:
            with self._lock:
                self.in_flight -= 1


def test_read_ranges_returns_results_in_input_order():
    # Results must line up with the request list even though workers may
    # finish in any order.
    source = _FakeHTTPSource(per_request_sleep=0.0)
    requested = [(0, 10), (100, 5), (50, 20), (200, 7)]
    fetched = source.read_ranges(requested, max_workers=4)
    assert len(fetched) == len(requested)
    expected = [f'{start}:{length}'.encode('ascii') for start, length in requested]
    assert fetched == expected


def test_read_ranges_empty_list():
    # An empty request list short-circuits to an empty result.
    source = _FakeHTTPSource(per_request_sleep=0.0)
    result = source.read_ranges([])
    assert result == []


def test_read_ranges_single_request_skips_pool():
    # A single range is fetched directly, with exactly one read_range call.
    source = _FakeHTTPSource(per_request_sleep=0.0)
    fetched = source.read_ranges([(42, 8)], max_workers=8)
    assert fetched == [b'42:8']
    assert source.call_count == 1


def test_read_ranges_dispatches_concurrently():
    """The threadpool should run multiple requests in flight at once.

    Asserting on observed in-flight concurrency is robust to CI scheduler
    jitter; a wall-clock assertion of the same effect is flaky on busy
    runners (the previous version of this test was a 50 ms per-request
    × 20-request setup that occasionally exceeded its 0.5 s budget by a
    few ms on macOS).
    """
    num_ranges = 20
    source = _FakeHTTPSource(per_request_sleep=0.02)
    spans = [(idx * 100, 10) for idx in range(num_ranges)]

    fetched = source.read_ranges(spans, max_workers=8)

    assert source.call_count == num_ranges
    assert len(fetched) == num_ranges
    # Sequential dispatch would peak at 1 in flight. The pool should
    # run several in parallel; require at least 2 (very loose) to keep
    # the test robust on heavily oversubscribed CI runners.
    assert source.max_in_flight >= 2, (
        f'expected >=2 concurrent in-flight calls, '
        f'got max_in_flight={source.max_in_flight}'
    )


# ---------------------------------------------------------------------------
# _read_cog_http: correctness via local http.server
# ---------------------------------------------------------------------------

class _RangeHandler(http.server.BaseHTTPRequestHandler):
"""Serve a single in-memory bytes payload with HTTP Range support."""

payload: bytes = b''

def do_GET(self): # noqa: N802
rng = self.headers.get('Range')
if rng and rng.startswith('bytes='):
spec = rng[len('bytes='):]
# Single range only -- matches what _HTTPSource sends.
start_s, _, end_s = spec.partition('-')
start = int(start_s)
end = int(end_s) if end_s else len(self.payload) - 1
chunk = self.payload[start:end + 1]
self.send_response(206)
self.send_header('Content-Type', 'application/octet-stream')
self.send_header(
'Content-Range',
f'bytes {start}-{start + len(chunk) - 1}/{len(self.payload)}',
)
self.send_header('Content-Length', str(len(chunk)))
self.end_headers()
self.wfile.write(chunk)
return
self.send_response(200)
self.send_header('Content-Type', 'application/octet-stream')
self.send_header('Content-Length', str(len(self.payload)))
self.end_headers()
self.wfile.write(self.payload)

def log_message(self, *_args, **_kwargs):
# Silence the default access log during tests.
pass


@pytest.fixture
def cog_http_server(tmp_path):
    """Spin up a local http.server serving a tiled COG, yield (url, arr)."""
    expected = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
    tif_path = str(tmp_path / 'tmp_1480_cog.tif')
    write(expected, tif_path, compression='deflate', tiled=True, tile_size=16,
          cog=True, overview_levels=[1])

    with open(tif_path, 'rb') as fh:
        blob = fh.read()

    # Bind the payload onto a per-test handler subclass so parallel test
    # runs cannot clobber each other's class attribute.
    handler_cls = type('RangeHandler1480', (_RangeHandler,), {'payload': blob})
    server = socketserver.TCPServer(('127.0.0.1', 0), handler_cls)
    serve_thread = threading.Thread(target=server.serve_forever, daemon=True)
    serve_thread.start()

    try:
        yield f'http://127.0.0.1:{server.server_address[1]}/cog.tif', expected
    finally:
        server.shutdown()
        server.server_close()


def test_cog_http_round_trip_matches_local_read(cog_http_server):
    # Reading the COG over HTTP must reproduce the original array exactly.
    served_url, expected = cog_http_server
    fetched, _ = _read_cog_http(served_url)
    np.testing.assert_array_equal(fetched, expected)


def test_read_to_array_dispatches_to_http(cog_http_server):
    # The public entry point should route http:// URLs through the same path.
    served_url, expected = cog_http_server
    fetched, _ = read_to_array(served_url)
    np.testing.assert_array_equal(fetched, expected)
Loading