Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/integration/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import time
from typing import TYPE_CHECKING, Literal, TypeVar

from crawlee._utils.crypto import crypto_random_object_id
Expand Down Expand Up @@ -48,6 +49,32 @@ async def call_with_exp_backoff(
raise ValueError(f'Invalid rq_access_mode: {rq_access_mode}')


async def poll_until_condition(
    fn: Callable[[], Awaitable[T]],
    condition: Callable[[T], bool],
    *,
    timeout: float = 60,
    poll_interval: float = 5,
) -> T:
    """Repeatedly await `fn` until `condition(result)` holds or `timeout` elapses.

    The result of each poll is passed to `condition`; polling stops as soon as it
    returns True, or once `timeout` seconds have passed since the call started.
    The most recent result is returned either way — callers must re-check the
    condition themselves if they need to distinguish success from timeout.

    Prefer this over a fixed `asyncio.sleep` when waiting on eventually-consistent
    API state (e.g. request queue stats) whose propagation time varies.
    """
    started = time.monotonic()
    while True:
        result = await fn()
        if condition(result):
            return result
        # Never sleep past the deadline; give up (returning the stale result)
        # once no time remains.
        remaining = started + timeout - time.monotonic()
        if remaining <= 0:
            return result
        await asyncio.sleep(poll_interval if poll_interval < remaining else remaining)


def generate_unique_resource_name(label: str) -> str:
"""Generates a unique resource name, which will contain the given label."""
name_template = 'python-sdk-tests-{}-generated-{}'
Expand Down
88 changes: 58 additions & 30 deletions tests/integration/test_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from crawlee import service_locator
from crawlee.crawlers import BasicCrawler

from ._utils import call_with_exp_backoff, generate_unique_resource_name
from ._utils import call_with_exp_backoff, generate_unique_resource_name, poll_until_condition
from apify import Actor, Request
from apify.storage_clients import ApifyStorageClient
from apify.storage_clients._apify import ApifyRequestQueueClient
Expand Down Expand Up @@ -856,10 +856,9 @@ async def test_request_queue_metadata_another_client(
api_client = apify_client_async.request_queue(request_queue_id=rq.id, client_key=None)
await api_client.add_request(Request.from_url('http://example.com/1').model_dump(by_alias=True, exclude={'id'}))

# Wait to be sure that the API has updated the global metadata
await asyncio.sleep(10)

assert (await rq.get_metadata()).total_request_count == 1
# Poll until the API has propagated the metadata change.
metadata = await poll_until_condition(rq.get_metadata, lambda m: m.total_request_count >= 1)
assert metadata.total_request_count == 1


async def test_request_queue_had_multiple_clients(
Expand Down Expand Up @@ -950,12 +949,18 @@ async def default_handler(context: BasicCrawlingContext) -> None:
assert crawler.statistics.state.requests_finished == requests

try:
# Check the request queue stats
await asyncio.sleep(10) # Wait to be sure that metadata are updated
# Poll until request queue stats are propagated by the API.
expected_write_count = requests * expected_write_count_per_request

async def _get_rq_metadata() -> ApifyRequestQueueMetadata:
return cast('ApifyRequestQueueMetadata', await rq.get_metadata())

metadata = cast('ApifyRequestQueueMetadata', await rq.get_metadata())
metadata = await poll_until_condition(
_get_rq_metadata,
lambda m: m.stats.write_count >= expected_write_count,
)
Actor.log.info(f'{metadata.stats=}')
assert metadata.stats.write_count == requests * expected_write_count_per_request
assert metadata.stats.write_count == expected_write_count

finally:
await rq.drop()
Expand Down Expand Up @@ -1009,13 +1014,16 @@ async def test_request_queue_has_stats(request_queue_apify: RequestQueue) -> Non

await rq.add_requests([Request.from_url(f'http://example.com/{i}') for i in range(add_request_count)])

# Wait for stats to become stable
await asyncio.sleep(10)
# Poll until stats are propagated by the API.
async def _get_rq_metadata() -> ApifyRequestQueueMetadata:
return cast('ApifyRequestQueueMetadata', await rq.get_metadata())

metadata = await rq.get_metadata()
apify_metadata = await poll_until_condition(
_get_rq_metadata,
lambda m: m.stats.write_count >= add_request_count,
)

assert hasattr(metadata, 'stats')
apify_metadata = cast('ApifyRequestQueueMetadata', metadata)
assert hasattr(apify_metadata, 'stats')
assert apify_metadata.stats.write_count == add_request_count


Expand Down Expand Up @@ -1153,10 +1161,15 @@ def return_unprocessed_requests(requests: list[dict], *_: Any, **__: Any) -> dic
# This will succeed.
await request_queue_apify.add_requests(['http://example.com/1'])

await asyncio.sleep(10) # Wait to be sure that metadata are updated
_rq = await rq_client.get()
assert _rq
stats_after = _rq.get('stats', {})
# Poll until stats reflect the successful write.
async def _get_rq_stats() -> dict:
result = await rq_client.get()
return (result or {}).get('stats', {})

stats_after = await poll_until_condition(
_get_rq_stats,
lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= 1,
)
Actor.log.info(stats_after)

assert (stats_after['writeCount'] - stats_before['writeCount']) == 1
Expand Down Expand Up @@ -1256,10 +1269,15 @@ async def test_request_queue_deduplication(
await rq.add_request(request1)
await rq.add_request(request2)

await asyncio.sleep(10) # Wait to be sure that metadata are updated
_rq = await rq_client.get()
assert _rq
stats_after = _rq.get('stats', {})
# Poll until stats reflect the write.
async def _get_rq_stats() -> dict:
result = await rq_client.get()
return (result or {}).get('stats', {})

stats_after = await poll_until_condition(
_get_rq_stats,
lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= 1,
)

assert (stats_after['writeCount'] - stats_before['writeCount']) == 1

Expand All @@ -1283,10 +1301,15 @@ async def test_request_queue_deduplication_use_extended_unique_key(
await rq.add_request(request1)
await rq.add_request(request2)

await asyncio.sleep(10) # Wait to be sure that metadata are updated
_rq = await rq_client.get()
assert _rq
stats_after = _rq.get('stats', {})
# Poll until stats reflect both writes.
async def _get_rq_stats() -> dict:
result = await rq_client.get()
return (result or {}).get('stats', {})

stats_after = await poll_until_condition(
_get_rq_stats,
lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= 2,
)

assert (stats_after['writeCount'] - stats_before['writeCount']) == 2

Expand Down Expand Up @@ -1316,10 +1339,15 @@ async def add_requests_worker() -> None:
add_requests_workers = [asyncio.create_task(add_requests_worker()) for _ in range(worker_count)]
await asyncio.gather(*add_requests_workers)

await asyncio.sleep(10) # Wait to be sure that metadata are updated
_rq = await rq_client.get()
assert _rq
stats_after = _rq.get('stats', {})
# Poll until stats reflect all written requests.
async def _get_rq_stats() -> dict:
result = await rq_client.get()
return (result or {}).get('stats', {})

stats_after = await poll_until_condition(
_get_rq_stats,
lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= len(requests),
)

assert (stats_after['writeCount'] - stats_before['writeCount']) == len(requests)

Expand Down
Loading