Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ public static Optional<Long> parseContentRangeForTotalSize(String contentRange)
* @param partSize size of each part in bytes
* @return the number of parts
*/
public static int calculateTotalParts(long contentLength, long partSize) {
return (int) Math.ceil((double) contentLength / partSize);
public static long calculateTotalParts(long contentLength, long partSize) {
return (contentLength / partSize) + (contentLength % partSize == 0 ? 0 : 1);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand Down Expand Up @@ -62,23 +62,23 @@ public class ParallelPresignedUrlMultipartDownloaderSubscriber
private final CompletableFuture<GetObjectResponse> resultFuture;
private final int maxInFlightParts;

private final AtomicInteger partNumber = new AtomicInteger(0);
private final AtomicInteger completedParts = new AtomicInteger(0);
private final AtomicLong partNumber = new AtomicLong(0);
private final AtomicLong completedParts = new AtomicLong(0);
private final Semaphore inFlightPermits;
/**
* CAS gate ensuring only the first part failure triggers error handling and cancellation.
*/
private final AtomicBoolean downloadFailed = new AtomicBoolean(false);
private final AtomicBoolean processingPending = new AtomicBoolean(false);
private final Map<Integer, CompletableFuture<GetObjectResponse>> inFlightRequests = new ConcurrentHashMap<>();
private final Queue<Pair<Integer, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>>> pendingTransformers =
private final Map<Long, CompletableFuture<GetObjectResponse>> inFlightRequests = new ConcurrentHashMap<>();
private final Queue<Pair<Long, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>>> pendingTransformers =
new ConcurrentLinkedQueue<>();

private final Object subscriptionLock = new Object();
private Subscription subscription;

private volatile Long totalContentLength;
private volatile Integer totalParts;
private volatile Long totalParts;
private volatile String eTag;
private volatile GetObjectResponse firstResponse;

Expand Down Expand Up @@ -112,7 +112,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
throw new NullPointerException("onNext must not be called with null asyncResponseTransformer");
}

int currentPart = partNumber.getAndIncrement();
long currentPart = partNumber.getAndIncrement();

if (currentPart == 0) {
sendFirstRequest(asyncResponseTransformer);
Expand All @@ -129,7 +129,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
}

private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer) {
PresignedUrlDownloadRequest partRequest = createRangedGetRequest(0);
PresignedUrlDownloadRequest partRequest = createRangedGetRequest(0L);
log.debug(() -> "Sending first range request with range=" + partRequest.range());

if (!inFlightPermits.tryAcquire()) {
Expand All @@ -139,11 +139,11 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj
CompletableFuture<GetObjectResponse> response =
s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer);

inFlightRequests.put(0, response);
inFlightRequests.put(0L, response);
CompletableFutureUtils.forwardExceptionTo(resultFuture, response);

response.whenComplete((res, error) -> {
inFlightRequests.remove(0);
inFlightRequests.remove(0L);
inFlightPermits.release();

if (error != null) {
Expand All @@ -154,7 +154,7 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj
subscription.cancel();
}
} else {
handlePartError(error, 0);
handlePartError(error, 0L);
}
return;
}
Expand All @@ -170,23 +170,23 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj

String contentRange = res.contentRange();
if (contentRange == null) {
handlePartError(PresignedUrlDownloadHelper.missingContentRangeHeader(), 0);
handlePartError(PresignedUrlDownloadHelper.missingContentRangeHeader(), 0L);
return;
}

Optional<Long> parsedTotal = MultipartDownloadUtils.parseContentRangeForTotalSize(contentRange);
if (!parsedTotal.isPresent()) {
handlePartError(PresignedUrlDownloadHelper.invalidContentRangeHeader(contentRange), 0);
handlePartError(PresignedUrlDownloadHelper.invalidContentRangeHeader(contentRange), 0L);
return;
}

this.totalContentLength = parsedTotal.get();
this.totalParts = MultipartDownloadUtils.calculateTotalParts(totalContentLength, configuredPartSizeInBytes);
log.debug(() -> String.format("Total content length: %d, Total parts: %d", totalContentLength, totalParts));

Optional<SdkClientException> validationError = validatePartResponse(res, 0);
Optional<SdkClientException> validationError = validatePartResponse(res, 0L);
if (validationError.isPresent()) {
handlePartError(validationError.get(), 0);
handlePartError(validationError.get(), 0L);
return;
}

Expand All @@ -200,16 +200,16 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj

processPendingTransformers();

int remainingParts = totalParts - 1;
int toRequest = Math.min(remainingParts, maxInFlightParts);
long remainingParts = totalParts - 1;
long toRequest = Math.min(remainingParts, maxInFlightParts);
synchronized (subscriptionLock) {
subscription.request(toRequest);
}
});
}

private void processRequest(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer,
int currentPart) {
long currentPart) {
if (currentPart >= totalParts) {
return;
}
Expand All @@ -224,7 +224,7 @@ private void processRequest(AsyncResponseTransformer<GetObjectResponse, GetObjec
}

private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer,
int partIndex) {
long partIndex) {
if (downloadFailed.get()) {
inFlightPermits.release();
return;
Expand Down Expand Up @@ -259,7 +259,7 @@ private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
}

log.debug(() -> "Completed part: " + partIndex);
int totalComplete = completedParts.incrementAndGet();
long totalComplete = completedParts.incrementAndGet();

if (totalComplete == totalParts) {
resultFuture.complete(firstResponse);
Expand All @@ -285,7 +285,7 @@ private void processPendingTransformers() {
try {
// Drain pending queue while permits are available
while (!pendingTransformers.isEmpty() && inFlightPermits.tryAcquire()) {
Pair<Integer, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> pendingPart =
Pair<Long, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> pendingPart =
pendingTransformers.poll();
if (pendingPart != null && pendingPart.left() < totalParts) {
sendPartRequest(pendingPart.right(), pendingPart.left());
Expand All @@ -299,12 +299,12 @@ private void processPendingTransformers() {
} while (!pendingTransformers.isEmpty() && inFlightPermits.availablePermits() > 0);
}

private Optional<SdkClientException> validatePartResponse(GetObjectResponse response, int partIndex) {
private Optional<SdkClientException> validatePartResponse(GetObjectResponse response, long partIndex) {
return PresignedUrlDownloadHelper.validatePartResponse(
response, partIndex, configuredPartSizeInBytes, totalContentLength, totalParts);
}

private void handlePartError(Throwable error, int partIndex) {
private void handlePartError(Throwable error, long partIndex) {
if (downloadFailed.compareAndSet(false, true)) {
log.debug(() -> "Error on part " + partIndex, error);
resultFuture.completeExceptionally(error);
Expand All @@ -317,7 +317,7 @@ private void handlePartError(Throwable error, int partIndex) {
}
}

private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) {
private PresignedUrlDownloadRequest createRangedGetRequest(long partIndex) {
return PresignedUrlDownloadHelper.createRangedGetRequest(
presignedUrlDownloadRequest, partIndex, configuredPartSizeInBytes, totalContentLength, eTag);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ static SdkClientException invalidContentLength() {
* @return empty if valid, or an SdkClientException describing the mismatch
*/
static Optional<SdkClientException> validatePartResponse(GetObjectResponse response,
int partIndex,
long partIndex,
long partSizeInBytes,
Long totalContentLength,
Integer totalParts) {
Long totalParts) {
String contentRange = response.contentRange();
if (contentRange == null) {
return Optional.of(missingContentRangeHeader());
Expand Down Expand Up @@ -214,7 +214,7 @@ static Optional<SdkClientException> validatePartResponse(GetObjectResponse respo
* @return a new PresignedUrlDownloadRequest with the appropriate Range and If-Match headers
*/
static PresignedUrlDownloadRequest createRangedGetRequest(PresignedUrlDownloadRequest originalRequest,
int partIndex,
long partIndex,
long partSizeInBytes,
Long totalContentLength,
String eTag) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand Down Expand Up @@ -67,16 +67,16 @@ public class PresignedUrlMultipartDownloaderSubscriber
*/
private final CompletableFuture<?> resultFuture;
private final Object lock = new Object();
private final AtomicInteger nextPartIndex;
private final AtomicInteger requestsSent;
private final AtomicLong nextPartIndex;
private final AtomicLong requestsSent;

/**
* Store the GetObject futures so we can cancel them if onError() is invoked.
*/
private final Queue<CompletableFuture<GetObjectResponse>> getObjectFutures = new ConcurrentLinkedQueue<>();

private volatile Long totalContentLength;
private volatile Integer totalParts;
private volatile Long totalParts;
private volatile String eTag;
private Subscription subscription;

Expand All @@ -88,8 +88,8 @@ public PresignedUrlMultipartDownloaderSubscriber(
this.s3AsyncClient = s3AsyncClient;
this.presignedUrlDownloadRequest = presignedUrlDownloadRequest;
this.configuredPartSizeInBytes = configuredPartSizeInBytes;
this.nextPartIndex = new AtomicInteger(0);
this.requestsSent = new AtomicInteger(0);
this.nextPartIndex = new AtomicLong(0);
this.requestsSent = new AtomicLong(0);
this.future = new CompletableFuture<>();
this.resultFuture = resultFuture;
}
Expand All @@ -110,7 +110,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
throw new NullPointerException("onNext must not be called with null asyncResponseTransformer");
}

int currentPartIndex;
long currentPartIndex;
synchronized (lock) {
currentPartIndex = nextPartIndex.get();
if (totalParts != null && currentPartIndex >= totalParts) {
Expand All @@ -123,7 +123,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
makeRangeRequest(currentPartIndex, asyncResponseTransformer);
}

private void makeRangeRequest(int partIndex,
private void makeRangeRequest(long partIndex,
AsyncResponseTransformer<GetObjectResponse,
GetObjectResponse> asyncResponseTransformer) {
PresignedUrlDownloadRequest partRequest = createRangedGetRequest(partIndex);
Expand Down Expand Up @@ -153,9 +153,9 @@ private void makeRangeRequest(int partIndex,
});
}

private boolean validatePart(GetObjectResponse response, int partIndex,
private boolean validatePart(GetObjectResponse response, long partIndex,
AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> asyncResponseTransformer) {
int dispatched = nextPartIndex.get();
long dispatched = nextPartIndex.get();
log.debug(() -> String.format("Dispatched %d parts so far", dispatched));

String responseETag = response.eTag();
Expand Down Expand Up @@ -191,7 +191,7 @@ private boolean validatePart(GetObjectResponse response, int partIndex,
return true;
}

private void requestMoreIfNeeded(int dispatched) {
private void requestMoreIfNeeded(long dispatched) {
synchronized (lock) {
if (hasMoreParts(dispatched)) {
subscription.request(1);
Expand All @@ -207,16 +207,16 @@ private void requestMoreIfNeeded(int dispatched) {
}
}

private Optional<SdkClientException> validateResponse(GetObjectResponse response, int partIndex) {
private Optional<SdkClientException> validateResponse(GetObjectResponse response, long partIndex) {
return PresignedUrlDownloadHelper.validatePartResponse(
response, partIndex, configuredPartSizeInBytes, totalContentLength, totalParts);
}

private boolean hasMoreParts(int dispatched) {
private boolean hasMoreParts(long dispatched) {
return totalParts != null && dispatched < totalParts;
}

private PresignedUrlDownloadRequest createRangedGetRequest(int partIndex) {
private PresignedUrlDownloadRequest createRangedGetRequest(long partIndex) {
return PresignedUrlDownloadHelper.createRangedGetRequest(
presignedUrlDownloadRequest, partIndex, configuredPartSizeInBytes, totalContentLength, eTag);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,12 @@ void calculateTotalParts_shouldCalculateCorrectly() {
assertThat(MultipartDownloadUtils.calculateTotalParts(1, 16)).isEqualTo(1); // smaller than part size
assertThat(MultipartDownloadUtils.calculateTotalParts(16, 16)).isEqualTo(1); // exactly one part
assertThat(MultipartDownloadUtils.calculateTotalParts(0, 16)).isEqualTo(0); // empty object
// 5 GiB / 1 byte = 5_368_709_120 parts (exceeds Integer.MAX_VALUE)
long fiveGiB = 5L * 1024L * 1024L * 1024L;
assertThat(MultipartDownloadUtils.calculateTotalParts(fiveGiB, 1L)).isEqualTo(fiveGiB);
assertThat(MultipartDownloadUtils.calculateTotalParts(Long.MAX_VALUE - 1, 1L))
.isEqualTo(Long.MAX_VALUE - 1);
assertThat(MultipartDownloadUtils.calculateTotalParts(Long.MAX_VALUE, 2))
.isEqualTo((Long.MAX_VALUE / 2) + 1);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void validatePartResponse_validResponse_shouldReturnEmpty() {
.build();

Optional<SdkClientException> result = PresignedUrlDownloadHelper.validatePartResponse(
response, 0, 16L, 32L, 2);
response, 0, 16L, 32L, 2L);

assertThat(result).isEmpty();
}
Expand All @@ -47,7 +47,7 @@ void validatePartResponse_missingContentRange_shouldReturnError() {
.build();

Optional<SdkClientException> result = PresignedUrlDownloadHelper.validatePartResponse(
response, 0, 16L, 32L, 2);
response, 0, 16L, 32L, 2L);

assertThat(result).isPresent();
assertThat(result.get().getMessage()).contains("No Content-Range header");
Expand All @@ -61,7 +61,7 @@ void validatePartResponse_invalidContentLength_shouldReturnError() {
.build();

Optional<SdkClientException> result = PresignedUrlDownloadHelper.validatePartResponse(
response, 0, 16L, 32L, 2);
response, 0, 16L, 32L, 2L);

assertThat(result).isPresent();
assertThat(result.get().getMessage()).contains("Invalid or missing Content-Length");
Expand All @@ -75,7 +75,7 @@ void validatePartResponse_contentRangeMismatch_shouldReturnError() {
.build();

Optional<SdkClientException> result = PresignedUrlDownloadHelper.validatePartResponse(
response, 0, 16L, 32L, 2);
response, 0, 16L, 32L, 2L);

assertThat(result).isPresent();
assertThat(result.get().getMessage()).contains("Content-Range mismatch for part 0");
Expand All @@ -89,7 +89,7 @@ void validatePartResponse_contentLengthMismatch_shouldReturnError() {
.build();

Optional<SdkClientException> result = PresignedUrlDownloadHelper.validatePartResponse(
response, 0, 16L, 32L, 2);
response, 0, 16L, 32L, 2L);

assertThat(result).isPresent();
assertThat(result.get().getMessage()).contains("content length validation failed");
Expand All @@ -104,7 +104,7 @@ void validatePartResponse_lastPartSmallerSize_shouldPass() {
.build();

Optional<SdkClientException> result = PresignedUrlDownloadHelper.validatePartResponse(
response, 1, 16L, 30L, 2);
response, 1, 16L, 30L, 2L);

assertThat(result).isEmpty();
}
Expand Down
Loading