diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java b/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java index 50c873c03..18c3acf57 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java @@ -14,8 +14,8 @@ import tech.ydb.core.StatusCode; public abstract class TopicRetryableStream { + protected final String debugId; private final Logger logger; - private final String debugId; private final RetryConfig retryConfig; private final ScheduledExecutorService scheduler; @@ -32,11 +32,6 @@ public TopicRetryableStream(Logger logger, String debugId, RetryConfig config, S this.scheduler = scheduler; } - @Override - public String toString() { - return "Session[" + debugId + "]"; - } - protected abstract TopicStream createNewStream(String debugId); protected abstract W getInitRequest(); @@ -54,8 +49,7 @@ public void start() { TopicStream stream = createNewStream(streamID); if (!realStream.compareAndSet(null, stream)) { - logger.warn("{} double start of stream, skipping", this); - stream.close(); + logger.warn("[{}] double start of stream, skipping", debugId); return; } @@ -78,18 +72,21 @@ protected void resetRetries() { public void send(W msg) { TopicStream stream = realStream.get(); if (stream == null) { - logger.warn("{} send message before stream is ready", this); + logger.warn("[{}] send message before stream is ready", debugId); return; } stream.send(msg); } - public void close() { + public boolean close() { isClosed = true; TopicStream stream = realStream.getAndSet(null); - if (stream != null) { - stream.close(); + if (stream == null) { + return false; } + + stream.close(); + return true; } private void onStreamStop(Status status, RetryPolicy policy) { @@ -99,7 +96,7 @@ private void onStreamStop(Status status, RetryPolicy policy) { } if (policy == null) { - logger.warn("{} stopped by non-retryable status {}", this, status); + logger.warn("[{}] stopped by non-retryable status {}", debugId, status); onClose(status); return; } @@ -107,26 +104,26 @@ private void onStreamStop(Status status, RetryPolicy policy) { long nextRetryMs = state.nextRetryMs(policy); if (nextRetryMs < 0) { - logger.warn("{} stopped after retry policy evaluation for status {}", this, status); + logger.warn("[{}] stopped after retry policy evaluation for status {}", debugId, status); onClose(status); return; } if (nextRetryMs == 0) { // retry immediately - logger.warn("{} retry #{}. Retry immediately...", this, state.retryNumber()); + logger.warn("[{}] retry #{}. Retry immediately...", debugId, state.retryNumber()); onRetry(status); start(); return; } // retry scheduling - logger.warn("{} retry #{}. Scheduling reconnect in {}ms...", debugId, state.retryNumber(), nextRetryMs); + logger.warn("[{}] retry #{}. Scheduling reconnect in {}ms...", debugId, state.retryNumber(), nextRetryMs); onRetry(status); try { scheduler.schedule(this::start, nextRetryMs, TimeUnit.MILLISECONDS); } catch (Exception ex) { - logger.error("{} cannot schedule reconnect, stopping", debugId, ex); + logger.error("[{}] cannot schedule reconnect, stopping", debugId, ex); onClose(status); } } diff --git a/topic/src/main/java/tech/ydb/topic/settings/TopicRetryConfig.java b/topic/src/main/java/tech/ydb/topic/settings/TopicRetryConfig.java new file mode 100644 index 000000000..91edd854a --- /dev/null +++ b/topic/src/main/java/tech/ydb/topic/settings/TopicRetryConfig.java @@ -0,0 +1,43 @@ +package tech.ydb.topic.settings; + +import tech.ydb.common.retry.ExponentialBackoffRetry; +import tech.ydb.common.retry.RetryConfig; +import tech.ydb.common.retry.RetryPolicy; +import tech.ydb.core.Status; + +/** + * Predefined {@link RetryConfig} instances for topic writers and readers. + *

+ * Pass one of these constants (or a custom {@link RetryConfig}) to + * {@link WriterSettings.Builder#setRetryConfig} to control how the writer + * behaves when its underlying stream is interrupted. + * + * @author Aleksandr Gorshenin + */ +public class TopicRetryConfig { + // Max backoff will be random delay from 32.768s to 65.536s + private static final RetryPolicy DEFAULT_BACKOFF = new ExponentialBackoffRetry(32, 10); + + /** + * Retry any stream disconnection indefinitely with exponential backoff. + *

+ * Every status code, including {@link Status#SUCCESS}, is treated as retryable. + * The delay between reconnection attempts grows exponentially and is capped at a + * random value between 32 and 65 seconds. + *

+ * This is the default retry configuration for topic writers and readers. + */ + public static final RetryConfig FOREVER = status -> DEFAULT_BACKOFF; + + /** + * Disable retries entirely. + *

+ * Any stream disconnection is reported immediately as a terminal error through + * the errors handler configured via + * {@link WriterSettings.Builder#setErrorsHandler}. + * Use this when you need full control over reconnection logic in application code. + */ + public static final RetryConfig NEVER = status -> null; + + private TopicRetryConfig() { } +} diff --git a/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java b/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java index 061a2c067..b301a427a 100644 --- a/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java +++ b/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java @@ -2,6 +2,7 @@ import java.util.function.BiConsumer; +import tech.ydb.common.retry.RetryConfig; import tech.ydb.core.Status; import tech.ydb.topic.description.Codec; @@ -20,6 +21,7 @@ public class WriterSettings { private final int codec; private final long maxSendBufferMemorySize; private final int maxSendBufferMessagesCount; + private final RetryConfig retryConfig; private final BiConsumer errorsHandler; private WriterSettings(Builder builder) { @@ -31,6 +33,7 @@ private WriterSettings(Builder builder) { this.codec = builder.codec; this.maxSendBufferMemorySize = builder.maxSendBufferMemorySize; this.maxSendBufferMessagesCount = builder.maxSendBufferMessagesCount; + this.retryConfig = builder.retryConfig; this.errorsHandler = builder.errorsHandler; } @@ -58,6 +61,10 @@ public BiConsumer getErrorsHandler() { return errorsHandler; } + public RetryConfig getRetryConfig() { + return retryConfig; + } + public Long getPartitionId() { return partitionId; } @@ -86,6 +93,7 @@ public static class Builder { private int codec = Codec.GZIP; private long maxSendBufferMemorySize = MAX_MEMORY_USAGE_BYTES_DEFAULT; private int maxSendBufferMessagesCount = MAX_IN_FLIGHT_COUNT_DEFAULT; + private RetryConfig retryConfig = TopicRetryConfig.FOREVER; private BiConsumer errorsHandler = null; /** @@ -183,6 +191,30 @@ public Builder setErrorsHandler(BiConsumer handler) { return this; } + /** + * Set retry configuration for the writer's underlying stream connection. + * Controls how the writer reconnects when the stream is interrupted. + *

+ * The default value is {@link TopicRetryConfig#FOREVER}, which retries any disconnection + * indefinitely with exponential backoff (up to ~65 seconds between attempts). + *

+ * Use {@link TopicRetryConfig#NEVER} to disable retries and surface errors immediately + * via the errors handler set by {@link #setErrorsHandler}. + * + * @param config retry configuration, must not be {@code null} + * @return this builder + * @throws NullPointerException if {@code config} is {@code null} + * @see TopicRetryConfig#FOREVER + * @see TopicRetryConfig#NEVER + */ + public Builder setRetryConfig(RetryConfig config) { + if (config == null) { + throw new NullPointerException("RetryConfig must not be null"); + } + this.retryConfig = config; + return this; + } + public WriterSettings build() { return new WriterSettings(this); } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/BufferManager.java b/topic/src/main/java/tech/ydb/topic/write/impl/BufferManager.java index 8b3216c4a..e3f056d9d 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/BufferManager.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/BufferManager.java @@ -7,6 +7,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import tech.ydb.core.Status; import tech.ydb.topic.settings.WriterSettings; import tech.ydb.topic.write.QueueOverflowException; @@ -18,7 +19,7 @@ public class BufferManager { // use logger from WriterImpl private static final Logger logger = LoggerFactory.getLogger(WriterImpl.class); - private final String id; + private final String debugId; private final long bufferMaxSize; private final int maxCount; private final int blockBitsCount; @@ -26,8 +27,10 @@ public class BufferManager { private final Semaphore blocksAvailable; private final Semaphore countAvailable; + private volatile Status closed = null; + public BufferManager(String id, WriterSettings settings) { - this.id = id; + this.debugId = id; this.maxCount = settings.getMaxSendBufferMessagesCount(); this.bufferMaxSize = settings.getMaxSendBufferMemorySize(); @@ -42,23 +45,56 @@ public long getMaxSize() { return bufferMaxSize; } + public void close(Status status) { + this.closed = status; + // release all waiters + this.blocksAvailable.release(calculateBlocksCount(bufferMaxSize, blockBitsCount)); + this.countAvailable.release(maxCount); + } + public void acquire(long messageSize) throws InterruptedException, QueueOverflowException { + if (closed != null) { + throw new IllegalStateException("Writer was closed with status " + closed); + } + countAvailable.acquire(); + + if (closed != null) { + countAvailable.release(); + throw new IllegalStateException("Writer was closed with status " + closed); + } + + int messageBlocks = calculateBlocksCount(messageSize, blockBitsCount); + try { - int messageBlocks = calculateBlocksCount(messageSize, blockBitsCount); blocksAvailable.acquire(messageBlocks); } catch (InterruptedException ex) { countAvailable.release(); throw ex; } + + if (closed != null) { + blocksAvailable.release(messageBlocks); + countAvailable.release(); + throw new IllegalStateException("Writer was closed with status " + closed); + } } public void tryAcquire(long messageSize) throws QueueOverflowException { + if (closed != null) { + throw new IllegalStateException("Writer was closed with status " + closed); + } + if (!countAvailable.tryAcquire()) { - String errorMessage = "[" + id + "] Rejecting a message due to reaching message queue in-flight limit of " + String errorMsg = "[" + debugId + "] Rejecting a message due to reaching message queue in-flight limit of " + maxCount; - logger.warn(errorMessage); - throw new QueueOverflowException(errorMessage); + logger.warn(errorMsg); + throw new QueueOverflowException(errorMsg); + } + + if (closed != null) { + countAvailable.release(); + throw new IllegalStateException("Writer was closed with status " + closed); } int messageBlocks = calculateBlocksCount(messageSize, blockBitsCount); @@ -66,43 +102,64 @@ public void tryAcquire(long messageSize) throws QueueOverflowException { countAvailable.release(); int count = maxCount - countAvailable.availablePermits(); long size = ((long) blocksAvailable.availablePermits()) << blockBitsCount; - String errorMessage = "[" + id + "] Rejecting a message of " + messageSize + + String errorMsg = "[" + debugId + "] Rejecting a message of " + messageSize + " bytes: not enough space in message queue. Buffer currently has " + count + " messages with " + size + " / " + bufferMaxSize + " bytes available"; - logger.warn(errorMessage); - throw new QueueOverflowException(errorMessage); + logger.warn(errorMsg); + throw new QueueOverflowException(errorMsg); + } + + if (closed != null) { + blocksAvailable.release(messageBlocks); + countAvailable.release(); + throw new IllegalStateException("Writer was closed with status " + closed); } } public void tryAcquire(long messageSize, long timeout, TimeUnit unit) throws InterruptedException, QueueOverflowException, TimeoutException { + if (closed != null) { + throw new IllegalStateException("Writer was closed with status " + closed); + } + long expireAt = System.nanoTime() + unit.toNanos(timeout); if (!countAvailable.tryAcquire(timeout, unit)) { - String errorMessage = "[" + id + "] Rejecting a message due to reaching message queue in-flight limit of " + String errorMsg = "[" + debugId + "] Rejecting a message due to reaching message queue in-flight limit of " + maxCount; - logger.warn(errorMessage); - throw new TimeoutException(errorMessage); + logger.warn(errorMsg); + throw new TimeoutException(errorMsg); } + if (closed != null) { + countAvailable.release(); + throw new IllegalStateException("Writer was closed with status " + closed); + } + + int messageBlocks = calculateBlocksCount(messageSize, blockBitsCount); + try { // negative timeout is allowed for tryAcquire - long timeout2 = unit.convert(expireAt - System.nanoTime(), TimeUnit.NANOSECONDS); - int messageBlocks = calculateBlocksCount(messageSize, blockBitsCount); - if (!blocksAvailable.tryAcquire(messageBlocks, timeout2, unit)) { + long timeout2 = expireAt - System.nanoTime(); + if (!blocksAvailable.tryAcquire(messageBlocks, timeout2, TimeUnit.NANOSECONDS)) { countAvailable.release(); int count = maxCount - countAvailable.availablePermits(); long size = ((long) blocksAvailable.availablePermits()) << blockBitsCount; - String errorMessage = "[" + id + "] Rejecting a message of " + messageSize + + String errorMsg = "[" + debugId + "] Rejecting a message of " + messageSize + " bytes: not enough space in message queue. Buffer currently has " + count + " messages with " + size + " / " + bufferMaxSize + " bytes available"; - logger.warn(errorMessage); - throw new TimeoutException(errorMessage); + logger.warn(errorMsg); + throw new TimeoutException(errorMsg); } } catch (InterruptedException ex) { countAvailable.release(); throw ex; } + if (closed != null) { + blocksAvailable.release(messageBlocks); + countAvailable.release(); + throw new IllegalStateException("Writer was closed with status " + closed); + } } public void releaseMessage(long messageSize) { diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/EnqueuedMessage.java b/topic/src/main/java/tech/ydb/topic/write/impl/EnqueuedMessage.java index 80d1ed062..e5ae85c75 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/EnqueuedMessage.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/EnqueuedMessage.java @@ -15,7 +15,7 @@ public class EnqueuedMessage { private final CompletableFuture ackFuture = new CompletableFuture<>(); private volatile ByteString data = null; - private volatile Throwable encodingProblem = null; + private volatile Throwable problem = null; private volatile long bufferSize; public EnqueuedMessage(MessageMeta meta, long bufferSize) { @@ -33,7 +33,7 @@ public ByteString getData() { } public Throwable getProblem() { - return encodingProblem; + return problem; } public CompletableFuture getAckFuture() { @@ -45,7 +45,7 @@ public boolean isReady() { } public boolean hasProblem() { - return encodingProblem != null; + return problem != null; } public long getBufferSize() { @@ -55,10 +55,10 @@ public long getBufferSize() { public void setData(ByteString data, long updatedSize) { this.bufferSize = updatedSize; this.data = data; - this.encodingProblem = null; + this.problem = null; } public void setError(Throwable ex) { - this.encodingProblem = ex; + this.problem = ex; } } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/MessageSender.java b/topic/src/main/java/tech/ydb/topic/write/impl/MessageSender.java index ac873930d..d7f5cf7b4 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/MessageSender.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/MessageSender.java @@ -49,6 +49,7 @@ public class MessageSender { REQUEST_OVERHEAD); } + private final String debugId; private final int codecCode; private final List messages = new ArrayList<>(); private final AtomicInteger messagesPbSize = new AtomicInteger(0); @@ -56,7 +57,8 @@ public class MessageSender { private volatile YdbTransaction currentTransaction = null; - public MessageSender(int codecCode, Consumer session) { + public MessageSender(String debugId, int codecCode, Consumer session) { + this.debugId = debugId; this.codecCode = codecCode; this.session = session; } @@ -65,7 +67,7 @@ public int getCurrentRequestSize() { return REQUEST_OVERHEAD + messagesPbSize.get() + MESSAGE_OVERHEAD * messages.size(); } - public void sendWriteRequest() { + private void sendWriteRequest() { YdbTopic.StreamWriteMessage.WriteRequest.Builder req = YdbTopic.StreamWriteMessage.WriteRequest.newBuilder(); if (currentTransaction != null) { req.setTx(YdbTopic.TransactionIdentity.newBuilder() @@ -80,13 +82,16 @@ public void sendWriteRequest() { .setWriteRequest(req.build()) .build(); - if (logger.isDebugEnabled()) { - logger.debug("Predicted request size: {} = {}(request overhead) + {}(all MessageData protos) " + + if (logger.isTraceEnabled()) { + logger.trace("Predicted request size: {} = {}(request overhead) + {}(all MessageData protos) " + "+ {}(message overheads) Actual request size: {} bytes", getCurrentRequestSize(), REQUEST_OVERHEAD, messagesPbSize, MESSAGE_OVERHEAD * messages.size(), fromClient.getSerializedSize()); } + logger.debug("[{}] write {} messages with seq numbers {}-{}", debugId, messages.size(), + messages.get(0).getSeqNo(), messages.get(messages.size() - 1).getSeqNo()); + session.accept(fromClient); messages.clear(); messagesPbSize.set(0); diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java index 16d31f3f3..a1fbe271b 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java @@ -1,66 +1,57 @@ package tech.ydb.topic.write.impl; import java.util.List; +import java.util.function.BiConsumer; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import tech.ydb.core.Issue; import tech.ydb.core.Status; -import tech.ydb.core.StatusCode; import tech.ydb.core.utils.ProtobufUtils; -import tech.ydb.proto.StatusCodesProtos; import tech.ydb.proto.topic.YdbTopic; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; import tech.ydb.topic.TopicRpc; -import tech.ydb.topic.impl.SessionBase; +import tech.ydb.topic.impl.TopicRetryableStream; import tech.ydb.topic.settings.WriterSettings; import tech.ydb.topic.write.WriteAck; /** * @author Nikolay Perfilov */ -public final class WriteSession extends SessionBase { +public final class WriteSession extends TopicRetryableStream { private static final Logger logger = LoggerFactory.getLogger(WriteSession.class); - private final WriterImpl writer; - private final MessageSender sender; - private final YdbTopic.StreamWriteMessage.InitRequest initRequest; + public interface Listener { + void onAck(WriteAck ack); - private volatile String sessionId = null; - private volatile Status finishStatus = null; + void onStart(long lastSeqNo, String sessionId); + void onStop(Status status); - public WriteSession(WriterImpl writer, TopicRpc rpc, String streamId, WriterSettings settings) { - super(rpc.writeSession(streamId), streamId); - this.writer = writer; - this.initRequest = buildInitRequest(settings); - this.sender = new MessageSender(settings.getCodec(), this::safeSend); + void onClose(Status status); } - @Override - protected Logger getLogger() { - return logger; - } - - public boolean isStarted() { - return sessionId != null; + private final Listener listener; + private final StreamFactory streamFactory; + private final MessageSender sender; + private final BiConsumer errorsHandler; + + public WriteSession(String debugId, TopicRpc rpc, WriterSettings settings, Listener controller) { + super(logger, debugId, settings.getRetryConfig(), rpc.getScheduler()); + this.listener = controller; + this.streamFactory = new StreamFactory(rpc, settings); + this.sender = new MessageSender(debugId, settings.getCodec(), this::send); + this.errorsHandler = settings.getErrorsHandler(); } @Override - protected void sendUpdateTokenRequest(String token) { - streamConnection.sendNext(YdbTopic.StreamWriteMessage.FromClient.newBuilder() - .setUpdateTokenRequest(YdbTopic.UpdateTokenRequest.newBuilder() - .setToken(token) - .build()) - .build() - ); + protected WriteStream createNewStream(String id) { + return streamFactory.createNewStream(id); } - private void safeSend(YdbTopic.StreamWriteMessage.FromClient msg) { - if (finishStatus == null) { - send(msg); - } + @Override + protected FromClient getInitRequest() { + return streamFactory.initRequest(); } public void sendAll(Supplier generator) { @@ -72,26 +63,19 @@ public void sendAll(Supplier generator) { sender.flush(); } - @Override - public void startAndInitialize() { - logger.debug("[{}] Session startAndInitialize called", streamId); - start(this::processMessage).whenComplete(this::closeDueToError); - safeSend(YdbTopic.StreamWriteMessage.FromClient.newBuilder().setInitRequest(initRequest).build()); - } - private void onInitResponse(YdbTopic.StreamWriteMessage.InitResponse response) { long lastSeqNo = response.getLastSeqNo(); - writer.onInit(lastSeqNo); - sessionId = response.getSessionId(); + String sessionId = response.getSessionId(); + resetRetries(); logger.info("[{}] Session with id {} (partition {}) initialized for topic \"{}\", lastSeqNo {}", - streamId, sessionId, response.getPartitionId(), initRequest.getPath(), lastSeqNo); - writer.onStart(lastSeqNo); + debugId, sessionId, response.getPartitionId(), streamFactory.topicPath, lastSeqNo); + listener.onStart(lastSeqNo, sessionId); } // Shouldn't be called more than once at a time due to grpc guarantees private void onWriteResponse(YdbTopic.StreamWriteMessage.WriteResponse response) { List acks = response.getAcksList(); - logger.debug("[{}] Received WriteResponse with {} WriteAcks", streamId, acks.size()); + logger.debug("[{}] Received WriteResponse with {} WriteAcks", debugId, acks.size()); WriteAck.Statistics statistics = null; if (response.getWriteStatistics() != null) { YdbTopic.StreamWriteMessage.WriteResponse.WriteStatistics src = response.getWriteStatistics(); @@ -105,31 +89,12 @@ private void onWriteResponse(YdbTopic.StreamWriteMessage.WriteResponse response) } for (YdbTopic.StreamWriteMessage.WriteResponse.WriteAck ack : acks) { - writer.onAck(mapAck(statistics, ack)); - } - } - - private void processMessage(YdbTopic.StreamWriteMessage.FromServer message) { - if (message.getStatus() != StatusCodesProtos.StatusIds.StatusCode.SUCCESS) { - Status status = Status.of(StatusCode.fromProto(message.getStatus()), - Issue.fromPb(message.getIssuesList())); - logger.warn("[{}] Got non-success status in processMessage method: {}", streamId, status); - closeDueToError(status, null); - return; - } - if (message.hasInitResponse()) { - onInitResponse(message.getInitResponse()); - } else if (message.hasWriteResponse()) { - onWriteResponse(message.getWriteResponse()); - } else if (message.hasUpdateTokenResponse()) { - logger.debug("[{}] got update token response", streamId); - } else { - logger.warn("[{}] got unknown type message", streamId); + listener.onAck(mapAck(statistics, ack)); } } WriteAck mapAck(WriteAck.Statistics statistics, YdbTopic.StreamWriteMessage.WriteResponse.WriteAck ack) { - logger.debug("[{}] Received WriteAck with seqNo {} and status {}", streamId, ack.getSeqNo(), + logger.debug("[{}] Received WriteAck with seqNo {} and status {}", debugId, ack.getSeqNo(), ack.getMessageWriteStatusCase()); if (ack.hasSkipped()) { return new WriteAck(ack.getSeqNo(), WriteAck.State.ALREADY_WRITTEN, null, statistics); @@ -146,39 +111,75 @@ WriteAck mapAck(WriteAck.Statistics statistics, YdbTopic.StreamWriteMessage.Writ return new WriteAck(ack.getSeqNo(), null, null, statistics); } - private void closeDueToError(Status status, Throwable th) { - finishStatus = status != null ? status : Status.of(StatusCode.CLIENT_INTERNAL_ERROR, th); - logger.info("[{}] Session {} closeDueToError called", streamId, sessionId); - if (shutdown()) { - // Signal writer to retry - writer.onSessionClosed(status, th); + @Override + public void onRetry(Status status) { + logger.warn("[{}] Session onRetry with status {} called", debugId, status); + listener.onStop(status); + if (errorsHandler != null) { + errorsHandler.accept(status, null); } } @Override - protected void onStop() { - logger.debug("[{}] Session {} onStop called", streamId, sessionId); + public void onClose(Status status) { + logger.debug("[{}] Session onClose with status {} called", debugId, status); + listener.onClose(status); + if (errorsHandler != null && !status.isSuccess()) { + errorsHandler.accept(status, null); + } } - private static YdbTopic.StreamWriteMessage.InitRequest buildInitRequest(WriterSettings settings) { - YdbTopic.StreamWriteMessage.InitRequest.Builder req = YdbTopic.StreamWriteMessage.InitRequest - .newBuilder() - .setPath(settings.getTopicPath()); - String producerId = settings.getProducerId(); - if (producerId != null) { - req.setProducerId(producerId); + @Override + public void onNext(YdbTopic.StreamWriteMessage.FromServer message) { + if (message.hasInitResponse()) { + onInitResponse(message.getInitResponse()); + } else if (message.hasWriteResponse()) { + onWriteResponse(message.getWriteResponse()); + } else if (message.hasUpdateTokenResponse()) { + logger.debug("[{}] got update token response", debugId); + } else { + logger.warn("[{}] got unknown type message", debugId); } - String messageGroupId = settings.getMessageGroupId(); - Long partitionId = settings.getPartitionId(); - if (messageGroupId != null) { - if (partitionId != null) { - throw new IllegalArgumentException("Both MessageGroupId and PartitionId are set in WriterSettings"); + } + + private class StreamFactory { + private final String topicPath; + private final TopicRpc rpc; + private final YdbTopic.StreamWriteMessage.InitRequest initRequest; + + StreamFactory(TopicRpc rpc, WriterSettings settings) { + this.rpc = rpc; + this.topicPath = settings.getTopicPath(); + + YdbTopic.StreamWriteMessage.InitRequest.Builder req = YdbTopic.StreamWriteMessage.InitRequest + .newBuilder() + .setPath(topicPath); + String producerId = settings.getProducerId(); + if (producerId != null) { + req.setProducerId(producerId); + } + String messageGroupId = settings.getMessageGroupId(); + Long partitionId = settings.getPartitionId(); + if (messageGroupId != null) { + if (partitionId != null) { + throw new IllegalArgumentException("Both MessageGroupId and PartitionId are set in WriterSettings"); + } + req.setMessageGroupId(messageGroupId); + } else if (partitionId != null) { + req.setPartitionId(partitionId); } - req.setMessageGroupId(messageGroupId); - } else if (partitionId != null) { - req.setPartitionId(partitionId); + + this.initRequest = req.build(); + } + + public WriteStream createNewStream(String id) { + return new WriteStream(id, rpc); } - return req.build(); + public YdbTopic.StreamWriteMessage.FromClient initRequest() { + return YdbTopic.StreamWriteMessage.FromClient.newBuilder() + .setInitRequest(initRequest) + .build(); + } } } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java new file mode 100644 index 000000000..1f431e3a4 --- /dev/null +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java @@ -0,0 +1,37 @@ +package tech.ydb.topic.write.impl; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import tech.ydb.core.Issue; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.proto.topic.YdbTopic.UpdateTokenRequest; +import tech.ydb.topic.TopicRpc; +import tech.ydb.topic.impl.TopicStream; + +/** + * + * @author Aleksandr Gorshenin + */ +public class WriteStream extends TopicStream { + private static final Logger logger = LoggerFactory.getLogger(WriteStream.class); + + public WriteStream(String id, TopicRpc rpc) { + super(logger, id, rpc.writeSession(id)); + } + + @Override + protected FromClient updateTokenMessage(String token) { + return FromClient.newBuilder().setUpdateTokenRequest( + UpdateTokenRequest.newBuilder().setToken(token).build() + ).build(); + } + + @Override + protected Status parseMessageStatus(FromServer message) { + return Status.of(StatusCode.fromProto(message.getStatus()), Issue.fromPb(message.getIssuesList())); + } +} diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java index 85cf86012..255667313 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java @@ -5,8 +5,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nonnull; @@ -14,9 +13,12 @@ import org.slf4j.LoggerFactory; import tech.ydb.common.transaction.YdbTransaction; +import tech.ydb.core.Issue; +import tech.ydb.core.Status; +import tech.ydb.core.UnexpectedResultException; import tech.ydb.topic.TopicRpc; import tech.ydb.topic.description.CodecRegistry; -import tech.ydb.topic.impl.GrpcStreamRetrier; +import tech.ydb.topic.impl.DebugTools; import tech.ydb.topic.impl.SerialRunnable; import tech.ydb.topic.settings.SendSettings; import tech.ydb.topic.settings.WriterSettings; @@ -28,58 +30,57 @@ /** * @author Nikolay Perfilov */ -public class WriterImpl extends GrpcStreamRetrier { +public class WriterImpl { private static final Logger logger = LoggerFactory.getLogger(WriterImpl.class); + private final String debugId; private final WriterQueue writeQueue; - private final WriteSessionFactory sessionFactory; - private final SerialRunnable sendTask = new SerialRunnable(new SendTaskImpl()); + private final WriteSession stream; + private final Runnable sendTask = new SerialRunnable(new SendTask()); - private final AtomicReference> initResultFutureRef = new AtomicReference<>(null); + private final CompletableFuture initFuture = new CompletableFuture<>(); + private final CompletableFuture shutdownFuture = new CompletableFuture<>(); + private final AtomicBoolean isClosed = new AtomicBoolean(false); + + private volatile boolean isReady = false; - private volatile WriteSession session = null; private Boolean isSeqNoProvided = null; public WriterImpl(TopicRpc topicRpc, WriterSettings settings, Executor compressionExecutor, @Nonnull CodecRegistry codecRegistry) { - super(settings.getLogPrefix(), topicRpc.getScheduler(), settings.getErrorsHandler()); - - this.writeQueue = new WriterQueue(id, settings, codecRegistry, compressionExecutor, sendTask); - this.sessionFactory = new WriteSessionFactory(topicRpc, settings); - - String message = "Writer" + - " (generated id " + id + ")" + - " created for topic \"" + settings.getTopicPath() + "\"" + - " with producerId \"" + settings.getProducerId() + "\"" + - " and messageGroupId \"" + settings.getMessageGroupId() + "\""; - logger.info(message); - } + this.debugId = DebugTools.createDebugId(settings.getLogPrefix()); + this.stream = new WriteSession(debugId, topicRpc, settings, new ListenerImpl()); + this.writeQueue = new WriterQueue(debugId, settings, codecRegistry, compressionExecutor, sendTask); - @Override - protected Logger getLogger() { - return logger; - } - - @Override - protected String getStreamName() { - return "Writer"; + logger.info("Writer with id {} created for topic \"{}\" with producerId \"{}\" and messageGroupId \"{}\"", + debugId, settings.getTopicPath(), settings.getProducerId(), settings.getMessageGroupId()); } public CompletableFuture init() { - logger.info("[{}] initImpl called", id); - if (initResultFutureRef.compareAndSet(null, new CompletableFuture<>())) { - session = sessionFactory.createNextSession(); - session.startAndInitialize(); - } else { - logger.warn("[{}] Init is called on this writer more than once. Nothing is done", id); + if (isClosed.get()) { + throw new IllegalStateException("Writer is already stopped"); } - return initResultFutureRef.get(); + logger.info("[{}] start called", debugId); + stream.start(); + return initFuture; } public CompletableFuture shutdown() { - return shutdownImpl(""); + if (!isClosed.compareAndSet(false, true)) { + return shutdownFuture; + } + + if (!stream.close()) { + // implicit closing because stream will never call onClose + Status status = Status.SUCCESS.withIssues(Issue.of("Closed by client", Issue.Severity.INFO)); + initFuture.completeExceptionally(new UnexpectedResultException("Cannot init write session", status)); + shutdownFuture.complete(null); + writeQueue.close(status); + } + + return shutdownFuture; } public CompletableFuture flush() { @@ -87,17 +88,17 @@ public CompletableFuture flush() { } private Message validate(Message message) { - if (isStopped.get()) { - throw new RuntimeException("Writer is already stopped"); + if (isClosed.get()) { + throw new IllegalStateException("Writer is already stopped"); } if (isSeqNoProvided != null) { if (message.getSeqNo() != null && !isSeqNoProvided) { - throw new RuntimeException( + throw new IllegalArgumentException( "SeqNo was provided for a message after it had not been provided for another message. " + "SeqNo should either be provided for all messages or none of them."); } if (message.getSeqNo() == null && isSeqNoProvided) { - throw new RuntimeException( + throw new IllegalArgumentException( "SeqNo was not provided for a message after it had been provided for another message. " + "SeqNo should either be provided for all messages or none of them."); } @@ -126,64 +127,46 @@ protected CompletableFuture nonblockingSend(Message msg, SendSettings return writeQueue.tryEnqueue(validate(msg), getTx(settings)); } - @Override - protected void onStreamReconnect() { - session = sessionFactory.createNextSession(); - session.startAndInitialize(); - } - - @Override - protected void onShutdown(String reason) { - if (session != null) { - session.shutdown(); - } - if (initResultFutureRef.get() != null && !initResultFutureRef.get().isDone()) { - initResultFutureRef.get().completeExceptionally(new RuntimeException(reason)); + private class ListenerImpl implements WriteSession.Listener { + @Override + public void onStart(long lastSeqNo, String sessionId) { + // resend all sent messages in writing queue + Iterator resend = writeQueue.updateSeqNo(lastSeqNo); + stream.sendAll(() -> resend.hasNext() ? resend.next() : null); + isReady = true; + initFuture.complete(new InitResult(lastSeqNo)); + sendTask.run(); } - } - void onInit(long lastSeqNo) { - reconnectCounter.set(0); - Iterator resend = writeQueue.updateSeqNo(lastSeqNo); - session.sendAll(() -> resend.hasNext() ? resend.next() : null); - } + @Override + public void onStop(Status status) { + isReady = false; + } - void onStart(long lastSeqNo) { - if (initResultFutureRef.get() != null) { - initResultFutureRef.get().complete(new InitResult(lastSeqNo)); + @Override + public void onAck(WriteAck ack) { + writeQueue.confirmAck(ack); } - sendTask.run(); - } - void onAck(WriteAck ack) { - writeQueue.confirmAck(ack); + @Override + public void onClose(Status status) { + isClosed.set(true); + isReady = false; + initFuture.completeExceptionally(new UnexpectedResultException("Cannot init write session", status)); + shutdownFuture.complete(null); + writeQueue.close(status); + } } - private class SendTaskImpl implements Runnable { + private class SendTask implements Runnable { @Override public void run() { - if (session == null || !session.isStarted()) { - logger.debug("[{}] Can't send data: current session is not yet initialized", id); + if (!isReady) { + logger.debug("[{}] Can't send data: current session is not ready yet", debugId); return; } - session.sendAll(writeQueue::nextMessageToSend); - } - } - - private class WriteSessionFactory { - private final TopicRpc rpc; - private final WriterSettings settings; - private final AtomicLong sessionCounter = new AtomicLong(0); - - WriteSessionFactory(TopicRpc rpc, WriterSettings settings) { - this.rpc = rpc; - this.settings = settings; - } - - public WriteSession createNextSession() { - String streamID = id + '.' + sessionCounter.incrementAndGet(); - return new WriteSession(WriterImpl.this, rpc, streamID, settings); + stream.sendAll(writeQueue::nextMessageToSend); } } } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriterQueue.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriterQueue.java index 7c2b90459..f7e42c71f 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriterQueue.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriterQueue.java @@ -18,6 +18,7 @@ import org.slf4j.LoggerFactory; import tech.ydb.common.transaction.YdbTransaction; +import tech.ydb.core.Status; import tech.ydb.topic.description.Codec; import tech.ydb.topic.description.CodecRegistry; import tech.ydb.topic.settings.WriterSettings; @@ -32,7 +33,7 @@ public class WriterQueue { private static final Logger logger = LoggerFactory.getLogger(WriterImpl.class); - private final String id; + private final String debugId; private final BufferManager buffer; private final Codec codec; private final Executor compressionExecutor; @@ -48,10 +49,10 @@ public class WriterQueue { // Future for flush method private volatile EnqueuedMessage lastAcceptedMessage = null; - public WriterQueue(String id, WriterSettings settings, CodecRegistry codecRegistry, Executor compressionExecutor, - Runnable readyNotify) { - this.id = id; - this.buffer = new BufferManager(id, settings); + public WriterQueue(String debugId, WriterSettings settings, CodecRegistry codecRegistry, + Executor compressionExecutor, Runnable readyNotify) { + this.debugId = debugId; + this.buffer = new BufferManager(debugId, settings); this.codec = codecRegistry.getCodec(settings.getCodec()); if (codec == null) { @@ -96,7 +97,7 @@ SentMessage nextMessageToSend() { if (userSeqNo != null) { if (userSeqNo < seqNo) { buffer.releaseMessage(next.getBufferSize()); - String error = "[" + id + "] Message wasn't sent because seqNo " + userSeqNo + String error = "[" + debugId + "] Message wasn't sent because seqNo " + userSeqNo + " is less than current seqNo " + seqNo; logger.warn(error); next.getAckFuture().completeExceptionally(new IllegalArgumentException(error)); @@ -107,7 +108,7 @@ SentMessage nextMessageToSend() { } SentMessage sentMsg = new SentMessage(next, seqNo); - logger.trace("[{}] prepare sent message with seqNo {}", id, seqNo); + logger.debug("[{}] prepare sent message with seqNo {}", debugId, seqNo); sent.offer(sentMsg); return sentMsg; } @@ -128,11 +129,38 @@ void confirmAck(WriteAck ack) { } } + void close(Status status) { + buffer.close(status); + + while (!queue.isEmpty()) { + RuntimeException ex = new RuntimeException("Message sending was cancelled with status " + status); + Iterator it = queue.iterator(); + while (it.hasNext()) { + EnqueuedMessage next = it.next(); + next.setError(ex); + next.getAckFuture().completeExceptionally(ex); + it.remove(); + } + } + + while (!sent.isEmpty()) { + RuntimeException ex = new RuntimeException("Message had been sent but the writer was stopped with status " + + status); + Iterator it = sent.iterator(); + while (it.hasNext()) { + it.next().getAckFuture().completeExceptionally(ex); + it.remove(); + } + } + } + Iterator updateSeqNo(long newSeqNo) { - lastSeqNo.set(newSeqNo); + if (newSeqNo > lastSeqNo.get()) { + lastSeqNo.set(newSeqNo); + } - WriteAck lostAck = new WriteAck(newSeqNo, WriteAck.State.ALREADY_WRITTEN, null, null); // complete all messages with lost acks + WriteAck lostAck = new WriteAck(newSeqNo, WriteAck.State.ALREADY_WRITTEN, null, null); Iterator it = sent.iterator(); while (it.hasNext()) { SentMessage msg = it.next(); @@ -185,7 +213,7 @@ private CompletableFuture accept(Message message, YdbTransaction tx, l try { compressionExecutor.execute(() -> encode(message.getData(), msgSize, msg)); } catch (Throwable ex) { - logger.warn("[{}] Message wasn't sent because of processing error", id, ex); + logger.warn("[{}] Message wasn't sent because of processing error", debugId, ex); msg.setError(ex); readyNotify.run(); } @@ -194,13 +222,17 @@ private CompletableFuture accept(Message message, YdbTransaction tx, l } private void encode(byte[] data, long msgSize, EnqueuedMessage msg) { - logger.trace("[{}] Started encoding message", id); + if (msg.hasProblem()) { + return; + } + + logger.trace("[{}] Started encoding message", debugId); try (ByteString.Output encoded = ByteString.newOutput()) { try (OutputStream os = codec.encode(encoded)) { os.write(data, 0, data.length); } - logger.trace("[{}] Message compressed from {} to {} bytes", id, msgSize, encoded.size()); + logger.trace("[{}] Message compressed from {} to {} bytes", debugId, msgSize, encoded.size()); long bufferSize = msgSize; if (msgSize > encoded.size()) { // if compressed lenght is less than uncompression - update buffer size @@ -210,13 +242,9 @@ private void encode(byte[] data, long msgSize, EnqueuedMessage msg) { msg.setData(encoded.toByteString(), bufferSize); } catch (Throwable ex) { - logger.warn("[{}] Message wasn't sent because of encoding error", id, ex); + logger.warn("[{}] Message wasn't sent because of encoding error", debugId, ex); msg.setError(ex); } readyNotify.run(); } - - boolean hasMore() { - return queue.peek() != null && queue.peek().isReady(); - } } diff --git a/topic/src/test/java/tech/ydb/topic/FailableWriterInterceptor.java b/topic/src/test/java/tech/ydb/topic/FailableWriterInterceptor.java new file mode 100644 index 000000000..9acf78603 --- /dev/null +++ b/topic/src/test/java/tech/ydb/topic/FailableWriterInterceptor.java @@ -0,0 +1,218 @@ +package tech.ydb.topic; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; + +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.WriteRequest; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.WriteResponse; + + +/** + * + * @author Aleksandr Gorshenin + */ +public class FailableWriterInterceptor implements Consumer>, ClientInterceptor { + private final AtomicInteger initCounter = new AtomicInteger(); + + private final Map initErrors = new HashMap<>(); + private final TreeMap ackErrors = new TreeMap<>(); + private final TreeMap sendErrors = new TreeMap<>(); + + public void reset() { + initErrors.clear(); + ackErrors.clear(); + sendErrors.clear(); + initCounter.set(0); + } + + @Override + public void accept(ManagedChannelBuilder t) { + t.intercept(this); + } + + public void unavailableOnInit(int number) { + initErrors.put(number, closeStream(Status.UNAVAILABLE)); + } + + public void badRequestOnInit(int number) { + initErrors.put(number, sendError(StatusCodesProtos.StatusIds.StatusCode.BAD_REQUEST)); + } + + public void unavailableOnAckWithSeqNo(long seqNo) { + ackErrors.put(seqNo, closeStream(Status.UNAVAILABLE)); + } + + public void badRequestOnAckWithSeqNo(long seqNo) { + ackErrors.put(seqNo, sendError(StatusCodesProtos.StatusIds.StatusCode.BAD_REQUEST)); + } + + public void unavailableOnSendMsgWithSeqNo(long seqNo) { + sendErrors.put(seqNo, closeStream(Status.UNAVAILABLE)); + } + + public void badSessionOnSendMsgWithSeqNo(long seqNo) { + sendErrors.put(seqNo, sendError(StatusCodesProtos.StatusIds.StatusCode.BAD_SESSION)); + } + + + @Override + public ClientCall interceptCall(MethodDescriptor method, CallOptions callOptions, Channel next) { + return new ProxyCall<>(next.newCall(method, callOptions)); + } + + interface Error { + boolean fail(ClientCall.Listener listener); + } + + private class ProxyCall extends ClientCall { + + private final ClientCall realCall; + private volatile ProxyListener proxyListener; + private volatile boolean isClosed = false; + + ProxyCall(ClientCall delegate) { + this.realCall = delegate; + } + + @Override + public void start(Listener listener, Metadata headers) { + proxyListener = new ProxyListener<>(listener); + realCall.start(proxyListener, headers); + } + + @Override + public void request(int numMessages) { + realCall.request(numMessages); + } + + @Override + public void cancel(String message, Throwable cause) { + realCall.cancel(message, cause); + } + + @Override + public void halfClose() { + realCall.halfClose(); + } + + @Override + @SuppressWarnings("unchecked") + public void sendMessage(W message) { + if (isClosed) { + return; + } + + Error error = null; + if (message instanceof FromClient) { + FromClient msg = (FromClient) message; + if (msg.hasWriteRequest()) { + List list = msg.getWriteRequest().getMessagesList(); + long seqNo = list.get(list.size() - 1).getSeqNo(); + NavigableMap errors = sendErrors.headMap(seqNo, true); + if (errors.lastEntry() != null) { + error = errors.lastEntry().getValue(); + } + errors.clear(); + } + } + + if (error == null) { + realCall.sendMessage(message); + return; + } + + isClosed = error.fail((Listener) proxyListener); + if (isClosed) { + realCall.halfClose(); + } + } + + private class ProxyListener extends Listener { + private final Listener realListener; + + ProxyListener(Listener realListener) { + this.realListener = realListener; + } + + @Override + public void onClose(Status status, Metadata trailers) { + if (!isClosed) { + realListener.onClose(status, trailers); + } + } + + @Override + public void onHeaders(Metadata headers) { + if (!isClosed) { + realListener.onHeaders(headers); + } + } + + @Override + @SuppressWarnings("unchecked") + public void onMessage(R message) { + if (isClosed) { + return; + } + + Error error = null; + if (message instanceof FromServer) { + FromServer msg = (FromServer) message; + if (msg.hasInitResponse()) { + error = initErrors.get(initCounter.incrementAndGet()); + } + if (msg.hasWriteResponse()) { + List acks = msg.getWriteResponse().getAcksList(); + long lastAck = acks.get(acks.size() - 1).getSeqNo(); + NavigableMap errors = ackErrors.headMap(lastAck, true); + if (errors.lastEntry() != null) { + error = errors.lastEntry().getValue(); + } + errors.clear(); + } + } + if (error == null) { + realListener.onMessage(message); + return; + } + + isClosed = error.fail((Listener) realListener); + if (isClosed) { + realCall.halfClose(); + } + } + } + } + + private static Error closeStream(Status grpcStatus) { + return (ClientCall.Listener listener) -> { + listener.onClose(grpcStatus, new Metadata()); + return true; + }; + } + + private static Error sendError(StatusCodesProtos.StatusIds.StatusCode ydbStatus) { + return (ClientCall.Listener listener) -> { + listener.onMessage(FromServer.newBuilder().setStatus(ydbStatus).build()); + return false; + }; + } +} + diff --git a/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java b/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java new file mode 100644 index 000000000..a0a931699 --- /dev/null +++ b/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java @@ -0,0 +1,259 @@ +package tech.ydb.topic; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import tech.ydb.core.StatusCode; +import tech.ydb.core.utils.FutureTools; +import tech.ydb.test.junit4.GrpcTransportRule; +import tech.ydb.topic.description.Consumer; +import tech.ydb.topic.read.DeferredCommitter; +import tech.ydb.topic.read.SyncReader; +import tech.ydb.topic.settings.CreateTopicSettings; +import tech.ydb.topic.settings.ReaderSettings; +import tech.ydb.topic.settings.TopicReadSettings; +import tech.ydb.topic.settings.WriterSettings; +import tech.ydb.topic.write.AsyncWriter; +import tech.ydb.topic.write.Message; +import tech.ydb.topic.write.QueueOverflowException; +import tech.ydb.topic.write.SyncWriter; +import tech.ydb.topic.write.WriteAck; + +/** + * + * @author Aleksandr Gorshenin + */ +public class TopicWritersIntegrationTest { + private final static Logger logger = LoggerFactory.getLogger(TopicWritersIntegrationTest.class); + + private final static FailableWriterInterceptor PROXY = new FailableWriterInterceptor(); + + @ClassRule + public final static GrpcTransportRule ydbTransport = new GrpcTransportRule() + .withGrpcTransportCustomizer(b -> b.addChannelInitializer(PROXY)); + + private final static String TEST_TOPIC = "topic_writers_test"; + + private final static String TEST_PRODUCER1 = "producer"; + private final static String TEST_CONSUMER1 = "consumer"; + + private static TopicClient client; + + @BeforeClass + public static void initClient() { + client = TopicClient.newClient(ydbTransport).build(); + } + + @AfterClass + public static void closeClient() { + client.close(); + } + + @Before + public void initTopic() { + PROXY.reset(); + + logger.info("Create test topic {} ...", TEST_TOPIC); + client.createTopic(TEST_TOPIC, CreateTopicSettings.newBuilder() + .addConsumer(Consumer.newBuilder().setName(TEST_CONSUMER1).build()) + .build()) + .join().expectSuccess("can't create a new topic"); + } + + @After + public void dropTable() { + logger.info("Drop test topic {} ...", TEST_TOPIC); + client.dropTopic(TEST_TOPIC).join(); + } + + private void assertTopicContent(List messages) { + try { + SyncReader reader = client.createSyncReader(ReaderSettings.newBuilder().addTopic( + TopicReadSettings.newBuilder().setPath(TEST_TOPIC).build() + ).setConsumerName(TEST_CONSUMER1).build()); + + reader.initAndWait(); + int idx = 0; + DeferredCommitter committer = DeferredCommitter.newInstance(); + for (byte[] expected: messages) { + tech.ydb.topic.read.Message next = reader.receive(1, TimeUnit.SECONDS); + Assert.assertNotNull("Expected message " + idx, next); + Assert.assertArrayEquals("Unexpected content for message " + idx, expected, next.getData()); + idx++; + + committer.add(next); + } + + committer.commit(); + reader.shutdown(); + } catch (InterruptedException ex) { + throw new AssertionError("Unexpected exception", ex); + } + } + + @Test + public void messageBufferOverflowTest() throws Exception { + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath(TEST_TOPIC) + .setProducerId(TEST_PRODUCER1) + .setMaxSendBufferMemorySize(1000) + .build(); + + SyncWriter writer = client.createSyncWriter(settings); + writer.initAndWait(); + + byte[] msg1 = new byte[1000]; + byte[] msg2 = new byte[1001]; + Arrays.fill(msg1, (byte) 0x10); + Arrays.fill(msg2, (byte) 0x11); + + writer.send(Message.of(msg1)); + writer.send(Message.of(msg1)); + writer.send(Message.of(msg1)); + writer.send(Message.of(msg2)); // this message is more than buffer limit + writer.send(Message.of(msg1)); + writer.send(Message.of(msg2)); // this message is more than buffer limit + writer.send(Message.of(msg2)); // this message is more than buffer limit + writer.send(Message.of(msg2)); // this message is more than buffer limit + writer.send(Message.of(msg1)); + writer.send(Message.of(msg1)); + + writer.flush(); + writer.shutdown(10, TimeUnit.SECONDS); + + assertTopicContent(Arrays.asList(msg1, msg1, msg1, msg2, msg1, msg2, msg2, msg2, msg1, msg1)); + } + + @Test + public void lazyInitTest() throws Exception { + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath(TEST_TOPIC) + .setProducerId(TEST_PRODUCER1) + .build(); + + AsyncWriter writer = client.createAsyncWriter(settings); + + CountDownLatch latch = new CountDownLatch(1); + List written = new ArrayList<>(); + CompletableFuture lastMessage = CompletableFuture.supplyAsync(() -> { + ThreadLocalRandom rnd = ThreadLocalRandom.current(); + try { + CompletableFuture ack = FutureTools.failedFuture(new RuntimeException("not started")); + for (int idx = 0; idx < 100; idx++) { + byte[] msg = new byte[1000]; + rnd.nextBytes(msg); + ack = writer.send(Message.of(msg)); + written.add(msg); + } + latch.countDown(); + return ack.join(); + } catch (QueueOverflowException ex) { + latch.countDown(); + throw new RuntimeException(ex); + } + }); + + latch.await(10, TimeUnit.SECONDS); + writer.init(); + + WriteAck ack = lastMessage.get(10, TimeUnit.SECONDS); + Assert.assertEquals(WriteAck.State.WRITTEN, ack.getState()); + + writer.shutdown().join(); + + assertTopicContent(written); + } + + @Test + public void doubleInitTest() throws Exception { + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath(TEST_TOPIC) + .setProducerId(TEST_PRODUCER1) + .build(); + + AsyncWriter writer = client.createAsyncWriter(settings); + + writer.init(); + writer.init(); + + byte[] msg = "hello".getBytes(); + writer.send(Message.of(msg)).join(); + + writer.shutdown().join(); + + assertTopicContent(Collections.singletonList(msg)); + } + + @Test + public void defaultRetryPolicyWriter() throws Exception { + // errors pattern in order of processing + PROXY.unavailableOnAckWithSeqNo(15); + PROXY.badRequestOnInit(2); + PROXY.badSessionOnSendMsgWithSeqNo(35); + PROXY.unavailableOnInit(4); + PROXY.unavailableOnInit(5); + PROXY.unavailableOnInit(6); + PROXY.badRequestOnAckWithSeqNo(60); + PROXY.unavailableOnAckWithSeqNo(90); + + List expectedErrors = Arrays.asList( + StatusCode.TRANSPORT_UNAVAILABLE, + StatusCode.BAD_REQUEST, + StatusCode.BAD_SESSION, + StatusCode.TRANSPORT_UNAVAILABLE, + StatusCode.TRANSPORT_UNAVAILABLE, + StatusCode.TRANSPORT_UNAVAILABLE, + StatusCode.BAD_REQUEST, + StatusCode.TRANSPORT_UNAVAILABLE + ); + + List realErrors = new ArrayList<>(); + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath(TEST_TOPIC) + .setProducerId(TEST_PRODUCER1) + .setErrorsHandler((st, th) -> { + if (st != null) { + realErrors.add(st.getCode()); + } + if (th != null) { + realErrors.add(StatusCode.CLIENT_INTERNAL_ERROR); + } + }) + .build(); + + SyncWriter writer = client.createSyncWriter(settings); + writer.initAndWait(); + + List written = new ArrayList<>(); + for (int batch = 0; batch < 10; batch++) { + for (int idx = 0; idx < 10; idx++) { + byte[] msg = new byte[1000]; + Arrays.fill(msg, (byte) (batch * 10 + idx)); + writer.send(Message.of(msg), 1, TimeUnit.MINUTES); + written.add(msg); + } + writer.flush(); + } + + writer.shutdown(10, TimeUnit.SECONDS); + + Assert.assertEquals(expectedErrors, realErrors); + assertTopicContent(written); + } +} diff --git a/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java b/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java index 2ef55a3b3..36cbacf72 100644 --- a/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java +++ b/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java @@ -119,7 +119,7 @@ public void simpleStartAndCloseTest() { Mockito.verify(h.grpc).start(Mockito.any()); Mockito.verify(h.grpc, Mockito.times(2)).sendNext(EMPTY); // init + sent request - retryable.close(); + Assert.assertTrue(retryable.close()); h.complete(Status.SUCCESS); @@ -140,7 +140,7 @@ public void doubleStartTest() { Mockito.verify(h1.grpc).start(Mockito.any()); Mockito.verify(h2.grpc, Mockito.never()).start(Mockito.any()); // h2 was never started - Mockito.verify(h2.grpc).close(); // h2 was closed immediately + Mockito.verify(h2.grpc, Mockito.never()).close(); // h2 was never closed } @Test @@ -150,8 +150,8 @@ public void doubleCloseTest() { retryable.start(); - retryable.close(); - retryable.close(); + Assert.assertTrue(retryable.close()); + Assert.assertFalse(retryable.close()); Mockito.verify(h1.grpc).start(Mockito.any()); Mockito.verify(h1.grpc).close(); @@ -160,7 +160,7 @@ public void doubleCloseTest() { @Test public void startAfterCloseTest() { TestStream retryable = new TestStream(Arrays.asList(), RetryConfig.noRetries(), mockScheduler()); - retryable.close(); + Assert.assertFalse(retryable.close()); retryable.start(); // nothing } @@ -179,7 +179,7 @@ public void closeBeforeStartIsNoOpTest() { StreamHandle h = new StreamHandle(); TestStream retryable = new TestStream(Arrays.asList(h), RetryConfig.noRetries(), mockScheduler()); - retryable.close(); // no stream yet, should not throw + Assert.assertFalse(retryable.close()); // no stream yet, should not throw Mockito.verify(h.grpc, Mockito.never()).close(); } @@ -241,7 +241,7 @@ public void immediateRetryTest() { retryable.send(EMPTY); h3.complete(s3); - retryable.close(); // no effect + Assert.assertFalse(retryable.close()); // no effect Mockito.verify(h1.grpc, Mockito.times(2)).sendNext(EMPTY); // init req + send Mockito.verify(h1.grpc, Mockito.never()).close(); // stream was closed by error diff --git a/topic/src/test/java/tech/ydb/topic/impl/TopicWritersIntegrationTest.java b/topic/src/test/java/tech/ydb/topic/impl/TopicWritersIntegrationTest.java deleted file mode 100644 index cddf7fb05..000000000 --- a/topic/src/test/java/tech/ydb/topic/impl/TopicWritersIntegrationTest.java +++ /dev/null @@ -1,125 +0,0 @@ -package tech.ydb.topic.impl; - -import java.util.Arrays; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import tech.ydb.core.Status; -import tech.ydb.core.utils.FutureTools; -import tech.ydb.test.junit4.GrpcTransportRule; -import tech.ydb.topic.TopicClient; -import tech.ydb.topic.settings.CreateTopicSettings; -import tech.ydb.topic.settings.WriterSettings; -import tech.ydb.topic.write.AsyncWriter; -import tech.ydb.topic.write.Message; -import tech.ydb.topic.write.QueueOverflowException; -import tech.ydb.topic.write.SyncWriter; -import tech.ydb.topic.write.WriteAck; - -/** - * - * @author Aleksandr Gorshenin - */ -public class TopicWritersIntegrationTest { - private final static Logger logger = LoggerFactory.getLogger(TopicWritersIntegrationTest.class); - - @ClassRule - public final static GrpcTransportRule ydbTransport = new GrpcTransportRule(); - - private final static String TEST_TOPIC = "topic_writers_test"; - - private final static String TEST_PRODUCER1 = "producer"; - - private static TopicClient client; - - @BeforeClass - public static void initTopic() { - logger.info("Create test table {} ...", TEST_TOPIC); - - client = TopicClient.newClient(ydbTransport).build(); - client.createTopic(TEST_TOPIC, CreateTopicSettings.newBuilder().build()) - .join().expectSuccess("can't create a new topic"); - } - - @AfterClass - public static void dropTopic() { - logger.info("Drop test topic {} ...", TEST_TOPIC); - Status dropStatus = client.dropTopic(TEST_TOPIC).join(); - client.close(); - dropStatus.expectSuccess("can't drop test topic"); - } - - @Test - public void messageBufferOverflowTest() throws Exception { - WriterSettings settings = WriterSettings.newBuilder() - .setTopicPath(TEST_TOPIC) - .setProducerId(TEST_PRODUCER1) - .setMaxSendBufferMemorySize(1000) - .build(); - - SyncWriter writer = client.createSyncWriter(settings); - writer.initAndWait(); - - byte[] msg1 = new byte[1000]; - byte[] msg2 = new byte[1001]; - Arrays.fill(msg1, (byte) 0x10); - Arrays.fill(msg2, (byte) 0x11); - - writer.send(Message.of(msg1)); - writer.send(Message.of(msg1)); - writer.send(Message.of(msg1)); - writer.send(Message.of(msg2)); // this message is more than buffer limit - writer.send(Message.of(msg1)); - writer.send(Message.of(msg1)); - writer.send(Message.of(msg1)); - - writer.flush(); - writer.shutdown(10, TimeUnit.SECONDS); - } - - @Test - public void lazyInitTest() throws Exception { - WriterSettings settings = WriterSettings.newBuilder() - .setTopicPath(TEST_TOPIC) - .setProducerId(TEST_PRODUCER1) - .build(); - - AsyncWriter writer = client.createAsyncWriter(settings); - - CountDownLatch latch = new CountDownLatch(1); - CompletableFuture lastMessage = CompletableFuture.supplyAsync(() -> { - ThreadLocalRandom rnd = ThreadLocalRandom.current(); - try { - CompletableFuture ack = FutureTools.failedFuture(new RuntimeException("not started")); - for (int idx = 0; idx < 100; idx++) { - byte[] msg = new byte[1000]; - rnd.nextBytes(msg); - ack = writer.send(Message.of(msg)); - } - latch.countDown(); - return ack.join(); - } catch (QueueOverflowException ex) { - latch.countDown(); - throw new RuntimeException(ex); - } - }); - - latch.await(10, TimeUnit.SECONDS); - writer.init(); - - WriteAck ack = lastMessage.get(10, TimeUnit.SECONDS); - Assert.assertEquals(WriteAck.State.WRITTEN, ack.getState()); - - writer.shutdown().join(); - } -} diff --git a/topic/src/test/java/tech/ydb/topic/write/impl/BufferManagerTest.java b/topic/src/test/java/tech/ydb/topic/write/impl/BufferManagerTest.java index 580624674..af76bfc1d 100644 --- a/topic/src/test/java/tech/ydb/topic/write/impl/BufferManagerTest.java +++ b/topic/src/test/java/tech/ydb/topic/write/impl/BufferManagerTest.java @@ -1,5 +1,7 @@ package tech.ydb.topic.write.impl; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -9,6 +11,8 @@ import org.junit.Test; import org.junit.function.ThrowingRunnable; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; import tech.ydb.topic.settings.WriterSettings; import tech.ydb.topic.write.QueueOverflowException; @@ -37,6 +41,12 @@ private static void assertTimeout(String msg, ThrowingRunnable runnable) { Assert.assertEquals(msg, ex.getMessage()); } + private static void assertIllegalState(String msg, ThrowingRunnable runnable) { + IllegalStateException ex = Assert.assertThrows("Must be thrown IllegalStateException", + IllegalStateException.class, runnable); + Assert.assertEquals(msg, ex.getMessage()); + } + private static void assertInterrupted(ThrowingRunnable runnable) throws InterruptedException { // Now try to acquire more bytes in a separate thread — it will block AtomicBoolean interrupted = new AtomicBoolean(false); @@ -53,10 +63,11 @@ private static void assertInterrupted(ThrowingRunnable runnable) throws Interrup } }); t.start(); + Assert.assertTrue(started.await(1, TimeUnit.SECONDS)); while (t.isAlive()) { t.interrupt(); - t.join(100); + t.join(2000); } Assert.assertTrue(interrupted.get()); @@ -68,8 +79,6 @@ public void testGetMaxSize() { Assert.assertEquals(1024, bm.getMaxSize()); } - // --- acquire / release --- - @Test public void testAcquireAndRelease() throws Exception { BufferManager bm = manager(70, 5); @@ -211,4 +220,114 @@ public void testLargeBufferSize() throws Exception { bm.releaseMessage(1); bm.releaseMessage(1); } + + @Test + public void testClosedBuffer() { + BufferManager bm = manager(100, 3); + + bm.close(Status.SUCCESS); + + assertIllegalState("Writer was closed with status Status{code = SUCCESS}", + () -> bm.acquire(1)); + assertIllegalState("Writer was closed with status Status{code = SUCCESS}", + () -> bm.tryAcquire(1)); + assertIllegalState("Writer was closed with status Status{code = SUCCESS}", + () -> bm.tryAcquire(1, 1, TimeUnit.SECONDS)); + } + + @Test + public void testReleaseCountOnBufferClosing() throws InterruptedException, QueueOverflowException { + BufferManager bm = manager(100, 3); + + bm.acquire(10); + bm.acquire(10); + bm.acquire(10); + + CountDownLatch started = new CountDownLatch(2); + Queue problems = new ConcurrentLinkedQueue<>(); + Thread t1 = new Thread(() -> { + try { + started.countDown(); + bm.acquire(10); + } catch (InterruptedException | QueueOverflowException | RuntimeException ex) { + problems.add(ex); + } + }); + Thread t2 = new Thread(() -> { + try { + started.countDown(); + bm.tryAcquire(10, 1, TimeUnit.MINUTES); + } catch (InterruptedException | QueueOverflowException | TimeoutException | RuntimeException ex) { + problems.add(ex); + } + }); + t1.setDaemon(true); + t2.setDaemon(true); + t1.start(); + t2.start(); + + Assert.assertTrue(started.await(1, TimeUnit.SECONDS)); + + bm.close(Status.of(StatusCode.ABORTED)); + + t1.join(2000); + t2.join(2000); + + Assert.assertFalse("Thread t1 must be finished", t1.isAlive()); + Assert.assertFalse("Thread t2 must be finished", t2.isAlive()); + + Assert.assertEquals(2, problems.size()); + for (Exception ex : problems) { + Assert.assertTrue("Unexpected " + ex.getClass(), ex instanceof IllegalStateException); + Assert.assertEquals("Writer was closed with status Status{code = ABORTED(code=400040)}", ex.getMessage()); + } + } + + @Test + public void testReleaseSizeOnBufferClosing() throws InterruptedException, QueueOverflowException { + BufferManager bm = manager(70, 5); + + bm.acquire(20); + bm.acquire(20); + bm.acquire(20); + + CountDownLatch started = new CountDownLatch(2); + Queue problems = new ConcurrentLinkedQueue<>(); + Thread t1 = new Thread(() -> { + try { + started.countDown(); + bm.acquire(70); + } catch (InterruptedException | QueueOverflowException | RuntimeException ex) { + problems.add(ex); + } + }); + Thread t2 = new Thread(() -> { + try { + started.countDown(); + bm.tryAcquire(70, 1, TimeUnit.MINUTES); + } catch (InterruptedException | QueueOverflowException | TimeoutException | RuntimeException ex) { + problems.add(ex); + } + }); + t1.setDaemon(true); + t2.setDaemon(true); + t1.start(); + t2.start(); + + Assert.assertTrue(started.await(1, TimeUnit.SECONDS)); + + bm.close(Status.of(StatusCode.TIMEOUT)); + + t1.join(2000); + t2.join(2000); + + Assert.assertFalse("Thread t1 must be finished", t1.isAlive()); + Assert.assertFalse("Thread t2 must be finished", t2.isAlive()); + + Assert.assertEquals(2, problems.size()); + for (Exception ex : problems) { + Assert.assertTrue("Unexpected " + ex.getClass(), ex instanceof IllegalStateException); + Assert.assertEquals("Writer was closed with status Status{code = TIMEOUT(code=400090)}", ex.getMessage()); + } + } } diff --git a/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java b/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java new file mode 100644 index 000000000..f341347c7 --- /dev/null +++ b/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java @@ -0,0 +1,360 @@ +package tech.ydb.topic.write.impl; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.function.ThrowingRunnable; +import org.mockito.Mockito; + +import tech.ydb.common.retry.RetryConfig; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadStream; +import tech.ydb.core.grpc.GrpcReadWriteStream; +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.topic.TopicRpc; +import tech.ydb.topic.description.Codec; +import tech.ydb.topic.description.CodecRegistry; +import tech.ydb.topic.settings.TopicRetryConfig; +import tech.ydb.topic.settings.WriterSettings; +import tech.ydb.topic.write.InitResult; +import tech.ydb.topic.write.Message; +import tech.ydb.topic.write.WriteAck; + +/** + * @author Aleksandr Gorshenin + */ +public class WriterImplTest { + private static final RetryConfig IMMEDIATELY_FOREVER = status -> (number, elapsed) -> 0; + + private static TopicRpc mockRpc(StreamMock first, StreamMock... rest) { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + Mockito.when(rpc.getScheduler()).thenReturn(Mockito.mock(ScheduledExecutorService.class)); + Mockito.when(rpc.writeSession(Mockito.any())).thenReturn(first, rest); + return rpc; + } + + private static WriterImpl createWriter(TopicRpc rpc) { + return createWriter(rpc, TopicRetryConfig.NEVER); + } + + private static WriterImpl createWriter(TopicRpc rpc, RetryConfig retryConfig) { + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("test-producer") + .setCodec(Codec.RAW) + .setRetryConfig(retryConfig) + .build(); + return new WriterImpl(rpc, settings, Runnable::run, new CodecRegistry()); + } + + private static void assertIllegalState(String msg, ThrowingRunnable runnable) { + IllegalStateException ex = Assert.assertThrows("Must be thrown IllegalStateException", + IllegalStateException.class, runnable); + Assert.assertEquals(msg, ex.getMessage()); + } + + private static void assertRuntimeException(String msg, ThrowingRunnable runnable) { + RuntimeException ex = Assert.assertThrows("Must be thrown RuntimeException", + RuntimeException.class, runnable); + Assert.assertEquals(msg, ex.getMessage()); + } + + private static ThrowingRunnable futureGet(CompletableFuture future) { + return () -> { + try { + future.get(); + } catch (ExecutionException ex) { + throw ex.getCause(); + } + }; + } + + private static final Message MSG1 = Message.of(new byte[] { 0x00, 0x01, 0x02 }); + + @Test + public void doubleInitTest() throws Exception { + StreamMock s = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s)); + + CompletableFuture m1 = writer.blockingSend(MSG1, null); + CompletableFuture m2 = writer.blockingSend(MSG1, null); + + Assert.assertFalse(m1.isDone()); + Assert.assertFalse(m2.isDone()); + + CompletableFuture initFuture = writer.init(); + + Assert.assertEquals(1, s.messages.size()); // init req + Assert.assertSame(initFuture, writer.init()); + Assert.assertEquals(1, s.messages.size()); // init req + + s.sendInitResponse(123L); + + Assert.assertEquals(2, s.messages.size()); // init req + write request + + Assert.assertFalse(m1.isDone()); + Assert.assertFalse(m2.isDone()); + + s.sendAckResponse(124, 1); + s.sendAckResponse(125, 2); + + Assert.assertEquals(124, m1.join().getSeqNo()); + Assert.assertEquals(125, m2.join().getSeqNo()); + + writer.shutdown(); + + Assert.assertNotNull(s.observer); + Assert.assertTrue(s.isClosed); + Assert.assertFalse(s.isCanceled); + } + + @Test + public void closeBeforeInitTest() throws Exception { + StreamMock s = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s)); + + CompletableFuture m1 = writer.blockingSend(MSG1, null); + CompletableFuture m2 = writer.blockingSend(MSG1, null); + + Assert.assertFalse(m1.isDone()); + Assert.assertFalse(m2.isDone()); + + CompletableFuture closeFuture = writer.shutdown(); + + Assert.assertTrue(closeFuture.isDone()); + Assert.assertTrue(m1.isDone()); + Assert.assertTrue(m2.isDone()); + + Assert.assertSame(closeFuture, writer.shutdown()); + + assertIllegalState("Writer is already stopped", writer::init); + assertIllegalState("Writer is already stopped", () -> writer.blockingSend(MSG1, null)); + assertIllegalState("Writer is already stopped", () -> writer.blockingSend(MSG1, null, 1, TimeUnit.SECONDS)); + assertIllegalState("Writer is already stopped", () -> writer.nonblockingSend(MSG1, null)); + + Assert.assertTrue(m1.isCompletedExceptionally()); + Assert.assertTrue(m2.isCompletedExceptionally()); + + assertRuntimeException("Message sending was cancelled with status Status{code = SUCCESS, " + + "issues = [Closed by client (S_INFO)]}", futureGet(m1)); + assertRuntimeException("Message sending was cancelled with status Status{code = SUCCESS, " + + "issues = [Closed by client (S_INFO)]}", futureGet(m2)); + + Assert.assertNull(s.observer); + Assert.assertFalse(s.isClosed); + Assert.assertFalse(s.isCanceled); + } + + @Test + public void shutdownCancelsPendingMessages() throws Exception { + StreamMock s = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s)); + writer.init(); + + s.sendInitResponse(0L); + + CompletableFuture m1 = writer.blockingSend(MSG1, null); + CompletableFuture m2 = writer.blockingSend(MSG1, null); + + writer.shutdown(); + Assert.assertTrue(s.isClosed); + s.close(Status.of(StatusCode.SUCCESS)); + + Assert.assertTrue(m1.isCompletedExceptionally()); + Assert.assertTrue(m2.isCompletedExceptionally()); + + assertRuntimeException("Message had been sent but the writer was stopped with status Status{code = SUCCESS}", + futureGet(m1)); + assertRuntimeException("Message had been sent but the writer was stopped with status Status{code = SUCCESS}", + futureGet(m2)); + } + + @Test + public void streamFailureTest() throws Exception { + StreamMock s = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s)); + + CompletableFuture initFuture = writer.init(); + CompletableFuture m1 = writer.blockingSend(MSG1, null); + CompletableFuture m2 = writer.blockingSend(MSG1, null); + CompletableFuture flushFuture = writer.flush(); + + Assert.assertFalse(initFuture.isDone()); + Assert.assertFalse(m1.isDone()); + Assert.assertFalse(m2.isDone()); + Assert.assertFalse(flushFuture.isDone()); + + s.close(Status.of(StatusCode.SCHEME_ERROR)); + + assertIllegalState("Writer is already stopped", () -> writer.blockingSend(MSG1, null)); + assertIllegalState("Writer is already stopped", () -> writer.blockingSend(MSG1, null, 1, TimeUnit.SECONDS)); + assertIllegalState("Writer is already stopped", () -> writer.nonblockingSend(MSG1, null)); + + Assert.assertTrue(initFuture.isCompletedExceptionally()); + Assert.assertTrue(m1.isCompletedExceptionally()); + Assert.assertTrue(m2.isCompletedExceptionally()); + Assert.assertTrue(flushFuture.isDone()); + Assert.assertFalse(flushFuture.isCompletedExceptionally()); + + CompletableFuture shutdownFuture = writer.shutdown(); + + Assert.assertTrue(shutdownFuture.isDone()); + Assert.assertFalse(shutdownFuture.isCompletedExceptionally()); + + Assert.assertTrue(m1.isCompletedExceptionally()); + Assert.assertTrue(m2.isCompletedExceptionally()); + + assertRuntimeException("Message sending was cancelled with status Status{code = SCHEME_ERROR(code=400070)}", + futureGet(m1)); + assertRuntimeException("Message sending was cancelled with status Status{code = SCHEME_ERROR(code=400070)}", + futureGet(m2)); + + Assert.assertNotNull(s.observer); + Assert.assertFalse(s.isClosed); // stream was closed itself + Assert.assertFalse(s.isCanceled); + } + + @Test + public void withSeqNoConsistencyTest() throws Exception { + StreamMock s = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s)); + writer.init(); + s.sendInitResponse(0L); + + // first message without seqNo — establishes isSeqNoProvided = false + Message msg1 = Message.of("msg1".getBytes()); + writer.nonblockingSend(msg1, null); + + // second message WITH seqNo must fail + Message msg2 = Message.newBuilder().setData("msg2".getBytes()).setSeqNo(2L).build(); + assertRuntimeException("SeqNo was provided for a message after it had not been provided for another message. " + + "SeqNo should either be provided for all messages or none of them.", + () -> writer.nonblockingSend(msg2, null)); + } + + @Test + public void withOutSeqNoConsistencyTest() throws Exception { + StreamMock s = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s)); + writer.init(); + s.sendInitResponse(0L); + + // first message with seqNo — establishes isSeqNoProvided = true + Message msg1 = Message.newBuilder().setData("msg2".getBytes()).setSeqNo(1L).build(); + writer.nonblockingSend(msg1, null); + + // second message WITHOUT seqNo must fail + Message msg2 = Message.of("msg2".getBytes()); + assertRuntimeException("SeqNo was not provided for a message after it had been provided for another message. " + + "SeqNo should either be provided for all messages or none of them.", + () -> writer.nonblockingSend(msg2, null)); + } + + @Test + public void retryResendsPendingMessagesTest() throws Exception { + StreamMock s1 = new StreamMock(); + StreamMock s2 = new StreamMock(); + WriterImpl writer = createWriter(mockRpc(s1, s2), IMMEDIATELY_FOREVER); + + writer.init(); + s1.sendInitResponse(0L); + + CompletableFuture m1 = writer.nonblockingSend(MSG1, null); + CompletableFuture m2 = writer.nonblockingSend(MSG1, null); + + Assert.assertEquals(3, s1.messages.size()); // init req + 2 write requests + Assert.assertFalse(m1.isDone()); + Assert.assertFalse(m2.isDone()); + + // stream 1 fails — first retry delay is 0ms, so start() is called synchronously + s1.close(Status.of(StatusCode.UNAVAILABLE)); + + // stream 2 is now connected; lastSeqNo=0 means message was not yet persisted + s2.sendInitResponse(0L); + + Assert.assertEquals(2, s2.messages.size()); // init req + write request (with two messages) + + s2.sendAckResponse(1L, 42L); + + Assert.assertTrue(m1.isDone()); + Assert.assertEquals(1L, m1.join().getSeqNo()); + Assert.assertEquals(WriteAck.State.WRITTEN, m1.join().getState()); + + Assert.assertFalse(m2.isDone()); + } + + private static class StreamMock implements GrpcReadWriteStream { + private final CompletableFuture future = new CompletableFuture<>(); + private final List messages = new ArrayList<>(); + private GrpcReadStream.Observer observer = null; + private boolean isClosed = false; + private boolean isCanceled = false; + + void sendInitResponse(long lastSeqNo) { + observer.onNext(FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder() + .setLastSeqNo(lastSeqNo) + .setSessionId("test-session") + .build()) + .build()); + } + + void sendAckResponse(long seqNo, long offset) { + observer.onNext(FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setWriteResponse(YdbTopic.StreamWriteMessage.WriteResponse.newBuilder() + .addAcks(YdbTopic.StreamWriteMessage.WriteResponse.WriteAck.newBuilder() + .setSeqNo(seqNo) + .setWritten(YdbTopic.StreamWriteMessage.WriteResponse.WriteAck.Written.newBuilder() + .setOffset(offset) + .build()) + .build()) + .build()) + .build() + ); + } + + + void close(Status status) { + future.complete(status); + } + + @Override + public String authToken() { + return "token"; + } + + @Override + public void sendNext(FromClient message) { + messages.add(message); + } + + @Override + public void close() { + this.isClosed = true; + } + + @Override + public CompletableFuture start(GrpcReadStream.Observer observer) { + this.observer = observer; + return future; + } + + @Override + public void cancel() { + this.isCanceled = true; + } + } + +}