Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import socket
import statistics
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

import psutil

Expand All @@ -28,7 +28,7 @@ def publish_benchmark_extra_info(
benchmark_group: str = "read",
true_times: List[float] = [],
download_bytes_list: Optional[List[int]] = None,
duration: Optional[int] = None,
duration: Optional[Union[float, List[float]]] = None,
) -> None:
"""
Helper function to publish benchmark parameters to the extra_info property.
Expand All @@ -50,9 +50,19 @@ def publish_benchmark_extra_info(

if download_bytes_list is not None:
assert duration is not None, (
"Duration must be provided if total_bytes_transferred is provided."
"Duration must be provided if download_bytes_list is provided."
)
throughputs_list = [x / duration / (1024 * 1024) for x in download_bytes_list]
if isinstance(duration, list):
assert len(download_bytes_list) == len(duration), (
"Download bytes and duration lists must have the same length."
)
throughputs_list = [
x / d / (1024 * 1024) for x, d in zip(download_bytes_list, duration)
]
else:
throughputs_list = [
x / duration / (1024 * 1024) for x in download_bytes_list
]
min_throughput = min(throughputs_list)
max_throughput = max(throughputs_list)
mean_throughput = statistics.mean(throughputs_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import random
import time
from io import BytesIO
from typing import List, NamedTuple, Optional

import pytest

Expand All @@ -39,11 +40,68 @@
all_params = config._get_params()


# Summary of one worker's time-based download run: the bytes accumulated after
# warmup, plus the perf_counter timestamps that bracket the measured window.
DownloadResult = NamedTuple(
    "DownloadResult",
    [
        ("total_bytes", int),
        ("measured_start_time", float),
        ("measured_end_time", float),
    ],
)

async def create_client():
    """Construct and return a new AsyncGrpcClient.

    NOTE(review): this coroutine does not interact with the event loop itself —
    it only instantiates the client. Presumably it is async so instantiation
    happens inside the worker's running loop; confirm against the callers.
    """
    return AsyncGrpcClient()


def _aggregate_download_results(results: List[DownloadResult]) -> DownloadResult:
    """Merge per-worker download results into a single aggregate result.

    Byte counts are summed, and the measured window spans from the earliest
    worker start to the latest worker end.

    Raises:
        ValueError: if ``results`` is empty, or if the merged window is not a
            strictly positive span of time.
    """
    if not results:
        raise ValueError("At least one download result is required.")

    byte_total = sum(r.total_bytes for r in results)
    earliest_start = min(r.measured_start_time for r in results)
    latest_end = max(r.measured_end_time for r in results)

    if latest_end <= earliest_start:
        raise ValueError("Measured elapsed time must be positive.")

    return DownloadResult(
        total_bytes=byte_total,
        measured_start_time=earliest_start,
        measured_end_time=latest_end,
    )


def _calculate_average_throughput_mib_s(
download_bytes_list: List[int], download_elapsed_times: List[float]
) -> float:
total_bytes_downloaded = sum(download_bytes_list)
total_elapsed_time = sum(download_elapsed_times)
if total_elapsed_time <= 0:
raise ValueError("Total measured elapsed time must be positive.")

return (total_bytes_downloaded / total_elapsed_time) / (1024 * 1024)


def _record_measured_start(
measured_start_time: Optional[float], current_time: float
) -> float:
if measured_start_time is None:
return current_time
return measured_start_time


def _build_download_result(
    total_bytes_downloaded: int,
    measured_start_time: Optional[float],
    measured_end_time: Optional[float],
) -> DownloadResult:
    """Package a worker's counters into a DownloadResult.

    Raises:
        ValueError: if either timestamp is missing, meaning the loop never
            entered (or never completed work inside) the measured interval.
    """
    have_window = measured_start_time is not None and measured_end_time is not None
    if not have_window:
        raise ValueError("No downloads completed during the measured interval.")

    return DownloadResult(
        total_bytes=total_bytes_downloaded,
        measured_start_time=measured_start_time,
        measured_end_time=measured_end_time,
    )
Comment thread
zhixiangli marked this conversation as resolved.


# --- Global Variables for Worker Process ---
worker_loop = None
worker_client = None
Expand Down Expand Up @@ -78,15 +136,20 @@ def _download_time_based_json(client, filename, params):

offset = 0
is_warming_up = True
start_time = time.monotonic()
start_time = time.perf_counter()
warmup_end_time = start_time + params.warmup_duration
test_end_time = warmup_end_time + params.duration
measured_start_time = None
measured_end_time = None

while time.monotonic() < test_end_time:
current_time = time.monotonic()
while time.perf_counter() < test_end_time:
current_time = time.perf_counter()
if is_warming_up and current_time >= warmup_end_time:
is_warming_up = False
total_bytes_downloaded = 0 # Reset counter after warmup
measured_start_time = _record_measured_start(
measured_start_time, current_time
)

bytes_in_iteration = 0
# For JSON, we can't batch ranges like gRPC, so we download one by one
Expand All @@ -110,8 +173,11 @@ def _download_time_based_json(client, filename, params):

if not is_warming_up:
total_bytes_downloaded += bytes_in_iteration
measured_end_time = time.perf_counter()

return total_bytes_downloaded
return _build_download_result(
total_bytes_downloaded, measured_start_time, measured_end_time
)


async def _download_time_based_async(client, filename, params):
Expand All @@ -122,15 +188,20 @@ async def _worker_coro():
total_bytes_downloaded = 0
offset = 0
is_warming_up = True
start_time = time.monotonic()
start_time = time.perf_counter()
warmup_end_time = start_time + params.warmup_duration
test_end_time = warmup_end_time + params.duration
measured_start_time = None
measured_end_time = None

while time.monotonic() < test_end_time:
current_time = time.monotonic()
while time.perf_counter() < test_end_time:
current_time = time.perf_counter()
if is_warming_up and current_time >= warmup_end_time:
is_warming_up = False
total_bytes_downloaded = 0 # Reset counter after warmup
measured_start_time = _record_measured_start(
measured_start_time, current_time
)

ranges = []
if params.pattern == "rand":
Expand All @@ -153,13 +224,16 @@ async def _worker_coro():

if not is_warming_up:
total_bytes_downloaded += params.chunk_size_bytes * params.num_ranges
return total_bytes_downloaded
measured_end_time = time.perf_counter()
return _build_download_result(
total_bytes_downloaded, measured_start_time, measured_end_time
)

tasks = [asyncio.create_task(_worker_coro()) for _ in range(params.num_coros)]
results = await asyncio.gather(*tasks)

await mrd.close()
return sum(results)
return _aggregate_download_results(results)


def _download_files_worker(process_idx, filename, params, bucket_type):
Expand All @@ -175,7 +249,8 @@ def download_files_mp_mc_wrapper(pool, files_names, params, bucket_type):
args = [(i, files_names[i], params, bucket_type) for i in range(len(files_names))]

results = pool.starmap(_download_files_worker, args)
return sum(results)
agg_res = _aggregate_download_results(results)
return agg_res.total_bytes, agg_res.measured_end_time - agg_res.measured_start_time


@pytest.mark.parametrize(
Expand All @@ -198,9 +273,14 @@ def test_downloads_multi_proc_multi_coro(
)

download_bytes_list = []
download_elapsed_times = []

def target_wrapper(*args, **kwargs):
download_bytes_list.append(download_files_mp_mc_wrapper(pool, *args, **kwargs))
total_bytes, measured_elapsed_time = download_files_mp_mc_wrapper(
pool, *args, **kwargs
)
download_bytes_list.append(total_bytes)
download_elapsed_times.append(measured_elapsed_time)
return

try:
Expand All @@ -214,10 +294,9 @@ def target_wrapper(*args, **kwargs):
finally:
pool.close()
pool.join()
total_bytes_downloaded = sum(download_bytes_list)
throughput_mib_s = (
total_bytes_downloaded / params.duration / params.rounds
) / (1024 * 1024)
throughput_mib_s = _calculate_average_throughput_mib_s(
download_bytes_list, download_elapsed_times
)
benchmark.extra_info["avg_throughput_mib_s"] = f"{throughput_mib_s:.2f}"
print(
f"Avg Throughput of {params.rounds} round(s): {throughput_mib_s:.2f} MiB/s"
Expand All @@ -226,6 +305,6 @@ def target_wrapper(*args, **kwargs):
benchmark,
params,
download_bytes_list=download_bytes_list,
duration=params.duration,
duration=download_elapsed_times,
)
publish_resource_metrics(benchmark, m)
Loading