From ea55529dbf87f0610140b281dda6be0c53db1677 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Fri, 17 Apr 2026 21:58:04 +0530 Subject: [PATCH 1/9] feat(spanner): add shared endpoint cooldowns for location-aware rerouting --- .../v1/EndpointOverloadCooldownTracker.java | 152 +++++ .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 6 +- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 80 ++- .../cloud/spanner/spi/v1/KeyRangeCache.java | 453 +++++++-------- ...nAwareSharedBackendReplicaHarnessTest.java | 526 ++++++++++++++++++ .../cloud/spanner/MockSpannerServiceImpl.java | 25 +- .../spanner/SharedBackendReplicaHarness.java | 310 +++++++++++ .../spanner/spi/v1/KeyAwareChannelTest.java | 96 +++- 8 files changed, 1372 insertions(+), 276 deletions(-) create mode 100644 java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java create mode 100644 java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java create mode 100644 java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java new file mode 100644 index 000000000000..0663adf4d9b0 --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java @@ -0,0 +1,152 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.spi.v1; + +import com.google.common.annotations.VisibleForTesting; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.LongUnaryOperator; + +/** + * Tracks short-lived endpoint cooldowns after routed {@code RESOURCE_EXHAUSTED} failures. + * + *

This allows later requests to try a different replica instead of immediately routing back to + * the same overloaded endpoint. + */ +final class EndpointOverloadCooldownTracker { + + @VisibleForTesting static final Duration DEFAULT_INITIAL_COOLDOWN = Duration.ofSeconds(10); + @VisibleForTesting static final Duration DEFAULT_MAX_COOLDOWN = Duration.ofMinutes(1); + @VisibleForTesting static final Duration DEFAULT_RESET_AFTER = Duration.ofMinutes(10); + + @VisibleForTesting + static final class CooldownState { + private final int consecutiveFailures; + private final Instant cooldownUntil; + private final Instant lastFailureAt; + + private CooldownState(int consecutiveFailures, Instant cooldownUntil, Instant lastFailureAt) { + this.consecutiveFailures = consecutiveFailures; + this.cooldownUntil = cooldownUntil; + this.lastFailureAt = lastFailureAt; + } + } + + private final ConcurrentHashMap entries = new ConcurrentHashMap<>(); + private final Duration initialCooldown; + private final Duration maxCooldown; + private final Duration resetAfter; + private final Clock clock; + private final LongUnaryOperator randomLong; + + EndpointOverloadCooldownTracker() { + this( + DEFAULT_INITIAL_COOLDOWN, + DEFAULT_MAX_COOLDOWN, + DEFAULT_RESET_AFTER, + Clock.systemUTC(), + bound -> ThreadLocalRandom.current().nextLong(bound)); + } + + @VisibleForTesting + EndpointOverloadCooldownTracker( + Duration initialCooldown, + Duration maxCooldown, + Duration resetAfter, + Clock clock, + LongUnaryOperator randomLong) { + Duration resolvedInitial = + (initialCooldown == null || initialCooldown.isZero() || initialCooldown.isNegative()) + ? DEFAULT_INITIAL_COOLDOWN + : initialCooldown; + Duration resolvedMax = + (maxCooldown == null || maxCooldown.isZero() || maxCooldown.isNegative()) + ? 
DEFAULT_MAX_COOLDOWN + : maxCooldown; + if (resolvedMax.compareTo(resolvedInitial) < 0) { + resolvedMax = resolvedInitial; + } + this.initialCooldown = resolvedInitial; + this.maxCooldown = resolvedMax; + this.resetAfter = + (resetAfter == null || resetAfter.isZero() || resetAfter.isNegative()) + ? DEFAULT_RESET_AFTER + : resetAfter; + this.clock = clock == null ? Clock.systemUTC() : clock; + this.randomLong = + randomLong == null ? bound -> ThreadLocalRandom.current().nextLong(bound) : randomLong; + } + + boolean isCoolingDown(String address) { + if (address == null || address.isEmpty()) { + return false; + } + Instant now = clock.instant(); + CooldownState state = entries.get(address); + if (state == null) { + return false; + } + if (state.cooldownUntil.isAfter(now)) { + return true; + } + if (Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) { + return false; + } + entries.remove(address, state); + CooldownState current = entries.get(address); + return current != null && current.cooldownUntil.isAfter(now); + } + + void recordFailure(String address) { + if (address == null || address.isEmpty()) { + return; + } + Instant now = clock.instant(); + entries.compute( + address, + (ignored, state) -> { + int consecutiveFailures = 1; + if (state != null + && Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) { + consecutiveFailures = state.consecutiveFailures + 1; + } + Duration cooldown = cooldownForFailures(consecutiveFailures); + return new CooldownState(consecutiveFailures, now.plus(cooldown), now); + }); + } + + private Duration cooldownForFailures(int failures) { + Duration cooldown = initialCooldown; + for (int i = 1; i < failures; i++) { + if (cooldown.compareTo(maxCooldown.dividedBy(2)) > 0) { + cooldown = maxCooldown; + break; + } + cooldown = cooldown.multipliedBy(2); + } + long bound = Math.max(1L, cooldown.toMillis() + 1L); + return Duration.ofMillis(randomLong.applyAsLong(bound)); + } + + @VisibleForTesting + 
CooldownState getState(String address) { + return entries.get(address); + } +} diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 7c1b6be1c1bd..6cc0a485d056 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -432,7 +432,11 @@ public GapicSpannerRpc(final SpannerOptions options) { this.readRetrySettings = options.getSpannerStubSettings().streamingReadSettings().getRetrySettings(); this.readRetryableCodes = - options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes(); + ImmutableSet.builder() + .addAll( + options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes()) + .add(Code.RESOURCE_EXHAUSTED) + .build(); this.executeQueryRetrySettings = options.getSpannerStubSettings().executeStreamingSqlSettings().getRetrySettings(); this.executeQueryRetryableCodes = diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index d7b32f72bcd6..90ff41f35973 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -21,6 +21,7 @@ import com.google.api.core.InternalApi; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.cloud.spanner.XGoogSpannerRequestId; +import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.protobuf.ByteString; @@ -102,11 +103,20 @@ final 
class KeyAwareChannel extends ManagedChannel { .maximumSize(MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS) .expireAfterWrite(EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES, TimeUnit.MINUTES) .build(); + private final EndpointOverloadCooldownTracker endpointOverloadCooldowns; private KeyAwareChannel( InstantiatingGrpcChannelProvider channelProvider, @Nullable ChannelEndpointCacheFactory endpointCacheFactory) throws IOException { + this(channelProvider, endpointCacheFactory, new EndpointOverloadCooldownTracker()); + } + + private KeyAwareChannel( + InstantiatingGrpcChannelProvider channelProvider, + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + EndpointOverloadCooldownTracker endpointOverloadCooldowns) + throws IOException { if (endpointCacheFactory == null) { this.endpointCache = new GrpcChannelEndpointCache(channelProvider); } else { @@ -120,6 +130,7 @@ private KeyAwareChannel( // would interfere with test assertions. this.lifecycleManager = (endpointCacheFactory == null) ? new EndpointLifecycleManager(endpointCache) : null; + this.endpointOverloadCooldowns = endpointOverloadCooldowns; } static KeyAwareChannel create( @@ -129,6 +140,15 @@ static KeyAwareChannel create( return new KeyAwareChannel(channelProvider, endpointCacheFactory); } + @VisibleForTesting + static KeyAwareChannel create( + InstantiatingGrpcChannelProvider channelProvider, + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + EndpointOverloadCooldownTracker endpointOverloadCooldowns) + throws IOException { + return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointOverloadCooldowns); + } + private static final class ChannelFinderReference extends SoftReference { final String databaseId; @@ -321,36 +341,56 @@ void clearTransactionAndChannelAffinity(ByteString transactionId, @Nullable Long private void maybeExcludeEndpointOnNextCall( @Nullable ChannelEndpoint endpoint, @Nullable String logicalRequestKey) { - if (endpoint == null || logicalRequestKey == null) { + if (endpoint 
== null) { return; } String address = endpoint.getAddress(); - if (!defaultEndpointAddress.equals(address)) { - excludedEndpointsForLogicalRequest - .asMap() - .compute( - logicalRequestKey, - (ignored, excludedEndpoints) -> { - Set updated = - excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints; - updated.add(address); - return updated; - }); + if (defaultEndpointAddress.equals(address)) { + return; + } + endpointOverloadCooldowns.recordFailure(address); + if (logicalRequestKey == null) { + return; } + excludedEndpointsForLogicalRequest + .asMap() + .compute( + logicalRequestKey, + (ignored, excludedEndpoints) -> { + Set updated = + excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints; + updated.add(address); + return updated; + }); } private Predicate consumeExcludedEndpointsForCurrentCall( @Nullable String logicalRequestKey) { - if (logicalRequestKey == null) { - return address -> false; + Predicate requestScopedExcluded = address -> false; + if (logicalRequestKey != null) { + Set excludedEndpoints = + excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey); + if (excludedEndpoints != null && !excludedEndpoints.isEmpty()) { + excludedEndpoints = new HashSet<>(excludedEndpoints); + requestScopedExcluded = excludedEndpoints::contains; + } } + Predicate finalRequestScopedExcluded = requestScopedExcluded; + return address -> + finalRequestScopedExcluded.test(address) + || endpointOverloadCooldowns.isCoolingDown(address); + } + + @VisibleForTesting + boolean isCoolingDown(String address) { + return endpointOverloadCooldowns.isCoolingDown(address); + } + + @VisibleForTesting + boolean hasExcludedEndpointForLogicalRequest(String logicalRequestKey, String address) { Set excludedEndpoints = - excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey); - if (excludedEndpoints == null || excludedEndpoints.isEmpty()) { - return address -> false; - } - excludedEndpoints = new 
HashSet<>(excludedEndpoints); - return excludedEndpoints::contains; + excludedEndpointsForLogicalRequest.getIfPresent(logicalRequestKey); + return excludedEndpoints != null && excludedEndpoints.contains(address); } private boolean isReadOnlyTransaction(ByteString transactionId) { diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index 59955ccb4bd2..41b8798d9611 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -28,6 +28,7 @@ import com.google.spanner.v1.Tablet; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; @@ -185,11 +186,10 @@ Set getActiveAddresses() { readLock.lock(); try { for (CachedGroup group : groups.values()) { - synchronized (group) { - for (CachedTablet tablet : group.tablets) { - if (!tablet.serverAddress.isEmpty()) { - addresses.add(tablet.serverAddress); - } + GroupSnapshot snapshot = group.snapshot; + for (TabletSnapshot tablet : snapshot.tablets) { + if (!tablet.serverAddress.isEmpty()) { + addresses.add(tablet.serverAddress); } } } @@ -487,34 +487,27 @@ private int compare(ByteString left, ByteString right) { return ByteString.unsignedLexicographicalComparator().compare(left, right); } - /** Represents a single tablet within a group. 
*/ - private class CachedTablet { - long tabletUid = 0; - ByteString incarnation = ByteString.EMPTY; - String serverAddress = ""; - int distance = 0; - boolean skip = false; - Tablet.Role role = Tablet.Role.ROLE_UNSPECIFIED; - String location = ""; - - ChannelEndpoint endpoint = null; - - void update(Tablet tabletIn) { - if (tabletUid > 0 && compare(incarnation, tabletIn.getIncarnation()) > 0) { - return; - } - - tabletUid = tabletIn.getTabletUid(); - incarnation = tabletIn.getIncarnation(); - distance = tabletIn.getDistance(); - skip = tabletIn.getSkip(); - role = tabletIn.getRole(); - location = tabletIn.getLocation(); - - if (!serverAddress.equals(tabletIn.getServerAddress())) { - serverAddress = tabletIn.getServerAddress(); - endpoint = null; - } + private static final GroupSnapshot EMPTY_GROUP_SNAPSHOT = + new GroupSnapshot(ByteString.EMPTY, -1, Collections.emptyList()); + + /** Immutable tablet metadata used by the read path without per-group locking. */ + private static final class TabletSnapshot { + final long tabletUid; + final ByteString incarnation; + final String serverAddress; + final int distance; + final boolean skip; + final Tablet.Role role; + final String location; + + private TabletSnapshot(Tablet tabletIn) { + this.tabletUid = tabletIn.getTabletUid(); + this.incarnation = tabletIn.getIncarnation(); + this.serverAddress = tabletIn.getServerAddress(); + this.distance = tabletIn.getDistance(); + this.skip = tabletIn.getSkip(); + this.role = tabletIn.getRole(); + this.location = tabletIn.getLocation(); } boolean matches(DirectedReadOptions directedReadOptions) { @@ -555,132 +548,6 @@ private boolean matches(DirectedReadOptions.ReplicaSelection selection) { } } - /** - * Evaluates whether this tablet should be skipped for location-aware routing. - * - *

State-aware skip logic: - * - *

- */ - boolean shouldSkip( - RoutingHint.Builder hintBuilder, - Predicate excludedEndpoints, - Set skippedTabletUids) { - // Server-marked skip, no address, or excluded endpoint: always report. - if (skip || serverAddress.isEmpty() || excludedEndpoints.test(serverAddress)) { - addSkippedTablet(hintBuilder, skippedTabletUids); - return true; - } - - // If the cached endpoint's channel has been shut down (e.g. after idle eviction), - // discard the stale reference so we re-lookup from the cache below. - if (endpoint != null && endpoint.getChannel().isShutdown()) { - logger.log( - Level.FINE, - "Tablet {0} at {1}: cached endpoint is shutdown, clearing stale reference", - new Object[] {tabletUid, serverAddress}); - endpoint = null; - } - - // Lookup without creating: location-aware routing should not trigger foreground endpoint - // creation. - if (endpoint == null) { - endpoint = endpointCache.getIfPresent(serverAddress); - } - - // No endpoint exists yet - skip silently, request background recreation so the - // endpoint becomes available for future requests. - if (endpoint == null) { - logger.log( - Level.FINE, - "Tablet {0} at {1}: no endpoint present, skipping silently", - new Object[] {tabletUid, serverAddress}); - maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids); - if (lifecycleManager != null) { - lifecycleManager.requestEndpointRecreation(serverAddress); - } - return true; - } - - // READY - usable for location-aware routing. - if (endpoint.isHealthy()) { - return false; - } - - // TRANSIENT_FAILURE - skip and report so server can refresh client cache. - if (endpoint.isTransientFailure()) { - logger.log( - Level.FINE, - "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to skipped_tablets", - new Object[] {tabletUid, serverAddress}); - addSkippedTablet(hintBuilder, skippedTabletUids); - return true; - } - - // IDLE, CONNECTING, SHUTDOWN, or unsupported - skip silently. 
- logger.log( - Level.FINE, - "Tablet {0} at {1}: endpoint not ready, skipping silently", - new Object[] {tabletUid, serverAddress}); - maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids); - return true; - } - - private void addSkippedTablet(RoutingHint.Builder hintBuilder, Set skippedTabletUids) { - if (!skippedTabletUids.add(tabletUid)) { - return; - } - RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder(); - skipped.setTabletUid(tabletUid); - skipped.setIncarnation(incarnation); - } - - private void recordKnownTransientFailure( - RoutingHint.Builder hintBuilder, - Predicate excludedEndpoints, - Set skippedTabletUids) { - if (skip || serverAddress.isEmpty() || excludedEndpoints.test(serverAddress)) { - return; - } - - if (endpoint != null && endpoint.getChannel().isShutdown()) { - endpoint = null; - } - - if (endpoint == null) { - endpoint = endpointCache.getIfPresent(serverAddress); - } - - if (endpoint != null && endpoint.isTransientFailure()) { - addSkippedTablet(hintBuilder, skippedTabletUids); - return; - } - - maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids); - } - - private void maybeAddRecentTransientFailureSkip( - RoutingHint.Builder hintBuilder, Set skippedTabletUids) { - if (lifecycleManager != null - && lifecycleManager.wasRecentlyEvictedTransientFailure(serverAddress)) { - addSkippedTablet(hintBuilder, skippedTabletUids); - } - } - - ChannelEndpoint pick(RoutingHint.Builder hintBuilder) { - hintBuilder.setTabletUid(tabletUid); - // Endpoint must already exist and be READY if shouldSkip returned false. 
- return endpoint; - } - String debugString() { return tabletUid + ":" @@ -698,19 +565,40 @@ String debugString() { } } + private static final class GroupSnapshot { + final ByteString generation; + final int leaderIndex; + final List tablets; + + private GroupSnapshot(ByteString generation, int leaderIndex, List tablets) { + this.generation = generation; + this.leaderIndex = leaderIndex; + this.tablets = Collections.unmodifiableList(new ArrayList<>(tablets)); + } + + boolean hasLeader() { + return leaderIndex >= 0 && leaderIndex < tablets.size(); + } + + TabletSnapshot leader() { + return tablets.get(leaderIndex); + } + } + /** Represents a paxos group with its tablets. */ private class CachedGroup { final long groupUid; - ByteString generation = ByteString.EMPTY; - List tablets = new ArrayList<>(); - int leaderIndex = -1; + volatile GroupSnapshot snapshot = EMPTY_GROUP_SNAPSHOT; int refs = 1; CachedGroup(long groupUid) { this.groupUid = groupUid; } - synchronized void update(Group groupIn) { + void update(Group groupIn) { + GroupSnapshot current = snapshot; + ByteString generation = current.generation; + int leaderIndex = current.leaderIndex; if (compare(groupIn.getGeneration(), generation) > 0) { generation = groupIn.getGeneration(); if (groupIn.getLeaderIndex() >= 0 && groupIn.getLeaderIndex() < groupIn.getTabletsCount()) { @@ -720,37 +608,11 @@ synchronized void update(Group groupIn) { } } - if (tablets.size() == groupIn.getTabletsCount()) { - boolean mismatch = false; - for (int t = 0; t < groupIn.getTabletsCount(); t++) { - if (tablets.get(t).tabletUid != groupIn.getTablets(t).getTabletUid()) { - mismatch = true; - break; - } - } - if (!mismatch) { - for (int t = 0; t < groupIn.getTabletsCount(); t++) { - tablets.get(t).update(groupIn.getTablets(t)); - } - return; - } - } - - Map tabletsByUid = new HashMap<>(tablets.size()); - for (CachedTablet tablet : tablets) { - tabletsByUid.put(tablet.tabletUid, tablet); - } - List newTablets = new 
ArrayList<>(groupIn.getTabletsCount()); + List tablets = new ArrayList<>(groupIn.getTabletsCount()); for (int t = 0; t < groupIn.getTabletsCount(); t++) { - Tablet tabletIn = groupIn.getTablets(t); - CachedTablet tablet = tabletsByUid.get(tabletIn.getTabletUid()); - if (tablet == null) { - tablet = new CachedTablet(); - } - tablet.update(tabletIn); - newTablets.add(tablet); + tablets.add(new TabletSnapshot(groupIn.getTablets(t))); } - tablets = newTablets; + snapshot = new GroupSnapshot(generation, leaderIndex, tablets); } ChannelEndpoint fillRoutingHint( @@ -758,59 +620,72 @@ ChannelEndpoint fillRoutingHint( DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints) { + GroupSnapshot snapshot = this.snapshot; Set skippedTabletUids = skippedTabletUids(hintBuilder); boolean hasDirectedReadOptions = directedReadOptions.getReplicasCase() != DirectedReadOptions.ReplicasCase.REPLICAS_NOT_SET; - - // Select a tablet while holding the lock. With state-aware routing, only READY - // endpoints pass shouldSkip(), so the selected tablet always has a cached - // endpoint. No foreground endpoint creation is needed — the lifecycle manager - // creates endpoints in the background. 
- synchronized (this) { - CachedTablet selected = - selectTabletLocked( - preferLeader, - hasDirectedReadOptions, - hintBuilder, - directedReadOptions, - excludedEndpoints, - skippedTabletUids); - if (selected == null) { - return null; - } - recordKnownTransientFailuresLocked( - selected, directedReadOptions, hintBuilder, excludedEndpoints, skippedTabletUids); - return selected.pick(hintBuilder); - } - } - - private CachedTablet selectTabletLocked( + Map resolvedEndpoints = new HashMap<>(); + + TabletSnapshot selected = + selectTablet( + snapshot, + preferLeader, + hasDirectedReadOptions, + hintBuilder, + directedReadOptions, + excludedEndpoints, + skippedTabletUids, + resolvedEndpoints); + if (selected == null) { + return null; + } + recordKnownTransientFailures( + snapshot, + selected, + directedReadOptions, + hintBuilder, + excludedEndpoints, + skippedTabletUids, + resolvedEndpoints); + hintBuilder.setTabletUid(selected.tabletUid); + return resolveEndpoint(selected, resolvedEndpoints); + } + + private TabletSnapshot selectTablet( + GroupSnapshot snapshot, boolean preferLeader, boolean hasDirectedReadOptions, RoutingHint.Builder hintBuilder, DirectedReadOptions directedReadOptions, Predicate excludedEndpoints, - Set skippedTabletUids) { + Set skippedTabletUids, + Map resolvedEndpoints) { boolean checkedLeader = false; if (preferLeader && !hasDirectedReadOptions - && hasLeader() - && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { + && snapshot.hasLeader() + && snapshot.leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { checkedLeader = true; - if (!leader().shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) { - return leader(); + if (!shouldSkip( + snapshot.leader(), + hintBuilder, + excludedEndpoints, + skippedTabletUids, + resolvedEndpoints)) { + return snapshot.leader(); } } - for (int index = 0; index < tablets.size(); index++) { - if (checkedLeader && index == leaderIndex) { + for (int index = 0; index < snapshot.tablets.size(); index++) { + if 
(checkedLeader && index == snapshot.leaderIndex) { continue; } - CachedTablet tablet = tablets.get(index); + TabletSnapshot tablet = snapshot.tablets.get(index); if (!tablet.matches(directedReadOptions)) { continue; } - if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) { + if (shouldSkip( + tablet, hintBuilder, excludedEndpoints, skippedTabletUids, resolvedEndpoints)) { continue; } return tablet; @@ -818,17 +693,20 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { return null; } - private void recordKnownTransientFailuresLocked( - CachedTablet selected, + private void recordKnownTransientFailures( + GroupSnapshot snapshot, + TabletSnapshot selected, DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints, - Set skippedTabletUids) { - for (CachedTablet tablet : tablets) { + Set skippedTabletUids, + Map resolvedEndpoints) { + for (TabletSnapshot tablet : snapshot.tablets) { if (tablet == selected || !tablet.matches(directedReadOptions)) { continue; } - tablet.recordKnownTransientFailure(hintBuilder, excludedEndpoints, skippedTabletUids); + recordKnownTransientFailure( + tablet, hintBuilder, excludedEndpoints, skippedTabletUids, resolvedEndpoints); } } @@ -840,27 +718,124 @@ private Set skippedTabletUids(RoutingHint.Builder hintBuilder) { return skippedTabletUids; } - boolean hasLeader() { - return leaderIndex >= 0 && leaderIndex < tablets.size(); + private boolean shouldSkip( + TabletSnapshot tablet, + RoutingHint.Builder hintBuilder, + Predicate excludedEndpoints, + Set skippedTabletUids, + Map resolvedEndpoints) { + if (tablet.skip + || tablet.serverAddress.isEmpty() + || excludedEndpoints.test(tablet.serverAddress)) { + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + return true; + } + + ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); + if (endpoint == null) { + logger.log( + Level.FINE, + "Tablet {0} at {1}: no endpoint present, skipping silently", + 
new Object[] {tablet.tabletUid, tablet.serverAddress}); + maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + if (lifecycleManager != null) { + lifecycleManager.requestEndpointRecreation(tablet.serverAddress); + } + return true; + } + if (endpoint.isHealthy()) { + return false; + } + if (endpoint.isTransientFailure()) { + logger.log( + Level.FINE, + "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to skipped_tablets", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + return true; + } + + logger.log( + Level.FINE, + "Tablet {0} at {1}: endpoint not ready, skipping silently", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + return true; } - CachedTablet leader() { - return tablets.get(leaderIndex); + private void recordKnownTransientFailure( + TabletSnapshot tablet, + RoutingHint.Builder hintBuilder, + Predicate excludedEndpoints, + Set skippedTabletUids, + Map resolvedEndpoints) { + if (tablet.skip + || tablet.serverAddress.isEmpty() + || excludedEndpoints.test(tablet.serverAddress)) { + return; + } + + ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); + if (endpoint != null && endpoint.isTransientFailure()) { + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + return; + } + + maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + } + + private ChannelEndpoint resolveEndpoint( + TabletSnapshot tablet, Map resolvedEndpoints) { + if (tablet.serverAddress.isEmpty()) { + return null; + } + if (resolvedEndpoints.containsKey(tablet.serverAddress)) { + return resolvedEndpoints.get(tablet.serverAddress); + } + ChannelEndpoint endpoint = endpointCache.getIfPresent(tablet.serverAddress); + if (endpoint != null && endpoint.getChannel().isShutdown()) { + logger.log( + Level.FINE, + "Tablet {0} at {1}: cached endpoint is shutdown, 
clearing stale reference", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + endpoint = null; + } + resolvedEndpoints.put(tablet.serverAddress, endpoint); + return endpoint; + } + + private void maybeAddRecentTransientFailureSkip( + TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Set skippedTabletUids) { + if (lifecycleManager != null + && lifecycleManager.wasRecentlyEvictedTransientFailure(tablet.serverAddress)) { + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + } + } + + private void addSkippedTablet( + TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Set skippedTabletUids) { + if (!skippedTabletUids.add(tablet.tabletUid)) { + return; + } + RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder(); + skipped.setTabletUid(tablet.tabletUid); + skipped.setIncarnation(tablet.incarnation); } String debugString() { + GroupSnapshot snapshot = this.snapshot; StringBuilder sb = new StringBuilder(); sb.append(groupUid).append(":["); - for (int i = 0; i < tablets.size(); i++) { - sb.append(tablets.get(i).debugString()); - if (hasLeader() && i == leaderIndex) { + for (int i = 0; i < snapshot.tablets.size(); i++) { + sb.append(snapshot.tablets.get(i).debugString()); + if (snapshot.hasLeader() && i == snapshot.leaderIndex) { sb.append(" (leader)"); } - if (i < tablets.size() - 1) { + if (i < snapshot.tablets.size() - 1) { sb.append(", "); } } - sb.append("]@").append(generation.toStringUtf8()); + sb.append("]@").append(snapshot.generation.toStringUtf8()); sb.append("#").append(refs); return sb.toString(); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java new file mode 100644 index 000000000000..2196fad56696 --- /dev/null +++ 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -0,0 +1,526 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.spi.v1.KeyRecipeCache; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.ByteString; +import com.google.protobuf.ListValue; +import com.google.protobuf.TextFormat; +import com.google.protobuf.Value; +import com.google.rpc.RetryInfo; +import com.google.spanner.v1.CacheUpdate; +import com.google.spanner.v1.DirectedReadOptions; +import com.google.spanner.v1.DirectedReadOptions.IncludeReplicas; +import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection; +import com.google.spanner.v1.Group; +import com.google.spanner.v1.Range; +import com.google.spanner.v1.ReadRequest; +import com.google.spanner.v1.RecipeList; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.RoutingHint; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.Tablet; +import 
com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.ProtoUtils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class LocationAwareSharedBackendReplicaHarnessTest { + + private static final String PROJECT = "fake-project"; + private static final String INSTANCE = "fake-instance"; + private static final String DATABASE = "fake-database"; + private static final String TABLE = "T"; + private static final String REPLICA_LOCATION = "us-east1"; + private static final Statement SEED_QUERY = Statement.of("SELECT 1"); + private static final ByteString RESUME_TOKEN_AFTER_FIRST_ROW = + ByteString.copyFromUtf8("000000001"); + private static final DirectedReadOptions DIRECTED_READ_OPTIONS = + DirectedReadOptions.newBuilder() + .setIncludeReplicas( + IncludeReplicas.newBuilder() + .addReplicaSelections( + ReplicaSelection.newBuilder() + .setLocation(REPLICA_LOCATION) + .setType(ReplicaSelection.Type.READ_ONLY) + .build()) + .build()) + .build(); + + @BeforeClass + public static void enableLocationAwareRouting() { + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return true; + } + }); + } + + @AfterClass + public static void restoreEnvironment() { + SpannerOptions.useDefaultEnvironment(); + } + + @Test + public void singleUseReadReroutesOnResourceExhaustedForBypassTraffic() throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = 
spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, + resourceExhausted("busy-routed-replica")); + + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(resultSet.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 1, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + ReadRequest replicaARequest = + (ReadRequest) + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + assertTrue(replicaARequest.getResumeToken().isEmpty()); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + harness + .replicas + .get(1) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0)); + } + } + + @Test + public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + 
SharedBackendReplicaHarness.METHOD_STREAMING_READ, + resourceExhaustedWithRetryInfo("busy-routed-replica")); + + try (ResultSet firstRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(firstRead.next()); + } + + try (ResultSet secondRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(secondRead.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 2, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + List replicaBRequests = + harness.replicas.get(1).getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + for (AbstractMessage request : replicaBRequests) { + assertTrue(((ReadRequest) request).getResumeToken().isEmpty()); + } + List replicaBRequestIds = + harness.replicas.get(1).getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + replicaBRequestIds.get(0)); + assertNotEquals( + XGoogSpannerRequestId.of(replicaBRequestIds.get(0)).getLogicalRequestKey(), + XGoogSpannerRequestId.of(replicaBRequestIds.get(1)).getLogicalRequestKey()); + } + } + + @Test + public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTraffic() + throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, multiRowReadResultSet("b", "c", "d")); + DatabaseClient client = 
spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness.backend.setStreamingReadExecutionTime( + SimulatedExecutionTime.ofStreamException(resourceExhausted("busy-routed-replica"), 1L)); + + List rows = new ArrayList<>(); + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + while (resultSet.next()) { + rows.add(resultSet.getString(0)); + } + } + + assertEquals(Arrays.asList("b", "c", "d"), rows); + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 1, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + + ReadRequest replicaARequest = + (ReadRequest) + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + ReadRequest replicaBRequest = + (ReadRequest) + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + assertTrue(replicaARequest.getResumeToken().isEmpty()); + assertEquals(RESUME_TOKEN_AFTER_FIRST_ROW, replicaBRequest.getResumeToken()); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + harness + .replicas + .get(1) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0)); + } + } + + private static Spanner createSpanner(SharedBackendReplicaHarness harness) { + return SpannerOptions.newBuilder() + .usePlainText() + .setExperimentalHost(harness.defaultAddress) + .setProjectId(PROJECT) + 
.setCredentials(NoCredentials.getInstance()) + .setChannelEndpointCacheFactory(null) + .build() + .getService(); + } + + private static void configureBackend( + SharedBackendReplicaHarness harness, com.google.spanner.v1.ResultSet readResultSet) + throws TextFormat.ParseException { + Statement readStatement = + StatementResult.createReadStatement( + TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k")); + harness.backend.putStatementResult(StatementResult.query(readStatement, readResultSet)); + harness.backend.putStatementResult( + StatementResult.query( + SEED_QUERY, + singleRowReadResultSet("seed").toBuilder() + .setCacheUpdate(cacheUpdate(harness)) + .build())); + } + + private static void seedLocationMetadata(DatabaseClient client) { + try (com.google.cloud.spanner.ResultSet resultSet = + client.singleUse().executeQuery(SEED_QUERY)) { + while (resultSet.next()) { + // Consume the cache update on the first query result. + } + } + } + + private static void waitForReplicaRoutedRead( + DatabaseClient client, SharedBackendReplicaHarness harness, int replicaIndex) + throws InterruptedException { + long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + while (System.nanoTime() < deadlineNanos) { + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + if (resultSet.next() + && !harness + .replicas + .get(replicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .isEmpty()) { + return; + } + } + Thread.sleep(50L); + } + throw new AssertionError("Timed out waiting for location-aware read to route to replica"); + } + + private static CacheUpdate cacheUpdate(SharedBackendReplicaHarness harness) + throws TextFormat.ParseException { + RecipeList recipes = readRecipeList(); + RoutingHint routingHint = exactReadRoutingHint(recipes); + ByteString limitKey = routingHint.getLimitKey(); + if (limitKey.isEmpty()) { + 
limitKey = routingHint.getKey().concat(ByteString.copyFrom(new byte[] {0})); + } + + return CacheUpdate.newBuilder() + .setDatabaseId(12345L) + .setKeyRecipes(recipes) + .addRange( + Range.newBuilder() + .setStartKey(routingHint.getKey()) + .setLimitKey(limitKey) + .setGroupUid(1L) + .setSplitId(1L) + .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1"))) + .addGroup( + Group.newBuilder() + .setGroupUid(1L) + .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1")) + .setLeaderIndex(0) + .addTablets( + Tablet.newBuilder() + .setTabletUid(11L) + .setServerAddress(harness.replicaAddresses.get(0)) + .setLocation(REPLICA_LOCATION) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0)) + .addTablets( + Tablet.newBuilder() + .setTabletUid(12L) + .setServerAddress(harness.replicaAddresses.get(1)) + .setLocation(REPLICA_LOCATION) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0))) + .build(); + } + + private static RecipeList readRecipeList() throws TextFormat.ParseException { + RecipeList.Builder recipes = RecipeList.newBuilder(); + TextFormat.merge( + "schema_generation: \"1\"\n" + + "recipe {\n" + + " table_name: \"" + + TABLE + + "\"\n" + + " part { tag: 1 }\n" + + " part {\n" + + " order: ASCENDING\n" + + " null_order: NULLS_FIRST\n" + + " type { code: STRING }\n" + + " identifier: \"k\"\n" + + " }\n" + + "}\n", + recipes); + return recipes.build(); + } + + private static RoutingHint exactReadRoutingHint(RecipeList recipes) { + KeyRecipeCache recipeCache = new KeyRecipeCache(); + recipeCache.addRecipes(recipes); + ReadRequest.Builder request = + ReadRequest.newBuilder() + .setSession( + String.format( + "projects/%s/instances/%s/databases/%s/sessions/test-session", + PROJECT, INSTANCE, DATABASE)) + .setTable(TABLE) + .addAllColumns(Arrays.asList("k")) + .setDirectedReadOptions(DIRECTED_READ_OPTIONS); + KeySet.singleKey(Key.of("b")).appendToProto(request.getKeySetBuilder()); + recipeCache.computeKeys(request); + return request.getRoutingHint(); 
+ } + + private static io.grpc.StatusRuntimeException resourceExhaustedWithRetryInfo(String description) { + Metadata trailers = new Metadata(); + trailers.put( + ProtoUtils.keyForProto(RetryInfo.getDefaultInstance()), + RetryInfo.newBuilder() + .setRetryDelay( + com.google.protobuf.Duration.newBuilder() + .setNanos((int) TimeUnit.MILLISECONDS.toNanos(1L)) + .build()) + .build()); + return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException(trailers); + } + + private static StatusRuntimeException resourceExhausted(String description) { + return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException(); + } + + private static void assertRetriedOnSameLogicalRequest( + String firstRequestId, String secondRequestId) { + XGoogSpannerRequestId first = XGoogSpannerRequestId.of(firstRequestId); + XGoogSpannerRequestId second = XGoogSpannerRequestId.of(secondRequestId); + assertEquals(first.getLogicalRequestKey(), second.getLogicalRequestKey()); + assertEquals(first.getAttempt() + 1, second.getAttempt()); + } + + private static com.google.spanner.v1.ResultSet singleRowReadResultSet(String value) { + return readResultSet(Arrays.asList(value)); + } + + private static com.google.spanner.v1.ResultSet multiRowReadResultSet(String... 
values) { + return readResultSet(Arrays.asList(values)); + } + + private static com.google.spanner.v1.ResultSet readResultSet(List values) { + com.google.spanner.v1.ResultSet.Builder builder = + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + StructType.Field.newBuilder() + .setName("k") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())); + for (String value : values) { + builder.addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(value).build()) + .build()); + } + return builder.build(); + } +} diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 6f40052d0aed..10fa82dd6d26 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -202,9 +202,13 @@ private static class PartialResultSetsIterator implements Iterator responseObserver, SimulatedExecutionTime executionTime, boolean isMultiplexedSession) @@ -1783,7 +1801,8 @@ private void returnPartialResultSet( new PartialResultSetsIterator( resultSet, isMultiplexedSession && isReadWriteTransaction(transactionId), - transactionId); + transactionId, + resumeToken); long index = 0L; while (iterator.hasNext()) { SimulatedExecutionTime.checkStreamException( diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java new file mode 100644 index 000000000000..891ae0f7d19e --- /dev/null +++ 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java @@ -0,0 +1,310 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Empty; +import com.google.spanner.v1.BatchCreateSessionsRequest; +import com.google.spanner.v1.BatchCreateSessionsResponse; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.DeleteSessionRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.GetSessionRequest; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ReadRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.Session; +import com.google.spanner.v1.SpannerGrpc; +import com.google.spanner.v1.Transaction; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; 
+import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Shared-backend replica harness for end-to-end location-aware routing tests. */ +final class SharedBackendReplicaHarness implements Closeable { + + static final String METHOD_BATCH_CREATE_SESSIONS = "BatchCreateSessions"; + static final String METHOD_BEGIN_TRANSACTION = "BeginTransaction"; + static final String METHOD_COMMIT = "Commit"; + static final String METHOD_CREATE_SESSION = "CreateSession"; + static final String METHOD_DELETE_SESSION = "DeleteSession"; + static final String METHOD_EXECUTE_SQL = "ExecuteSql"; + static final String METHOD_EXECUTE_STREAMING_SQL = "ExecuteStreamingSql"; + static final String METHOD_GET_SESSION = "GetSession"; + static final String METHOD_READ = "Read"; + static final String METHOD_ROLLBACK = "Rollback"; + static final String METHOD_STREAMING_READ = "StreamingRead"; + + static final class HookedReplicaSpannerService extends SpannerGrpc.SpannerImplBase { + private final MockSpannerServiceImpl backend; + private final Map> methodErrors = new HashMap<>(); + private final Map> requests = new HashMap<>(); + private final Map> requestIds = new HashMap<>(); + + private HookedReplicaSpannerService(MockSpannerServiceImpl backend) { + this.backend = backend; + } + + synchronized void putMethodErrors(String method, Throwable... 
errors) { + ArrayDeque queue = new ArrayDeque<>(); + for (Throwable error : errors) { + queue.addLast(error); + } + methodErrors.put(method, queue); + } + + synchronized List getRequests(String method) { + return new ArrayList<>(requests.getOrDefault(method, new ArrayList<>())); + } + + synchronized List getRequestIds(String method) { + return new ArrayList<>(requestIds.getOrDefault(method, new ArrayList<>())); + } + + synchronized void clearRequests() { + requests.clear(); + requestIds.clear(); + } + + private synchronized void recordRequest(String method, AbstractMessage request) { + requests.computeIfAbsent(method, ignored -> new ArrayList<>()).add(request); + } + + private synchronized void recordRequestId(String method, String requestId) { + requestIds.computeIfAbsent(method, ignored -> new ArrayList<>()).add(requestId); + } + + private synchronized Throwable nextError(String method) { + ArrayDeque queue = methodErrors.get(method); + if (queue == null || queue.isEmpty()) { + return null; + } + return queue.removeFirst(); + } + + private boolean maybeFail(String method, StreamObserver responseObserver) { + Throwable error = nextError(method); + if (error == null) { + return false; + } + responseObserver.onError(error); + return true; + } + + @Override + public void batchCreateSessions( + BatchCreateSessionsRequest request, + StreamObserver responseObserver) { + recordRequest(METHOD_BATCH_CREATE_SESSIONS, request); + if (!maybeFail(METHOD_BATCH_CREATE_SESSIONS, responseObserver)) { + backend.batchCreateSessions(request, responseObserver); + } + } + + @Override + public void beginTransaction( + BeginTransactionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_BEGIN_TRANSACTION, request); + if (!maybeFail(METHOD_BEGIN_TRANSACTION, responseObserver)) { + backend.beginTransaction(request, responseObserver); + } + } + + @Override + public void commit(CommitRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_COMMIT, 
request); + if (!maybeFail(METHOD_COMMIT, responseObserver)) { + backend.commit(request, responseObserver); + } + } + + @Override + public void createSession( + CreateSessionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_CREATE_SESSION, request); + if (!maybeFail(METHOD_CREATE_SESSION, responseObserver)) { + backend.createSession(request, responseObserver); + } + } + + @Override + public void deleteSession( + DeleteSessionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_DELETE_SESSION, request); + if (!maybeFail(METHOD_DELETE_SESSION, responseObserver)) { + backend.deleteSession(request, responseObserver); + } + } + + @Override + public void executeSql(ExecuteSqlRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_EXECUTE_SQL, request); + if (!maybeFail(METHOD_EXECUTE_SQL, responseObserver)) { + backend.executeSql(request, responseObserver); + } + } + + @Override + public void executeStreamingSql( + ExecuteSqlRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_EXECUTE_STREAMING_SQL, request); + if (!maybeFail(METHOD_EXECUTE_STREAMING_SQL, responseObserver)) { + backend.executeStreamingSql(request, responseObserver); + } + } + + @Override + public void getSession(GetSessionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_GET_SESSION, request); + if (!maybeFail(METHOD_GET_SESSION, responseObserver)) { + backend.getSession(request, responseObserver); + } + } + + @Override + public void read(ReadRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_READ, request); + if (!maybeFail(METHOD_READ, responseObserver)) { + backend.read(request, responseObserver); + } + } + + @Override + public void rollback(RollbackRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_ROLLBACK, request); + if (!maybeFail(METHOD_ROLLBACK, responseObserver)) { + backend.rollback(request, responseObserver); + } + } + + @Override + 
public void streamingRead( + ReadRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_STREAMING_READ, request); + if (!maybeFail(METHOD_STREAMING_READ, responseObserver)) { + backend.streamingRead(request, responseObserver); + } + } + } + + private final List servers; + final MockSpannerServiceImpl backend; + final HookedReplicaSpannerService defaultReplica; + final String defaultAddress; + final List replicas; + final List replicaAddresses; + + private SharedBackendReplicaHarness( + MockSpannerServiceImpl backend, + HookedReplicaSpannerService defaultReplica, + String defaultAddress, + List replicas, + List replicaAddresses, + List servers) { + this.backend = backend; + this.defaultReplica = defaultReplica; + this.defaultAddress = defaultAddress; + this.replicas = replicas; + this.replicaAddresses = replicaAddresses; + this.servers = servers; + } + + static SharedBackendReplicaHarness create(int replicaCount) throws IOException { + MockSpannerServiceImpl backend = new MockSpannerServiceImpl(); + backend.setAbortProbability(0.0D); + List servers = new ArrayList<>(); + HookedReplicaSpannerService defaultReplica = new HookedReplicaSpannerService(backend); + List replicas = new ArrayList<>(); + List replicaAddresses = new ArrayList<>(); + String defaultAddress = startServer(servers, defaultReplica); + for (int i = 0; i < replicaCount; i++) { + HookedReplicaSpannerService replica = new HookedReplicaSpannerService(backend); + replicas.add(replica); + replicaAddresses.add(startServer(servers, replica)); + } + return new SharedBackendReplicaHarness( + backend, defaultReplica, defaultAddress, replicas, replicaAddresses, servers); + } + + private static String startServer(List servers, HookedReplicaSpannerService service) + throws IOException { + InetSocketAddress address = new InetSocketAddress("localhost", 0); + ServerInterceptor interceptor = + new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, 
Metadata headers, ServerCallHandler next) { + service.recordRequestId( + call.getMethodDescriptor().getBareMethodName(), + headers.get(XGoogSpannerRequestId.REQUEST_ID_HEADER_KEY)); + return next.startCall(call, headers); + } + }; + Server server = + NettyServerBuilder.forAddress(address) + .addService(ServerInterceptors.intercept(service, interceptor)) + .build() + .start(); + servers.add(server); + return "localhost:" + server.getPort(); + } + + void clearRequests() { + defaultReplica.clearRequests(); + for (HookedReplicaSpannerService replica : replicas) { + replica.clearRequests(); + } + } + + @Override + public void close() throws IOException { + IOException failure = null; + for (Server server : servers) { + server.shutdown(); + } + for (Server server : servers) { + try { + server.awaitTermination(5L, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + if (failure == null) { + failure = new IOException("Interrupted while stopping replica harness", e); + } + } + } + if (failure != null) { + throw failure; + } + } +} diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index 1ad3888b4f9d..1c0a277ca4f4 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -56,6 +56,10 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -458,9 +462,11 @@ public void singleUseCommitUsesSameMutationSelectionHeuristicAsBeginTransaction( @Test public void 
resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { - TestHarness harness = createHarness(); + TestHarness harness = createHarness(createDeterministicCooldownTracker()); seedCache(harness, createLeaderAndReplicaCacheUpdate()); - CallOptions retryCallOptions = retryCallOptions(1L); + XGoogSpannerRequestId requestId = retryRequestId(1L); + CallOptions retryCallOptions = retryCallOptions(requestId); + String logicalRequestKey = requestId.getLogicalRequestKey(); ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder() @@ -481,6 +487,12 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { harness.endpointCache.latestCallForAddress("server-a:1234"); firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata()); + assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + logicalRequestKey, "server-a:1234")) + .isTrue(); + ClientCall secondCall = harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions); secondCall.start(new CapturingListener(), new Metadata()); @@ -488,6 +500,11 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + logicalRequestKey, "server-a:1234")) + .isFalse(); } @Test @@ -590,11 +607,16 @@ public void resourceExhaustedRoutedEndpointFallsBackToDefaultWhenNoReplicaExists } @Test - public void resourceExhaustedSkipDoesNotAffectDifferentLogicalRequest() throws Exception { - TestHarness harness = createHarness(); + public void resourceExhaustedCooldownAffectsDifferentLogicalRequestButExclusionDoesNot() + throws Exception { + TestHarness harness = 
createHarness(createDeterministicCooldownTracker()); seedCache(harness, createLeaderAndReplicaCacheUpdate()); - CallOptions firstLogicalRequest = retryCallOptions(4L); - CallOptions secondLogicalRequest = retryCallOptions(5L); + XGoogSpannerRequestId firstRequestId = retryRequestId(4L); + XGoogSpannerRequestId secondRequestId = retryRequestId(5L); + CallOptions firstLogicalRequest = retryCallOptions(firstRequestId); + CallOptions secondLogicalRequest = retryCallOptions(secondRequestId); + String firstLogicalRequestKey = firstRequestId.getLogicalRequestKey(); + String secondLogicalRequestKey = secondRequestId.getLogicalRequestKey(); ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder() @@ -613,21 +635,47 @@ public void resourceExhaustedSkipDoesNotAffectDifferentLogicalRequest() throws E harness.endpointCache.latestCallForAddress("server-a:1234"); firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata()); + assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + firstLogicalRequestKey, "server-a:1234")) + .isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + secondLogicalRequestKey, "server-a:1234")) + .isFalse(); + ClientCall unrelatedCall = harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), secondLogicalRequest); unrelatedCall.start(new CapturingListener(), new Metadata()); unrelatedCall.sendMessage(request); - assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2); - assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(0); + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + firstLogicalRequestKey, "server-a:1234")) + .isTrue(); + assertThat( + 
harness.channel.hasExcludedEndpointForLogicalRequest( + secondLogicalRequestKey, "server-a:1234")) + .isFalse(); ClientCall retriedFirstCall = harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), firstLogicalRequest); retriedFirstCall.start(new CapturingListener(), new Metadata()); retriedFirstCall.sendMessage(request); - assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2); - assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(2); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + firstLogicalRequestKey, "server-a:1234")) + .isFalse(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + secondLogicalRequestKey, "server-a:1234")) + .isFalse(); } @Test @@ -1235,13 +1283,28 @@ private static RecipeList parseRecipeList(String text) throws TextFormat.ParseEx } private static TestHarness createHarness() throws IOException { + return createHarness(new EndpointOverloadCooldownTracker()); + } + + private static TestHarness createHarness(EndpointOverloadCooldownTracker tracker) + throws IOException { FakeEndpointCache endpointCache = new FakeEndpointCache(DEFAULT_ADDRESS); InstantiatingGrpcChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost:9999").build(); - KeyAwareChannel channel = KeyAwareChannel.create(provider, baseProvider -> endpointCache); + KeyAwareChannel channel = + KeyAwareChannel.create(provider, baseProvider -> endpointCache, tracker); return new TestHarness(channel, endpointCache, endpointCache.defaultManagedChannel()); } + private static EndpointOverloadCooldownTracker createDeterministicCooldownTracker() { + return new EndpointOverloadCooldownTracker( + Duration.ofMinutes(1), + Duration.ofMinutes(1), + Duration.ofMinutes(10), + 
Clock.fixed(Instant.ofEpochSecond(100), ZoneOffset.UTC), + bound -> bound - 1L); + } + private static final class TestHarness { private final KeyAwareChannel channel; private final FakeEndpointCache endpointCache; @@ -1483,9 +1546,16 @@ private static ByteString bytes(String value) { return ByteString.copyFromUtf8(value); } + private static XGoogSpannerRequestId retryRequestId(long nthRequest) { + return XGoogSpannerRequestId.of(1L, 0L, nthRequest, 0L); + } + private static CallOptions retryCallOptions(long nthRequest) { + return retryCallOptions(retryRequestId(nthRequest)); + } + + private static CallOptions retryCallOptions(XGoogSpannerRequestId requestId) { return CallOptions.DEFAULT.withOption( - XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY, - XGoogSpannerRequestId.of(1L, 0L, nthRequest, 0L)); + XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY, requestId); } } From 2598291cc467602a9542249866f5e7fa545e8c4a Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Sat, 18 Apr 2026 03:48:54 +0530 Subject: [PATCH 2/9] retry unavailable errors on different replica --- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 13 +- ...nAwareSharedBackendReplicaHarnessTest.java | 154 ++++++++++++++++++ .../cloud/spanner/MockSpannerServiceImpl.java | 4 +- 3 files changed, 166 insertions(+), 5 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index 90ff41f35973..0d48d07924c6 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -95,9 +95,9 @@ final class KeyAwareChannel extends ManagedChannel { // Bounded to prevent unbounded growth if application code does not close read-only transactions. 
private final Cache readOnlyTxPreferLeader = CacheBuilder.newBuilder().maximumSize(MAX_TRACKED_READ_ONLY_TRANSACTIONS).build(); - // If a routed endpoint returns RESOURCE_EXHAUSTED, the next retry attempt of that same logical - // request should avoid that endpoint once so other requests are unaffected. Bound and age out - // entries in case a caller gives up and never issues a retry. + // If a routed endpoint returns RESOURCE_EXHAUSTED or UNAVAILABLE, the next retry attempt of + // that same logical request should avoid that endpoint once so other requests are unaffected. + // Bound and age out entries in case a caller gives up and never issues a retry. private final Cache> excludedEndpointsForLogicalRequest = CacheBuilder.newBuilder() .maximumSize(MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS) @@ -364,6 +364,11 @@ private void maybeExcludeEndpointOnNextCall( }); } + private static boolean shouldExcludeEndpointOnRetry(io.grpc.Status.Code statusCode) { + return statusCode == io.grpc.Status.Code.RESOURCE_EXHAUSTED + || statusCode == io.grpc.Status.Code.UNAVAILABLE; + } + private Predicate consumeExcludedEndpointsForCurrentCall( @Nullable String logicalRequestKey) { Predicate requestScopedExcluded = address -> false; @@ -898,7 +903,7 @@ public void onMessage(ResponseT message) { @Override public void onClose(io.grpc.Status status, Metadata trailers) { - if (status.getCode() == io.grpc.Status.Code.RESOURCE_EXHAUSTED) { + if (shouldExcludeEndpointOnRetry(status.getCode())) { call.parentChannel.maybeExcludeEndpointOnNextCall( call.selectedEndpoint, call.logicalRequestKey); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java index 2196fad56696..9b6c2de65397 100644 --- 
a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -248,6 +248,156 @@ public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() thr } } + @Test + public void singleUseReadReroutesOnUnavailableForBypassTraffic() throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica")); + + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(resultSet.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 1, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + ReadRequest replicaARequest = + (ReadRequest) + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + assertTrue(replicaARequest.getResumeToken().isEmpty()); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + harness + .replicas + .get(1) + 
.getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0)); + } + } + + @Test + public void singleUseReadCooldownSkipsUnavailableReplicaOnNextRequestForBypassTraffic() + throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica")); + + try (ResultSet firstRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(firstRead.next()); + } + + try (ResultSet secondRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(secondRead.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 2, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + List replicaBRequests = + harness.replicas.get(1).getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + for (AbstractMessage request : replicaBRequests) { + assertTrue(((ReadRequest) request).getResumeToken().isEmpty()); + } + List replicaBRequestIds = + harness.replicas.get(1).getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + 
.getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + replicaBRequestIds.get(0)); + assertNotEquals( + XGoogSpannerRequestId.of(replicaBRequestIds.get(0)).getLogicalRequestKey(), + XGoogSpannerRequestId.of(replicaBRequestIds.get(1)).getLogicalRequestKey()); + } + } + @Test public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTraffic() throws Exception { @@ -486,6 +636,10 @@ private static StatusRuntimeException resourceExhausted(String description) { return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException(); } + private static StatusRuntimeException unavailable(String description) { + return Status.UNAVAILABLE.withDescription(description).asRuntimeException(); + } + private static void assertRetriedOnSameLogicalRequest( String firstRequestId, String secondRequestId) { XGoogSpannerRequestId first = XGoogSpannerRequestId.of(firstRequestId); diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 10fa82dd6d26..3ea19ad2422a 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -208,7 +208,9 @@ private PartialResultSetsIterator( ByteString resumeToken) { this.resultSet = resultSet; this.currentRow = parseResumeToken(resumeToken); - this.hasNext = currentRow < resultSet.getRowsCount(); + this.hasNext = + currentRow < resultSet.getRowsCount() + || (currentRow == 0 && resultSet.getRowsCount() == 0); this.setPrecommitToken = setPrecommitToken; this.transactionId = transactionId; } From 2463cbf56986d4b8c5853e5e5ccc38334004fc09 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Sat, 18 Apr 2026 04:03:28 +0530 Subject: [PATCH 3/9] address comments --- 
.../cloud/spanner/spi/v1/GapicSpannerRpc.java | 13 +++-- .../spanner/spi/v1/GapicSpannerRpcTest.java | 48 +++++++++++++++++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 6cc0a485d056..808ade76fe39 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -431,12 +431,15 @@ public GapicSpannerRpc(final SpannerOptions options) { && isEnableDirectAccess; this.readRetrySettings = options.getSpannerStubSettings().streamingReadSettings().getRetrySettings(); + Set streamingReadRetryableCodes = + options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes(); this.readRetryableCodes = - ImmutableSet.builder() - .addAll( - options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes()) - .add(Code.RESOURCE_EXHAUSTED) - .build(); + enableLocationApi + ? 
ImmutableSet.builder() + .addAll(streamingReadRetryableCodes) + .add(Code.RESOURCE_EXHAUSTED) + .build() + : streamingReadRetryableCodes; this.executeQueryRetrySettings = options.getSpannerStubSettings().executeStreamingSqlSettings().getRetrySettings(); this.executeQueryRetryableCodes = diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java index 165557608ac3..3a53c85b4a09 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpcTest.java @@ -34,6 +34,7 @@ import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.ApiClientHeaderProvider; import com.google.api.gax.rpc.HeaderProvider; +import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.auth.Credentials; import com.google.auth.oauth2.AccessToken; @@ -1059,6 +1060,53 @@ public boolean isEnableLocationApi() { } } + @Test + public void testReadRetryableCodesIncludeResourceExhaustedWhenLocationApiEnabled() { + try { + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return true; + } + }); + GapicSpannerRpc rpc = new GapicSpannerRpc(createSpannerOptions(), true); + try { + assertThat(rpc.getReadRetryableCodes()).contains(Code.RESOURCE_EXHAUSTED); + } finally { + rpc.shutdown(); + } + } finally { + SpannerOptions.useDefaultEnvironment(); + } + } + + @Test + public void testReadRetryableCodesDoNotAddResourceExhaustedWhenLocationApiDisabled() { + try { + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return false; + } + }); + GapicSpannerRpc rpc = new 
GapicSpannerRpc(createSpannerOptions(), true); + try { + assertThat(rpc.getReadRetryableCodes()) + .isEqualTo( + createSpannerOptions() + .getSpannerStubSettings() + .streamingReadSettings() + .getRetryableCodes()); + } finally { + rpc.shutdown(); + } + } finally { + SpannerOptions.useDefaultEnvironment(); + } + } + @Test public void testGrpcGcpExtensionPreservesChannelConfigurator() throws Exception { InstantiatingGrpcChannelProvider.Builder channelProviderBuilder = From 7c4bb2e07996c4adb44e09094ff991fb997e3d2b Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 20 Apr 2026 02:12:41 +0530 Subject: [PATCH 4/9] add EWMA support for stale_reads --- .../spi/v1/EndpointLatencyRegistry.java | 171 ++++++ .../spi/v1/EndpointLifecycleManager.java | 47 +- .../spanner/spi/v1/EwmaLatencyTracker.java | 53 +- .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 52 +- .../spi/v1/GrpcChannelEndpointCache.java | 22 +- .../GrpcGcpEndpointChannelConfigurator.java | 54 ++ .../spanner/spi/v1/HeaderInterceptor.java | 29 +- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 68 ++- .../v1/KeyAwareTransportChannelProvider.java | 34 +- .../cloud/spanner/spi/v1/KeyRangeCache.java | 552 +++++++++++++++++- .../spi/v1/RequestIdTargetTracker.java | 87 +++ .../spi/v1/EndpointLifecycleManagerTest.java | 67 +++ .../spi/v1/EwmaLatencyTrackerTest.java | 34 +- .../spi/v1/GrpcChannelEndpointCacheTest.java | 83 ++- .../spanner/spi/v1/KeyAwareChannelTest.java | 123 +++- .../spanner/spi/v1/KeyRangeCacheTest.java | 340 ++++++++++- .../v1/ReplicaSelectionMockServerTest.java | 207 +++++++ 17 files changed, 1922 insertions(+), 101 deletions(-) create mode 100644 java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java create mode 100644 java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java create mode 100644 
java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java new file mode 100644 index 000000000000..be3eff53b75c --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java @@ -0,0 +1,171 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.spi.v1; + +import com.google.common.annotations.VisibleForTesting; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** Shared process-local latency scores for routed Spanner endpoints. 
*/ +final class EndpointLatencyRegistry { + + static final Duration DEFAULT_ERROR_PENALTY = Duration.ofSeconds(10); + static final Duration DEFAULT_RTT = Duration.ofMillis(10); + static final double DEFAULT_PENALTY_VALUE = 1_000_000.0; + + private static final ConcurrentHashMap TRACKERS = + new ConcurrentHashMap<>(); + private static final ConcurrentHashMap INFLIGHT_REQUESTS = + new ConcurrentHashMap<>(); + + private EndpointLatencyRegistry() {} + + static boolean hasScore(long operationUid, String endpointLabelOrAddress) { + TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + return trackerKey != null && TRACKERS.containsKey(trackerKey); + } + + static double getSelectionCost(long operationUid, String endpointLabelOrAddress) { + TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + if (trackerKey == null) { + return Double.MAX_VALUE; + } + double activeRequests = getInflight(endpointLabelOrAddress); + LatencyTracker tracker = TRACKERS.get(trackerKey); + if (tracker != null) { + return tracker.getScore() * (activeRequests + 1.0); + } + if (activeRequests > 0.0) { + return DEFAULT_PENALTY_VALUE + activeRequests; + } + return defaultRttMicros() * (activeRequests + 1.0); + } + + static void recordLatency(long operationUid, String endpointLabelOrAddress, Duration latency) { + TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + if (trackerKey == null || latency == null) { + return; + } + TRACKERS.computeIfAbsent(trackerKey, ignored -> new EwmaLatencyTracker()).update(latency); + } + + static void recordError(long operationUid, String endpointLabelOrAddress) { + recordError(operationUid, endpointLabelOrAddress, DEFAULT_ERROR_PENALTY); + } + + static void recordError(long operationUid, String endpointLabelOrAddress, Duration penalty) { + TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + if (trackerKey == null || penalty == null) { + return; + } + 
TRACKERS.computeIfAbsent(trackerKey, ignored -> new EwmaLatencyTracker()).recordError(penalty); + } + + static void beginRequest(String endpointLabelOrAddress) { + String address = normalizeAddress(endpointLabelOrAddress); + if (address == null) { + return; + } + INFLIGHT_REQUESTS.computeIfAbsent(address, ignored -> new AtomicInteger()).incrementAndGet(); + } + + static void finishRequest(String endpointLabelOrAddress) { + String address = normalizeAddress(endpointLabelOrAddress); + if (address == null) { + return; + } + AtomicInteger counter = INFLIGHT_REQUESTS.get(address); + if (counter == null) { + return; + } + int updated = counter.decrementAndGet(); + if (updated <= 0) { + INFLIGHT_REQUESTS.remove(address, counter); + } + } + + static int getInflight(String endpointLabelOrAddress) { + String address = normalizeAddress(endpointLabelOrAddress); + if (address == null) { + return 0; + } + AtomicInteger counter = INFLIGHT_REQUESTS.get(address); + return counter == null ? 0 : Math.max(0, counter.get()); + } + + @VisibleForTesting + static void clear() { + TRACKERS.clear(); + INFLIGHT_REQUESTS.clear(); + } + + @VisibleForTesting + static String normalizeAddress(String endpointLabelOrAddress) { + if (endpointLabelOrAddress == null || endpointLabelOrAddress.isEmpty()) { + return null; + } + return endpointLabelOrAddress; + } + + @VisibleForTesting + static TrackerKey trackerKey(long operationUid, String endpointLabelOrAddress) { + String address = normalizeAddress(endpointLabelOrAddress); + if (operationUid <= 0 || address == null) { + return null; + } + return new TrackerKey(operationUid, address); + } + + private static long defaultRttMicros() { + return DEFAULT_RTT.toNanos() / 1_000L; + } + + @VisibleForTesting + static final class TrackerKey { + private final long operationUid; + private final String address; + + private TrackerKey(long operationUid, String address) { + this.operationUid = operationUid; + this.address = address; + } + + @Override + public boolean 
equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof TrackerKey)) { + return false; + } + TrackerKey that = (TrackerKey) other; + return operationUid == that.operationUid && Objects.equals(address, that.address); + } + + @Override + public int hashCode() { + return Objects.hash(operationUid, address); + } + + @Override + public String toString() { + return operationUid + "@" + address; + } + } +} diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java index ae78f07b14a3..af8bc64ba118 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManager.java @@ -24,6 +24,7 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -70,8 +71,9 @@ class EndpointLifecycleManager { private static final long EVICTION_CHECK_INTERVAL_SECONDS = 300; /** - * Maximum consecutive TRANSIENT_FAILURE probes before evicting an endpoint. Gives the channel - * time to recover from transient network issues before we tear it down and recreate. + * Maximum observed TRANSIENT_FAILURE probes before evicting an endpoint. The counter resets only + * after the channel reaches READY, so CONNECTING/IDLE oscillation does not hide a persistently + * unhealthy endpoint. 
*/ private static final int MAX_TRANSIENT_FAILURE_COUNT = 3; @@ -104,6 +106,7 @@ static final class EndpointState { private final ChannelEndpointCache endpointCache; private final Map endpoints = new ConcurrentHashMap<>(); + private final Set evictedAddresses = ConcurrentHashMap.newKeySet(); private final Set transientFailureEvictedAddresses = ConcurrentHashMap.newKeySet(); private final Map finderGenerations = new ConcurrentHashMap<>(); private final Map pendingActiveAddressUpdates = @@ -215,6 +218,7 @@ private boolean ensureEndpointExists(String address) { address, addr -> { logger.log(Level.FINE, "Creating endpoint state for address: {0}", addr); + evictedAddresses.remove(addr); created[0] = true; return new EndpointState(addr, clock.instant()); }); @@ -493,7 +497,8 @@ private void stopProbing(String address) { *

All exceptions are caught to prevent {@link ScheduledExecutorService} from cancelling future * runs of this task. */ - private void probe(String address) { + @VisibleForTesting + void probe(String address) { try { if (isShutdown.get()) { return; @@ -530,25 +535,24 @@ private void probe(String address) { logger.log( Level.FINE, "Probe for {0}: channel IDLE, requesting connection (warmup)", address); channel.getState(true); - state.consecutiveTransientFailures = 0; break; case CONNECTING: - state.consecutiveTransientFailures = 0; break; case TRANSIENT_FAILURE: state.consecutiveTransientFailures++; logger.log( Level.FINE, - "Probe for {0}: channel in TRANSIENT_FAILURE ({1}/{2})", + "Probe for {0}: channel in TRANSIENT_FAILURE ({1}/{2} observed failures since last" + + " READY)", new Object[] { address, state.consecutiveTransientFailures, MAX_TRANSIENT_FAILURE_COUNT }); if (state.consecutiveTransientFailures >= MAX_TRANSIENT_FAILURE_COUNT) { logger.log( Level.FINE, - "Evicting endpoint {0}: {1} consecutive TRANSIENT_FAILURE probes", + "Evicting endpoint {0}: {1} TRANSIENT_FAILURE probes without reaching READY", new Object[] {address, state.consecutiveTransientFailures}); evictEndpoint(address, EvictionReason.TRANSIENT_FAILURE); } @@ -608,6 +612,7 @@ private void evictEndpoint(String address, EvictionReason reason) { stopProbing(address); endpoints.remove(address); + evictedAddresses.add(address); if (reason == EvictionReason.TRANSIENT_FAILURE) { markTransientFailureEvicted(address); } else { @@ -636,6 +641,7 @@ void requestEndpointRecreation(String address) { logger.log(Level.FINE, "Recreating previously evicted endpoint for address: {0}", address); EndpointState state = new EndpointState(address, clock.instant()); + evictedAddresses.remove(address); if (endpoints.putIfAbsent(address, state) == null) { // Schedule after putIfAbsent returns so the entry is visible to the scheduler thread. 
scheduler.submit(() -> createAndStartProbing(address)); @@ -663,6 +669,32 @@ int managedEndpointCount() { return endpoints.size(); } + Map snapshotEndpointStateCounts() { + Map counts = new HashMap<>(); + snapshotEndpointStates().values().forEach(state -> counts.merge(state, 1L, Long::sum)); + return counts; + } + + Map snapshotEndpointStates() { + Map states = new HashMap<>(); + for (String address : endpoints.keySet()) { + ChannelEndpoint endpoint = endpointCache.getIfPresent(address); + String stateName = "unknown"; + if (endpoint != null) { + ConnectivityState state = endpoint.getChannel().getState(false); + stateName = + state == ConnectivityState.TRANSIENT_FAILURE + ? "transient_failure" + : state.name().toLowerCase(); + } + states.put(address, stateName); + } + for (String address : evictedAddresses) { + states.putIfAbsent(address, "evicted"); + } + return states; + } + /** Shuts down the lifecycle manager and all probing. */ void shutdown() { if (!isShutdown.compareAndSet(false, true)) { @@ -684,6 +716,7 @@ void shutdown() { } } endpoints.clear(); + evictedAddresses.clear(); transientFailureEvictedAddresses.clear(); pendingActiveAddressUpdates.clear(); queuedFinderKeys.clear(); diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java index 0cb2331660f9..64101277e298 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java @@ -18,25 +18,33 @@ import com.google.api.core.BetaApi; import com.google.api.core.InternalApi; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import java.time.Duration; import java.util.concurrent.TimeUnit; +import java.util.function.LongSupplier; +import 
javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; /** * Implementation of {@link LatencyTracker} using Exponentially Weighted Moving Average (EWMA). * - *

Formula: $S_{i+1} = \alpha * new\_latency + (1 - \alpha) * S_i$ + *

By default, this tracker uses a time-decayed EWMA: + * $S_{i+1} = \alpha(\Delta t) * new\_latency + (1 - \alpha(\Delta t)) * S_i$, where $\alpha(\Delta + * t) = 1 - e^{-\Delta t / \tau}$. * - *

This class is thread-safe. + *

A fixed-alpha constructor is retained for focused tests. */ @InternalApi @BetaApi public class EwmaLatencyTracker implements LatencyTracker { public static final double DEFAULT_ALPHA = 0.05; + public static final Duration DEFAULT_DECAY_TIME = Duration.ofSeconds(10); - private final double alpha; + @Nullable private final Double fixedAlpha; + private final long tauNanos; + private final LongSupplier nanoTimeSupplier; private final Object lock = new Object(); @GuardedBy("lock") @@ -45,9 +53,12 @@ public class EwmaLatencyTracker implements LatencyTracker { @GuardedBy("lock") private boolean initialized = false; - /** Creates a new tracker with the default alpha value of 0.05. */ + @GuardedBy("lock") + private long lastUpdatedAtNanos; + + /** Creates a new tracker with Envoy-style time-based decay and a 10-second decay window. */ public EwmaLatencyTracker() { - this(DEFAULT_ALPHA); + this(DEFAULT_DECAY_TIME, System::nanoTime); } /** @@ -56,8 +67,25 @@ public EwmaLatencyTracker() { * @param alpha the smoothing factor, must be in the range (0, 1] */ public EwmaLatencyTracker(double alpha) { + this(alpha, System::nanoTime); + } + + @VisibleForTesting + EwmaLatencyTracker(Duration decayTime, LongSupplier nanoTimeSupplier) { + Preconditions.checkArgument( + decayTime != null && !decayTime.isZero() && !decayTime.isNegative(), + "decayTime must be > 0"); + this.fixedAlpha = null; + this.tauNanos = decayTime.toNanos(); + this.nanoTimeSupplier = nanoTimeSupplier; + } + + @VisibleForTesting + EwmaLatencyTracker(double alpha, LongSupplier nanoTimeSupplier) { Preconditions.checkArgument(alpha > 0.0 && alpha <= 1.0, "alpha must be in (0, 1]"); - this.alpha = alpha; + this.fixedAlpha = alpha; + this.tauNanos = 0L; + this.nanoTimeSupplier = nanoTimeSupplier; } @Override @@ -77,12 +105,16 @@ public void update(Duration latency) { // Use Long.MAX_VALUE to give it the lowest possible priority. 
latencyMicros = Long.MAX_VALUE; } + long nowNanos = nanoTimeSupplier.getAsLong(); synchronized (lock) { if (!initialized) { score = latencyMicros; initialized = true; + lastUpdatedAtNanos = nowNanos; } else { + double alpha = fixedAlpha != null ? fixedAlpha : calculateTimeBasedAlpha(nowNanos); score = alpha * latencyMicros + (1 - alpha) * score; + lastUpdatedAtNanos = nowNanos; } } } @@ -92,4 +124,13 @@ public void recordError(Duration penalty) { // Treat the error as a sample with high latency (penalty) update(penalty); } + + private double calculateTimeBasedAlpha(long nowNanos) { + long deltaNanos = nowNanos - lastUpdatedAtNanos; + if (deltaNanos <= 0L) { + return 1.0; + } + double alpha = 1.0 - Math.exp(-(double) deltaNanos / (double) tauNanos); + return Math.min(1.0, Math.max(0.0, alpha)); + } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 808ade76fe39..0cbab272d213 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -369,7 +369,11 @@ public GapicSpannerRpc(final SpannerOptions options) { GrpcTransportOptions.setUpCredentialsProvider(options); InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder = - createChannelProviderBuilder(options, headerProviderWithUserAgent, isEnableDirectAccess); + createBaseChannelProviderBuilder( + options, headerProviderWithUserAgent, isEnableDirectAccess); + GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator = + createGrpcGcpEndpointChannelConfigurator(defaultChannelProviderBuilder, options); + maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); if (options.getChannelProvider() == null && isEnableDirectAccess @@ -391,7 +395,8 @@ public 
GapicSpannerRpc(final SpannerOptions options) { enableLocationApi && baseChannelProvider instanceof InstantiatingGrpcChannelProvider ? new KeyAwareTransportChannelProvider( (InstantiatingGrpcChannelProvider) baseChannelProvider, - options.getChannelEndpointCacheFactory()) + options.getChannelEndpointCacheFactory(), + endpointChannelConfigurator) : baseChannelProvider; spannerWatchdog = @@ -732,6 +737,17 @@ private InstantiatingGrpcChannelProvider.Builder createChannelProviderBuilder( final SpannerOptions options, final HeaderProvider headerProviderWithUserAgent, boolean isEnableDirectAccess) { + InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder = + createBaseChannelProviderBuilder( + options, headerProviderWithUserAgent, isEnableDirectAccess); + maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); + return defaultChannelProviderBuilder; + } + + private InstantiatingGrpcChannelProvider.Builder createBaseChannelProviderBuilder( + final SpannerOptions options, + final HeaderProvider headerProviderWithUserAgent, + boolean isEnableDirectAccess) { InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder = InstantiatingGrpcChannelProvider.newBuilder() .setChannelConfigurator(options.getChannelConfigurator()) @@ -777,8 +793,6 @@ private InstantiatingGrpcChannelProvider.Builder createChannelProviderBuilder( defaultChannelProviderBuilder.setExecutor(executor); } } - // If it is enabled in options uses the channel pool provided by the gRPC-GCP extension. 
- maybeEnableGrpcGcpExtension(defaultChannelProviderBuilder, options); return defaultChannelProviderBuilder; } @@ -834,6 +848,36 @@ static GcpChannelPoolOptions getGrpcGcpChannelPoolOptions(SpannerOptions options .build(); } + @VisibleForTesting + static GcpChannelPoolOptions getGrpcGcpEndpointChannelPoolOptions(SpannerOptions options) { + GcpChannelPoolOptions channelPoolOptions = options.getGcpChannelPoolOptions(); + return GcpChannelPoolOptions.newBuilder() + .setMaxSize(1) + .setMinSize(1) + .setInitSize(1) + .disableDynamicScaling() + .setAffinityKeyLifetime(channelPoolOptions.getAffinityKeyLifetime()) + .setCleanupInterval(channelPoolOptions.getCleanupInterval()) + .build(); + } + + @Nullable + private static GrpcGcpEndpointChannelConfigurator createGrpcGcpEndpointChannelConfigurator( + InstantiatingGrpcChannelProvider.Builder channelProviderBuilder, SpannerOptions options) { + if (!options.isGrpcGcpExtensionEnabled()) { + return null; + } + + GcpManagedChannelOptions endpointGrpcGcpOptions = + GcpManagedChannelOptions.newBuilder(grpcGcpOptionsWithMetricsAndDcp(options)) + .withChannelPoolOptions(getGrpcGcpEndpointChannelPoolOptions(options)) + .build(); + return new GrpcGcpEndpointChannelConfigurator( + channelProviderBuilder.getChannelConfigurator(), + parseGrpcGcpApiConfig(), + endpointGrpcGcpOptions); + } + @SuppressWarnings("rawtypes") private static void maybeEnableGrpcGcpExtension( InstantiatingGrpcChannelProvider.Builder defaultChannelProviderBuilder, diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java index 98e7f83b094f..54ee67439e59 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java @@ 
-17,6 +17,7 @@ package com.google.cloud.spanner.spi.v1; import com.google.api.core.InternalApi; +import com.google.api.gax.grpc.ChannelPoolSettings; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; import com.google.cloud.spanner.ErrorCode; @@ -54,6 +55,7 @@ class GrpcChannelEndpointCache implements ChannelEndpointCache { private final Map servers = new ConcurrentHashMap<>(); private final GrpcChannelEndpoint defaultEndpoint; private final String defaultAuthority; + @Nullable private final GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator; private final AtomicBoolean isShutdown = new AtomicBoolean(false); /** @@ -65,7 +67,15 @@ class GrpcChannelEndpointCache implements ChannelEndpointCache { */ public GrpcChannelEndpointCache(InstantiatingGrpcChannelProvider channelProvider) throws IOException { + this(channelProvider, null); + } + + public GrpcChannelEndpointCache( + InstantiatingGrpcChannelProvider channelProvider, + @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator) + throws IOException { this.baseProvider = channelProvider; + this.endpointChannelConfigurator = endpointChannelConfigurator; String defaultEndpoint = channelProvider.getEndpoint(); this.defaultEndpoint = new GrpcChannelEndpoint(defaultEndpoint, channelProvider); this.defaultAuthority = this.defaultEndpoint.getChannel().authority(); @@ -110,19 +120,25 @@ public ChannelEndpoint getIfPresent(String address) { return servers.get(address); } - private InstantiatingGrpcChannelProvider createProviderWithAuthorityOverride(String address) { + @VisibleForTesting + InstantiatingGrpcChannelProvider createProviderWithAuthorityOverride(String address) { InstantiatingGrpcChannelProvider endpointProvider = (InstantiatingGrpcChannelProvider) baseProvider.withEndpoint(address); if (Objects.equals(defaultAuthority, address)) { return endpointProvider; } Builder builder = 
endpointProvider.toBuilder(); + builder.setChannelPoolSettings(ChannelPoolSettings.staticallySized(1)); + builder.setKeepAliveWithoutCalls(Boolean.TRUE); final com.google.api.core.ApiFunction - baseConfigurator = builder.getChannelConfigurator(); + baseConfigurator = + endpointChannelConfigurator == null ? builder.getChannelConfigurator() : null; builder.setChannelConfigurator( channelBuilder -> { ManagedChannelBuilder effectiveBuilder = channelBuilder; - if (baseConfigurator != null) { + if (endpointChannelConfigurator != null) { + effectiveBuilder = endpointChannelConfigurator.configure(effectiveBuilder); + } else if (baseConfigurator != null) { effectiveBuilder = baseConfigurator.apply(effectiveBuilder); } return effectiveBuilder.overrideAuthority(defaultAuthority); diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java new file mode 100644 index 000000000000..67b9cb3f0140 --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcGcpEndpointChannelConfigurator.java @@ -0,0 +1,54 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.spanner.spi.v1; + +import com.google.api.core.ApiFunction; +import com.google.cloud.grpc.GcpManagedChannelBuilder; +import com.google.cloud.grpc.GcpManagedChannelOptions; +import io.grpc.ManagedChannelBuilder; +import javax.annotation.Nullable; + +/** + * Rebuilds the grpc-gcp wrapper for routed endpoint channels while preserving the base channel + * configuration. + */ +final class GrpcGcpEndpointChannelConfigurator { + @Nullable + private final ApiFunction baseConfigurator; + + private final String apiConfigJson; + private final GcpManagedChannelOptions grpcGcpOptions; + + GrpcGcpEndpointChannelConfigurator( + @Nullable ApiFunction baseConfigurator, + String apiConfigJson, + GcpManagedChannelOptions grpcGcpOptions) { + this.baseConfigurator = baseConfigurator; + this.apiConfigJson = apiConfigJson; + this.grpcGcpOptions = grpcGcpOptions; + } + + ManagedChannelBuilder configure(ManagedChannelBuilder channelBuilder) { + ManagedChannelBuilder effectiveBuilder = channelBuilder; + if (baseConfigurator != null) { + effectiveBuilder = baseConfigurator.apply(effectiveBuilder); + } + return GcpManagedChannelBuilder.forDelegateBuilder(effectiveBuilder) + .withApiConfigJsonString(apiConfigJson) + .withOptions(grpcGcpOptions); + } +} diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java index 861e839a0366..e01e4e97529a 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java @@ -42,9 +42,11 @@ import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.common.AttributesBuilder; import io.opentelemetry.api.trace.Span; +import java.time.Duration; import java.util.HashMap; import java.util.Map; import 
java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Level; import java.util.logging.Logger; import java.util.regex.Matcher; @@ -104,6 +106,8 @@ public ClientCall interceptCall( public void start(Listener responseListener, Metadata headers) { try { Span span = Span.current(); + long startedAtNanos = System.nanoTime(); + AtomicBoolean firstResponseRecorded = new AtomicBoolean(false); DatabaseName databaseName = extractDatabaseName(headers); String key = extractKey(databaseName, method.getFullMethodName()); String requestId = extractRequestId(headers); @@ -115,6 +119,7 @@ public void start(Listener responseListener, Metadata headers) { new SimpleForwardingClientCallListener(responseListener) { @Override public void onHeaders(Metadata metadata) { + recordFirstResponseLatency(requestId, startedAtNanos, firstResponseRecorded); String serverTiming = metadata.get(SERVER_TIMING_HEADER_KEY); try { // Get gfe and afe Latency value @@ -137,17 +142,22 @@ public void onClose(Status status, Metadata trailers) { recordCustomMetrics(tagContext, attributes, isDirectPathUsed); Map builtInMetricsAttributes = new HashMap<>(); try { - builtInMetricsAttributes = getBuiltInMetricAttributes(key, databaseName); + builtInMetricsAttributes = + new HashMap<>(getBuiltInMetricAttributes(key, databaseName)); } catch (ExecutionException e) { LOGGER.log( LEVEL, "Unable to get built-in metric attributes {}", e.getMessage()); } + if (status.isOk()) { + recordFirstResponseLatency(requestId, startedAtNanos, firstResponseRecorded); + } recordBuiltInMetrics( compositeTracer, builtInMetricsAttributes, requestId, isDirectPathUsed, isAfeEnabled); + RequestIdTargetTracker.remove(requestId); super.onClose(status, trailers); } }, @@ -208,6 +218,23 @@ private void recordBuiltInMetrics( } } + private void recordFirstResponseLatency( + String requestId, long startedAtNanos, AtomicBoolean firstResponseRecorded) { + if 
(!firstResponseRecorded.compareAndSet(false, true)) { + return; + } + RequestIdTargetTracker.RoutingTarget routingTarget = RequestIdTargetTracker.get(requestId); + if (routingTarget == null + || routingTarget.operationUid <= 0 + || routingTarget.targetEndpoint == null + || routingTarget.targetEndpoint.isEmpty()) { + return; + } + long latencyNanos = Math.max(0L, System.nanoTime() - startedAtNanos); + EndpointLatencyRegistry.recordLatency( + routingTarget.operationUid, routingTarget.targetEndpoint, Duration.ofNanos(latencyNanos)); + } + private Map parseServerTimingHeader(String serverTiming) { Map serverTimingMetrics = new HashMap<>(); if (serverTiming != null) { diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index 0d48d07924c6..127c75c9adb3 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -107,18 +107,25 @@ final class KeyAwareChannel extends ManagedChannel { private KeyAwareChannel( InstantiatingGrpcChannelProvider channelProvider, - @Nullable ChannelEndpointCacheFactory endpointCacheFactory) + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator) throws IOException { - this(channelProvider, endpointCacheFactory, new EndpointOverloadCooldownTracker()); + this( + channelProvider, + endpointCacheFactory, + endpointChannelConfigurator, + new EndpointOverloadCooldownTracker()); } private KeyAwareChannel( InstantiatingGrpcChannelProvider channelProvider, @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator, EndpointOverloadCooldownTracker endpointOverloadCooldowns) 
throws IOException { if (endpointCacheFactory == null) { - this.endpointCache = new GrpcChannelEndpointCache(channelProvider); + this.endpointCache = + new GrpcChannelEndpointCache(channelProvider, endpointChannelConfigurator); } else { this.endpointCache = endpointCacheFactory.create(channelProvider); } @@ -137,7 +144,15 @@ static KeyAwareChannel create( InstantiatingGrpcChannelProvider channelProvider, @Nullable ChannelEndpointCacheFactory endpointCacheFactory) throws IOException { - return new KeyAwareChannel(channelProvider, endpointCacheFactory); + return new KeyAwareChannel(channelProvider, endpointCacheFactory, null); + } + + static KeyAwareChannel create( + InstantiatingGrpcChannelProvider channelProvider, + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator) + throws IOException { + return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointChannelConfigurator); } @VisibleForTesting @@ -146,7 +161,8 @@ static KeyAwareChannel create( @Nullable ChannelEndpointCacheFactory endpointCacheFactory, EndpointOverloadCooldownTracker endpointOverloadCooldowns) throws IOException { - return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointOverloadCooldowns); + return new KeyAwareChannel( + channelProvider, endpointCacheFactory, null, endpointOverloadCooldowns); } private static final class ChannelFinderReference extends SoftReference { @@ -364,6 +380,18 @@ private void maybeExcludeEndpointOnNextCall( }); } + private void maybeRecordErrorPenalty( + @Nullable ChannelEndpoint endpoint, io.grpc.Status.Code statusCode, long operationUid) { + if (!shouldExcludeEndpointOnRetry(statusCode) || endpoint == null || operationUid <= 0L) { + return; + } + String address = endpoint.getAddress(); + if (defaultEndpointAddress.equals(address)) { + return; + } + EndpointLatencyRegistry.recordError(operationUid, address); + } + private static boolean 
shouldExcludeEndpointOnRetry(io.grpc.Status.Code statusCode) { return statusCode == io.grpc.Status.Code.RESOURCE_EXHAUSTED || statusCode == io.grpc.Status.Code.UNAVAILABLE; @@ -497,6 +525,8 @@ static final class KeyAwareClientCall private ChannelFinder channelFinder; @Nullable private Predicate excludedEndpoints; @Nullable private ChannelEndpoint selectedEndpoint; + @Nullable private String selectedTargetEndpoint; + private long selectedOperationUid; @Nullable private ByteString transactionIdToClear; private boolean allowDefaultAffinity; private long pendingRequests; @@ -563,6 +593,7 @@ public void sendMessage(RequestT message) { Predicate excludedEndpoints = excludedEndpoints(); ChannelEndpoint endpoint = null; ChannelFinder finder = null; + long operationUid = 0L; if (message instanceof ReadRequest) { ReadRequest.Builder reqBuilder = ((ReadRequest) message).toBuilder(); @@ -570,6 +601,7 @@ public void sendMessage(RequestT message) { RoutingDecision routing = routeFromRequest(reqBuilder); finder = routing.finder; endpoint = routing.endpoint; + operationUid = routing.operationUid; message = (RequestT) reqBuilder.build(); } else if (message instanceof ExecuteSqlRequest) { ExecuteSqlRequest.Builder reqBuilder = ((ExecuteSqlRequest) message).toBuilder(); @@ -577,6 +609,7 @@ public void sendMessage(RequestT message) { RoutingDecision routing = routeFromRequest(reqBuilder); finder = routing.finder; endpoint = routing.endpoint; + operationUid = routing.operationUid; message = (RequestT) reqBuilder.build(); } else if (message instanceof BeginTransactionRequest) { BeginTransactionRequest.Builder reqBuilder = @@ -638,7 +671,15 @@ public void sendMessage(RequestT message) { throw new IllegalStateException("No default endpoint available for key-aware call"); } selectedEndpoint = endpoint; + selectedTargetEndpoint = endpoint.getAddress(); + selectedOperationUid = operationUid; this.channelFinder = finder; + EndpointLatencyRegistry.beginRequest(selectedTargetEndpoint); + 
XGoogSpannerRequestId requestId = callOptions.getOption(REQUEST_ID_CALL_OPTIONS_KEY); + if (requestId != null) { + RequestIdTargetTracker.record( + requestId.getHeaderValue(), selectedTargetEndpoint, operationUid); + } // Record real traffic for idle eviction tracking. parentChannel.onRequestRouted(endpoint); @@ -815,7 +856,7 @@ private RoutingDecision routeFromRequest(ReadRequest.Builder reqBuilder) { : finder.findServer(reqBuilder, excludedEndpoints); endpoint = routed; } - return new RoutingDecision(finder, endpoint); + return new RoutingDecision(finder, endpoint, operationUid(reqBuilder.getRoutingHint())); } private RoutingDecision routeFromRequest(ExecuteSqlRequest.Builder reqBuilder) { @@ -838,20 +879,27 @@ private RoutingDecision routeFromRequest(ExecuteSqlRequest.Builder reqBuilder) { : finder.findServer(reqBuilder, excludedEndpoints); endpoint = routed; } - return new RoutingDecision(finder, endpoint); + return new RoutingDecision(finder, endpoint, operationUid(reqBuilder.getRoutingHint())); } } private static final class RoutingDecision { @Nullable private final ChannelFinder finder; @Nullable private final ChannelEndpoint endpoint; + private final long operationUid; - private RoutingDecision(@Nullable ChannelFinder finder, @Nullable ChannelEndpoint endpoint) { + private RoutingDecision( + @Nullable ChannelFinder finder, @Nullable ChannelEndpoint endpoint, long operationUid) { this.finder = finder; this.endpoint = endpoint; + this.operationUid = operationUid; } } + private static long operationUid(com.google.spanner.v1.RoutingHint routingHint) { + return routingHint == null ? 
0L : routingHint.getOperationUid(); + } + static final class KeyAwareClientCallListener extends SimpleForwardingClientCallListener { private final KeyAwareClientCall call; @@ -904,9 +952,13 @@ public void onMessage(ResponseT message) { @Override public void onClose(io.grpc.Status status, Metadata trailers) { if (shouldExcludeEndpointOnRetry(status.getCode())) { + call.parentChannel.maybeRecordErrorPenalty( + call.selectedEndpoint, status.getCode(), call.selectedOperationUid); call.parentChannel.maybeExcludeEndpointOnNextCall( call.selectedEndpoint, call.logicalRequestKey); } + EndpointLatencyRegistry.finishRequest(call.selectedTargetEndpoint); + RequestIdTargetTracker.remove(call.logicalRequestKey); call.maybeClearAffinity(); super.onClose(status, trailers); } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java index 438717c3c98f..a772af4c3567 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareTransportChannelProvider.java @@ -29,25 +29,22 @@ final class KeyAwareTransportChannelProvider implements TransportChannelProvider { private final InstantiatingGrpcChannelProvider baseProvider; @Nullable private final ChannelEndpointCacheFactory endpointCacheFactory; - - KeyAwareTransportChannelProvider( - InstantiatingGrpcChannelProvider.Builder builder, - @Nullable ChannelEndpointCacheFactory endpointCacheFactory) { - this.baseProvider = builder.build(); - this.endpointCacheFactory = endpointCacheFactory; - } + @Nullable private final GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator; KeyAwareTransportChannelProvider( InstantiatingGrpcChannelProvider baseProvider, - @Nullable 
ChannelEndpointCacheFactory endpointCacheFactory) { + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + @Nullable GrpcGcpEndpointChannelConfigurator endpointChannelConfigurator) { this.baseProvider = baseProvider; this.endpointCacheFactory = endpointCacheFactory; + this.endpointChannelConfigurator = endpointChannelConfigurator; } @Override public GrpcTransportChannel getTransportChannel() throws IOException { return GrpcTransportChannel.newBuilder() - .setManagedChannel(KeyAwareChannel.create(baseProvider, endpointCacheFactory)) + .setManagedChannel( + KeyAwareChannel.create(baseProvider, endpointCacheFactory, endpointChannelConfigurator)) .build(); } @@ -85,41 +82,48 @@ public boolean shouldAutoClose() { public TransportChannelProvider withEndpoint(String endpoint) { return new KeyAwareTransportChannelProvider( (InstantiatingGrpcChannelProvider) baseProvider.withEndpoint(endpoint), - endpointCacheFactory); + endpointCacheFactory, + endpointChannelConfigurator); } @Override public TransportChannelProvider withCredentials(Credentials credentials) { return new KeyAwareTransportChannelProvider( (InstantiatingGrpcChannelProvider) baseProvider.withCredentials(credentials), - endpointCacheFactory); + endpointCacheFactory, + endpointChannelConfigurator); } @Override public TransportChannelProvider withHeaders(Map headers) { return new KeyAwareTransportChannelProvider( - (InstantiatingGrpcChannelProvider) baseProvider.withHeaders(headers), endpointCacheFactory); + (InstantiatingGrpcChannelProvider) baseProvider.withHeaders(headers), + endpointCacheFactory, + endpointChannelConfigurator); } @Override public TransportChannelProvider withPoolSize(int poolSize) { return new KeyAwareTransportChannelProvider( (InstantiatingGrpcChannelProvider) baseProvider.withPoolSize(poolSize), - endpointCacheFactory); + endpointCacheFactory, + endpointChannelConfigurator); } @Override public TransportChannelProvider withExecutor(ScheduledExecutorService executor) { return new 
KeyAwareTransportChannelProvider( (InstantiatingGrpcChannelProvider) baseProvider.withExecutor(executor), - endpointCacheFactory); + endpointCacheFactory, + endpointChannelConfigurator); } @Override public TransportChannelProvider withExecutor(Executor executor) { return new KeyAwareTransportChannelProvider( (InstantiatingGrpcChannelProvider) baseProvider.withExecutor(executor), - endpointCacheFactory); + endpointCacheFactory, + endpointChannelConfigurator); } @Override diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index 41b8798d9611..027a9c176572 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -26,6 +26,7 @@ import com.google.spanner.v1.Range; import com.google.spanner.v1.RoutingHint; import com.google.spanner.v1.Tablet; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -65,6 +66,115 @@ public enum RangeMode { PICK_RANDOM } + enum RouteFailureReason { + NONE, + MISSING_ROUTING_KEY, + CACHE_MISS, + ALL_EXCLUDED_OR_COOLDOWN, + NO_READY_REPLICA, + NO_MATCHING_REPLICA, + NO_ROUTABLE_REPLICA + } + + static final class RouteLookupResult { + @javax.annotation.Nullable final ChannelEndpoint endpoint; + @javax.annotation.Nullable final String targetEndpointLabel; + final List skippedTabletDetails; + final RouteFailureReason failureReason; + @javax.annotation.Nullable final SelectionDetail selectionDetail; + + private RouteLookupResult( + @javax.annotation.Nullable ChannelEndpoint endpoint, + @javax.annotation.Nullable String targetEndpointLabel, + List skippedTabletDetails, + RouteFailureReason failureReason, + @javax.annotation.Nullable SelectionDetail selectionDetail) { + 
this.endpoint = endpoint; + this.targetEndpointLabel = targetEndpointLabel; + this.skippedTabletDetails = skippedTabletDetails; + this.failureReason = failureReason; + this.selectionDetail = selectionDetail; + } + + static RouteLookupResult routed( + ChannelEndpoint endpoint, + String targetEndpointLabel, + List skippedTabletDetails, + @javax.annotation.Nullable SelectionDetail selectionDetail) { + return new RouteLookupResult( + endpoint, + targetEndpointLabel, + Collections.unmodifiableList(new ArrayList<>(skippedTabletDetails)), + RouteFailureReason.NONE, + selectionDetail); + } + + static RouteLookupResult failed( + RouteFailureReason failureReason, List skippedTabletDetails) { + return new RouteLookupResult( + null, + null, + Collections.unmodifiableList(new ArrayList<>(skippedTabletDetails)), + failureReason, + null); + } + } + + static final class SkippedTabletDetail { + @javax.annotation.Nullable final String targetEndpointLabel; + final String reason; + + private SkippedTabletDetail( + @javax.annotation.Nullable String targetEndpointLabel, String reason) { + this.targetEndpointLabel = targetEndpointLabel; + this.reason = reason; + } + } + + static final class SelectionDetail { + final String selectionReason; + final long operationUid; + final int eligibleCandidateCount; + final int scoredCandidateCount; + final double selectedScore; + final double bestEligibleScore; + final String alternativesSummary; + + private SelectionDetail( + String selectionReason, + long operationUid, + int eligibleCandidateCount, + int scoredCandidateCount, + double selectedScore, + double bestEligibleScore, + String alternativesSummary) { + this.selectionReason = selectionReason; + this.operationUid = operationUid; + this.eligibleCandidateCount = eligibleCandidateCount; + this.scoredCandidateCount = scoredCandidateCount; + this.selectedScore = selectedScore; + this.bestEligibleScore = bestEligibleScore; + this.alternativesSummary = alternativesSummary; + } + + double scoreGap() { 
+ if (!Double.isFinite(selectedScore) + || !Double.isFinite(bestEligibleScore) + || selectedScore == Double.MAX_VALUE + || bestEligibleScore == Double.MAX_VALUE) { + return Double.NaN; + } + return selectedScore - bestEligibleScore; + } + } + + static String formatTargetEndpointLabel(String address, boolean isLeader) { + if (address == null || address.isEmpty() || !isLeader) { + return address; + } + return address + "-LEADER"; + } + private final ChannelEndpointCache endpointCache; @javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager; private final NavigableMap ranges = @@ -74,6 +184,7 @@ public enum RangeMode { private final Lock readLock = cacheLock.readLock(); private final Lock writeLock = cacheLock.writeLock(); private final AtomicLong accessCounter = new AtomicLong(); + private final ReplicaSelector replicaSelector = new PowerOfTwoReplicaSelector(); private volatile boolean deterministicRandom = false; private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK; @@ -99,6 +210,16 @@ void setMinCacheEntriesForRandomPick(int value) { minCacheEntriesForRandomPick = value; } + @VisibleForTesting + void recordReplicaLatency(long operationUid, String address, Duration latency) { + EndpointLatencyRegistry.recordLatency(operationUid, address, latency); + } + + @VisibleForTesting + void recordReplicaError(long operationUid, String address) { + EndpointLatencyRegistry.recordError(operationUid, address); + } + /** Applies cache updates. Tablets are processed inside group updates. 
*/ public void addRanges(CacheUpdate cacheUpdate) { List touchedGroups = new ArrayList<>(); @@ -154,9 +275,21 @@ public ChannelEndpoint fillRoutingHint( DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints) { + return lookupRoutingHint( + preferLeader, rangeMode, directedReadOptions, hintBuilder, excludedEndpoints) + .endpoint; + } + + RouteLookupResult lookupRoutingHint( + boolean preferLeader, + RangeMode rangeMode, + DirectedReadOptions directedReadOptions, + RoutingHint.Builder hintBuilder, + Predicate excludedEndpoints) { + List skippedTabletDetails = new ArrayList<>(); ByteString key = hintBuilder.getKey(); if (key.isEmpty()) { - return null; + return RouteLookupResult.failed(RouteFailureReason.MISSING_ROUTING_KEY, skippedTabletDetails); } CachedRange targetRange; @@ -168,7 +301,7 @@ public ChannelEndpoint fillRoutingHint( } if (targetRange == null || targetRange.group == null) { - return null; + return RouteLookupResult.failed(RouteFailureReason.CACHE_MISS, skippedTabletDetails); } hintBuilder.setGroupUid(targetRange.group.groupUid); @@ -176,8 +309,8 @@ public ChannelEndpoint fillRoutingHint( hintBuilder.setKey(targetRange.startKey); hintBuilder.setLimitKey(targetRange.limitKey); - return targetRange.group.fillRoutingHint( - preferLeader, directedReadOptions, hintBuilder, excludedEndpoints); + return targetRange.group.lookupRoutingHint( + preferLeader, directedReadOptions, hintBuilder, excludedEndpoints, skippedTabletDetails); } /** Returns all server addresses currently referenced by cached tablets. 
*/ @@ -615,17 +748,19 @@ void update(Group groupIn) { snapshot = new GroupSnapshot(generation, leaderIndex, tablets); } - ChannelEndpoint fillRoutingHint( + RouteLookupResult lookupRoutingHint( boolean preferLeader, DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, - Predicate excludedEndpoints) { + Predicate excludedEndpoints, + List skippedTabletDetails) { GroupSnapshot snapshot = this.snapshot; Set skippedTabletUids = skippedTabletUids(hintBuilder); boolean hasDirectedReadOptions = directedReadOptions.getReplicasCase() != DirectedReadOptions.ReplicasCase.REPLICAS_NOT_SET; Map resolvedEndpoints = new HashMap<>(); + SelectionStats selectionStats = new SelectionStats(); TabletSnapshot selected = selectTablet( @@ -636,9 +771,34 @@ ChannelEndpoint fillRoutingHint( directedReadOptions, excludedEndpoints, skippedTabletUids, - resolvedEndpoints); + skippedTabletDetails, + resolvedEndpoints, + selectionStats); if (selected == null) { - return null; + RouteFailureReason failureReason = selectionStats.toFailureReason(); + if (failureReason == RouteFailureReason.ALL_EXCLUDED_OR_COOLDOWN) { + selected = + selectRandomExcludedOrCoolingDownTablet( + snapshot, directedReadOptions, hintBuilder, resolvedEndpoints); + if (selected != null) { + recordKnownTransientFailures( + snapshot, + selected, + directedReadOptions, + hintBuilder, + excludedEndpoints, + skippedTabletUids, + skippedTabletDetails, + resolvedEndpoints); + hintBuilder.setTabletUid(selected.tabletUid); + return RouteLookupResult.routed( + resolveEndpoint(selected, resolvedEndpoints), + endpointLabel(snapshot, selected), + skippedTabletDetails, + selectionStats.selectionDetail); + } + } + return RouteLookupResult.failed(failureReason, skippedTabletDetails); } recordKnownTransientFailures( snapshot, @@ -647,9 +807,14 @@ ChannelEndpoint fillRoutingHint( hintBuilder, excludedEndpoints, skippedTabletUids, + skippedTabletDetails, resolvedEndpoints); hintBuilder.setTabletUid(selected.tabletUid); - 
return resolveEndpoint(selected, resolvedEndpoints); + return RouteLookupResult.routed( + resolveEndpoint(selected, resolvedEndpoints), + endpointLabel(snapshot, selected), + skippedTabletDetails, + selectionStats.selectionDetail); } private TabletSnapshot selectTablet( @@ -660,19 +825,37 @@ private TabletSnapshot selectTablet( DirectedReadOptions directedReadOptions, Predicate excludedEndpoints, Set skippedTabletUids, - Map resolvedEndpoints) { + List skippedTabletDetails, + Map resolvedEndpoints, + SelectionStats selectionStats) { + if (!preferLeader) { + return selectLatencyAwareTablet( + snapshot, + directedReadOptions, + hintBuilder, + excludedEndpoints, + skippedTabletUids, + skippedTabletDetails, + resolvedEndpoints, + selectionStats); + } + boolean checkedLeader = false; if (preferLeader && !hasDirectedReadOptions && snapshot.hasLeader() && snapshot.leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { checkedLeader = true; + selectionStats.matchingReplicas++; if (!shouldSkip( + snapshot, snapshot.leader(), hintBuilder, excludedEndpoints, skippedTabletUids, - resolvedEndpoints)) { + skippedTabletDetails, + resolvedEndpoints, + selectionStats)) { return snapshot.leader(); } } @@ -684,8 +867,16 @@ private TabletSnapshot selectTablet( if (!tablet.matches(directedReadOptions)) { continue; } + selectionStats.matchingReplicas++; if (shouldSkip( - tablet, hintBuilder, excludedEndpoints, skippedTabletUids, resolvedEndpoints)) { + snapshot, + tablet, + hintBuilder, + excludedEndpoints, + skippedTabletUids, + skippedTabletDetails, + resolvedEndpoints, + selectionStats)) { continue; } return tablet; @@ -693,6 +884,150 @@ private TabletSnapshot selectTablet( return null; } + private TabletSnapshot selectLatencyAwareTablet( + GroupSnapshot snapshot, + DirectedReadOptions directedReadOptions, + RoutingHint.Builder hintBuilder, + Predicate excludedEndpoints, + Set skippedTabletUids, + List skippedTabletDetails, + Map resolvedEndpoints, + SelectionStats selectionStats) { + 
long operationUid = hintBuilder.getOperationUid(); + List eligibleTablets = new ArrayList<>(); + List eligibleEndpoints = new ArrayList<>(); + int scoredCandidates = 0; + + for (TabletSnapshot tablet : snapshot.tablets) { + if (!tablet.matches(directedReadOptions)) { + continue; + } + selectionStats.matchingReplicas++; + if (shouldSkip( + snapshot, + tablet, + hintBuilder, + excludedEndpoints, + skippedTabletUids, + skippedTabletDetails, + resolvedEndpoints, + selectionStats)) { + continue; + } + + ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); + if (endpoint == null) { + continue; + } + eligibleTablets.add(tablet); + eligibleEndpoints.add(endpoint); + if (EndpointLatencyRegistry.hasScore(operationUid, tablet.serverAddress)) { + scoredCandidates++; + } + } + + if (eligibleTablets.isEmpty()) { + return null; + } + if (eligibleTablets.size() == 1) { + TabletSnapshot selected = eligibleTablets.get(0); + selectionStats.selectionDetail = + buildSelectionDetail( + snapshot, + eligibleTablets, + operationUid, + "single_candidate", + selected, + scoredCandidates); + return selected; + } + if (deterministicRandom) { + TabletSnapshot selected = + eligibleTablets.stream() + .min( + Comparator.comparingDouble( + tablet -> + EndpointLatencyRegistry.getSelectionCost( + operationUid, tablet.serverAddress))) + .orElse(eligibleTablets.get(0)); + selectionStats.selectionDetail = + buildSelectionDetail( + snapshot, + eligibleTablets, + operationUid, + "latency_score", + selected, + scoredCandidates); + return selected; + } + + ChannelEndpoint selectedEndpoint = + replicaSelector.select( + eligibleEndpoints, + endpoint -> + EndpointLatencyRegistry.getSelectionCost(operationUid, endpoint.getAddress())); + if (selectedEndpoint == null) { + TabletSnapshot selected = eligibleTablets.get(0); + selectionStats.selectionDetail = + buildSelectionDetail( + snapshot, + eligibleTablets, + operationUid, + "latency_score", + selected, + scoredCandidates); + return selected; 
+ } + for (int i = 0; i < eligibleTablets.size(); i++) { + if (eligibleEndpoints.get(i) == selectedEndpoint) { + TabletSnapshot selected = eligibleTablets.get(i); + selectionStats.selectionDetail = + buildSelectionDetail( + snapshot, + eligibleTablets, + operationUid, + "latency_score", + selected, + scoredCandidates); + return selected; + } + } + TabletSnapshot selected = eligibleTablets.get(0); + selectionStats.selectionDetail = + buildSelectionDetail( + snapshot, eligibleTablets, operationUid, "latency_score", selected, scoredCandidates); + return selected; + } + + @javax.annotation.Nullable + private TabletSnapshot selectRandomExcludedOrCoolingDownTablet( + GroupSnapshot snapshot, + DirectedReadOptions directedReadOptions, + RoutingHint.Builder hintBuilder, + Map resolvedEndpoints) { + List candidates = new ArrayList<>(); + for (TabletSnapshot tablet : snapshot.tablets) { + if (!tablet.matches(directedReadOptions) || tablet.skip || tablet.serverAddress.isEmpty()) { + continue; + } + ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); + if (endpoint == null || !endpoint.isHealthy()) { + continue; + } + candidates.add(tablet); + } + if (candidates.isEmpty()) { + return null; + } + int index = + uniformRandom( + candidates.size(), + hintBuilder.getKey(), + hintBuilder.getLimitKey(), + snapshot.generation); + return candidates.get(index); + } + private void recordKnownTransientFailures( GroupSnapshot snapshot, TabletSnapshot selected, @@ -700,13 +1035,20 @@ private void recordKnownTransientFailures( RoutingHint.Builder hintBuilder, Predicate excludedEndpoints, Set skippedTabletUids, + List skippedTabletDetails, Map resolvedEndpoints) { for (TabletSnapshot tablet : snapshot.tablets) { if (tablet == selected || !tablet.matches(directedReadOptions)) { continue; } recordKnownTransientFailure( - tablet, hintBuilder, excludedEndpoints, skippedTabletUids, resolvedEndpoints); + snapshot, + tablet, + hintBuilder, + excludedEndpoints, + skippedTabletUids, 
+ skippedTabletDetails, + resolvedEndpoints); } } @@ -719,25 +1061,46 @@ private Set skippedTabletUids(RoutingHint.Builder hintBuilder) { } private boolean shouldSkip( + GroupSnapshot snapshot, TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints, Set skippedTabletUids, - Map resolvedEndpoints) { - if (tablet.skip - || tablet.serverAddress.isEmpty() - || excludedEndpoints.test(tablet.serverAddress)) { - addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + List skippedTabletDetails, + Map resolvedEndpoints, + SelectionStats selectionStats) { + String targetEndpointLabel = endpointLabel(snapshot, tablet); + if (tablet.skip) { + selectionStats.tabletMarkedSkipCount++; + addSkippedTablet( + tablet, + hintBuilder, + skippedTabletUids, + skippedTabletDetails, + targetEndpointLabel, + "tablet_marked_skip"); + return true; + } + if (tablet.serverAddress.isEmpty()) { + selectionStats.missingAddressCount++; + addSkippedTablet( + tablet, hintBuilder, skippedTabletUids, skippedTabletDetails, null, "missing_address"); + return true; + } + if (excludedEndpoints.test(tablet.serverAddress)) { + selectionStats.excludedCount++; return true; } ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); if (endpoint == null) { + selectionStats.missingEndpointCount++; logger.log( Level.FINE, "Tablet {0} at {1}: no endpoint present, skipping silently", new Object[] {tablet.tabletUid, tablet.serverAddress}); - maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + maybeAddRecentTransientFailureSkip( + tablet, targetEndpointLabel, hintBuilder, skippedTabletUids, skippedTabletDetails); if (lifecycleManager != null) { lifecycleManager.requestEndpointRecreation(tablet.serverAddress); } @@ -747,27 +1110,132 @@ private boolean shouldSkip( return false; } if (endpoint.isTransientFailure()) { + selectionStats.transientFailureCount++; logger.log( Level.FINE, "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to 
skipped_tablets", new Object[] {tablet.tabletUid, tablet.serverAddress}); - addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + addSkippedTablet( + tablet, + hintBuilder, + skippedTabletUids, + skippedTabletDetails, + targetEndpointLabel, + "transient_failure"); return true; } + selectionStats.notReadyCount++; logger.log( Level.FINE, "Tablet {0} at {1}: endpoint not ready, skipping silently", new Object[] {tablet.tabletUid, tablet.serverAddress}); - maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + maybeAddRecentTransientFailureSkip( + tablet, targetEndpointLabel, hintBuilder, skippedTabletUids, skippedTabletDetails); return true; } + private final class SelectionStats { + private int matchingReplicas; + private int excludedCount; + private int transientFailureCount; + private int notReadyCount; + private int missingEndpointCount; + private int tabletMarkedSkipCount; + private int missingAddressCount; + @javax.annotation.Nullable private SelectionDetail selectionDetail; + + private RouteFailureReason toFailureReason() { + if (matchingReplicas == 0) { + return RouteFailureReason.NO_MATCHING_REPLICA; + } + if (excludedCount == matchingReplicas) { + return RouteFailureReason.ALL_EXCLUDED_OR_COOLDOWN; + } + if (transientFailureCount > 0 || notReadyCount > 0 || missingEndpointCount > 0) { + return RouteFailureReason.NO_READY_REPLICA; + } + if (tabletMarkedSkipCount > 0 || missingAddressCount > 0) { + return RouteFailureReason.NO_ROUTABLE_REPLICA; + } + return RouteFailureReason.NO_ROUTABLE_REPLICA; + } + } + + private SelectionDetail buildSelectionDetail( + GroupSnapshot snapshot, + List eligibleTablets, + long operationUid, + String selectionReason, + TabletSnapshot selected, + int scoredCandidates) { + double bestScore = Double.MAX_VALUE; + for (TabletSnapshot tablet : eligibleTablets) { + bestScore = + Math.min( + bestScore, + EndpointLatencyRegistry.getSelectionCost(operationUid, tablet.serverAddress)); + } + + StringBuilder 
alternatives = new StringBuilder(); + int appended = 0; + double selectedScore = + EndpointLatencyRegistry.getSelectionCost(operationUid, selected.serverAddress); + for (TabletSnapshot tablet : eligibleTablets) { + if (tablet == selected || appended >= 4) { + continue; + } + if (alternatives.length() > 0) { + alternatives.append(", "); + } + double candidateScore = + EndpointLatencyRegistry.getSelectionCost(operationUid, tablet.serverAddress); + alternatives + .append(endpointLabel(snapshot, tablet)) + .append("=") + .append(formatScore(candidateScore)) + .append(":") + .append(alternativeReason(candidateScore, selectedScore)); + appended++; + } + + return new SelectionDetail( + selectionReason, + operationUid, + eligibleTablets.size(), + scoredCandidates, + selectedScore, + bestScore, + alternatives.toString()); + } + + private String alternativeReason(double candidateScore, double selectedScore) { + if (!Double.isFinite(candidateScore) || candidateScore == Double.MAX_VALUE) { + return "unscored"; + } + if (candidateScore < selectedScore) { + return "not_sampled_by_policy"; + } + if (candidateScore > selectedScore) { + return "higher_score"; + } + return "tie_not_selected"; + } + + private String formatScore(double score) { + if (!Double.isFinite(score) || score == Double.MAX_VALUE) { + return "unknown"; + } + return Long.toString(Math.round(score)); + } + private void recordKnownTransientFailure( + GroupSnapshot snapshot, TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints, Set skippedTabletUids, + List skippedTabletDetails, Map resolvedEndpoints) { if (tablet.skip || tablet.serverAddress.isEmpty() @@ -777,11 +1245,27 @@ private void recordKnownTransientFailure( ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); if (endpoint != null && endpoint.isTransientFailure()) { - addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + addSkippedTablet( + tablet, + hintBuilder, + skippedTabletUids, + 
skippedTabletDetails, + endpointLabel(snapshot, tablet), + "known_transient_failure"); return; } - maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + maybeAddRecentTransientFailureSkip( + tablet, + endpointLabel(snapshot, tablet), + hintBuilder, + skippedTabletUids, + skippedTabletDetails); + } + + private String endpointLabel(GroupSnapshot snapshot, TabletSnapshot tablet) { + return formatTargetEndpointLabel( + tablet.serverAddress, snapshot.hasLeader() && snapshot.leader() == tablet); } private ChannelEndpoint resolveEndpoint( @@ -805,21 +1289,37 @@ private ChannelEndpoint resolveEndpoint( } private void maybeAddRecentTransientFailureSkip( - TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Set skippedTabletUids) { + TabletSnapshot tablet, + @javax.annotation.Nullable String targetEndpointLabel, + RoutingHint.Builder hintBuilder, + Set skippedTabletUids, + List skippedTabletDetails) { if (lifecycleManager != null && lifecycleManager.wasRecentlyEvictedTransientFailure(tablet.serverAddress)) { - addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + addSkippedTablet( + tablet, + hintBuilder, + skippedTabletUids, + skippedTabletDetails, + targetEndpointLabel, + "recent_transient_failure_eviction"); } } private void addSkippedTablet( - TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Set skippedTabletUids) { + TabletSnapshot tablet, + RoutingHint.Builder hintBuilder, + Set skippedTabletUids, + List skippedTabletDetails, + @javax.annotation.Nullable String targetEndpointLabel, + String reason) { if (!skippedTabletUids.add(tablet.tabletUid)) { return; } RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder(); skipped.setTabletUid(tablet.tabletUid); skipped.setIncarnation(tablet.incarnation); + skippedTabletDetails.add(new SkippedTabletDetail(targetEndpointLabel, reason)); } String debugString() { diff --git 
a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java new file mode 100644 index 000000000000..f64a3bd23728 --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java @@ -0,0 +1,87 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.spanner.spi.v1; + +import com.google.cloud.spanner.XGoogSpannerRequestId; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; + +final class RequestIdTargetTracker { + + private static final Cache TARGETS = + CacheBuilder.newBuilder() + .maximumSize(100_000_000L) + .expireAfterWrite(10, TimeUnit.MINUTES) + .build(); + + private RequestIdTargetTracker() {} + + static void record(String requestId, String targetEndpoint, long operationUid) { + String trackingKey = normalizeRequestKey(requestId); + if (trackingKey == null || targetEndpoint == null || targetEndpoint.isEmpty()) { + return; + } + TARGETS.put(trackingKey, new RoutingTarget(targetEndpoint, operationUid)); + } + + @Nullable + static RoutingTarget get(String requestId) { + String trackingKey = normalizeRequestKey(requestId); + if (trackingKey == null) { + return null; + } + return TARGETS.getIfPresent(trackingKey); + } + + static void remove(String requestId) { + String trackingKey = normalizeRequestKey(requestId); + if (trackingKey == null) { + return; + } + TARGETS.invalidate(trackingKey); + } + + @VisibleForTesting + static void clear() { + TARGETS.invalidateAll(); + } + + @VisibleForTesting + static String normalizeRequestKey(String requestId) { + if (requestId == null || requestId.isEmpty()) { + return null; + } + try { + return XGoogSpannerRequestId.of(requestId).getLogicalRequestKey(); + } catch (IllegalStateException e) { + return requestId; + } + } + + static final class RoutingTarget { + final String targetEndpoint; + final long operationUid; + + private RoutingTarget(String targetEndpoint, long operationUid) { + this.targetEndpoint = targetEndpoint; + this.operationUid = operationUid; + } + } +} diff --git 
a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java index 552cfd9bd2c8..341974f17fd6 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLifecycleManagerTest.java @@ -267,6 +267,73 @@ public void transientFailureEvictionTrackedUntilEndpointReadyAgain() throws Exce () -> !manager.wasRecentlyEvictedTransientFailure("server1")); } + @Test + public void transientFailureOscillationWithConnectingStillEvictsEndpoint() throws Exception { + KeyRangeCacheTest.FakeEndpointCache cache = new KeyRangeCacheTest.FakeEndpointCache(); + manager = + new EndpointLifecycleManager( + cache, /* probeIntervalSeconds= */ 60, Duration.ofMinutes(30), Clock.systemUTC()); + + registerAddresses(manager, "server1"); + awaitCondition( + "endpoint should be created in background", () -> cache.getIfPresent("server1") != null); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE); + manager.probe("server1"); + assertEquals(1, manager.getEndpointState("server1").consecutiveTransientFailures); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.CONNECTING); + manager.probe("server1"); + assertEquals(1, manager.getEndpointState("server1").consecutiveTransientFailures); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE); + manager.probe("server1"); + assertEquals(2, manager.getEndpointState("server1").consecutiveTransientFailures); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.CONNECTING); + manager.probe("server1"); + assertEquals(2, manager.getEndpointState("server1").consecutiveTransientFailures); + + cache.setState("server1", 
KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE); + manager.probe("server1"); + + assertFalse(manager.isManaged("server1")); + assertTrue(manager.wasRecentlyEvictedTransientFailure("server1")); + assertNull(cache.getIfPresent("server1")); + } + + @Test + public void readyResetsTransientFailureCounterAfterRecovery() throws Exception { + KeyRangeCacheTest.FakeEndpointCache cache = new KeyRangeCacheTest.FakeEndpointCache(); + manager = + new EndpointLifecycleManager( + cache, /* probeIntervalSeconds= */ 60, Duration.ofMinutes(30), Clock.systemUTC()); + + registerAddresses(manager, "server1"); + awaitCondition( + "endpoint should be created in background", () -> cache.getIfPresent("server1") != null); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE); + manager.probe("server1"); + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.CONNECTING); + manager.probe("server1"); + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE); + manager.probe("server1"); + assertEquals(2, manager.getEndpointState("server1").consecutiveTransientFailures); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.READY); + manager.probe("server1"); + EndpointLifecycleManager.EndpointState state = manager.getEndpointState("server1"); + assertNotNull(state); + assertEquals(0, state.consecutiveTransientFailures); + assertNotNull(state.lastReadyAt); + + cache.setState("server1", KeyRangeCacheTest.EndpointHealthState.TRANSIENT_FAILURE); + manager.probe("server1"); + assertEquals(1, manager.getEndpointState("server1").consecutiveTransientFailures); + assertTrue(manager.isManaged("server1")); + } + @Test public void transientFailureEvictionMarkerRemovedWhenAddressNoLongerActive() throws Exception { KeyRangeCacheTest.FakeEndpointCache cache = new KeyRangeCacheTest.FakeEndpointCache(); diff --git 
a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java index 306628b9bdab..84fa32b4fb7e 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTrackerTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertThrows; import java.time.Duration; +import java.util.function.LongSupplier; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -29,7 +30,7 @@ public class EwmaLatencyTrackerTest { @Test public void testInitialization() { - EwmaLatencyTracker tracker = new EwmaLatencyTracker(); + EwmaLatencyTracker tracker = new EwmaLatencyTracker(Duration.ofSeconds(10), new FakeClock()); tracker.update(Duration.ofNanos(100 * 1000)); assertEquals(100.0, tracker.getScore(), 0.001); } @@ -42,7 +43,7 @@ public void testUninitializedScore() { @Test public void testOverflowScore() { - EwmaLatencyTracker tracker = new EwmaLatencyTracker(); + EwmaLatencyTracker tracker = new EwmaLatencyTracker(Duration.ofSeconds(10), new FakeClock()); tracker.update(Duration.ofSeconds(Long.MAX_VALUE)); assertEquals((double) Long.MAX_VALUE, tracker.getScore(), 0.001); } @@ -50,7 +51,7 @@ public void testOverflowScore() { @Test public void testEwmaCalculation() { double alpha = 0.5; - EwmaLatencyTracker tracker = new EwmaLatencyTracker(alpha); + EwmaLatencyTracker tracker = new EwmaLatencyTracker(alpha, new FakeClock()); tracker.update(Duration.ofNanos(100 * 1000)); // Initial score = 100 assertEquals(100.0, tracker.getScore(), 0.001); @@ -63,19 +64,21 @@ public void testEwmaCalculation() { } @Test - public void testDefaultAlpha() { - EwmaLatencyTracker tracker = new EwmaLatencyTracker(); + public void 
testDefaultDecayUsesTimeBasedAlpha() { + FakeClock clock = new FakeClock(); + EwmaLatencyTracker tracker = new EwmaLatencyTracker(Duration.ofSeconds(10), clock); tracker.update(Duration.ofNanos(100 * 1000)); + clock.advance(Duration.ofSeconds(10)); tracker.update(Duration.ofNanos(200 * 1000)); - double expected = - EwmaLatencyTracker.DEFAULT_ALPHA * 200 + (1 - EwmaLatencyTracker.DEFAULT_ALPHA) * 100; + double alpha = 1.0 - Math.exp(-1.0); + double expected = alpha * 200 + (1.0 - alpha) * 100; assertEquals(expected, tracker.getScore(), 0.001); } @Test public void testRecordError() { - EwmaLatencyTracker tracker = new EwmaLatencyTracker(0.5); + EwmaLatencyTracker tracker = new EwmaLatencyTracker(0.5, new FakeClock()); tracker.update(Duration.ofNanos(100 * 1000)); tracker.recordError(Duration.ofNanos(10000 * 1000)); // Score = 0.5 * 10000 + 0.5 * 100 = 5050 @@ -91,11 +94,24 @@ public void testInvalidAlpha() { @Test public void testAlphaOne() { - EwmaLatencyTracker tracker = new EwmaLatencyTracker(1.0); + EwmaLatencyTracker tracker = new EwmaLatencyTracker(1.0, new FakeClock()); tracker.update(Duration.ofNanos(100 * 1000)); assertEquals(100.0, tracker.getScore(), 0.001); tracker.update(Duration.ofNanos(200 * 1000)); assertEquals(200.0, tracker.getScore(), 0.001); } + + private static final class FakeClock implements LongSupplier { + private long currentNanos; + + @Override + public long getAsLong() { + return currentNanos; + } + + void advance(Duration duration) { + currentNanos += duration.toNanos(); + } + } } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java index 74afec18bfc3..c0eea3f88c4f 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java +++ 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java @@ -25,6 +25,9 @@ import org.junit.Test; public class GrpcChannelEndpointCacheTest { + private static final String DEFAULT_ENDPOINT = "default.invalid:1234"; + private static final String ROUTED_ENDPOINT_A = "replica-a.invalid:1111"; + private static final String ROUTED_ENDPOINT_B = "replica-b.invalid:2222"; private static InstantiatingGrpcChannelProvider createProvider(String endpoint) { return InstantiatingGrpcChannelProvider.newBuilder() @@ -35,7 +38,7 @@ private static InstantiatingGrpcChannelProvider createProvider(String endpoint) @Test public void defaultChannelIsCached() throws Exception { - GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); try { ChannelEndpoint defaultChannel = cache.defaultChannel(); ChannelEndpoint server = cache.get(defaultChannel.getAddress()); @@ -47,11 +50,11 @@ public void defaultChannelIsCached() throws Exception { @Test public void getCachesPerAddress() throws Exception { - GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); try { - ChannelEndpoint first = cache.get("localhost:1111"); - ChannelEndpoint second = cache.get("localhost:1111"); - ChannelEndpoint third = cache.get("localhost:2222"); + ChannelEndpoint first = cache.get(ROUTED_ENDPOINT_A); + ChannelEndpoint second = cache.get(ROUTED_ENDPOINT_A); + ChannelEndpoint third = cache.get(ROUTED_ENDPOINT_B); assertThat(second).isSameInstanceAs(first); assertThat(third).isNotSameInstanceAs(first); @@ -62,11 +65,57 @@ public void getCachesPerAddress() throws Exception { @Test public void routedChannelsReuseDefaultAuthority() throws Exception { - GrpcChannelEndpointCache cache = new 
GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); try { - ChannelEndpoint routed = cache.get("localhost:1111"); + ChannelEndpoint routed = cache.get(ROUTED_ENDPOINT_A); - assertThat(routed.getChannel().authority()).isEqualTo("localhost:1234"); + assertThat(routed.getChannel().authority()).isEqualTo(DEFAULT_ENDPOINT); + } finally { + cache.shutdown(); + } + } + + @Test + public void routedChannelsUseSingleUnderlyingChannel() throws Exception { + InstantiatingGrpcChannelProvider provider = + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint(DEFAULT_ENDPOINT) + .setPoolSize(4) + .setChannelConfigurator(ManagedChannelBuilder::usePlaintext) + .build(); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(provider); + try { + InstantiatingGrpcChannelProvider routedProvider = + cache.createProviderWithAuthorityOverride(ROUTED_ENDPOINT_A); + + assertThat(provider.toBuilder().getPoolSize()).isEqualTo(4); + assertThat(routedProvider.getChannelPoolSettings().getInitialChannelCount()).isEqualTo(1); + assertThat(routedProvider.getChannelPoolSettings().getMinChannelCount()).isEqualTo(1); + assertThat(routedProvider.getChannelPoolSettings().getMaxChannelCount()).isEqualTo(1); + } finally { + cache.shutdown(); + } + } + + @Test + public void routedChannelsEnableKeepAliveWithoutCallsOnlyForEndpointProvider() throws Exception { + InstantiatingGrpcChannelProvider provider = + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint(DEFAULT_ENDPOINT) + .setPoolSize(4) + .setKeepAliveTimeDuration(java.time.Duration.ofSeconds(120)) + .setKeepAliveWithoutCalls(Boolean.FALSE) + .setChannelConfigurator(ManagedChannelBuilder::usePlaintext) + .build(); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(provider); + try { + InstantiatingGrpcChannelProvider routedProvider = + cache.createProviderWithAuthorityOverride(ROUTED_ENDPOINT_A); + + 
assertThat(provider.getKeepAliveWithoutCalls()).isFalse(); + assertThat(routedProvider.getKeepAliveWithoutCalls()).isTrue(); + assertThat(routedProvider.getKeepAliveTimeDuration()) + .isEqualTo(provider.getKeepAliveTimeDuration()); } finally { cache.shutdown(); } @@ -74,11 +123,11 @@ public void routedChannelsReuseDefaultAuthority() throws Exception { @Test public void evictRemovesNonDefaultServer() throws Exception { - GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); try { - ChannelEndpoint first = cache.get("localhost:1111"); - cache.evict("localhost:1111"); - ChannelEndpoint second = cache.get("localhost:1111"); + ChannelEndpoint first = cache.get(ROUTED_ENDPOINT_A); + cache.evict(ROUTED_ENDPOINT_A); + ChannelEndpoint second = cache.get(ROUTED_ENDPOINT_A); assertThat(second).isNotSameInstanceAs(first); } finally { @@ -88,7 +137,7 @@ public void evictRemovesNonDefaultServer() throws Exception { @Test public void evictIgnoresDefaultChannel() throws Exception { - GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); try { ChannelEndpoint defaultChannel = cache.defaultChannel(); cache.evict(defaultChannel.getAddress()); @@ -102,18 +151,18 @@ public void evictIgnoresDefaultChannel() throws Exception { @Test public void shutdownPreventsNewServers() throws Exception { - GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); cache.shutdown(); - assertThrows(SpannerException.class, () -> cache.get("localhost:1111")); + assertThrows(SpannerException.class, () -> cache.get(ROUTED_ENDPOINT_A)); 
assertThat(cache.defaultChannel().getChannel().isShutdown()).isTrue(); } @Test public void healthReflectsChannelShutdown() throws Exception { - GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider("localhost:1234")); + GrpcChannelEndpointCache cache = new GrpcChannelEndpointCache(createProvider(DEFAULT_ENDPOINT)); try { - ChannelEndpoint server = cache.get("localhost:1111"); + ChannelEndpoint server = cache.get(ROUTED_ENDPOINT_A); // Newly created channel is not READY (likely IDLE), so isHealthy is false for location aware. // isHealthy now requires READY state for location aware routing. assertThat(server.isHealthy()).isFalse(); diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index 1c0a277ca4f4..ba795dd7074a 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -507,6 +507,13 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { .isFalse(); } + @Test + public void resourceExhaustedOrUnavailableRoutedEndpointRecordsErrorPenalty() throws Exception { + assertRoutedEndpointErrorPenaltyRecorded(Status.RESOURCE_EXHAUSTED, 101L); + EndpointLatencyRegistry.clear(); + assertRoutedEndpointErrorPenaltyRecorded(Status.UNAVAILABLE, 102L); + } + @Test public void resourceExhaustedAffinityEndpointIsAvoidedForSubsequentTransactionRequest() throws Exception { @@ -572,7 +579,7 @@ public void resourceExhaustedAffinityEndpointIsAvoidedForSubsequentTransactionRe } @Test - public void resourceExhaustedRoutedEndpointFallsBackToDefaultWhenNoReplicaExists() + public void resourceExhaustedRoutedEndpointRetriesSameReplicaWhenSingleReplicaIsExcluded() throws Exception { TestHarness harness 
= createHarness(); CallOptions retryCallOptions = retryCallOptions(3L); @@ -602,8 +609,84 @@ public void resourceExhaustedRoutedEndpointFallsBackToDefaultWhenNoReplicaExists secondCall.start(new CapturingListener(), new Metadata()); secondCall.sendMessage(request); - assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); - assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(2); + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2); + assertThat(harness.defaultManagedChannel.callCount()).isEqualTo(1); + } + + @Test + public void resourceExhaustedRoutedEndpointRetriesRandomExcludedReplicaWhenAllReplicasExcluded() + throws Exception { + boolean selectedServerAOnThirdAttempt = false; + boolean selectedServerBOnThirdAttempt = false; + + for (int attempt = 0; + attempt < 20 && (!selectedServerAOnThirdAttempt || !selectedServerBOnThirdAttempt); + attempt++) { + TestHarness harness = createHarness(); + seedCache(harness, createLeaderAndReplicaCacheUpdate()); + CallOptions retryCallOptions = retryCallOptions(100L + attempt); + ExecuteSqlRequest request = + ExecuteSqlRequest.newBuilder() + .setSession(SESSION) + .setRoutingHint(RoutingHint.newBuilder().setKey(bytes("a")).build()) + .build(); + + ClientCall firstCall = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions); + firstCall.start(new CapturingListener(), new Metadata()); + firstCall.sendMessage(request); + + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(0); + + @SuppressWarnings("unchecked") + RecordingClientCall firstDelegate = + (RecordingClientCall) + harness.endpointCache.latestCallForAddress("server-a:1234"); + firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata()); + + ClientCall secondCall = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions); + 
secondCall.start(new CapturingListener(), new Metadata()); + secondCall.sendMessage(request); + + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + + @SuppressWarnings("unchecked") + RecordingClientCall secondDelegate = + (RecordingClientCall) + harness.endpointCache.latestCallForAddress("server-b:1234"); + secondDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata()); + + int serverACallCountBeforeThirdAttempt = + harness.endpointCache.callCountForAddress("server-a:1234"); + int serverBCallCountBeforeThirdAttempt = + harness.endpointCache.callCountForAddress("server-b:1234"); + + ClientCall thirdCall = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions); + thirdCall.start(new CapturingListener(), new Metadata()); + thirdCall.sendMessage(request); + + int serverACallCountAfterThirdAttempt = + harness.endpointCache.callCountForAddress("server-a:1234"); + int serverBCallCountAfterThirdAttempt = + harness.endpointCache.callCountForAddress("server-b:1234"); + + assertThat( + (serverACallCountAfterThirdAttempt - serverACallCountBeforeThirdAttempt) + + (serverBCallCountAfterThirdAttempt - serverBCallCountBeforeThirdAttempt)) + .isEqualTo(1); + + selectedServerAOnThirdAttempt |= + serverACallCountAfterThirdAttempt == serverACallCountBeforeThirdAttempt + 1; + selectedServerBOnThirdAttempt |= + serverBCallCountAfterThirdAttempt == serverBCallCountBeforeThirdAttempt + 1; + } + + assertThat(selectedServerAOnThirdAttempt).isTrue(); + assertThat(selectedServerBOnThirdAttempt).isTrue(); } @Test @@ -1558,4 +1641,38 @@ private static CallOptions retryCallOptions(XGoogSpannerRequestId requestId) { return CallOptions.DEFAULT.withOption( XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY, requestId); } + + private static void assertRoutedEndpointErrorPenaltyRecorded(Status status, long operationUid) + throws Exception { + 
EndpointLatencyRegistry.clear(); + TestHarness harness = createHarness(); + seedCache(harness, createLeaderAndReplicaCacheUpdate()); + + ExecuteSqlRequest request = + ExecuteSqlRequest.newBuilder() + .setSession(SESSION) + .setRoutingHint( + RoutingHint.newBuilder().setKey(bytes("b")).setOperationUid(operationUid).build()) + .build(); + + ClientCall call = + harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions(operationUid)); + call.start(new CapturingListener(), new Metadata()); + call.sendMessage(request); + + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + + @SuppressWarnings("unchecked") + RecordingClientCall delegate = + (RecordingClientCall) + harness.endpointCache.latestCallForAddress("server-a:1234"); + long routedOperationUid = delegate.lastMessage.getRoutingHint().getOperationUid(); + assertThat(routedOperationUid).isGreaterThan(0L); + delegate.emitOnClose(status, new Metadata()); + + assertThat(EndpointLatencyRegistry.hasScore(routedOperationUid, "server-a:1234")).isTrue(); + assertThat(EndpointLatencyRegistry.getSelectionCost(routedOperationUid, "server-a:1234")) + .isEqualTo((double) EndpointLatencyRegistry.DEFAULT_ERROR_PENALTY.toNanos() / 1_000D); + assertThat(EndpointLatencyRegistry.hasScore(routedOperationUid, "server-b:1234")).isFalse(); + } } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java index b19123daa704..5dd2fa01a04b 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java @@ -33,16 +33,24 @@ import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; +import java.time.Duration; import java.util.HashMap; 
import java.util.Map; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class KeyRangeCacheTest { + private static final long TEST_OPERATION_UID = 101L; + + @After + public void tearDown() { + EndpointLatencyRegistry.clear(); + } @Test public void skipsTransientFailureTabletWithSkippedTablet() { @@ -130,8 +138,67 @@ public void skipsExplicitlyExcludedTablet() { assertNotNull(server); assertEquals("server2", server.getAddress()); - assertEquals(1, hint.getSkippedTabletUidCount()); - assertEquals(1L, hint.getSkippedTabletUid(0).getTabletUid()); + assertEquals(0, hint.getSkippedTabletUidCount()); + } + + @Test + public void lookupRoutingHintReportsCacheMiss() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + KeyRangeCache.RouteLookupResult result = + cache.lookupRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint, + address -> false); + + assertNull(result.endpoint); + assertEquals(KeyRangeCache.RouteFailureReason.CACHE_MISS, result.failureReason); + } + + @Test + public void lookupRoutingHintReusesReplicaWhenAllCandidatesAreExcludedOrCoolingDown() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.addRanges(singleReplicaUpdate("server1")); + endpointCache.get("server1"); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + KeyRangeCache.RouteLookupResult result = + cache.lookupRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint, + "server1"::equals); + + assertNotNull(result.endpoint); + assertEquals("server1", 
result.endpoint.getAddress()); + assertEquals(KeyRangeCache.RouteFailureReason.NONE, result.failureReason); + } + + @Test + public void lookupRoutingHintReportsNoReadyReplica() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.addRanges(singleReplicaUpdate("server1")); + endpointCache.get("server1"); + endpointCache.setState("server1", EndpointHealthState.IDLE); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + KeyRangeCache.RouteLookupResult result = + cache.lookupRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint, + address -> false); + + assertNull(result.endpoint); + assertEquals(KeyRangeCache.RouteFailureReason.NO_READY_REPLICA, result.failureReason); } @Test @@ -351,6 +418,29 @@ public void connectingEndpointCausesDefaultHostFallbackWithoutSkippedTablet() { assertEquals(0, hint.getSkippedTabletUidCount()); } + @Test + public void excludedEndpointDoesNotAddSkippedTablet() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.addRanges(singleReplicaUpdate("server1")); + + endpointCache.get("server1"); + endpointCache.setState("server1", EndpointHealthState.READY); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint, + "server1"::equals); + + assertNotNull(server); + assertEquals("server1", server.getAddress()); + assertEquals(0, hint.getSkippedTabletUidCount()); + } + @Test public void transientFailureEndpointCausesSkippedTabletPlusDefaultHostFallback() { FakeEndpointCache endpointCache = new FakeEndpointCache(); @@ -481,6 +571,7 @@ public void transientFailureReplicaSkippedAndReadyReplicaSelected() { public void 
laterTransientFailureReplicaReportedWhenEarlierReplicaSelected() { FakeEndpointCache endpointCache = new FakeEndpointCache(); KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); cache.addRanges(threeReplicaUpdate()); endpointCache.get("server1"); @@ -512,6 +603,7 @@ public void laterRecentlyEvictedTransientFailureReplicaReportedWhenEarlierReplic new RecentTransientFailureLifecycleManager(endpointCache); try { KeyRangeCache cache = new KeyRangeCache(endpointCache, lifecycleManager); + cache.useDeterministicRandom(); cache.addRanges(threeReplicaUpdate()); endpointCache.get("server1"); @@ -535,6 +627,250 @@ public void laterRecentlyEvictedTransientFailureReplicaReportedWhenEarlierReplic } } + @Test + public void preferLeaderFalseUsesLowestLatencyReplicaWhenScoresAvailable() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(300_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server3", Duration.ofNanos(200_000L)); + + RoutingHint.Builder hint = + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } + + @Test + public void preferLeaderTrueIgnoresLatencyScoresForLeaderSelection() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + 
cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(300_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server3", Duration.ofNanos(200_000L)); + + RoutingHint.Builder hint = + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID); + ChannelEndpoint server = + cache.fillRoutingHint( + true, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server1", server.getAddress()); + } + + @Test + public void preferLeaderFalseSkipsBestScoredReplicaWhenItIsNotReady() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + endpointCache.setState("server2", EndpointHealthState.IDLE); + + cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(300_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server3", Duration.ofNanos(200_000L)); + + RoutingHint.Builder hint = + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID); + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + + assertNotNull(server); + assertEquals("server3", server.getAddress()); + } + + @Test + public void preferLeaderFalseUsesOperationUidScopedScores() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new 
KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + cache.recordReplicaLatency(201L, "server1", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(201L, "server2", Duration.ofNanos(300_000L)); + cache.recordReplicaLatency(201L, "server3", Duration.ofNanos(200_000L)); + cache.recordReplicaLatency(202L, "server1", Duration.ofNanos(300_000L)); + cache.recordReplicaLatency(202L, "server2", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(202L, "server3", Duration.ofNanos(200_000L)); + + ChannelEndpoint firstOperationServer = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(201L)); + ChannelEndpoint secondOperationServer = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(202L)); + + assertNotNull(firstOperationServer); + assertNotNull(secondOperationServer); + assertEquals("server1", firstOperationServer.getAddress()); + assertEquals("server2", secondOperationServer.getAddress()); + } + + @Test + public void preferLeaderFalseUsesInflightCostForColdReplicaSelection() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + EndpointLatencyRegistry.beginRequest("server1"); + + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID)); + + 
assertNotNull(server); + assertEquals("server2", server.getAddress()); + } + + @Test + public void coldReplicaSelectionEmitsFiniteDefaultCost() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + KeyRangeCache.RouteLookupResult result = + cache.lookupRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID), + address -> false); + + assertNotNull(result.endpoint); + assertNotNull(result.selectionDetail); + assertEquals("latency_score", result.selectionDetail.selectionReason); + assertEquals(10_000.0D, result.selectionDetail.selectedScore, 0.0D); + assertTrue(Double.isFinite(result.selectionDetail.scoreGap())); + assertEquals(0.0D, result.selectionDetail.scoreGap(), 0.0D); + } + + @Test + public void preferLeaderFalseInflightCostCanOutweighLowerLatency() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(300_000L)); + EndpointLatencyRegistry.beginRequest("server1"); + EndpointLatencyRegistry.beginRequest("server1"); + EndpointLatencyRegistry.beginRequest("server1"); + + ChannelEndpoint server = + cache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + 
RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID)); + + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } + + @Test + public void preferLeaderFalseErrorPenaltySteersSelectionAwayFromPenalizedReplica() { + FakeEndpointCache baselineEndpointCache = new FakeEndpointCache(); + KeyRangeCache baselineCache = new KeyRangeCache(baselineEndpointCache); + baselineCache.useDeterministicRandom(); + baselineCache.addRanges(threeReplicaUpdate()); + + baselineEndpointCache.get("server1"); + baselineEndpointCache.get("server2"); + baselineEndpointCache.get("server3"); + + ChannelEndpoint baselineServer = + baselineCache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID)); + + assertNotNull(baselineServer); + + EndpointLatencyRegistry.clear(); + + FakeEndpointCache penalizedEndpointCache = new FakeEndpointCache(); + KeyRangeCache penalizedCache = new KeyRangeCache(penalizedEndpointCache); + penalizedCache.useDeterministicRandom(); + penalizedCache.addRanges(threeReplicaUpdate()); + + penalizedEndpointCache.get("server1"); + penalizedEndpointCache.get("server2"); + penalizedEndpointCache.get("server3"); + penalizedCache.recordReplicaError(TEST_OPERATION_UID, baselineServer.getAddress()); + + ChannelEndpoint penalizedSelection = + penalizedCache.fillRoutingHint( + false, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + RoutingHint.newBuilder().setKey(bytes("a")).setOperationUid(TEST_OPERATION_UID)); + + assertNotNull(penalizedSelection); + assertTrue(!baselineServer.getAddress().equals(penalizedSelection.getAddress())); + } + // --- Eviction and recreation tests --- @Test diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java index 7ac5faf2e16e..229c28d6c281 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner.spi.v1; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -23,11 +24,13 @@ import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.Options; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.TimestampBound; import com.google.common.base.Stopwatch; import com.google.protobuf.ByteString; import com.google.protobuf.ListValue; @@ -149,6 +152,8 @@ public void onCompleted() { @After public void tearDown() throws InterruptedException { + EndpointLatencyRegistry.clear(); + RequestIdTargetTracker.clear(); for (ServerInstance si : servers) { si.server.shutdown(); } @@ -330,4 +335,206 @@ public void testEndToEndWithSpannerOptions() throws Exception { server1ReceivedSuccessfulRead); } } + + @Test + public void testStaleSingleUseReadBootstrapsScoresAndConvergesToLowerLatencyReplica() + throws Exception { + SpannerOptions options = + SpannerOptions.newBuilder() + .usePlainText() + .setExperimentalHost("localhost:" + servers.get(0).port) + .setProjectId("fake-project") + .setChannelEndpointCacheFactory(null) + .build(); + + 
RecipeList.Builder recipeListBuilder = RecipeList.newBuilder(); + try { + TextFormat.merge( + "recipe {\n" + + " table_name: \"Table\"\n" + + " part { tag: 1 }\n" + + " part {\n" + + " order: ASCENDING\n" + + " null_order: NULLS_FIRST\n" + + " type { code: STRING }\n" + + " }\n" + + "}\n", + recipeListBuilder); + } catch (TextFormat.ParseException e) { + throw new RuntimeException(e); + } + + CacheUpdate cacheUpdate = + CacheUpdate.newBuilder() + .setDatabaseId(12345L) + .setKeyRecipes(recipeListBuilder.build()) + .addGroup( + Group.newBuilder() + .setGroupUid(1L) + .addTablets( + Tablet.newBuilder() + .setTabletUid(101L) + .setServerAddress("localhost:" + servers.get(0).port) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0) + .build()) + .addTablets( + Tablet.newBuilder() + .setTabletUid(102L) + .setServerAddress("localhost:" + servers.get(1).port) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0) + .build()) + .build()) + .addRange( + Range.newBuilder() + .setStartKey(ByteString.EMPTY) + .setLimitKey(ByteString.copyFromUtf8("\u00FF")) + .setGroupUid(1L) + .setSplitId(1L) + .setGeneration(ByteString.copyFromUtf8("gen1")) + .build()) + .build(); + + ResultSet resultSetWithUpdate = + SELECT1_RESULTSET.toBuilder().setCacheUpdate(cacheUpdate).build(); + + servers + .get(0) + .mockSpanner + .putStatementResult(StatementResult.query(Statement.of("SELECT 1"), resultSetWithUpdate)); + + com.google.cloud.spanner.Statement readStatement = + StatementResult.createReadStatement( + "Table", + com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of()), + Arrays.asList("Column")); + + servers + .get(0) + .mockSpanner + .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET)); + servers + .get(1) + .mockSpanner + .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET)); + servers + .get(0) + .mockSpanner + .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(40, 0)); + servers + .get(1) + 
.mockSpanner + .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(0, 0)); + + try (Spanner spanner = options.getService()) { + DatabaseClient client = + spanner.getDatabaseClient( + DatabaseId.of("fake-project", "fake-instance", "fake-database")); + + try (com.google.cloud.spanner.ResultSet rs = + client.singleUse().executeQuery(Statement.of("SELECT 1"))) { + while (rs.next()) { + /* consume */ + } + } + + clearServerRequests(); + boolean sampledServer0 = false; + boolean sampledServer1 = false; + Stopwatch watch = Stopwatch.createStarted(); + int attempt = 0; + long operationUid = 0L; + + while (watch.elapsed(TimeUnit.SECONDS) < 10 && (!sampledServer0 || !sampledServer1)) { + attempt++; + String key = "bootstrap-key-" + attempt; + try (com.google.cloud.spanner.ResultSet rs = + client + .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .read( + "Table", + com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)), + Arrays.asList("Column"))) { + while (rs.next()) { + /* consume */ + } + } + + long currentOperationUid = findReadOperationUid(key); + assertTrue("Expected stale read to carry operation_uid", currentOperationUid > 0L); + if (operationUid == 0L) { + operationUid = currentOperationUid; + } else { + assertEquals( + "Expected stale reads to reuse the same operation_uid", + operationUid, + currentOperationUid); + } + sampledServer0 = hasReadRequestForKey(servers.get(0).mockSpanner, key) || sampledServer0; + sampledServer1 = hasReadRequestForKey(servers.get(1).mockSpanner, key) || sampledServer1; + } + + assertTrue("Expected bootstrap exploration to sample server0", sampledServer0); + assertTrue("Expected bootstrap exploration to sample server1", sampledServer1); + assertTrue("Expected stale reads to reuse the same operation_uid", operationUid > 0L); + + clearServerRequests(); + boolean routedToLowerLatencyReplica = false; + int convergenceAttempt = 0; + while (watch.elapsed(TimeUnit.SECONDS) < 
10 && !routedToLowerLatencyReplica) { + convergenceAttempt++; + String key = "convergence-key-" + convergenceAttempt; + try (com.google.cloud.spanner.ResultSet rs = + client + .singleUse(TimestampBound.ofExactStaleness(15L, TimeUnit.SECONDS)) + .read( + "Table", + com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)), + Arrays.asList("Column"))) { + while (rs.next()) { + /* consume */ + } + } + + routedToLowerLatencyReplica = + hasReadRequestForKey(servers.get(1).mockSpanner, key) + && !hasReadRequestForKey(servers.get(0).mockSpanner, key); + } + + assertTrue( + "Expected latency-aware routing to converge to the faster replica", + routedToLowerLatencyReplica); + } + } + + private void clearServerRequests() { + for (ServerInstance server : servers) { + server.mockSpanner.clearRequests(); + } + } + + private long findReadOperationUid(String key) { + for (ServerInstance server : servers) { + for (ReadRequest request : server.mockSpanner.getRequestsOfType(ReadRequest.class)) { + if (request.getKeySet().getKeysCount() == 0 + || request.getKeySet().getKeys(0).getValuesCount() == 0) { + continue; + } + if (key.equals(request.getKeySet().getKeys(0).getValues(0).getStringValue())) { + return request.getRoutingHint().getOperationUid(); + } + } + } + return 0L; + } + + private boolean hasReadRequestForKey(MockSpannerServiceImpl mockSpanner, String key) { + return mockSpanner.getRequestsOfType(ReadRequest.class).stream() + .anyMatch( + request -> + request.getKeySet().getKeysCount() > 0 + && request.getKeySet().getKeys(0).getValuesCount() > 0 + && key.equals(request.getKeySet().getKeys(0).getValues(0).getStringValue())); + } } From 1b00d27d3353fd053da3f4ed37bbba8b62d39110 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 20 Apr 2026 02:36:56 +0530 Subject: [PATCH 5/9] support for strong reads --- .../spanner/spi/v1/EwmaLatencyTracker.java | 5 +- .../cloud/spanner/spi/v1/KeyRangeCache.java | 87 +++-- .../spanner/spi/v1/KeyRangeCacheTest.java | 29 
+- .../v1/ReplicaSelectionMockServerTest.java | 313 ++++++++++++++++++ 4 files changed, 409 insertions(+), 25 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java index 64101277e298..57f8547e7e04 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EwmaLatencyTracker.java @@ -29,9 +29,8 @@ /** * Implementation of {@link LatencyTracker} using Exponentially Weighted Moving Average (EWMA). * - *

<p>By default, this tracker uses a time-decayed EWMA: - * $S_{i+1} = \alpha(\Delta t) * new\_latency + (1 - \alpha(\Delta t)) * S_i$, where $\alpha(\Delta - * t) = 1 - e^{-\Delta t / \tau}$. + *

<p>By default, this tracker uses a time-decayed EWMA: $S_{i+1} = \alpha(\Delta t) * new\_latency + * + (1 - \alpha(\Delta t)) * S_i$, where $\alpha(\Delta t) = 1 - e^{-\Delta t / \tau}$. * *

A fixed-alpha constructor is retained for focused tests. */ diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index 027a9c176572..d02233b5a96b 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -44,6 +44,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Predicate; +import java.util.function.ToDoubleFunction; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.IntStream; @@ -57,6 +58,7 @@ public final class KeyRangeCache { private static final int MAX_LOCAL_REPLICA_DISTANCE = 5; private static final int DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK = 1000; + private static final double LOCAL_LEADER_SELECTION_COST_MULTIPLIER = 0.5D; /** Determines how to handle ranges that span multiple splits. */ public enum RangeMode { @@ -828,8 +830,10 @@ private TabletSnapshot selectTablet( List skippedTabletDetails, Map resolvedEndpoints, SelectionStats selectionStats) { - if (!preferLeader) { - return selectLatencyAwareTablet( + if (!preferLeader || hintBuilder.getOperationUid() > 0L) { + TabletSnapshot preferredLeader = + preferLeader ? localLeaderForScoreBias(snapshot, hasDirectedReadOptions) : null; + return selectScoreAwareTablet( snapshot, directedReadOptions, hintBuilder, @@ -837,7 +841,9 @@ private TabletSnapshot selectTablet( skippedTabletUids, skippedTabletDetails, resolvedEndpoints, - selectionStats); + selectionStats, + preferredLeader, + preferLeader ? 
"leader_preferred_latency_score" : "latency_score"); } boolean checkedLeader = false; @@ -884,7 +890,18 @@ private TabletSnapshot selectTablet( return null; } - private TabletSnapshot selectLatencyAwareTablet( + @javax.annotation.Nullable + private TabletSnapshot localLeaderForScoreBias( + GroupSnapshot snapshot, boolean hasDirectedReadOptions) { + if (!hasDirectedReadOptions + && snapshot.hasLeader() + && snapshot.leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { + return snapshot.leader(); + } + return null; + } + + private TabletSnapshot selectScoreAwareTablet( GroupSnapshot snapshot, DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, @@ -892,10 +909,13 @@ private TabletSnapshot selectLatencyAwareTablet( Set skippedTabletUids, List skippedTabletDetails, Map resolvedEndpoints, - SelectionStats selectionStats) { + SelectionStats selectionStats, + @javax.annotation.Nullable TabletSnapshot preferredLeader, + String selectionReason) { long operationUid = hintBuilder.getOperationUid(); List eligibleTablets = new ArrayList<>(); List eligibleEndpoints = new ArrayList<>(); + Map endpointByAddress = new HashMap<>(); int scoredCandidates = 0; for (TabletSnapshot tablet : snapshot.tablets) { @@ -921,6 +941,7 @@ private TabletSnapshot selectLatencyAwareTablet( } eligibleTablets.add(tablet); eligibleEndpoints.add(endpoint); + endpointByAddress.put(endpoint.getAddress(), tablet); if (EndpointLatencyRegistry.hasScore(operationUid, tablet.serverAddress)) { scoredCandidates++; } @@ -936,6 +957,7 @@ private TabletSnapshot selectLatencyAwareTablet( snapshot, eligibleTablets, operationUid, + selectionCostLookup(operationUid, preferredLeader), "single_candidate", selected, scoredCandidates); @@ -946,16 +968,15 @@ private TabletSnapshot selectLatencyAwareTablet( eligibleTablets.stream() .min( Comparator.comparingDouble( - tablet -> - EndpointLatencyRegistry.getSelectionCost( - operationUid, tablet.serverAddress))) + tablet -> selectionCost(operationUid, tablet, 
preferredLeader))) .orElse(eligibleTablets.get(0)); selectionStats.selectionDetail = buildSelectionDetail( snapshot, eligibleTablets, operationUid, - "latency_score", + selectionCostLookup(operationUid, preferredLeader), + selectionReason, selected, scoredCandidates); return selected; @@ -965,7 +986,8 @@ private TabletSnapshot selectLatencyAwareTablet( replicaSelector.select( eligibleEndpoints, endpoint -> - EndpointLatencyRegistry.getSelectionCost(operationUid, endpoint.getAddress())); + selectionCost( + operationUid, endpointByAddress.get(endpoint.getAddress()), preferredLeader)); if (selectedEndpoint == null) { TabletSnapshot selected = eligibleTablets.get(0); selectionStats.selectionDetail = @@ -973,7 +995,8 @@ private TabletSnapshot selectLatencyAwareTablet( snapshot, eligibleTablets, operationUid, - "latency_score", + selectionCostLookup(operationUid, preferredLeader), + selectionReason, selected, scoredCandidates); return selected; @@ -986,7 +1009,8 @@ private TabletSnapshot selectLatencyAwareTablet( snapshot, eligibleTablets, operationUid, - "latency_score", + selectionCostLookup(operationUid, preferredLeader), + selectionReason, selected, scoredCandidates); return selected; @@ -995,10 +1019,35 @@ private TabletSnapshot selectLatencyAwareTablet( TabletSnapshot selected = eligibleTablets.get(0); selectionStats.selectionDetail = buildSelectionDetail( - snapshot, eligibleTablets, operationUid, "latency_score", selected, scoredCandidates); + snapshot, + eligibleTablets, + operationUid, + selectionCostLookup(operationUid, preferredLeader), + selectionReason, + selected, + scoredCandidates); return selected; } + private ToDoubleFunction selectionCostLookup( + long operationUid, @javax.annotation.Nullable TabletSnapshot preferredLeader) { + return tablet -> selectionCost(operationUid, tablet, preferredLeader); + } + + private double selectionCost( + long operationUid, + @javax.annotation.Nullable TabletSnapshot tablet, + @javax.annotation.Nullable TabletSnapshot 
preferredLeader) { + if (tablet == null) { + return Double.MAX_VALUE; + } + double cost = EndpointLatencyRegistry.getSelectionCost(operationUid, tablet.serverAddress); + if (preferredLeader != null && tablet == preferredLeader) { + return cost * LOCAL_LEADER_SELECTION_COST_MULTIPLIER; + } + return cost; + } + @javax.annotation.Nullable private TabletSnapshot selectRandomExcludedOrCoolingDownTablet( GroupSnapshot snapshot, @@ -1166,21 +1215,18 @@ private SelectionDetail buildSelectionDetail( GroupSnapshot snapshot, List eligibleTablets, long operationUid, + ToDoubleFunction selectionCostLookup, String selectionReason, TabletSnapshot selected, int scoredCandidates) { double bestScore = Double.MAX_VALUE; for (TabletSnapshot tablet : eligibleTablets) { - bestScore = - Math.min( - bestScore, - EndpointLatencyRegistry.getSelectionCost(operationUid, tablet.serverAddress)); + bestScore = Math.min(bestScore, selectionCostLookup.applyAsDouble(tablet)); } StringBuilder alternatives = new StringBuilder(); int appended = 0; - double selectedScore = - EndpointLatencyRegistry.getSelectionCost(operationUid, selected.serverAddress); + double selectedScore = selectionCostLookup.applyAsDouble(selected); for (TabletSnapshot tablet : eligibleTablets) { if (tablet == selected || appended >= 4) { continue; @@ -1188,8 +1234,7 @@ private SelectionDetail buildSelectionDetail( if (alternatives.length() > 0) { alternatives.append(", "); } - double candidateScore = - EndpointLatencyRegistry.getSelectionCost(operationUid, tablet.serverAddress); + double candidateScore = selectionCostLookup.applyAsDouble(tablet); alternatives .append(endpointLabel(snapshot, tablet)) .append("=") diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java index 5dd2fa01a04b..a2bddc79fca2 100644 --- 
a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyRangeCacheTest.java @@ -656,7 +656,7 @@ public void preferLeaderFalseUsesLowestLatencyReplicaWhenScoresAvailable() { } @Test - public void preferLeaderTrueIgnoresLatencyScoresForLeaderSelection() { + public void preferLeaderTrueUsesLatencyScoresWhenOperationUidAvailable() { FakeEndpointCache endpointCache = new FakeEndpointCache(); KeyRangeCache cache = new KeyRangeCache(endpointCache); cache.useDeterministicRandom(); @@ -679,6 +679,33 @@ public void preferLeaderTrueIgnoresLatencyScoresForLeaderSelection() { DirectedReadOptions.getDefaultInstance(), hint); + assertNotNull(server); + assertEquals("server2", server.getAddress()); + } + + @Test + public void preferLeaderTrueWithoutOperationUidKeepsLeaderSelection() { + FakeEndpointCache endpointCache = new FakeEndpointCache(); + KeyRangeCache cache = new KeyRangeCache(endpointCache); + cache.useDeterministicRandom(); + cache.addRanges(threeReplicaUpdate()); + + endpointCache.get("server1"); + endpointCache.get("server2"); + endpointCache.get("server3"); + + cache.recordReplicaLatency(TEST_OPERATION_UID, "server1", Duration.ofNanos(300_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server2", Duration.ofNanos(100_000L)); + cache.recordReplicaLatency(TEST_OPERATION_UID, "server3", Duration.ofNanos(200_000L)); + + RoutingHint.Builder hint = RoutingHint.newBuilder().setKey(bytes("a")); + ChannelEndpoint server = + cache.fillRoutingHint( + true, + KeyRangeCache.RangeMode.COVERING_SPLIT, + DirectedReadOptions.getDefaultInstance(), + hint); + assertNotNull(server); assertEquals("server1", server.getAddress()); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java index 229c28d6c281..dc444eb687f6 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java @@ -40,6 +40,7 @@ import com.google.spanner.v1.DirectedReadOptions; import com.google.spanner.v1.DirectedReadOptions.IncludeReplicas; import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection; +import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.Group; import com.google.spanner.v1.Range; import com.google.spanner.v1.ReadRequest; @@ -88,6 +89,8 @@ public class ReplicaSelectionMockServerTest { .build()) .setMetadata(SELECT1_METADATA) .build(); + private static final String QUERY_SQL = "SELECT * FROM Table WHERE Column = @p1"; + private static final String QUERY_PARAM = "p1"; private List servers; private final int numServers = 2; @@ -508,12 +511,300 @@ public void testStaleSingleUseReadBootstrapsScoresAndConvergesToLowerLatencyRepl } } + @Test + public void testStrongSingleUseReadConvergesToLowerLatencyReplica() throws Exception { + SpannerOptions options = + SpannerOptions.newBuilder() + .usePlainText() + .setExperimentalHost("localhost:" + servers.get(0).port) + .setProjectId("fake-project") + .setChannelEndpointCacheFactory(null) + .build(); + + ResultSet resultSetWithUpdate = + SELECT1_RESULTSET.toBuilder() + .setCacheUpdate(createReplicaCacheUpdate(readRecipeList())) + .build(); + + servers + .get(0) + .mockSpanner + .putStatementResult(StatementResult.query(Statement.of("SELECT 1"), resultSetWithUpdate)); + + com.google.cloud.spanner.Statement readStatement = + StatementResult.createReadStatement( + "Table", + com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of()), + Arrays.asList("Column")); + + servers + 
.get(0) + .mockSpanner + .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET)); + servers + .get(1) + .mockSpanner + .putStatementResult(StatementResult.query(readStatement, SELECT1_RESULTSET)); + servers + .get(0) + .mockSpanner + .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(40, 0)); + servers + .get(1) + .mockSpanner + .setStreamingReadExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(0, 0)); + + try (Spanner spanner = options.getService()) { + DatabaseClient client = + spanner.getDatabaseClient( + DatabaseId.of("fake-project", "fake-instance", "fake-database")); + + try (com.google.cloud.spanner.ResultSet rs = + client.singleUse().executeQuery(Statement.of("SELECT 1"))) { + while (rs.next()) { + /* consume */ + } + } + + clearServerRequests(); + long operationUid = 0L; + + for (int attempt = 1; attempt <= 3; attempt++) { + String key = "strong-read-bootstrap-" + attempt; + try (com.google.cloud.spanner.ResultSet rs = + client + .singleUse() + .read( + "Table", + com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)), + Arrays.asList("Column"))) { + while (rs.next()) { + /* consume */ + } + } + + long currentOperationUid = findReadOperationUid(key); + assertTrue("Expected strong read to carry operation_uid", currentOperationUid > 0L); + if (operationUid == 0L) { + operationUid = currentOperationUid; + } else { + assertEquals( + "Expected strong reads to reuse the same operation_uid", + operationUid, + currentOperationUid); + } + } + + assertTrue("Expected strong reads to reuse the same operation_uid", operationUid > 0L); + + clearServerRequests(); + Stopwatch watch = Stopwatch.createStarted(); + boolean routedToLowerLatencyReplica = false; + int convergenceAttempt = 0; + while (watch.elapsed(TimeUnit.SECONDS) < 10 && !routedToLowerLatencyReplica) { + convergenceAttempt++; + String key = "strong-read-convergence-" + convergenceAttempt; + try 
(com.google.cloud.spanner.ResultSet rs = + client + .singleUse() + .read( + "Table", + com.google.cloud.spanner.KeySet.singleKey(com.google.cloud.spanner.Key.of(key)), + Arrays.asList("Column"))) { + while (rs.next()) { + /* consume */ + } + } + + routedToLowerLatencyReplica = + hasReadRequestForKey(servers.get(1).mockSpanner, key) + && !hasReadRequestForKey(servers.get(0).mockSpanner, key); + } + + assertTrue( + "Expected strong read routing to converge to the faster replica", + routedToLowerLatencyReplica); + } + } + + @Test + public void testStrongSingleUseQueryConvergesToLowerLatencyReplica() throws Exception { + SpannerOptions options = + SpannerOptions.newBuilder() + .usePlainText() + .setExperimentalHost("localhost:" + servers.get(0).port) + .setProjectId("fake-project") + .setChannelEndpointCacheFactory(null) + .build(); + + servers + .get(0) + .mockSpanner + .setExecuteStreamingSqlExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(40, 0)); + servers + .get(1) + .mockSpanner + .setExecuteStreamingSqlExecutionTime(SimulatedExecutionTime.ofMinimumAndRandomTime(0, 0)); + + try (Spanner spanner = options.getService()) { + DatabaseClient client = + spanner.getDatabaseClient( + DatabaseId.of("fake-project", "fake-instance", "fake-database")); + assertStrongQueryConvergesToLowerLatencyReplica( + statement -> { + try (com.google.cloud.spanner.ResultSet rs = + client.singleUse().executeQuery(statement)) { + while (rs.next()) { + /* consume */ + } + } + }); + } + } + + @FunctionalInterface + private interface QueryExecutor { + void execute(Statement statement) throws Exception; + } + + private void assertStrongQueryConvergesToLowerLatencyReplica(QueryExecutor queryExecutor) + throws Exception { + String seedKey = "query-seed"; + installQueryResultOnAllServers(seedKey, SELECT1_RESULTSET); + + queryExecutor.execute(queryStatement(seedKey)); + long operationUid = findQueryOperationUid(seedKey); + assertTrue("Expected strong query to carry operation_uid", 
operationUid > 0L); + + installQueryResultOnAllServers( + seedKey, + SELECT1_RESULTSET.toBuilder() + .setCacheUpdate(createReplicaCacheUpdate(queryRecipeList(operationUid))) + .build()); + queryExecutor.execute(queryStatement(seedKey)); + clearServerRequests(); + + for (int attempt = 1; attempt <= 3; attempt++) { + String key = "strong-query-bootstrap-" + attempt; + installQueryResultOnAllServers(key, SELECT1_RESULTSET); + queryExecutor.execute(queryStatement(key)); + + long currentOperationUid = findQueryOperationUid(key); + assertEquals( + "Expected strong queries to reuse the same operation_uid", + operationUid, + currentOperationUid); + } + + clearServerRequests(); + Stopwatch watch = Stopwatch.createStarted(); + boolean routedToLowerLatencyReplica = false; + int convergenceAttempt = 0; + while (watch.elapsed(TimeUnit.SECONDS) < 10 && !routedToLowerLatencyReplica) { + convergenceAttempt++; + String key = "strong-query-convergence-" + convergenceAttempt; + installQueryResultOnAllServers(key, SELECT1_RESULTSET); + queryExecutor.execute(queryStatement(key)); + + routedToLowerLatencyReplica = + hasQueryRequestForKey(servers.get(1).mockSpanner, key) + && !hasQueryRequestForKey(servers.get(0).mockSpanner, key); + } + + assertTrue( + "Expected strong query routing to converge to the faster replica", + routedToLowerLatencyReplica); + } + private void clearServerRequests() { for (ServerInstance server : servers) { server.mockSpanner.clearRequests(); } } + private CacheUpdate createReplicaCacheUpdate(RecipeList keyRecipes) { + return CacheUpdate.newBuilder() + .setDatabaseId(12345L) + .setKeyRecipes(keyRecipes) + .addGroup( + Group.newBuilder() + .setGroupUid(1L) + .setLeaderIndex(0) + .addTablets( + Tablet.newBuilder() + .setTabletUid(101L) + .setServerAddress("localhost:" + servers.get(0).port) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0) + .build()) + .addTablets( + Tablet.newBuilder() + .setTabletUid(102L) + .setServerAddress("localhost:" + 
servers.get(1).port) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0) + .build()) + .build()) + .addRange( + Range.newBuilder() + .setStartKey(ByteString.EMPTY) + .setLimitKey(ByteString.copyFromUtf8("\u00FF")) + .setGroupUid(1L) + .setSplitId(1L) + .setGeneration(ByteString.copyFromUtf8("gen1")) + .build()) + .build(); + } + + private RecipeList readRecipeList() throws TextFormat.ParseException { + RecipeList.Builder recipeListBuilder = RecipeList.newBuilder(); + TextFormat.merge( + "recipe {\n" + + " table_name: \"Table\"\n" + + " part { tag: 1 }\n" + + " part {\n" + + " order: ASCENDING\n" + + " null_order: NULLS_FIRST\n" + + " type { code: STRING }\n" + + " identifier: \"k\"\n" + + " }\n" + + "}\n", + recipeListBuilder); + return recipeListBuilder.build(); + } + + private RecipeList queryRecipeList(long operationUid) throws TextFormat.ParseException { + RecipeList.Builder recipeListBuilder = RecipeList.newBuilder(); + TextFormat.merge( + "recipe {\n" + + " operation_uid: " + + operationUid + + "\n" + + " part { tag: 1 }\n" + + " part {\n" + + " order: ASCENDING\n" + + " null_order: NULLS_FIRST\n" + + " type { code: STRING }\n" + + " identifier: \"" + + QUERY_PARAM + + "\"\n" + + " }\n" + + "}\n", + recipeListBuilder); + return recipeListBuilder.build(); + } + + private Statement queryStatement(String key) { + return Statement.newBuilder(QUERY_SQL).bind(QUERY_PARAM).to(key).build(); + } + + private void installQueryResultOnAllServers(String key, ResultSet resultSet) { + Statement statement = queryStatement(key); + for (ServerInstance server : servers) { + server.mockSpanner.putStatementResult(StatementResult.query(statement, resultSet)); + } + } + private long findReadOperationUid(String key) { for (ServerInstance server : servers) { for (ReadRequest request : server.mockSpanner.getRequestsOfType(ReadRequest.class)) { @@ -529,6 +820,19 @@ private long findReadOperationUid(String key) { return 0L; } + private long findQueryOperationUid(String key) { + for 
(ServerInstance server : servers) { + for (ExecuteSqlRequest request : + server.mockSpanner.getRequestsOfType(ExecuteSqlRequest.class)) { + if (request.getParams().getFieldsMap().containsKey(QUERY_PARAM) + && key.equals(request.getParams().getFieldsOrThrow(QUERY_PARAM).getStringValue())) { + return request.getRoutingHint().getOperationUid(); + } + } + } + return 0L; + } + private boolean hasReadRequestForKey(MockSpannerServiceImpl mockSpanner, String key) { return mockSpanner.getRequestsOfType(ReadRequest.class).stream() .anyMatch( @@ -537,4 +841,13 @@ private boolean hasReadRequestForKey(MockSpannerServiceImpl mockSpanner, String && request.getKeySet().getKeys(0).getValuesCount() > 0 && key.equals(request.getKeySet().getKeys(0).getValues(0).getStringValue())); } + + private boolean hasQueryRequestForKey(MockSpannerServiceImpl mockSpanner, String key) { + return mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() + .anyMatch( + request -> + request.getParams().getFieldsMap().containsKey(QUERY_PARAM) + && key.equals( + request.getParams().getFieldsOrThrow(QUERY_PARAM).getStringValue())); + } } From da002771158e3d325b5c33ddd6b0016354ef06e7 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 20 Apr 2026 02:51:16 +0530 Subject: [PATCH 6/9] fix tests --- ...nAwareSharedBackendReplicaHarnessTest.java | 101 +++++++++++------- .../spanner/spi/v1/KeyAwareChannelTest.java | 7 ++ 2 files changed, 67 insertions(+), 41 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java index 9b6c2de65397..41722e90bc73 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java +++ 
b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -105,12 +105,13 @@ public void singleUseReadReroutesOnResourceExhaustedForBypassTraffic() throws Ex DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); - waitForReplicaRoutedRead(client, harness, 0); + int firstReplicaIndex = waitForReplicaRoutedRead(client, harness); + int secondReplicaIndex = 1 - firstReplicaIndex; harness.clearRequests(); harness .replicas - .get(0) + .get(firstReplicaIndex) .putMethodErrors( SharedBackendReplicaHarness.METHOD_STREAMING_READ, resourceExhausted("busy-routed-replica")); @@ -130,14 +131,14 @@ public void singleUseReadReroutesOnResourceExhaustedForBypassTraffic() throws Ex 1, harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( 1, harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( @@ -150,19 +151,19 @@ public void singleUseReadReroutesOnResourceExhaustedForBypassTraffic() throws Ex (ReadRequest) harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0); assertTrue(replicaARequest.getResumeToken().isEmpty()); assertRetriedOnSameLogicalRequest( harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0), harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0)); } @@ -176,12 +177,13 @@ public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() thr DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); - waitForReplicaRoutedRead(client, harness, 0); + int 
firstReplicaIndex = waitForReplicaRoutedRead(client, harness); + int secondReplicaIndex = 1 - firstReplicaIndex; harness.clearRequests(); harness .replicas - .get(0) + .get(firstReplicaIndex) .putMethodErrors( SharedBackendReplicaHarness.METHOD_STREAMING_READ, resourceExhaustedWithRetryInfo("busy-routed-replica")); @@ -212,14 +214,14 @@ public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() thr 1, harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( 2, harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( @@ -229,16 +231,22 @@ public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() thr .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); List replicaBRequests = - harness.replicas.get(1).getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + harness + .replicas + .get(secondReplicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); for (AbstractMessage request : replicaBRequests) { assertTrue(((ReadRequest) request).getResumeToken().isEmpty()); } List replicaBRequestIds = - harness.replicas.get(1).getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + harness + .replicas + .get(secondReplicaIndex) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); assertRetriedOnSameLogicalRequest( harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0), replicaBRequestIds.get(0)); @@ -256,12 +264,13 @@ public void singleUseReadReroutesOnUnavailableForBypassTraffic() throws Exceptio DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); - waitForReplicaRoutedRead(client, harness, 0); + int firstReplicaIndex = waitForReplicaRoutedRead(client, 
harness); + int secondReplicaIndex = 1 - firstReplicaIndex; harness.clearRequests(); harness .replicas - .get(0) + .get(firstReplicaIndex) .putMethodErrors( SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica")); @@ -280,14 +289,14 @@ public void singleUseReadReroutesOnUnavailableForBypassTraffic() throws Exceptio 1, harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( 1, harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( @@ -300,19 +309,19 @@ public void singleUseReadReroutesOnUnavailableForBypassTraffic() throws Exceptio (ReadRequest) harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0); assertTrue(replicaARequest.getResumeToken().isEmpty()); assertRetriedOnSameLogicalRequest( harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0), harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0)); } @@ -327,12 +336,13 @@ public void singleUseReadCooldownSkipsUnavailableReplicaOnNextRequestForBypassTr DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); - waitForReplicaRoutedRead(client, harness, 0); + int firstReplicaIndex = waitForReplicaRoutedRead(client, harness); + int secondReplicaIndex = 1 - firstReplicaIndex; harness.clearRequests(); harness .replicas - .get(0) + .get(firstReplicaIndex) .putMethodErrors( SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica")); @@ -362,14 +372,14 @@ public void singleUseReadCooldownSkipsUnavailableReplicaOnNextRequestForBypassTr 1, harness .replicas - .get(0) + .get(firstReplicaIndex) 
.getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( 2, harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( @@ -379,16 +389,22 @@ public void singleUseReadCooldownSkipsUnavailableReplicaOnNextRequestForBypassTr .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); List replicaBRequests = - harness.replicas.get(1).getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + harness + .replicas + .get(secondReplicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); for (AbstractMessage request : replicaBRequests) { assertTrue(((ReadRequest) request).getResumeToken().isEmpty()); } List replicaBRequestIds = - harness.replicas.get(1).getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + harness + .replicas + .get(secondReplicaIndex) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); assertRetriedOnSameLogicalRequest( harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0), replicaBRequestIds.get(0)); @@ -407,7 +423,8 @@ public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTra DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); - waitForReplicaRoutedRead(client, harness, 0); + int firstReplicaIndex = waitForReplicaRoutedRead(client, harness); + int secondReplicaIndex = 1 - firstReplicaIndex; harness.clearRequests(); harness.backend.setStreamingReadExecutionTime( @@ -432,14 +449,14 @@ public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTra 1, harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( 1, harness .replicas - .get(1) + .get(secondReplicaIndex) 
.getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .size()); assertEquals( @@ -453,14 +470,14 @@ public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTra (ReadRequest) harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0); ReadRequest replicaBRequest = (ReadRequest) harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0); assertTrue(replicaARequest.getResumeToken().isEmpty()); @@ -468,12 +485,12 @@ public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTra assertRetriedOnSameLogicalRequest( harness .replicas - .get(0) + .get(firstReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0), harness .replicas - .get(1) + .get(secondReplicaIndex) .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .get(0)); } @@ -514,9 +531,8 @@ private static void seedLocationMetadata(DatabaseClient client) { } } - private static void waitForReplicaRoutedRead( - DatabaseClient client, SharedBackendReplicaHarness harness, int replicaIndex) - throws InterruptedException { + private static int waitForReplicaRoutedRead( + DatabaseClient client, SharedBackendReplicaHarness harness) throws InterruptedException { long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); while (System.nanoTime() < deadlineNanos) { try (ResultSet resultSet = @@ -527,13 +543,16 @@ private static void waitForReplicaRoutedRead( KeySet.singleKey(Key.of("b")), Arrays.asList("k"), Options.directedRead(DIRECTED_READ_OPTIONS))) { - if (resultSet.next() - && !harness + if (resultSet.next()) { + for (int replicaIndex = 0; replicaIndex < harness.replicas.size(); replicaIndex++) { + if (!harness .replicas .get(replicaIndex) .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) .isEmpty()) { - return; + return replicaIndex; + } + } } } Thread.sleep(50L); diff 
--git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index ba795dd7074a..a993c7be550f 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -66,6 +66,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -76,6 +77,12 @@ public class KeyAwareChannelTest { private static final String SESSION = "projects/p/instances/i/databases/d/sessions/test-session-id"; + @After + public void clearSharedRoutingState() { + EndpointLatencyRegistry.clear(); + RequestIdTargetTracker.clear(); + } + @Test public void cancelBeforeStartPreservesTrailersAndSkipsDelegateCreation() throws Exception { TestHarness harness = createHarness(); From b74b4537bd3619c78f4a89e001efe2250a9b9f93 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 20 Apr 2026 10:25:18 +0530 Subject: [PATCH 7/9] add more tests for R/W shape --- .../cloud/spanner/spi/v1/ChannelFinder.java | 8 +- .../spi/v1/EndpointLatencyRegistry.java | 108 +++++++++++++----- .../spanner/spi/v1/HeaderInterceptor.java | 5 +- .../cloud/spanner/spi/v1/KeyAwareChannel.java | 47 ++++++-- .../cloud/spanner/spi/v1/KeyRangeCache.java | 21 +++- .../spi/v1/RequestIdTargetTracker.java | 13 ++- ...nAwareSharedBackendReplicaHarnessTest.java | 102 +++++++++++++++++ .../spi/v1/EndpointLatencyRegistryTest.java | 104 +++++++++++++++++ .../spanner/spi/v1/KeyAwareChannelTest.java | 13 ++- .../v1/ReplicaSelectionMockServerTest.java | 12 +- 10 files changed, 371 insertions(+), 62 deletions(-) create mode 100644
java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java index 6e77ebd2692d..f4fdb41fe3c5 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java @@ -82,7 +82,8 @@ public ChannelFinder( ChannelEndpointCache endpointCache, @Nullable EndpointLifecycleManager lifecycleManager, @Nullable String finderKey) { - this.rangeCache = new KeyRangeCache(Objects.requireNonNull(endpointCache), lifecycleManager); + this.rangeCache = + new KeyRangeCache(Objects.requireNonNull(endpointCache), lifecycleManager, finderKey); this.lifecycleManager = lifecycleManager; this.finderKey = finderKey; } @@ -91,6 +92,11 @@ void useDeterministicRandom() { rangeCache.useDeterministicRandom(); } + @Nullable + String finderKey() { + return finderKey; + } + private static ExecutorService createCacheUpdatePool() { ThreadPoolExecutor executor = new ThreadPoolExecutor( diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java index be3eff53b75c..52ad76f3274c 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistry.java @@ -17,37 +17,51 @@ package com.google.cloud.spanner.spi.v1; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; +import com.google.common.cache.Cache; +import 
com.google.common.cache.CacheBuilder; import java.time.Duration; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; /** Shared process-local latency scores for routed Spanner endpoints. */ final class EndpointLatencyRegistry { + private static final String GLOBAL_SCOPE = "__global__"; static final Duration DEFAULT_ERROR_PENALTY = Duration.ofSeconds(10); static final Duration DEFAULT_RTT = Duration.ofMillis(10); static final double DEFAULT_PENALTY_VALUE = 1_000_000.0; + @VisibleForTesting static final Duration TRACKER_EXPIRE_AFTER_ACCESS = Duration.ofMinutes(10); + @VisibleForTesting static final long MAX_TRACKERS = 100_000L; - private static final ConcurrentHashMap TRACKERS = - new ConcurrentHashMap<>(); + private static volatile Cache TRACKERS = + newTrackerCache(Ticker.systemTicker()); private static final ConcurrentHashMap INFLIGHT_REQUESTS = new ConcurrentHashMap<>(); private EndpointLatencyRegistry() {} - static boolean hasScore(long operationUid, String endpointLabelOrAddress) { - TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); - return trackerKey != null && TRACKERS.containsKey(trackerKey); + static boolean hasScore( + @javax.annotation.Nullable String databaseScope, + long operationUid, + String endpointLabelOrAddress) { + TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress); + return trackerKey != null && TRACKERS.getIfPresent(trackerKey) != null; } - static double getSelectionCost(long operationUid, String endpointLabelOrAddress) { - TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + static double getSelectionCost( + @javax.annotation.Nullable String databaseScope, + long operationUid, + String endpointLabelOrAddress) { + TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress); if 
(trackerKey == null) { return Double.MAX_VALUE; } double activeRequests = getInflight(endpointLabelOrAddress); - LatencyTracker tracker = TRACKERS.get(trackerKey); + LatencyTracker tracker = TRACKERS.getIfPresent(trackerKey); if (tracker != null) { return tracker.getScore() * (activeRequests + 1.0); } @@ -57,24 +71,35 @@ static double getSelectionCost(long operationUid, String endpointLabelOrAddress) return defaultRttMicros() * (activeRequests + 1.0); } - static void recordLatency(long operationUid, String endpointLabelOrAddress, Duration latency) { - TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + static void recordLatency( + @javax.annotation.Nullable String databaseScope, + long operationUid, + String endpointLabelOrAddress, + Duration latency) { + TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress); if (trackerKey == null || latency == null) { return; } - TRACKERS.computeIfAbsent(trackerKey, ignored -> new EwmaLatencyTracker()).update(latency); + getOrCreateTracker(trackerKey).update(latency); } - static void recordError(long operationUid, String endpointLabelOrAddress) { - recordError(operationUid, endpointLabelOrAddress, DEFAULT_ERROR_PENALTY); + static void recordError( + @javax.annotation.Nullable String databaseScope, + long operationUid, + String endpointLabelOrAddress) { + recordError(databaseScope, operationUid, endpointLabelOrAddress, DEFAULT_ERROR_PENALTY); } - static void recordError(long operationUid, String endpointLabelOrAddress, Duration penalty) { - TrackerKey trackerKey = trackerKey(operationUid, endpointLabelOrAddress); + static void recordError( + @javax.annotation.Nullable String databaseScope, + long operationUid, + String endpointLabelOrAddress, + Duration penalty) { + TrackerKey trackerKey = trackerKey(databaseScope, operationUid, endpointLabelOrAddress); if (trackerKey == null || penalty == null) { return; } - TRACKERS.computeIfAbsent(trackerKey, ignored -> new 
EwmaLatencyTracker()).recordError(penalty); + getOrCreateTracker(trackerKey).recordError(penalty); } static void beginRequest(String endpointLabelOrAddress) { @@ -94,10 +119,7 @@ static void finishRequest(String endpointLabelOrAddress) { if (counter == null) { return; } - int updated = counter.decrementAndGet(); - if (updated <= 0) { - INFLIGHT_REQUESTS.remove(address, counter); - } + counter.updateAndGet(current -> current > 0 ? current - 1 : 0); } static int getInflight(String endpointLabelOrAddress) { @@ -111,10 +133,15 @@ static int getInflight(String endpointLabelOrAddress) { @VisibleForTesting static void clear() { - TRACKERS.clear(); + TRACKERS.invalidateAll(); INFLIGHT_REQUESTS.clear(); } + @VisibleForTesting + static void useTrackerTicker(Ticker ticker) { + TRACKERS = newTrackerCache(ticker); + } + @VisibleForTesting static String normalizeAddress(String endpointLabelOrAddress) { if (endpointLabelOrAddress == null || endpointLabelOrAddress.isEmpty()) { @@ -124,24 +151,49 @@ static String normalizeAddress(String endpointLabelOrAddress) { } @VisibleForTesting - static TrackerKey trackerKey(long operationUid, String endpointLabelOrAddress) { + static TrackerKey trackerKey( + @javax.annotation.Nullable String databaseScope, + long operationUid, + String endpointLabelOrAddress) { String address = normalizeAddress(endpointLabelOrAddress); if (operationUid <= 0 || address == null) { return null; } - return new TrackerKey(operationUid, address); + return new TrackerKey(normalizeScope(databaseScope), operationUid, address); } private static long defaultRttMicros() { return DEFAULT_RTT.toNanos() / 1_000L; } + private static String normalizeScope(@javax.annotation.Nullable String databaseScope) { + return (databaseScope == null || databaseScope.isEmpty()) ? 
GLOBAL_SCOPE : databaseScope; + } + + private static LatencyTracker getOrCreateTracker(TrackerKey trackerKey) { + try { + return TRACKERS.get(trackerKey, EwmaLatencyTracker::new); + } catch (ExecutionException e) { + throw new IllegalStateException("Failed to create latency tracker", e); + } + } + + private static Cache newTrackerCache(Ticker ticker) { + return CacheBuilder.newBuilder() + .maximumSize(MAX_TRACKERS) + .expireAfterAccess(TRACKER_EXPIRE_AFTER_ACCESS.toNanos(), TimeUnit.NANOSECONDS) + .ticker(ticker) + .build(); + } + @VisibleForTesting static final class TrackerKey { + private final String databaseScope; private final long operationUid; private final String address; - private TrackerKey(long operationUid, String address) { + private TrackerKey(String databaseScope, long operationUid, String address) { + this.databaseScope = databaseScope; this.operationUid = operationUid; this.address = address; } @@ -155,17 +207,19 @@ public boolean equals(Object other) { return false; } TrackerKey that = (TrackerKey) other; - return operationUid == that.operationUid && Objects.equals(address, that.address); + return operationUid == that.operationUid + && Objects.equals(databaseScope, that.databaseScope) + && Objects.equals(address, that.address); } @Override public int hashCode() { - return Objects.hash(operationUid, address); + return Objects.hash(databaseScope, operationUid, address); } @Override public String toString() { - return operationUid + "@" + address; + return databaseScope + ":" + operationUid + "@" + address; } } } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java index e01e4e97529a..445ea436947a 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java +++ 
b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/HeaderInterceptor.java @@ -232,7 +232,10 @@ private void recordFirstResponseLatency( } long latencyNanos = Math.max(0L, System.nanoTime() - startedAtNanos); EndpointLatencyRegistry.recordLatency( - routingTarget.operationUid, routingTarget.targetEndpoint, Duration.ofNanos(latencyNanos)); + routingTarget.databaseScope, + routingTarget.operationUid, + routingTarget.targetEndpoint, + Duration.ofNanos(latencyNanos)); } private Map parseServerTimingHeader(String serverTiming) { diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index 127c75c9adb3..3d64187fd056 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -207,10 +207,7 @@ private ChannelFinder getOrCreateChannelFinder(String databaseId) { ref = channelFinders.get(databaseId); finder = (ref != null) ? ref.get() : null; if (finder == null) { - finder = - lifecycleManager != null - ? 
new ChannelFinder(endpointCache, lifecycleManager, databaseId) - : new ChannelFinder(endpointCache); + finder = new ChannelFinder(endpointCache, lifecycleManager, databaseId); channelFinders.put( databaseId, new ChannelFinderReference(databaseId, finder, channelFinderReferenceQueue)); @@ -381,7 +378,10 @@ private void maybeExcludeEndpointOnNextCall( } private void maybeRecordErrorPenalty( - @Nullable ChannelEndpoint endpoint, io.grpc.Status.Code statusCode, long operationUid) { + @Nullable String databaseScope, + @Nullable ChannelEndpoint endpoint, + io.grpc.Status.Code statusCode, + long operationUid) { if (!shouldExcludeEndpointOnRetry(statusCode) || endpoint == null || operationUid <= 0L) { return; } @@ -389,7 +389,7 @@ private void maybeRecordErrorPenalty( if (defaultEndpointAddress.equals(address)) { return; } - EndpointLatencyRegistry.recordError(operationUid, address); + EndpointLatencyRegistry.recordError(databaseScope, operationUid, address); } private static boolean shouldExcludeEndpointOnRetry(io.grpc.Status.Code statusCode) { @@ -526,6 +526,7 @@ static final class KeyAwareClientCall @Nullable private Predicate excludedEndpoints; @Nullable private ChannelEndpoint selectedEndpoint; @Nullable private String selectedTargetEndpoint; + @Nullable private String selectedDatabaseScope; private long selectedOperationUid; @Nullable private ByteString transactionIdToClear; private boolean allowDefaultAffinity; @@ -593,6 +594,7 @@ public void sendMessage(RequestT message) { Predicate excludedEndpoints = excludedEndpoints(); ChannelEndpoint endpoint = null; ChannelFinder finder = null; + String databaseScope = null; long operationUid = 0L; if (message instanceof ReadRequest) { @@ -601,6 +603,7 @@ public void sendMessage(RequestT message) { RoutingDecision routing = routeFromRequest(reqBuilder); finder = routing.finder; endpoint = routing.endpoint; + databaseScope = routing.databaseScope; operationUid = routing.operationUid; message = (RequestT) reqBuilder.build(); } 
else if (message instanceof ExecuteSqlRequest) { @@ -609,6 +612,7 @@ public void sendMessage(RequestT message) { RoutingDecision routing = routeFromRequest(reqBuilder); finder = routing.finder; endpoint = routing.endpoint; + databaseScope = routing.databaseScope; operationUid = routing.operationUid; message = (RequestT) reqBuilder.build(); } else if (message instanceof BeginTransactionRequest) { @@ -617,6 +621,7 @@ public void sendMessage(RequestT message) { String databaseId = parentChannel.extractDatabaseIdFromSession(reqBuilder.getSession()); if (databaseId != null) { finder = parentChannel.getOrCreateChannelFinder(databaseId); + databaseScope = databaseId; } if (finder != null && reqBuilder.hasMutationKey()) { endpoint = finder.findServer(reqBuilder, excludedEndpoints); @@ -633,6 +638,7 @@ public void sendMessage(RequestT message) { String databaseId = parentChannel.extractDatabaseIdFromSession(request.getSession()); if (databaseId != null) { finder = parentChannel.getOrCreateChannelFinder(databaseId); + databaseScope = databaseId; } CommitRequest.Builder reqBuilder = null; if (finder != null && request.getMutationsCount() > 0) { @@ -672,13 +678,17 @@ public void sendMessage(RequestT message) { } selectedEndpoint = endpoint; selectedTargetEndpoint = endpoint.getAddress(); + selectedDatabaseScope = databaseScope != null ? databaseScope : routingScope(finder); selectedOperationUid = operationUid; this.channelFinder = finder; EndpointLatencyRegistry.beginRequest(selectedTargetEndpoint); XGoogSpannerRequestId requestId = callOptions.getOption(REQUEST_ID_CALL_OPTIONS_KEY); if (requestId != null) { RequestIdTargetTracker.record( - requestId.getHeaderValue(), selectedTargetEndpoint, operationUid); + requestId.getHeaderValue(), + selectedDatabaseScope, + selectedTargetEndpoint, + operationUid); } // Record real traffic for idle eviction tracking. 
@@ -856,7 +866,8 @@ private RoutingDecision routeFromRequest(ReadRequest.Builder reqBuilder) { : finder.findServer(reqBuilder, excludedEndpoints); endpoint = routed; } - return new RoutingDecision(finder, endpoint, operationUid(reqBuilder.getRoutingHint())); + return new RoutingDecision( + finder, endpoint, databaseId, operationUid(reqBuilder.getRoutingHint())); } private RoutingDecision routeFromRequest(ExecuteSqlRequest.Builder reqBuilder) { @@ -879,23 +890,34 @@ private RoutingDecision routeFromRequest(ExecuteSqlRequest.Builder reqBuilder) { : finder.findServer(reqBuilder, excludedEndpoints); endpoint = routed; } - return new RoutingDecision(finder, endpoint, operationUid(reqBuilder.getRoutingHint())); + return new RoutingDecision( + finder, endpoint, databaseId, operationUid(reqBuilder.getRoutingHint())); } } private static final class RoutingDecision { @Nullable private final ChannelFinder finder; @Nullable private final ChannelEndpoint endpoint; + @Nullable private final String databaseScope; private final long operationUid; private RoutingDecision( - @Nullable ChannelFinder finder, @Nullable ChannelEndpoint endpoint, long operationUid) { + @Nullable ChannelFinder finder, + @Nullable ChannelEndpoint endpoint, + @Nullable String databaseScope, + long operationUid) { this.finder = finder; this.endpoint = endpoint; + this.databaseScope = databaseScope; this.operationUid = operationUid; } } + @Nullable + private static String routingScope(@Nullable ChannelFinder finder) { + return finder == null ? null : finder.finderKey(); + } + private static long operationUid(com.google.spanner.v1.RoutingHint routingHint) { return routingHint == null ? 
0L : routingHint.getOperationUid(); } @@ -953,7 +975,10 @@ public void onMessage(ResponseT message) { public void onClose(io.grpc.Status status, Metadata trailers) { if (shouldExcludeEndpointOnRetry(status.getCode())) { call.parentChannel.maybeRecordErrorPenalty( - call.selectedEndpoint, status.getCode(), call.selectedOperationUid); + call.selectedDatabaseScope, + call.selectedEndpoint, + status.getCode(), + call.selectedOperationUid); call.parentChannel.maybeExcludeEndpointOnNextCall( call.selectedEndpoint, call.logicalRequestKey); } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index d02233b5a96b..52295c503b6e 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -179,6 +179,7 @@ static String formatTargetEndpointLabel(String address, boolean isLeader) { private final ChannelEndpointCache endpointCache; @javax.annotation.Nullable private final EndpointLifecycleManager lifecycleManager; + @javax.annotation.Nullable private final String databaseScope; private final NavigableMap ranges = new TreeMap<>(ByteString.unsignedLexicographicalComparator()); private final Map groups = new HashMap<>(); @@ -192,14 +193,22 @@ static String formatTargetEndpointLabel(String address, boolean isLeader) { private volatile int minCacheEntriesForRandomPick = DEFAULT_MIN_ENTRIES_FOR_RANDOM_PICK; public KeyRangeCache(ChannelEndpointCache endpointCache) { - this(endpointCache, null); + this(endpointCache, null, null); } public KeyRangeCache( ChannelEndpointCache endpointCache, @javax.annotation.Nullable EndpointLifecycleManager lifecycleManager) { + this(endpointCache, lifecycleManager, null); + } + + KeyRangeCache( + ChannelEndpointCache endpointCache, + 
@javax.annotation.Nullable EndpointLifecycleManager lifecycleManager, + @javax.annotation.Nullable String databaseScope) { this.endpointCache = Objects.requireNonNull(endpointCache); this.lifecycleManager = lifecycleManager; + this.databaseScope = databaseScope; } @VisibleForTesting @@ -214,12 +223,12 @@ void setMinCacheEntriesForRandomPick(int value) { @VisibleForTesting void recordReplicaLatency(long operationUid, String address, Duration latency) { - EndpointLatencyRegistry.recordLatency(operationUid, address, latency); + EndpointLatencyRegistry.recordLatency(databaseScope, operationUid, address, latency); } @VisibleForTesting void recordReplicaError(long operationUid, String address) { - EndpointLatencyRegistry.recordError(operationUid, address); + EndpointLatencyRegistry.recordError(databaseScope, operationUid, address); } /** Applies cache updates. Tablets are processed inside group updates. */ @@ -942,7 +951,7 @@ private TabletSnapshot selectScoreAwareTablet( eligibleTablets.add(tablet); eligibleEndpoints.add(endpoint); endpointByAddress.put(endpoint.getAddress(), tablet); - if (EndpointLatencyRegistry.hasScore(operationUid, tablet.serverAddress)) { + if (EndpointLatencyRegistry.hasScore(databaseScope, operationUid, tablet.serverAddress)) { scoredCandidates++; } } @@ -1041,7 +1050,9 @@ private double selectionCost( if (tablet == null) { return Double.MAX_VALUE; } - double cost = EndpointLatencyRegistry.getSelectionCost(operationUid, tablet.serverAddress); + double cost = + EndpointLatencyRegistry.getSelectionCost( + databaseScope, operationUid, tablet.serverAddress); if (preferredLeader != null && tablet == preferredLeader) { return cost * LOCAL_LEADER_SELECTION_COST_MULTIPLIER; } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java index f64a3bd23728..4c9eef0de1fb 100644 --- 
a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/RequestIdTargetTracker.java @@ -24,21 +24,23 @@ import javax.annotation.Nullable; final class RequestIdTargetTracker { + @VisibleForTesting static final long MAX_TRACKED_TARGETS = 1_000_000L; private static final Cache TARGETS = CacheBuilder.newBuilder() - .maximumSize(100_000_000L) + .maximumSize(MAX_TRACKED_TARGETS) .expireAfterWrite(10, TimeUnit.MINUTES) .build(); private RequestIdTargetTracker() {} - static void record(String requestId, String targetEndpoint, long operationUid) { + static void record( + String requestId, @Nullable String databaseScope, String targetEndpoint, long operationUid) { String trackingKey = normalizeRequestKey(requestId); if (trackingKey == null || targetEndpoint == null || targetEndpoint.isEmpty()) { return; } - TARGETS.put(trackingKey, new RoutingTarget(targetEndpoint, operationUid)); + TARGETS.put(trackingKey, new RoutingTarget(databaseScope, targetEndpoint, operationUid)); } @Nullable @@ -76,10 +78,13 @@ static String normalizeRequestKey(String requestId) { } static final class RoutingTarget { + @Nullable final String databaseScope; final String targetEndpoint; final long operationUid; - private RoutingTarget(String targetEndpoint, long operationUid) { + private RoutingTarget( + @Nullable String databaseScope, String targetEndpoint, long operationUid) { + this.databaseScope = databaseScope; this.targetEndpoint = targetEndpoint; this.operationUid = operationUid; } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java index 41722e90bc73..ed2cfc7f192a 100644 --- 
a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import com.google.cloud.NoCredentials; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; @@ -52,6 +53,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -496,10 +498,101 @@ public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTra } } + @Test + public void readWriteTransactionAbortedCommitUsesReadAffinityReplicaForBypassTraffic() + throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + harness.clearRequests(); + AtomicInteger attempts = new AtomicInteger(); + AtomicInteger firstReplicaIndex = new AtomicInteger(-1); + + client + .readWriteTransaction() + .run( + transaction -> { + int attempt = attempts.incrementAndGet(); + try (ResultSet resultSet = + transaction.read(TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k"))) { + assertTrue(resultSet.next()); + } + + if (attempt == 1) { + int routedReplicaIndex = + findReplicaWithRequest( + harness, SharedBackendReplicaHarness.METHOD_STREAMING_READ); + if (routedReplicaIndex < 0) { + fail("Expected read-write transaction read to route to a bypass replica"); + } + 
firstReplicaIndex.set(routedReplicaIndex); + harness + .replicas + .get(routedReplicaIndex) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_COMMIT, + Status.ABORTED + .withDescription("commit aborted on routed replica") + .asRuntimeException()); + } + + transaction.buffer( + Mutation.newInsertOrUpdateBuilder("NoRecipeTable") + .set("id") + .to("row-1") + .build()); + return null; + }); + + assertEquals(2, attempts.get()); + assertTrue(firstReplicaIndex.get() >= 0); + int secondReplicaIndex = 1 - firstReplicaIndex.get(); + assertEquals( + 2, + harness + .replicas + .get(firstReplicaIndex.get()) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .replicas + .get(secondReplicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 2, + harness + .replicas + .get(firstReplicaIndex.get()) + .getRequests(SharedBackendReplicaHarness.METHOD_COMMIT) + .size()); + assertEquals( + 0, + harness + .replicas + .get(secondReplicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_COMMIT) + .size()); + assertEquals( + 0, harness.defaultReplica.getRequests(SharedBackendReplicaHarness.METHOD_COMMIT).size()); + } + } + private static Spanner createSpanner(SharedBackendReplicaHarness harness) { return SpannerOptions.newBuilder() .usePlainText() .setExperimentalHost(harness.defaultAddress) + .setSessionPoolOption( + SessionPoolOptions.newBuilder() + .setExperimentalHost() + .setUseMultiplexedSession(true) + .setUseMultiplexedSessionForRW(true) + .build()) .setProjectId(PROJECT) .setCredentials(NoCredentials.getInstance()) .setChannelEndpointCacheFactory(null) @@ -560,6 +653,15 @@ private static int waitForReplicaRoutedRead( throw new AssertionError("Timed out waiting for location-aware read to route to replica"); } + private static int findReplicaWithRequest(SharedBackendReplicaHarness harness, String method) { + for (int replicaIndex = 0; replicaIndex < 
harness.replicas.size(); replicaIndex++) { + if (!harness.replicas.get(replicaIndex).getRequests(method).isEmpty()) { + return replicaIndex; + } + } + return -1; + } + private static CacheUpdate cacheUpdate(SharedBackendReplicaHarness harness) throws TextFormat.ParseException { RecipeList recipes = readRecipeList(); diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java new file mode 100644 index 000000000000..45b071347ce5 --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/EndpointLatencyRegistryTest.java @@ -0,0 +1,104 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.spanner.spi.v1; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.base.Ticker; +import com.google.common.testing.FakeTicker; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class EndpointLatencyRegistryTest { + private static final String DATABASE_SCOPE = "projects/p/instances/i/databases/d"; + + @After + public void tearDown() { + EndpointLatencyRegistry.useTrackerTicker(Ticker.systemTicker()); + EndpointLatencyRegistry.clear(); + } + + @Test + public void trackersExpireAfterAccessWindow() { + FakeTicker ticker = new FakeTicker(); + EndpointLatencyRegistry.useTrackerTicker(ticker); + + EndpointLatencyRegistry.recordLatency( + DATABASE_SCOPE, 101L, "server-a:1234", Duration.ofMillis(5)); + + assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 101L, "server-a:1234")).isTrue(); + + ticker.advance( + EndpointLatencyRegistry.TRACKER_EXPIRE_AFTER_ACCESS.toNanos() + 1L, TimeUnit.NANOSECONDS); + + assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 101L, "server-a:1234")).isFalse(); + } + + @Test + public void accessKeepsTrackerAliveWithinExpiryWindow() { + FakeTicker ticker = new FakeTicker(); + EndpointLatencyRegistry.useTrackerTicker(ticker); + + EndpointLatencyRegistry.recordLatency( + DATABASE_SCOPE, 202L, "server-b:1234", Duration.ofMillis(7)); + + ticker.advance( + EndpointLatencyRegistry.TRACKER_EXPIRE_AFTER_ACCESS.toNanos() / 2L, TimeUnit.NANOSECONDS); + assertThat(EndpointLatencyRegistry.getSelectionCost(DATABASE_SCOPE, 202L, "server-b:1234")) + .isGreaterThan(0.0); + + ticker.advance( + EndpointLatencyRegistry.TRACKER_EXPIRE_AFTER_ACCESS.toNanos() / 2L, TimeUnit.NANOSECONDS); + + assertThat(EndpointLatencyRegistry.hasScore(DATABASE_SCOPE, 202L, "server-b:1234")).isTrue(); + } + + @Test + public void 
trackersAreIsolatedByDatabaseScope() { + EndpointLatencyRegistry.recordLatency( + "projects/p1/instances/i1/databases/d1", 303L, "server-a:1234", Duration.ofMillis(9)); + + assertThat( + EndpointLatencyRegistry.hasScore( + "projects/p1/instances/i1/databases/d1", 303L, "server-a:1234")) + .isTrue(); + assertThat( + EndpointLatencyRegistry.hasScore( + "projects/p2/instances/i2/databases/d2", 303L, "server-a:1234")) + .isFalse(); + } + + @Test + public void inflightCountDoesNotGoNegativeAndCanBeReusedAfterZero() { + EndpointLatencyRegistry.beginRequest("server-c:1234"); + assertThat(EndpointLatencyRegistry.getInflight("server-c:1234")).isEqualTo(1); + + EndpointLatencyRegistry.finishRequest("server-c:1234"); + assertThat(EndpointLatencyRegistry.getInflight("server-c:1234")).isEqualTo(0); + + EndpointLatencyRegistry.finishRequest("server-c:1234"); + assertThat(EndpointLatencyRegistry.getInflight("server-c:1234")).isEqualTo(0); + + EndpointLatencyRegistry.beginRequest("server-c:1234"); + assertThat(EndpointLatencyRegistry.getInflight("server-c:1234")).isEqualTo(1); + } +} diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index a993c7be550f..975967278cde 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -629,6 +629,8 @@ public void resourceExhaustedRoutedEndpointRetriesRandomExcludedReplicaWhenAllRe for (int attempt = 0; attempt < 20 && (!selectedServerAOnThirdAttempt || !selectedServerBOnThirdAttempt); attempt++) { + EndpointLatencyRegistry.clear(); + RequestIdTargetTracker.clear(); TestHarness harness = createHarness(); seedCache(harness, createLeaderAndReplicaCacheUpdate()); CallOptions retryCallOptions = 
retryCallOptions(100L + attempt); @@ -1677,9 +1679,14 @@ private static void assertRoutedEndpointErrorPenaltyRecorded(Status status, long assertThat(routedOperationUid).isGreaterThan(0L); delegate.emitOnClose(status, new Metadata()); - assertThat(EndpointLatencyRegistry.hasScore(routedOperationUid, "server-a:1234")).isTrue(); - assertThat(EndpointLatencyRegistry.getSelectionCost(routedOperationUid, "server-a:1234")) + String databaseScope = "projects/p/instances/i/databases/d"; + assertThat(EndpointLatencyRegistry.hasScore(databaseScope, routedOperationUid, "server-a:1234")) + .isTrue(); + assertThat( + EndpointLatencyRegistry.getSelectionCost( + databaseScope, routedOperationUid, "server-a:1234")) .isEqualTo((double) EndpointLatencyRegistry.DEFAULT_ERROR_PENALTY.toNanos() / 1_000D); - assertThat(EndpointLatencyRegistry.hasScore(routedOperationUid, "server-b:1234")).isFalse(); + assertThat(EndpointLatencyRegistry.hasScore(databaseScope, routedOperationUid, "server-b:1234")) + .isFalse(); } } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java index dc444eb687f6..05b773ad4f60 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/ReplicaSelectionMockServerTest.java @@ -443,14 +443,9 @@ public void testStaleSingleUseReadBootstrapsScoresAndConvergesToLowerLatencyRepl } clearServerRequests(); - boolean sampledServer0 = false; - boolean sampledServer1 = false; - Stopwatch watch = Stopwatch.createStarted(); - int attempt = 0; long operationUid = 0L; - while (watch.elapsed(TimeUnit.SECONDS) < 10 && (!sampledServer0 || !sampledServer1)) { - attempt++; + for (int attempt = 1; attempt <= 3; attempt++) { String key = 
"bootstrap-key-" + attempt; try (com.google.cloud.spanner.ResultSet rs = client @@ -474,15 +469,12 @@ public void testStaleSingleUseReadBootstrapsScoresAndConvergesToLowerLatencyRepl operationUid, currentOperationUid); } - sampledServer0 = hasReadRequestForKey(servers.get(0).mockSpanner, key) || sampledServer0; - sampledServer1 = hasReadRequestForKey(servers.get(1).mockSpanner, key) || sampledServer1; } - assertTrue("Expected bootstrap exploration to sample server0", sampledServer0); - assertTrue("Expected bootstrap exploration to sample server1", sampledServer1); assertTrue("Expected stale reads to reuse the same operation_uid", operationUid > 0L); clearServerRequests(); + Stopwatch watch = Stopwatch.createStarted(); boolean routedToLowerLatencyReplica = false; int convergenceAttempt = 0; while (watch.elapsed(TimeUnit.SECONDS) < 10 && !routedToLowerLatencyReplica) { From e94496f08024f207e1a359183cd91692fc2716bb Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 20 Apr 2026 10:35:43 +0530 Subject: [PATCH 8/9] more test assertions --- ...nAwareSharedBackendReplicaHarnessTest.java | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java index ed2cfc7f192a..2adee5be5fde 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -503,7 +503,7 @@ public void readWriteTransactionAbortedCommitUsesReadAffinityReplicaForBypassTra throws Exception { try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); Spanner spanner = createSpanner(harness)) { - 
configureBackend(harness, singleRowReadResultSet("b")); + configureBackend(harness, singleRowReadResultSet("b"), /* leaderReplicaIndex= */ 1); DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); @@ -538,17 +538,11 @@ public void readWriteTransactionAbortedCommitUsesReadAffinityReplicaForBypassTra .withDescription("commit aborted on routed replica") .asRuntimeException()); } - - transaction.buffer( - Mutation.newInsertOrUpdateBuilder("NoRecipeTable") - .set("id") - .to("row-1") - .build()); return null; }); assertEquals(2, attempts.get()); - assertTrue(firstReplicaIndex.get() >= 0); + assertEquals(1, firstReplicaIndex.get()); int secondReplicaIndex = 1 - firstReplicaIndex.get(); assertEquals( 2, @@ -603,6 +597,14 @@ private static Spanner createSpanner(SharedBackendReplicaHarness harness) { private static void configureBackend( SharedBackendReplicaHarness harness, com.google.spanner.v1.ResultSet readResultSet) throws TextFormat.ParseException { + configureBackend(harness, readResultSet, /* leaderReplicaIndex= */ 0); + } + + private static void configureBackend( + SharedBackendReplicaHarness harness, + com.google.spanner.v1.ResultSet readResultSet, + int leaderReplicaIndex) + throws TextFormat.ParseException { Statement readStatement = StatementResult.createReadStatement( TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k")); @@ -611,7 +613,7 @@ private static void configureBackend( StatementResult.query( SEED_QUERY, singleRowReadResultSet("seed").toBuilder() - .setCacheUpdate(cacheUpdate(harness)) + .setCacheUpdate(cacheUpdate(harness, leaderReplicaIndex)) .build())); } @@ -664,6 +666,12 @@ private static int findReplicaWithRequest(SharedBackendReplicaHarness harness, S private static CacheUpdate cacheUpdate(SharedBackendReplicaHarness harness) throws TextFormat.ParseException { + return cacheUpdate(harness, /* leaderReplicaIndex= */ 0); + } + + private static CacheUpdate cacheUpdate( + 
SharedBackendReplicaHarness harness, int leaderReplicaIndex) + throws TextFormat.ParseException { RecipeList recipes = readRecipeList(); RoutingHint routingHint = exactReadRoutingHint(recipes); ByteString limitKey = routingHint.getLimitKey(); @@ -685,7 +693,7 @@ private static CacheUpdate cacheUpdate(SharedBackendReplicaHarness harness) Group.newBuilder() .setGroupUid(1L) .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1")) - .setLeaderIndex(0) + .setLeaderIndex(leaderReplicaIndex) .addTablets( Tablet.newBuilder() .setTabletUid(11L) From 85d55c6b9aa6a74307593daafeae8a4745578a99 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Mon, 20 Apr 2026 10:48:09 +0530 Subject: [PATCH 9/9] fix keep alives --- .../spi/v1/GrpcChannelEndpointCache.java | 6 +++++ ...nAwareSharedBackendReplicaHarnessTest.java | 25 +++++++++++++++++++ .../spi/v1/GrpcChannelEndpointCacheTest.java | 10 ++++++-- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java index 54ee67439e59..f50a20304a98 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCache.java @@ -27,6 +27,7 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import java.io.IOException; +import java.time.Duration; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @@ -51,6 +52,9 @@ class GrpcChannelEndpointCache implements ChannelEndpointCache { /** Timeout for graceful channel shutdown. 
*/ private static final long SHUTDOWN_TIMEOUT_SECONDS = 5; + @VisibleForTesting static final Duration ROUTED_KEEPALIVE_TIME = Duration.ofSeconds(2); + @VisibleForTesting static final Duration ROUTED_KEEPALIVE_TIMEOUT = Duration.ofSeconds(20); + private final InstantiatingGrpcChannelProvider baseProvider; private final Map servers = new ConcurrentHashMap<>(); private final GrpcChannelEndpoint defaultEndpoint; @@ -129,6 +133,8 @@ InstantiatingGrpcChannelProvider createProviderWithAuthorityOverride(String addr } Builder builder = endpointProvider.toBuilder(); builder.setChannelPoolSettings(ChannelPoolSettings.staticallySized(1)); + builder.setKeepAliveTimeDuration(ROUTED_KEEPALIVE_TIME); + builder.setKeepAliveTimeoutDuration(ROUTED_KEEPALIVE_TIMEOUT); builder.setKeepAliveWithoutCalls(Boolean.TRUE); final com.google.api.core.ApiFunction baseConfigurator = diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java index 2adee5be5fde..817dafaae74d 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -507,6 +507,7 @@ public void readWriteTransactionAbortedCommitUsesReadAffinityReplicaForBypassTra DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); seedLocationMetadata(client); + waitForReplicaRoutedStrongRead(client, harness, /* expectedReplicaIndex= */ 1); harness.clearRequests(); AtomicInteger attempts = new AtomicInteger(); AtomicInteger firstReplicaIndex = new AtomicInteger(-1); @@ -655,6 +656,30 @@ private static int waitForReplicaRoutedRead( throw new AssertionError("Timed out waiting for location-aware read to 
route to replica"); } + private static void waitForReplicaRoutedStrongRead( + DatabaseClient client, SharedBackendReplicaHarness harness, int expectedReplicaIndex) + throws InterruptedException { + long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + while (System.nanoTime() < deadlineNanos) { + harness.clearRequests(); + try (ResultSet resultSet = + client.singleUse().read(TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k"))) { + if (resultSet.next()) { + if (!harness + .replicas + .get(expectedReplicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .isEmpty()) { + return; + } + } + } + Thread.sleep(50L); + } + throw new AssertionError( + "Timed out waiting for strong read to route to replica " + expectedReplicaIndex); + } + private static int findReplicaWithRequest(SharedBackendReplicaHarness harness, String method) { for (int replicaIndex = 0; replicaIndex < harness.replicas.size(); replicaIndex++) { if (!harness.replicas.get(replicaIndex).getRequests(method).isEmpty()) { diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java index c0eea3f88c4f..cca418a53f99 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/GrpcChannelEndpointCacheTest.java @@ -98,12 +98,13 @@ public void routedChannelsUseSingleUnderlyingChannel() throws Exception { } @Test - public void routedChannelsEnableKeepAliveWithoutCallsOnlyForEndpointProvider() throws Exception { + public void routedChannelsOverrideKeepAliveSettingsOnlyForEndpointProvider() throws Exception { InstantiatingGrpcChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder() .setEndpoint(DEFAULT_ENDPOINT) 
.setPoolSize(4) .setKeepAliveTimeDuration(java.time.Duration.ofSeconds(120)) + .setKeepAliveTimeoutDuration(java.time.Duration.ofSeconds(60)) .setKeepAliveWithoutCalls(Boolean.FALSE) .setChannelConfigurator(ManagedChannelBuilder::usePlaintext) .build(); @@ -113,9 +114,14 @@ public void routedChannelsEnableKeepAliveWithoutCallsOnlyForEndpointProvider() t cache.createProviderWithAuthorityOverride(ROUTED_ENDPOINT_A); assertThat(provider.getKeepAliveWithoutCalls()).isFalse(); + assertThat(provider.getKeepAliveTimeDuration()).isEqualTo(java.time.Duration.ofSeconds(120)); + assertThat(provider.getKeepAliveTimeoutDuration()) + .isEqualTo(java.time.Duration.ofSeconds(60)); assertThat(routedProvider.getKeepAliveWithoutCalls()).isTrue(); assertThat(routedProvider.getKeepAliveTimeDuration()) - .isEqualTo(provider.getKeepAliveTimeDuration()); + .isEqualTo(GrpcChannelEndpointCache.ROUTED_KEEPALIVE_TIME); + assertThat(routedProvider.getKeepAliveTimeoutDuration()) + .isEqualTo(GrpcChannelEndpointCache.ROUTED_KEEPALIVE_TIMEOUT); } finally { cache.shutdown(); }