diff --git a/topic/src/main/java/tech/ydb/topic/TopicRpc.java b/topic/src/main/java/tech/ydb/topic/TopicRpc.java index 6adeeef21..f4bc48a37 100644 --- a/topic/src/main/java/tech/ydb/topic/TopicRpc.java +++ b/topic/src/main/java/tech/ydb/topic/TopicRpc.java @@ -7,7 +7,17 @@ import tech.ydb.core.Status; import tech.ydb.core.grpc.GrpcReadWriteStream; import tech.ydb.core.grpc.GrpcRequestSettings; -import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.AlterTopicRequest; +import tech.ydb.proto.topic.YdbTopic.CommitOffsetRequest; +import tech.ydb.proto.topic.YdbTopic.CreateTopicRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeConsumerRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeConsumerResult; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicResult; +import tech.ydb.proto.topic.YdbTopic.DropTopicRequest; +import tech.ydb.proto.topic.YdbTopic.StreamReadMessage; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage; +import tech.ydb.proto.topic.YdbTopic.UpdateOffsetsInTransactionRequest; /** @@ -21,7 +31,7 @@ public interface TopicRpc { * @param settings rpc call settings * @return completable future with status of operation */ - CompletableFuture createTopic(YdbTopic.CreateTopicRequest request, GrpcRequestSettings settings); + CompletableFuture createTopic(CreateTopicRequest request, GrpcRequestSettings settings); /** * Alter topic. @@ -29,7 +39,7 @@ public interface TopicRpc { * @param settings rpc call settings * @return completable future with status of operation */ - CompletableFuture alterTopic(YdbTopic.AlterTopicRequest request, GrpcRequestSettings settings); + CompletableFuture alterTopic(AlterTopicRequest request, GrpcRequestSettings settings); /** * Drop topic. @@ -37,7 +47,7 @@ public interface TopicRpc { * @param settings rpc call settings * @return completable future with status of operation */ - CompletableFuture dropTopic(YdbTopic.DropTopicRequest request, GrpcRequestSettings settings); + CompletableFuture dropTopic(DropTopicRequest request, GrpcRequestSettings settings); /** * Describe topic. @@ -45,8 +55,8 @@ public interface TopicRpc { * @param settings rpc call settings * @return completable future with result of operation */ - CompletableFuture> describeTopic(YdbTopic.DescribeTopicRequest request, - GrpcRequestSettings settings); + CompletableFuture> describeTopic(DescribeTopicRequest request, + GrpcRequestSettings settings); /** * Describe consumer. @@ -54,8 +64,8 @@ CompletableFuture> describeTopic(YdbTopic.D * @param settings rpc call settings * @return completable future with result of operation */ - CompletableFuture> describeConsumer( - YdbTopic.DescribeConsumerRequest request, GrpcRequestSettings settings + CompletableFuture> describeConsumer( + DescribeConsumerRequest request, GrpcRequestSettings settings ); /** @@ -64,7 +74,7 @@ CompletableFuture> describeConsumer( * @param settings rpc call settings * @return completable future with result of operation */ - CompletableFuture commitOffset(YdbTopic.CommitOffsetRequest request, GrpcRequestSettings settings); + CompletableFuture commitOffset(CommitOffsetRequest request, GrpcRequestSettings settings); /** * Updates offsets in transaction. @@ -72,16 +82,17 @@ CompletableFuture> describeConsumer( * @param settings rpc call settings * @return completable future with result of operation */ - CompletableFuture updateOffsetsInTransaction(YdbTopic.UpdateOffsetsInTransactionRequest request, + CompletableFuture updateOffsetsInTransaction(UpdateOffsetsInTransactionRequest request, GrpcRequestSettings settings); - GrpcReadWriteStream writeSession( - String traceId - ); + GrpcReadWriteStream writeSession(String traceId); - GrpcReadWriteStream readSession( - String traceId - ); + default GrpcReadWriteStream writeSession( + GrpcRequestSettings settings) { + return writeSession(settings.getTraceId()); + } + + GrpcReadWriteStream readSession(String traceId); ScheduledExecutorService getScheduler(); } diff --git a/topic/src/main/java/tech/ydb/topic/impl/GrpcTopicRpc.java b/topic/src/main/java/tech/ydb/topic/impl/GrpcTopicRpc.java index d7be2799f..7aa1f518f 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/GrpcTopicRpc.java +++ b/topic/src/main/java/tech/ydb/topic/impl/GrpcTopicRpc.java @@ -13,6 +13,22 @@ import tech.ydb.core.grpc.GrpcTransport; import tech.ydb.core.operation.OperationBinder; import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.AlterTopicRequest; +import tech.ydb.proto.topic.YdbTopic.AlterTopicResponse; +import tech.ydb.proto.topic.YdbTopic.CommitOffsetRequest; +import tech.ydb.proto.topic.YdbTopic.CreateTopicRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeConsumerRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeConsumerResponse; +import tech.ydb.proto.topic.YdbTopic.DescribeConsumerResult; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicResponse; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicResult; +import tech.ydb.proto.topic.YdbTopic.DropTopicRequest; +import tech.ydb.proto.topic.YdbTopic.DropTopicResponse; +import tech.ydb.proto.topic.YdbTopic.StreamReadMessage; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage; +import tech.ydb.proto.topic.YdbTopic.UpdateOffsetsInTransactionRequest; +import tech.ydb.proto.topic.YdbTopic.UpdateOffsetsInTransactionResponse; import tech.ydb.proto.topic.v1.TopicServiceGrpc; import tech.ydb.topic.TopicRpc; @@ -33,77 +49,78 @@ public static GrpcTopicRpc useTransport(@WillNotClose GrpcTransport transport) { } @Override - public CompletableFuture createTopic(YdbTopic.CreateTopicRequest request, GrpcRequestSettings settings) { + public CompletableFuture createTopic(CreateTopicRequest request, GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getCreateTopicMethod(), settings, request) .thenApply(OperationBinder.bindSync(YdbTopic.CreateTopicResponse::getOperation)); } @Override - public CompletableFuture alterTopic(YdbTopic.AlterTopicRequest request, GrpcRequestSettings settings) { + public CompletableFuture alterTopic(AlterTopicRequest request, GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getAlterTopicMethod(), settings, request) - .thenApply(OperationBinder.bindSync(YdbTopic.AlterTopicResponse::getOperation)); + .thenApply(OperationBinder.bindSync(AlterTopicResponse::getOperation)); } @Override - public CompletableFuture> describeTopic(YdbTopic.DescribeTopicRequest request, + public CompletableFuture> describeTopic(DescribeTopicRequest request, GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getDescribeTopicMethod(), settings, request) - .thenApply(OperationBinder.bindSync( - YdbTopic.DescribeTopicResponse::getOperation, YdbTopic.DescribeTopicResult.class) - ); + .thenApply(OperationBinder.bindSync(DescribeTopicResponse::getOperation, DescribeTopicResult.class)); } @Override - public CompletableFuture> describeConsumer( - YdbTopic.DescribeConsumerRequest request, GrpcRequestSettings settings - ) { + public CompletableFuture> describeConsumer(DescribeConsumerRequest request, + GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getDescribeConsumerMethod(), settings, request) - .thenApply(OperationBinder.bindSync( - YdbTopic.DescribeConsumerResponse::getOperation, YdbTopic.DescribeConsumerResult.class) - ); + .thenApply(OperationBinder.bindSync(DescribeConsumerResponse::getOperation, + DescribeConsumerResult.class)); } @Override - public CompletableFuture dropTopic(YdbTopic.DropTopicRequest request, GrpcRequestSettings settings) { + public CompletableFuture dropTopic(DropTopicRequest request, GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getDropTopicMethod(), settings, request) - .thenApply(OperationBinder.bindSync(YdbTopic.DropTopicResponse::getOperation)); + .thenApply(OperationBinder.bindSync(DropTopicResponse::getOperation)); } @Override - public CompletableFuture commitOffset(YdbTopic.CommitOffsetRequest request, GrpcRequestSettings settings) { + public CompletableFuture commitOffset(CommitOffsetRequest request, GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getCommitOffsetMethod(), settings, request) .thenApply(OperationBinder.bindSync(YdbTopic.CommitOffsetResponse::getOperation)); } @Override - public CompletableFuture updateOffsetsInTransaction(YdbTopic.UpdateOffsetsInTransactionRequest request, + public CompletableFuture updateOffsetsInTransaction(UpdateOffsetsInTransactionRequest request, GrpcRequestSettings settings) { return transport .unaryCall(TopicServiceGrpc.getUpdateOffsetsInTransactionMethod(), settings, request) - .thenApply(OperationBinder.bindSync(YdbTopic.UpdateOffsetsInTransactionResponse::getOperation)); + .thenApply(OperationBinder.bindSync(UpdateOffsetsInTransactionResponse::getOperation)); } @Override - public GrpcReadWriteStream - writeSession(String streamId) { + public GrpcReadWriteStream writeSession(String id) { GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() - .withTraceId(streamId) + .withTraceId(id) .disableDeadline() .build(); + return writeSession(settings); + } + + @Override + public GrpcReadWriteStream writeSession( + GrpcRequestSettings settings) { return transport.readWriteStreamCall(TopicServiceGrpc.getStreamWriteMethod(), settings); } + @Override - public GrpcReadWriteStream - readSession(String streamId) { + public GrpcReadWriteStream readSession(String id) { GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() - .withTraceId(streamId) + .withTraceId(id) .disableDeadline() .build(); return transport.readWriteStreamCall(TopicServiceGrpc.getStreamReadMethod(), settings); diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java b/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java index d2421fc6b..d435693ad 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java @@ -1,79 +1,15 @@ package tech.ydb.topic.impl; -import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import com.google.protobuf.Message; -import org.slf4j.Logger; import tech.ydb.core.Status; -import tech.ydb.core.StatusCode; -import tech.ydb.core.grpc.GrpcReadWriteStream; -public abstract class TopicStream { - private final Logger logger; - private final String debugId; - private final GrpcReadWriteStream stream; - private final CompletableFuture streamStatus = new CompletableFuture<>(); - private volatile String token; +public interface TopicStream { + CompletableFuture start(W initReq, Consumer messageHandler); + void send(W request); - public TopicStream(Logger logger, String debugId, GrpcReadWriteStream stream) { - this.logger = logger; - this.debugId = debugId; - this.stream = stream; - this.token = stream.authToken(); - } - - protected abstract W updateTokenMessage(String token); - protected abstract Status parseMessageStatus(R message); - - public CompletableFuture start(W initReq, Consumer messageHandler) { - this.logger.debug("[{}] is about to start", debugId); - this.stream.start((R msg) -> { - Status messageStatus = parseMessageStatus(msg); - if (messageStatus.isSuccess()) { - messageHandler.accept(msg); - } else { - logger.warn("[{}] stopped by getting status {}", debugId, messageStatus); - if (streamStatus.complete(messageStatus)) { - stream.close(); - } - } - }).whenComplete((st, th) -> { - Status status = st != null ? st : Status.of(StatusCode.CLIENT_INTERNAL_ERROR, th); - logger.debug("[{}] finished with status {}", debugId, status); - streamStatus.complete(status); - }); - - if (!streamStatus.isDone()) { - stream.sendNext(initReq); - } - - return streamStatus; - } - - public void close() { - logger.debug("[{}] closed by app", debugId); - if (!streamStatus.isDone()) { - stream.close(); - } - } - - public void send(W req) { - if (streamStatus.isDone()) { - logger.warn("[{}] is already closed. Next message with type {} was NOT sent", debugId, - req.getDescriptorForType().getName()); - return; - } - - String currentToken = stream.authToken(); - if (!Objects.equals(token, currentToken)) { - token = currentToken; - logger.info("{} sends new token", this); - stream.sendNext(updateTokenMessage(token)); - } - - stream.sendNext(req); - } + void close(); } diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java new file mode 100644 index 000000000..b2179c6ad --- /dev/null +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java @@ -0,0 +1,82 @@ +package tech.ydb.topic.impl; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import com.google.protobuf.Message; +import org.slf4j.Logger; + +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadWriteStream; + +public abstract class TopicStreamBase implements TopicStream { + private final Logger logger; + private final String debugId; + private final GrpcReadWriteStream stream; + private final CompletableFuture streamStatus = new CompletableFuture<>(); + private volatile String token; + + public TopicStreamBase(Logger logger, String debugId, GrpcReadWriteStream stream) { + this.logger = logger; + this.debugId = debugId; + this.stream = stream; + this.token = stream.authToken(); + } + + protected abstract W updateTokenMessage(String token); + protected abstract Status parseMessageStatus(R message); + + @Override + public CompletableFuture start(W initReq, Consumer messageHandler) { + this.logger.debug("[{}] is about to start", debugId); + this.stream.start((R msg) -> { + Status messageStatus = parseMessageStatus(msg); + if (messageStatus.isSuccess()) { + messageHandler.accept(msg); + } else { + logger.warn("[{}] stopped by getting status {}", debugId, messageStatus); + if (streamStatus.complete(messageStatus)) { + stream.close(); + } + } + }).whenComplete((st, th) -> { + Status status = st != null ? st : Status.of(StatusCode.CLIENT_INTERNAL_ERROR, th); + logger.debug("[{}] finished with status {}", debugId, status); + streamStatus.complete(status); + }); + + if (!streamStatus.isDone()) { + stream.sendNext(initReq); + } + + return streamStatus; + } + + @Override + public void close() { + logger.debug("[{}] closed by app", debugId); + if (!streamStatus.isDone()) { + stream.close(); + } + } + + @Override + public void send(W req) { + if (streamStatus.isDone()) { + logger.warn("[{}] is already closed. Next message with type {} was NOT sent", debugId, + req.getDescriptorForType().getName()); + return; + } + + String currentToken = stream.authToken(); + if (!Objects.equals(token, currentToken)) { + token = currentToken; + logger.info("{} sends new token", this); + stream.sendNext(updateTokenMessage(token)); + } + + stream.sendNext(req); + } +} diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java new file mode 100644 index 000000000..20ef759a4 --- /dev/null +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java @@ -0,0 +1,38 @@ +package tech.ydb.topic.impl; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import com.google.protobuf.Message; +import org.slf4j.Logger; + +import tech.ydb.core.Status; + +public class TopicStreamFail implements TopicStream { + private final Logger logger; + private final String debugId; + + private final Status status; + + public TopicStreamFail(Logger logger, String debugId, Status status) { + this.logger = logger; + this.debugId = debugId; + this.status = status; + } + + @Override + public CompletableFuture start(W initReq, Consumer messageHandler) { + return CompletableFuture.completedFuture(status); + } + + @Override + public void send(W req) { + logger.warn("[{}] is failed stream with status {}. Next message with type {} was NOT sent", debugId, status, + req.getDescriptorForType().getName()); + } + + @Override + public void close() { + logger.warn("[{}] is failed stream with status {}. It doesn't need to close", debugId, status); + } +} diff --git a/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java b/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java index b301a427a..02b59135d 100644 --- a/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java +++ b/topic/src/main/java/tech/ydb/topic/settings/WriterSettings.java @@ -4,6 +4,7 @@ import tech.ydb.common.retry.RetryConfig; import tech.ydb.core.Status; +import tech.ydb.topic.TopicClient; import tech.ydb.topic.description.Codec; /** @@ -18,6 +19,8 @@ public class WriterSettings { private final String producerId; private final String messageGroupId; private final Long partitionId; + private final boolean useDirectWrite; + private final int codec; private final long maxSendBufferMemorySize; private final int maxSendBufferMessagesCount; @@ -30,6 +33,7 @@ private WriterSettings(Builder builder) { this.producerId = builder.producerId; this.messageGroupId = builder.messageGroupId; this.partitionId = builder.partitionId; + this.useDirectWrite = builder.useDirectWrite; this.codec = builder.codec; this.maxSendBufferMemorySize = builder.maxSendBufferMemorySize; this.maxSendBufferMessagesCount = builder.maxSendBufferMessagesCount; @@ -57,6 +61,10 @@ public String getMessageGroupId() { return messageGroupId; } + public boolean isDirectWrite() { + return useDirectWrite; + } + public BiConsumer getErrorsHandler() { return errorsHandler; } @@ -90,6 +98,7 @@ public static class Builder { private String producerId = null; private String messageGroupId = null; private Long partitionId = null; + private boolean useDirectWrite = false; private int codec = Codec.GZIP; private long maxSendBufferMemorySize = MAX_MEMORY_USAGE_BYTES_DEFAULT; private int maxSendBufferMessagesCount = MAX_IN_FLIGHT_COUNT_DEFAULT; @@ -153,6 +162,39 @@ public Builder setPartitionId(long partitionId) { return this; } + /** + * Enable or disable direct write mode, where the writer connects to the specific YDB node that owns the target + * partition rather than routing through a proxy. + *

+ * When enabled, the writer resolves the target node before opening the write stream: + *

    + *
  • If {@link #setPartitionId} is set, the node is resolved via + * {@link TopicClient#describeTopic(java.lang.String) describeTopic}. + *
  • If {@link #setProducerId} is set (and no explicit partition), the partition is resolved first via a + * probe write stream, then the node is resolved via + * {@link TopicClient#describeTopic(java.lang.String) describeTopic}. + *
+ * Direct write reduces write latency by eliminating the proxy hop at the cost of an extra RPC on + * (re)connection. + *

+ * Warning: direct write requires a direct network link from the client to the YDB nodes. If the client + * can only reach a YDB proxy (e.g. in cloud environments where nodes are not publicly accessible), enabling + * this mode will cause connection failures. Use it only when the client has direct network access to all YDB + * nodes in the cluster. + *

+ * Direct write requires either {@link #setPartitionId} or {@link #setProducerId} to be set; otherwise + * {@link TopicClient#createSyncWriter(tech.ydb.topic.settings.WriterSettings) createSyncWriter} and + * {@link TopicClient#createAsyncWriter(tech.ydb.topic.settings.WriterSettings) createAsyncWriter} will throw + * {@link IllegalArgumentException}. + * + * @param enabled {@code true} to enable direct write + * @return this builder + */ + public Builder setDirectWrite(boolean enabled) { + this.useDirectWrite = enabled; + return this; + } + /** * Set codec to use for data compression prior to write * @param codec compression codec diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java index ade535a09..2456b23f4 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java @@ -13,6 +13,7 @@ import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; import tech.ydb.topic.TopicRpc; import tech.ydb.topic.impl.TopicRetryableStream; +import tech.ydb.topic.impl.TopicStream; import tech.ydb.topic.settings.WriterSettings; import tech.ydb.topic.write.WriteAck; @@ -20,6 +21,8 @@ * @author Nikolay Perfilov */ public final class WriteSession extends TopicRetryableStream { + public interface Stream extends TopicStream { } + private static final Logger logger = LoggerFactory.getLogger(WriteSession.class); public interface Listener { @@ -32,20 +35,20 @@ public interface Listener { } private final Listener listener; - private final StreamFactory streamFactory; + private final WriteStreamFactory streamFactory; private final MessageSender sender; private final BiConsumer errorsHandler; public WriteSession(String debugId, TopicRpc rpc, WriterSettings settings, Listener controller) { super(logger, debugId, settings.getRetryConfig(), rpc.getScheduler()); this.listener = controller; - this.streamFactory = new StreamFactory(rpc, settings); + this.streamFactory = WriteStreamFactory.of(rpc, settings); this.sender = new MessageSender(debugId, settings.getCodec(), this::send); this.errorsHandler = settings.getErrorsHandler(); } @Override - protected WriteStream createNewStream(String id) { + protected Stream createNewStream(String id) { return streamFactory.createNewStream(id); } @@ -68,7 +71,7 @@ private void onInitResponse(YdbTopic.StreamWriteMessage.InitResponse response) { String sessionId = response.getSessionId(); resetRetries(); logger.info("[{}] Session with id {} (partition {}) initialized for topic \"{}\", lastSeqNo {}", - debugId, sessionId, response.getPartitionId(), streamFactory.topicPath, lastSeqNo); + debugId, sessionId, response.getPartitionId(), streamFactory.getTopicPath(), lastSeqNo); listener.onStart(lastSeqNo, sessionId); } @@ -141,45 +144,4 @@ public void onNext(YdbTopic.StreamWriteMessage.FromServer message) { logger.warn("[{}] got unknown type message", debugId); } } - - private class StreamFactory { - private final String topicPath; - private final TopicRpc rpc; - private final YdbTopic.StreamWriteMessage.InitRequest initRequest; - - StreamFactory(TopicRpc rpc, WriterSettings settings) { - this.rpc = rpc; - this.topicPath = settings.getTopicPath(); - - YdbTopic.StreamWriteMessage.InitRequest.Builder req = YdbTopic.StreamWriteMessage.InitRequest - .newBuilder() - .setPath(topicPath); - String producerId = settings.getProducerId(); - if (producerId != null) { - req.setProducerId(producerId); - } - String messageGroupId = settings.getMessageGroupId(); - Long partitionId = settings.getPartitionId(); - if (messageGroupId != null) { - if (partitionId != null) { - throw new IllegalArgumentException("Both MessageGroupId and PartitionId are set in WriterSettings"); - } - req.setMessageGroupId(messageGroupId); - } else if (partitionId != null) { - req.setPartitionId(partitionId); - } - - this.initRequest = req.build(); - } - - public WriteStream createNewStream(String id) { - return new WriteStream(id, rpc); - } - - public YdbTopic.StreamWriteMessage.FromClient initRequest() { - return YdbTopic.StreamWriteMessage.FromClient.newBuilder() - .setInitRequest(initRequest) - .build(); - } - } } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java index 1f431e3a4..10ccad609 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java @@ -6,21 +6,22 @@ import tech.ydb.core.Issue; import tech.ydb.core.Status; import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadWriteStream; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; import tech.ydb.proto.topic.YdbTopic.UpdateTokenRequest; -import tech.ydb.topic.TopicRpc; -import tech.ydb.topic.impl.TopicStream; +import tech.ydb.topic.impl.TopicStreamBase; +import tech.ydb.topic.impl.TopicStreamFail; /** * * @author Aleksandr Gorshenin */ -public class WriteStream extends TopicStream { +public class WriteStream extends TopicStreamBase implements WriteSession.Stream { private static final Logger logger = LoggerFactory.getLogger(WriteStream.class); - public WriteStream(String id, TopicRpc rpc) { - super(logger, id, rpc.writeSession(id)); + public WriteStream(String id, GrpcReadWriteStream stream) { + super(logger, id, stream); } @Override @@ -34,4 +35,10 @@ protected FromClient updateTokenMessage(String token) { protected Status parseMessageStatus(FromServer message) { return Status.of(StatusCode.fromProto(message.getStatus()), Issue.fromPb(message.getIssuesList())); } + + public static class Fail extends TopicStreamFail implements WriteSession.Stream { + public Fail(String id, Status status) { + super(logger, id, status); + } + } } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java new file mode 100644 index 000000000..a1920a740 --- /dev/null +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java @@ -0,0 +1,229 @@ +package tech.ydb.topic.write.impl; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import tech.ydb.core.Issue; +import tech.ydb.core.Result; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadWriteStream; +import tech.ydb.core.grpc.GrpcRequestSettings; +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicRequest; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicResult; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.topic.TopicRpc; +import tech.ydb.topic.settings.WriterSettings; + +/** + * + * @author Aleksandr Gorshenin + */ +public class WriteStreamFactory { + private static final Logger logger = LoggerFactory.getLogger(WriteStreamFactory.class); + + private final String topicPath; + private final StreamWriteMessage.InitRequest initRequest; + protected final TopicRpc rpc; + + private WriteStreamFactory(TopicRpc rpc, WriterSettings settings) { + this.rpc = rpc; + this.topicPath = settings.getTopicPath(); + + String producerId = settings.getProducerId(); + String messageGroupId = settings.getMessageGroupId(); + Long partitionId = settings.getPartitionId(); + + StreamWriteMessage.InitRequest.Builder req = StreamWriteMessage.InitRequest.newBuilder() + .setPath(topicPath); + + if (producerId != null) { + req.setProducerId(producerId); + } + if (messageGroupId != null) { + if (partitionId != null) { + throw new IllegalArgumentException("Both MessageGroupId and PartitionId are set in WriterSettings"); + } + req.setMessageGroupId(messageGroupId); + } else if (partitionId != null) { + req.setPartitionId(partitionId); + } + + this.initRequest = req.build(); + } + + public String getTopicPath() { + return topicPath; + } + + public WriteSession.Stream createNewStream(String id) { + return new WriteStream(id, rpc.writeSession(id)); + } + + public YdbTopic.StreamWriteMessage.FromClient initRequest() { + return YdbTopic.StreamWriteMessage.FromClient.newBuilder() + .setInitRequest(initRequest) + .build(); + } + + protected Result lookupNodeId(String id, long partitionId) { + logger.info("[{}] describe topic {} to look up node for partition {}", id, topicPath, partitionId); + Result describeTopic = rpc.describeTopic( + DescribeTopicRequest.newBuilder().setIncludeLocation(true).setPath(topicPath).build(), + GrpcRequestSettings.newBuilder().withDeadline(Duration.ofMinutes(1)).build() + ).join(); + + if (!describeTopic.isSuccess()) { + logger.warn("[{}] describe topic {} failed with status {}", id, topicPath, describeTopic.getStatus()); + return Result.fail(describeTopic.getStatus()); + } + + // lookup for nodeID + for (DescribeTopicResult.PartitionInfo partition : describeTopic.getValue().getPartitionsList()) { + if (partition.getPartitionId() == partitionId) { + return Result.success(partition.getPartitionLocation().getNodeId()); + } + } + + logger.warn("[{}] topic {} doesn't have partition {}, direct writing failed", id, topicPath, partitionId); + Issue issue = Issue.of("Cannot find partition " + partitionId, Issue.Severity.ERROR); + return Result.fail(Status.of(StatusCode.BAD_REQUEST, issue)); + } + + protected Result lookupPartitionId(String id, String producerId) { + CompletableFuture> partitionId = new CompletableFuture<>(); + + // create one-shot stream to detect partitionID for this producer + logger.info("[{}] create probe stream for topic {} with producer {}", id, topicPath, producerId); + GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() + .withTraceId(id + "-probe") + .withDeadline(Duration.ofMinutes(1)) + .build(); + GrpcReadWriteStream stream = rpc.writeSession(settings); + + CompletableFuture streamFuture = stream.start(resp -> { + if (resp.getStatus() != StatusCodesProtos.StatusIds.StatusCode.SUCCESS) { + Status status = Status.of(StatusCode.fromProto(resp.getStatus()), Issue.fromPb(resp.getIssuesList())); + logger.warn("[{}] probe stream to topic {} with producer {} got error {}", id, topicPath, + producerId, status); + partitionId.complete(Result.fail(status)); + return; + } + + if (resp.hasInitResponse()) { + long pid = resp.getInitResponse().getPartitionId(); + logger.info("[{}] probe stream to topic {} with producer {} has partition {}", id, topicPath, + producerId, pid); + partitionId.complete(Result.success(pid)); + return; + } + + logger.warn("[{}] probe stream to topic {} with producer {} got unexpected message {}", id, topicPath, + producerId, resp.getClass().getName()); + + Issue issue = Issue.of("Unexpected message from stream with producer " + producerId, Issue.Severity.ERROR); + partitionId.complete(Result.fail(Status.of(StatusCode.BAD_REQUEST, issue))); + }); + + if (streamFuture.isDone()) { + logger.warn("[{}] probe stream to topic {} with producer {} failed with status {}", id, topicPath, + producerId, streamFuture.join()); + return Result.fail(streamFuture.join()); + } + + try { + streamFuture.whenComplete((st, th) -> { + Status status = st != null ? st : Status.of(StatusCode.CLIENT_INTERNAL_ERROR, th); + if (!partitionId.isDone()) { + logger.warn("[{}] probe stream to topic {} with producer {} failed with status {}", id, topicPath, + producerId, status); + partitionId.complete(Result.fail(status)); + } + }); + stream.sendNext(initRequest()); + return partitionId.join(); + } finally { + if (!streamFuture.isDone()) { + stream.close(); + } + } + } + + public static WriteStreamFactory of(TopicRpc rpc, WriterSettings settings) { + if (!settings.isDirectWrite()) { + return new WriteStreamFactory(rpc, settings); + } + + if (settings.getPartitionId() != null) { + return new DirectWriteByPartitionId(rpc, settings, settings.getPartitionId()); + } + + if (settings.getProducerId() != null) { + return new DirectWriteByProducerId(rpc, settings, settings.getProducerId()); + } + + throw new IllegalArgumentException("Direct writing requires PartitionId or ProducerId in WriterSettings"); + } + + private static class DirectWriteByPartitionId extends WriteStreamFactory { + private final long partitionId; + + private DirectWriteByPartitionId(TopicRpc rpc, WriterSettings settings, long partitionId) { + super(rpc, settings); + this.partitionId = partitionId; + } + + @Override + public WriteSession.Stream createNewStream(String id) { + Result nodeId = lookupNodeId(id, partitionId); + if (!nodeId.isSuccess()) { + return new WriteStream.Fail(id, nodeId.getStatus()); + } + + GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() + .withTraceId(id) + .disableDeadline() + .withDirectMode(true) + .withPreferredNodeID(nodeId.getValue()) + .build(); + return new WriteStream(id, rpc.writeSession(settings)); + } + } + + private static class DirectWriteByProducerId extends WriteStreamFactory { + private final String producerId; + + private DirectWriteByProducerId(TopicRpc rpc, WriterSettings settings, String producerId) { + super(rpc, settings); + this.producerId = producerId; + } + + @Override + public WriteSession.Stream createNewStream(String id) { + Result partitionId = lookupPartitionId(id, producerId); + if (!partitionId.isSuccess()) { + return new WriteStream.Fail(id, partitionId.getStatus()); + } + + Result nodeId = lookupNodeId(id, partitionId.getValue()); + if (!nodeId.isSuccess()) { + return new WriteStream.Fail(id, nodeId.getStatus()); + } + + GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() + .withTraceId(id) + .disableDeadline() + .withDirectMode(true) + .withPreferredNodeID(nodeId.getValue()) + .build(); + return new WriteStream(id, rpc.writeSession(settings)); + } + } +} diff --git a/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java b/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java index 8e514b59d..3ef82d174 100644 --- a/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java +++ b/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java @@ -38,9 +38,9 @@ private static class StreamHandle { private final GrpcReadWriteStream grpc = Mockito.mock(GrpcReadWriteStream.class); private final CompletableFuture grpcFuture = new CompletableFuture<>(); - private final TopicStream stream; + private final TopicStreamBase stream; - StreamHandle(TopicStream mocked) { + StreamHandle(TopicStreamBase mocked) { this.stream = mocked; Mockito.when(mocked.start(Mockito.any(), Mockito.any())).thenReturn(grpcFuture); } @@ -49,7 +49,7 @@ private static class StreamHandle { Mockito.when(grpc.authToken()).thenReturn("token"); Mockito.when(grpc.start(Mockito.any())).thenReturn(grpcFuture); - stream = new TopicStream(logger, "inner", grpc) { + stream = new TopicStreamBase(logger, "inner", grpc) { @Override protected Empty updateTokenMessage(String token) { return EMPTY; @@ -205,7 +205,7 @@ public void noRetriesErrorStatusTest() { @Test public void noRetriesExceptionStatusTest() { @SuppressWarnings("unchecked") - StreamHandle h = new StreamHandle(Mockito.mock(TopicStream.class)); + StreamHandle h = new StreamHandle(Mockito.mock(TopicStreamBase.class)); TestStream retryable = new TestStream(Arrays.asList(h), RetryConfig.noRetries(), mockScheduler()); retryable.start(); diff --git a/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java b/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java index 41ec23b18..642195fc2 100644 --- a/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java +++ b/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java @@ -24,7 +24,7 @@ public class TopicStreamTest { private interface MockedStream extends GrpcReadWriteStream { } - private static class TestStream extends TopicStream { + private static class TestStream extends TopicStreamBase { TestStream(MockedStream mock) { super(logger, "test", mock); } diff --git a/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java b/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java new file mode 100644 index 000000000..ce48d8183 --- /dev/null +++ b/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java @@ -0,0 +1,404 @@ +package tech.ydb.topic.write.impl; + +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; + +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +import tech.ydb.core.Issue; +import tech.ydb.core.Result; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadStream; +import tech.ydb.core.grpc.GrpcReadWriteStream; +import tech.ydb.core.grpc.GrpcRequestSettings; +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.DescribeTopicResult; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.topic.TopicRpc; +import tech.ydb.topic.settings.WriterSettings; + +/** + * @author Aleksandr Gorshenin + */ +public class WriteStreamFactoryTest { + + @SuppressWarnings("unchecked") + private static GrpcReadWriteStream mockGrpcStream() { + GrpcReadWriteStream grpc = Mockito.mock(GrpcReadWriteStream.class); + Mockito.when(grpc.authToken()).thenReturn(""); + return grpc; + } + + private static void mockStreamError(GrpcReadWriteStream mock, Status error) { + Mockito.when(mock.start(Mockito.any())).thenReturn(CompletableFuture.completedFuture(error)); + } + + private static void mockStreamResponse(GrpcReadWriteStream mock, FromServer response) { + CompletableFuture result = new CompletableFuture<>(); + + Mockito.when(mock.start(Mockito.any())).thenAnswer(iom -> { + GrpcReadStream.Observer obs = iom.getArgument(0); + obs.onNext(response); + return result; + }).thenReturn(result); + + Mockito.doAnswer((iom) -> { + result.complete(Status.SUCCESS); + return null; + }).when(mock).close(); + } + + private static DescribeTopicResult.PartitionInfo partition(long partitionId, int nodeId) { + return DescribeTopicResult.PartitionInfo.newBuilder() + .setPartitionId(partitionId) + .setPartitionLocation(YdbTopic.PartitionLocation.newBuilder() + .setNodeId(nodeId) + .build()) + .build(); + } + + private static void mockDescribeResult(TopicRpc rpc, DescribeTopicResult.PartitionInfo... partitions) { + Mockito.when(rpc.describeTopic(Mockito.any(), Mockito.any())) + .thenReturn(CompletableFuture.completedFuture(Result.success( + DescribeTopicResult.newBuilder().addAllPartitions(Arrays.asList(partitions)).build()) + )); + } + + private static void mockDescribeResult(TopicRpc rpc, Status status) { + Mockito.when(rpc.describeTopic(Mockito.any(), Mockito.any())) + .thenReturn(CompletableFuture.completedFuture(Result.fail(status))); + } + + @Test + public void regularWriteTest() { + GrpcReadWriteStream grpc = mockGrpcStream(); + TopicRpc rpc = Mockito.mock(TopicRpc.class); + Mockito.when(rpc.writeSession(Mockito.eq("s1"))).thenReturn(grpc); + + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/local/topic") + .build(); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, settings); + Assert.assertEquals("/local/topic", factory.getTopicPath()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream); + Mockito.verify(rpc).writeSession("s1"); + } + + @Test + public void writeWithoutDeduplicationTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .build()); + + YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest() + .getInitRequest(); + Assert.assertEquals("/test/topic", req.getPath()); + Assert.assertEquals("", req.getProducerId()); + Assert.assertFalse(req.hasMessageGroupId()); + Assert.assertFalse(req.hasPartitionId()); + } + + @Test + public void writeWithProducerIdTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer") + .build()); + + YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest() + .getInitRequest(); + Assert.assertEquals("/test/topic", req.getPath()); + Assert.assertEquals("producer", req.getProducerId()); + Assert.assertFalse(req.hasMessageGroupId()); + Assert.assertFalse(req.hasPartitionId()); + } + + @Test + public void writeWithProducerIdAndMessageGroupIdTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer") + .setMessageGroupId("producer") + .build()); + + YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest() + .getInitRequest(); + Assert.assertEquals("/test/topic", req.getPath()); + Assert.assertEquals("producer", req.getProducerId()); + Assert.assertEquals("producer", req.getMessageGroupId()); + Assert.assertFalse(req.hasPartitionId()); + } + + @Test + public void writeWithPartitionIdTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setPartitionId(5L) + .build()); + + YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest().getInitRequest(); + Assert.assertEquals(5L, req.getPartitionId()); + Assert.assertFalse(req.hasMessageGroupId()); + } + + @Test + public void messageGroupAndPartitionErrorTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setMessageGroupId("group-1") + .setPartitionId(5L) + .build(); + Exception ex = Assert.assertThrows(IllegalArgumentException.class, () -> WriteStreamFactory.of(rpc, settings)); + Assert.assertEquals("Both MessageGroupId and PartitionId are set in WriterSettings", ex.getMessage()); + } + + @Test + public void invalidDirectWriteTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/local/topic") + .setDirectWrite(true) // requires producerId or partitionId + .build(); + + Exception ex = Assert.assertThrows(IllegalArgumentException.class, () -> WriteStreamFactory.of(rpc, settings)); + Assert.assertEquals("Direct writing requires PartitionId or ProducerId in WriterSettings", ex.getMessage()); + } + + @Test + public void directWriteByPartitionIdTest() { + GrpcReadWriteStream grpc = mockGrpcStream(); + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + mockDescribeResult(rpc, partition(1L, 10), partition(2L, 42), partition(3L, 23)); + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(grpc); + + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/local/topic") + .setPartitionId(2L) + .setDirectWrite(true) + .build(); + + // just verify it doesn't throw and returns a factory for the correct topic + WriteStreamFactory factory = WriteStreamFactory.of(rpc, settings); + Assert.assertEquals("/local/topic", factory.getTopicPath()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + } + + @Test + public void directWriteByPartitionIdTestDescribeFailTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + mockDescribeResult(rpc, Status.of(StatusCode.UNAVAILABLE)); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setPartitionId(3L) + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + + Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); + + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null, null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.UNAVAILABLE), res.join()); + + stream.close(); // no effect + } + + @Test + public void directWriteByPartitionIdTestPartitionNotFoundTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + // result has partition 5, but we're looking for partition 3 + mockDescribeResult(rpc, partition(4L, 99), partition(5L, 100)); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setPartitionId(3L) + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + + Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); + + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null, null); + Assert.assertTrue(res.isDone()); + Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Cannot find partition 3", Issue.Severity.ERROR)); + Assert.assertEquals(expected, res.join()); + + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + GrpcReadWriteStream probeGrpc = mockGrpcStream(); + GrpcReadWriteStream actualGrpc = mockGrpcStream(); + + FromServer initResponse = FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder() + .setLastSeqNo(0) + .setPartitionId(7L) + .setSessionId("session") + .build()) + .build(); + + mockStreamResponse(probeGrpc, initResponse); + mockDescribeResult(rpc, partition(7L, 55)); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))) + .thenReturn(probeGrpc).thenReturn(actualGrpc); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream); + Mockito.verify(rpc, Mockito.times(2)).writeSession(Mockito.any(GrpcRequestSettings.class)); + } + + @Test + public void directWriteByProducerIdProbeFailTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + GrpcReadWriteStream probeGrpc = mockGrpcStream(); + + mockStreamError(probeGrpc, Status.of(StatusCode.UNAUTHORIZED)); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null, null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.UNAUTHORIZED), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdProbeWrongResponseTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + GrpcReadWriteStream probeGrpc = mockGrpcStream(); + + FromServer initResponse = FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.INTERNAL_ERROR) + .build(); + mockStreamResponse(probeGrpc, initResponse); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null, null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.INTERNAL_ERROR), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdProbeUnexpectedResponseTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + GrpcReadWriteStream probeGrpc = mockGrpcStream(); + + FromServer initResponse = FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setUpdateTokenResponse(YdbTopic.UpdateTokenResponse.newBuilder().build()) + .build(); + mockStreamResponse(probeGrpc, initResponse); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null, null); + Assert.assertTrue(res.isDone()); + Issue issue = Issue.of("Unexpected message from stream with producer producer-1", Issue.Severity.ERROR); + Assert.assertEquals(Status.of(StatusCode.BAD_REQUEST, issue), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdPartitionNotFoundTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + GrpcReadWriteStream probeGrpc = mockGrpcStream(); +// GrpcReadWriteStream actualGrpc = mockGrpcStream(); + + FromServer initResponse = FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder() + .setLastSeqNo(0) + .setPartitionId(5L) + .setSessionId("session") + .build()) + .build(); + + mockStreamResponse(probeGrpc, initResponse); + mockDescribeResult(rpc, partition(1L, 55), partition(2L, 55)); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); + + WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null, null); + Assert.assertTrue(res.isDone()); + Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Cannot find partition 5", Issue.Severity.ERROR)); + Assert.assertEquals(expected, res.join()); + } +} diff --git a/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java b/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java index f341347c7..1b5d72306 100644 --- a/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java +++ b/topic/src/test/java/tech/ydb/topic/write/impl/WriterImplTest.java @@ -39,7 +39,7 @@ public class WriterImplTest { private static TopicRpc mockRpc(StreamMock first, StreamMock... rest) { TopicRpc rpc = Mockito.mock(TopicRpc.class); Mockito.when(rpc.getScheduler()).thenReturn(Mockito.mock(ScheduledExecutorService.class)); - Mockito.when(rpc.writeSession(Mockito.any())).thenReturn(first, rest); + Mockito.when(rpc.writeSession(Mockito.any(String.class))).thenReturn(first, rest); return rpc; }