diff --git a/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java b/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java index 9754a0a4dffe..1fd299e99393 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/datatable/StatMap.java @@ -280,7 +280,7 @@ public ObjectNode asJson() { if (key.includeDefaultInJson()) { node.put(key.getStatName(), false); } - } else { + } else if (key.includeInJson()) { node.put(key.getStatName(), (boolean) value); } break; @@ -289,7 +289,7 @@ public ObjectNode asJson() { if (key.includeDefaultInJson()) { node.put(key.getStatName(), 0); } - } else { + } else if (key.includeInJson()) { node.put(key.getStatName(), (int) value); } break; @@ -298,7 +298,7 @@ public ObjectNode asJson() { if (key.includeDefaultInJson()) { node.put(key.getStatName(), 0L); } - } else { + } else if (key.includeInJson()) { node.put(key.getStatName(), (long) value); } break; @@ -307,7 +307,7 @@ public ObjectNode asJson() { if (key.includeDefaultInJson()) { node.put(key.getStatName(), ""); } - } else { + } else if (key.includeInJson()) { node.put(key.getStatName(), (String) value); } break; @@ -502,6 +502,10 @@ default boolean includeDefaultInJson() { return false; } + default boolean includeInJson() { + return true; + } + static int minPositive(int value1, int value2) { if (value1 == 0 && value2 >= 0) { return value2; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java b/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java index 4eba8fe001c7..282cd5e962ea 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java @@ -75,6 +75,8 @@ public enum BrokerGauge implements AbstractMetrics.Gauge { * Per-server adaptive routing stats exported as metrics (MSE / multi-stage engine). */ ADAPTIVE_SERVER_MSE_NUM_IN_FLIGHT_REQUESTS("adaptiveServerMseNumInFlightRequests", false), + ADAPTIVE_SERVER_MSE_LATENCY_EMA("adaptiveServerMseLatencyEma", false), + ADAPTIVE_SERVER_MSE_HYBRID_SCORE("adaptiveServerMseHybridScore", false), /** * The queue size of ServerRoutingStatsManager main executor service. diff --git a/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java b/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java index fe85d62f761a..585f39dced8d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java @@ -29,7 +29,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.metrics.BrokerGauge; @@ -452,19 +451,17 @@ private void exportStatsAsMetrics() { BrokerGauge.ADAPTIVE_SERVER_NUM_IN_FLIGHT_REQUESTS, BrokerGauge.ADAPTIVE_SERVER_LATENCY_EMA, BrokerGauge.ADAPTIVE_SERVER_HYBRID_SCORE); - // TODO: Export MSE latency stats once we support it exportStatsForMap(_mseServerQueryStatsMap, "server.", BrokerGauge.ADAPTIVE_SERVER_MSE_NUM_IN_FLIGHT_REQUESTS, - null, - null); + BrokerGauge.ADAPTIVE_SERVER_MSE_LATENCY_EMA, + BrokerGauge.ADAPTIVE_SERVER_MSE_HYBRID_SCORE); } catch (Exception e) { LOGGER.error("Exception caught while exporting routing stats as metrics.", e); } } private void exportStatsForMap(ConcurrentHashMap statsMap, String tagPrefix, - BrokerGauge numInFlightGauge, @Nullable BrokerGauge latencyEmaGauge, - @Nullable BrokerGauge hybridScoreGauge) { + BrokerGauge numInFlightGauge, BrokerGauge latencyEmaGauge, BrokerGauge hybridScoreGauge) { for (Map.Entry entry : statsMap.entrySet()) { String serverInstanceId = entry.getKey(); ServerRoutingStatsEntry stats = entry.getValue(); @@ -484,12 +481,8 @@ private void exportStatsForMap(ConcurrentHashMap constructDispatchablePlanFragmentM if (dispatchablePlanMetadata.getTimeBoundaryInfo() != null) { dispatchablePlanFragment.setTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo()); } + + PlanFragment planFrag = dispatchablePlanFragment.getPlanFragment(); + PlanNode root = planFrag != null ? planFrag.getFragmentRoot() : null; + FragmentType fragmentType = FragmentType.classify(root, + !dispatchablePlanMetadata.getScannedTables().isEmpty(), _dispatchablePlanMetadataMap); + if (fragmentType != null) { + dispatchablePlanFragment.setFragmentType(fragmentType); + } } return dispatchablePlanFragmentMap; } + private Map createDispatchablePlanFragmentMap(PlanFragment planFragmentRoot) { HashMap result = Maps.newHashMapWithExpectedSize(_dispatchablePlanMetadataMap.size()); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanFragment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanFragment.java index 4e3798f5a1fe..bfe1941be344 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanFragment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanFragment.java @@ -49,6 +49,9 @@ public class DispatchablePlanFragment { // used for passing custom properties to build StageMetadata on the server. private final Map _customProperties; + // Broker-only classification — not shipped to servers. + private FragmentType _fragmentType; + public DispatchablePlanFragment(PlanFragment planFragment) { this(planFragment, new ArrayList<>(), new HashMap<>(), new HashMap<>()); } @@ -132,4 +135,12 @@ public void setServerInstanceToWorkerIdMap(Map getServerInstances() { return _serverInstanceToWorkerIdMap.keySet(); } + + public FragmentType getFragmentType() { + return _fragmentType; + } + + public void setFragmentType(FragmentType fragmentType) { + _fragmentType = fragmentType; + } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/FragmentType.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/FragmentType.java new file mode 100644 index 000000000000..d13395af009b --- /dev/null +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/FragmentType.java @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.planner.physical; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.RelDistribution; +import org.apache.pinot.query.planner.plannode.MailboxReceiveNode; +import org.apache.pinot.query.planner.plannode.MailboxSendNode; +import org.apache.pinot.query.planner.plannode.PlanNode; + + +public enum FragmentType { + // Scans a fact table; non-SINGLETON sender + LEAF, + // Dim-table leaf via SINGLETON exchange, with a leaf upstream sender + SINGLETON_LEAF, + // Non-scanning stage (join, agg, sort); timings reflect upstream cascade delays + INTERMEDIATE; + + public static FragmentType classify(PlanNode root, boolean hasScannedTable, + Map metadataMap) { + if (!(root instanceof MailboxSendNode)) { + return null; + } + if (!hasScannedTable) { + return INTERMEDIATE; + } + List singletonSenderStageIds = getSingletonReceiveSenderStageIds(root); + if (singletonSenderStageIds.isEmpty()) { + return LEAF; + } + boolean allSendersAreLeaves = singletonSenderStageIds.stream().allMatch(senderStageId -> { + DispatchablePlanMetadata senderMeta = metadataMap.get(senderStageId); + return senderMeta != null && !senderMeta.getScannedTables().isEmpty(); + }); + return allSendersAreLeaves ? SINGLETON_LEAF : INTERMEDIATE; + } + + static List getSingletonReceiveSenderStageIds(PlanNode node) { + List result = new ArrayList<>(); + collectSingletonReceiveSenderStageIds(node, result); + return result; + } + + private static void collectSingletonReceiveSenderStageIds(PlanNode node, List result) { + if (node instanceof MailboxReceiveNode) { + MailboxReceiveNode receiveNode = (MailboxReceiveNode) node; + if (receiveNode.getDistributionType() == RelDistribution.Type.SINGLETON) { + result.add(receiveNode.getSenderStageId()); + } + } + for (PlanNode input : node.getInputs()) { + collectSingletonReceiveSenderStageIds(input, result); + } + } +} diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/FragmentTypeTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/FragmentTypeTest.java new file mode 100644 index 000000000000..21d1c1b7282a --- /dev/null +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/physical/FragmentTypeTest.java @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.planner.physical; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.RelDistribution; +import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.query.planner.plannode.FilterNode; +import org.apache.pinot.query.planner.plannode.MailboxReceiveNode; +import org.apache.pinot.query.planner.plannode.MailboxSendNode; +import org.apache.pinot.query.planner.plannode.PlanNode; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + + +public class FragmentTypeTest { + + private static final DataSchema SCHEMA = new DataSchema( + new String[]{"col1"}, new ColumnDataType[]{ColumnDataType.INT}); + + @Test + public void testNonMailboxSendRootReturnsNull() { + PlanNode filterNode = new FilterNode(0, SCHEMA, null, List.of(), null); + assertNull(FragmentType.classify(filterNode, true, Map.of())); + } + + @Test + public void testNullRootReturnsNull() { + assertNull(FragmentType.classify(null, true, Map.of())); + } + + @Test + public void testNoScannedTableReturnsIntermediate() { + MailboxSendNode sendNode = new MailboxSendNode(0, SCHEMA, List.of(), + 1, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + assertEquals(FragmentType.classify(sendNode, false, Map.of()), FragmentType.INTERMEDIATE); + } + + @Test + public void testLeafWithNoSingletonReceive() { + // A send node with no receive children (pure leaf scan) + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + assertEquals(FragmentType.classify(sendNode, true, Map.of()), FragmentType.LEAF); + } + + @Test + public void testLeafWithNonSingletonReceive() { + // A receive node with HASH distribution (not SINGLETON) should not affect LEAF classification + MailboxReceiveNode receiveNode = new MailboxReceiveNode(1, SCHEMA, 2, + PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), null, false, false, null); + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(receiveNode), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + assertEquals(FragmentType.classify(sendNode, true, Map.of()), FragmentType.LEAF); + } + + @Test + public void testSingletonLeafWhenSenderIsLeaf() { + // Stage 1 has a SINGLETON receive from stage 2, and stage 2 has scanned tables + MailboxReceiveNode receiveNode = new MailboxReceiveNode(1, SCHEMA, 2, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(receiveNode), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + + Map metadataMap = new HashMap<>(); + DispatchablePlanMetadata senderMeta = new DispatchablePlanMetadata(); + senderMeta.addScannedTable("dimTable"); + metadataMap.put(2, senderMeta); + + assertEquals(FragmentType.classify(sendNode, true, metadataMap), FragmentType.SINGLETON_LEAF); + } + + @Test + public void testIntermediateWhenSingletonSenderHasNoScannedTables() { + // Stage 1 has a SINGLETON receive from stage 2, but stage 2 has no scanned tables + MailboxReceiveNode receiveNode = new MailboxReceiveNode(1, SCHEMA, 2, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(receiveNode), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + + Map metadataMap = new HashMap<>(); + DispatchablePlanMetadata senderMeta = new DispatchablePlanMetadata(); + metadataMap.put(2, senderMeta); + + assertEquals(FragmentType.classify(sendNode, true, metadataMap), FragmentType.INTERMEDIATE); + } + + @Test + public void testIntermediateWhenSingletonSenderMetadataMissing() { + // Stage 1 has a SINGLETON receive from stage 2, but stage 2 metadata is null + MailboxReceiveNode receiveNode = new MailboxReceiveNode(1, SCHEMA, 2, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(receiveNode), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + + assertEquals(FragmentType.classify(sendNode, true, Map.of()), FragmentType.INTERMEDIATE); + } + + @Test + public void testMultipleSingletonReceivesAllLeaves() { + // Two SINGLETON receives, both senders have scanned tables + MailboxReceiveNode receive1 = new MailboxReceiveNode(1, SCHEMA, 2, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxReceiveNode receive2 = new MailboxReceiveNode(1, SCHEMA, 3, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(receive1, receive2), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + + Map metadataMap = new HashMap<>(); + DispatchablePlanMetadata meta2 = new DispatchablePlanMetadata(); + meta2.addScannedTable("dim1"); + metadataMap.put(2, meta2); + DispatchablePlanMetadata meta3 = new DispatchablePlanMetadata(); + meta3.addScannedTable("dim2"); + metadataMap.put(3, meta3); + + assertEquals(FragmentType.classify(sendNode, true, metadataMap), FragmentType.SINGLETON_LEAF); + } + + @Test + public void testMultipleSingletonReceivesOnlyOneIsLeaf() { + // Two SINGLETON receives; one sender has scanned tables, the other does not + MailboxReceiveNode receive1 = new MailboxReceiveNode(1, SCHEMA, 2, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxReceiveNode receive2 = new MailboxReceiveNode(1, SCHEMA, 3, + PinotRelExchangeType.STREAMING, RelDistribution.Type.SINGLETON, + null, null, false, false, null); + MailboxSendNode sendNode = new MailboxSendNode(1, SCHEMA, List.of(receive1, receive2), + 0, PinotRelExchangeType.STREAMING, RelDistribution.Type.HASH_DISTRIBUTED, + List.of(0), false, null, false, "murmur"); + + Map metadataMap = new HashMap<>(); + DispatchablePlanMetadata meta2 = new DispatchablePlanMetadata(); + meta2.addScannedTable("dim1"); + metadataMap.put(2, meta2); + DispatchablePlanMetadata meta3 = new DispatchablePlanMetadata(); + metadataMap.put(3, meta3); + + assertEquals(FragmentType.classify(sendNode, true, metadataMap), FragmentType.INTERMEDIATE); + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java index 635fd3ace2aa..00de16c67c75 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/BaseMailboxReceiveOperator.java @@ -20,7 +20,9 @@ import com.google.common.base.Preconditions; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import org.apache.calcite.rel.RelDistribution; import org.apache.pinot.common.datatable.StatMap; @@ -28,13 +30,17 @@ import org.apache.pinot.query.mailbox.ReceivingMailbox; import org.apache.pinot.query.planner.physical.MailboxIdUtils; import org.apache.pinot.query.planner.plannode.MailboxReceiveNode; +import org.apache.pinot.query.routing.MailboxInfo; import org.apache.pinot.query.routing.MailboxInfos; import org.apache.pinot.query.runtime.blocks.MseBlock; import org.apache.pinot.query.runtime.operator.utils.AsyncStream; import org.apache.pinot.query.runtime.operator.utils.BlockingMultiStreamConsumer; import org.apache.pinot.query.runtime.plan.MultiStageQueryStats; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; +import org.apache.pinot.query.service.dispatch.AdaptiveRoutingUpstreamTimings; import org.apache.pinot.spi.query.QueryThreadContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** @@ -47,8 +53,11 @@ * When exchangeType is non-Singleton, we pull from each instance in round-robin way to get matched mailbox content. */ public abstract class BaseMailboxReceiveOperator extends MultiStageOperator { + private static final Logger LOGGER = LoggerFactory.getLogger(BaseMailboxReceiveOperator.class); + protected final MailboxService _mailboxService; protected final RelDistribution.Type _distributionType; + protected final int _senderStageId; protected final List _mailboxIds; protected final BlockingMultiStreamConsumer.OfMseBlock _multiConsumer; protected final List> _receivingStats; @@ -63,11 +72,11 @@ public BaseMailboxReceiveOperator(OpChainExecutionContext context, MailboxReceiv _distributionType = distributionType; long requestId = context.getRequestId(); - int senderStageId = node.getSenderStageId(); - MailboxInfos mailboxInfos = context.getWorkerMetadata().getMailboxInfosMap().get(senderStageId); + _senderStageId = node.getSenderStageId(); + MailboxInfos mailboxInfos = context.getWorkerMetadata().getMailboxInfosMap().get(_senderStageId); if (mailboxInfos != null) { _mailboxIds = - MailboxIdUtils.toMailboxIds(requestId, senderStageId, mailboxInfos.getMailboxInfos(), context.getStageId(), + MailboxIdUtils.toMailboxIds(requestId, _senderStageId, mailboxInfos.getMailboxInfos(), context.getStageId(), context.getWorkerId()); int numMailboxes = _mailboxIds.size(); List asyncStreams = new ArrayList<>(numMailboxes); @@ -79,12 +88,18 @@ public BaseMailboxReceiveOperator(OpChainExecutionContext context, MailboxReceiv asyncStreams.add(asyncStream); _receivingStats.add(asyncStream._mailbox.getStatMap()); } - _multiConsumer = new BlockingMultiStreamConsumer.OfMseBlock(context, asyncStreams, senderStageId); + boolean collectTiming = "true".equals( + context.getOpChainMetadata().get(AdaptiveRoutingUpstreamTimings.COLLECT_UPSTREAM_TIMING_KEY)); + Map streamIdToSenderKey = collectTiming + ? buildStreamIdToSenderKey(mailboxInfos, asyncStreams) + : Map.of(); + _multiConsumer = new BlockingMultiStreamConsumer.OfMseBlock( + context, asyncStreams, _senderStageId, streamIdToSenderKey); } else { // TODO: Revisit if we should throw exception here. _mailboxIds = List.of(); _receivingStats = List.of(); - _multiConsumer = new BlockingMultiStreamConsumer.OfMseBlock(context, List.of(), senderStageId); + _multiConsumer = new BlockingMultiStreamConsumer.OfMseBlock(context, List.of(), _senderStageId); } _statMap.merge(StatKey.FAN_IN, _mailboxIds.size()); } @@ -124,13 +139,32 @@ public MultiStageQueryStats calculateUpstreamStats() { @Override public StatMap copyStatMaps() { - return new StatMap<>(_statMap); + StatMap copy = new StatMap<>(_statMap); + // On the cancel/timeout path (where onEos() never ran), include ALL known senders: + // completed senders get their actual measured latency, pending senders get the current + // wall-clock elapsed time so the adaptive selector can identify slow servers. + mergeSenderTimingsInto(copy, _multiConsumer.getSenderElapsedMsIncludingPending()); + return copy; } protected void onEos() { for (StatMap receivingStats : _receivingStats) { addReceivingStats(receivingStats); } + mergeSenderTimingsInto(_statMap, _multiConsumer.getSenderElapsedMs()); + LOGGER.debug("==[UPSTREAM_TIMING]== stage {} onEos: merged sender timings", _senderStageId); + } + + /** + * Encodes per-sender elapsed-time data and merges it into {@code target}. + */ + private void mergeSenderTimingsInto(StatMap target, Map senderElapsedMs) { + if (!senderElapsedMs.isEmpty()) { + String encoded = AdaptiveRoutingUpstreamTimings.encode(senderElapsedMs); + if (encoded != null) { + target.merge(StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, encoded); + } + } } @Override @@ -150,6 +184,27 @@ public void registerExecution(long time, int numRows, long memoryUsedBytes, long _statMap.merge(StatKey.GC_TIME_MS, gcTimeMs); } + /** + * Builds a map from stream ID to sender key (hostname|mailboxPort) for per-sender elapsed-time tracking. + * Iterates {@link MailboxInfos} and {@code asyncStreams} in parallel (same order as + * {@link MailboxIdUtils#toMailboxIds}), mapping each stream's ID to its server's sender key. + */ + private static Map buildStreamIdToSenderKey(MailboxInfos mailboxInfos, + List> asyncStreams) { + Map streamIdToSenderKey = new HashMap<>(asyncStreams.size()); + int i = 0; + for (MailboxInfo info : mailboxInfos.getMailboxInfos()) { + String key = AdaptiveRoutingUpstreamTimings.senderKey(info.getHostname(), info.getPort()); + for (int ignored : info.getWorkerIds()) { + streamIdToSenderKey.put(asyncStreams.get(i).getId(), key); + i++; + } + } + Preconditions.checkState(i == asyncStreams.size(), + "MailboxInfos worker count (%s) does not match asyncStreams size (%s)", i, asyncStreams.size()); + return streamIdToSenderKey; + } + private void addReceivingStats(StatMap from) { _statMap.merge(StatKey.RAW_MESSAGES, from.getInt(ReceivingMailbox.StatKey.DESERIALIZED_MESSAGES)); _statMap.merge(StatKey.DESERIALIZED_BYTES, from.getLong(ReceivingMailbox.StatKey.DESERIALIZED_BYTES)); @@ -266,7 +321,35 @@ public int merge(int value1, int value2) { /** * Time spent on GC while this operator or its children in the same stage were running. */ - GC_TIME_MS(StatMap.Type.LONG); + GC_TIME_MS(StatMap.Type.LONG), + /** + * Per-upstream-sender wall-clock elapsed time (consumer construction -> EOS arrival), + * encoded as a semicolon-separated list of {@code "hostname|mailboxPort=elapsedMs"} pairs. + * + *

Populated by every {@link BaseMailboxReceiveOperator}. Elapsed time is anchored to consumer + * construction (after pipeline breakers), avoiding cross-host clock skew. When multiple workers on + * the same stage contribute stats for the same sender, the merge function takes the maximum elapsed + * time so the slowest observation is preserved for the same sending server. + * + *

On the broker side, {@code QueryDispatcher#extractMaxTimingsPerInstance} reads this stat and calls + * {@link org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager + * #recordStatsUponResponseArrival} + * for servers that the broker did not track directly (e.g. S2 in a join query routed through S1). + * Senders with no observed timing are omitted from the encoding. + */ + UPSTREAM_SERVER_RESPONSE_TIMES_MS(StatMap.Type.STRING) { + @Override + public String merge(@Nullable String value1, @Nullable String value2) { + return AdaptiveRoutingUpstreamTimings.mergeEncodings(value1, value2); + } + + @Override + public boolean includeInJson() { + // Excluded from stage stats in the query response because the encoded timing string can be large + // (one entry per upstream sender). It is only needed broker-side for adaptive routing. + return false; + } + }; private final StatMap.Type _type; diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumer.java index 741aa1ee2dd1..a87bcb86f24b 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumer.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumer.java @@ -18,11 +18,16 @@ */ package org.apache.pinot.query.runtime.operator.utils; +import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.function.LongSupplier; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.pinot.query.mailbox.ReceivingMailbox; @@ -92,10 +97,12 @@ public BlockingMultiStreamConsumer(Object id, long deadlineMs, List completedStream); /** * This method is called whenever a timeout is reached while reading an element. @@ -277,7 +284,7 @@ private E readDroppingSuccessEos() { LOGGER.debug("==[RECEIVE]== EOS received : " + _id + " in mailbox: " + removed.getId() + " (mailboxes alive: " + ids + ")"); } - onMailboxSuccess(block); + onMailboxSuccess(block, removed); block = readBlockOrNull(); } @@ -345,13 +352,35 @@ public static class OfMseBlock extends BlockingMultiStreamConsumer max elapsed time (ms) seen for that sender + private final Map _senderElapsedMs = new ConcurrentHashMap<>(); + // Maps stream ID -> sender key (hostname|mailboxPort) + private final Map _streamIdToSenderKey; public OfMseBlock(OpChainExecutionContext context, List> asyncProducers, int senderStageId) { + this(context, asyncProducers, senderStageId, Map.of()); + } + + public OfMseBlock(OpChainExecutionContext context, + List> asyncProducers, int senderStageId, + Map streamIdToSenderKey) { + this(context, asyncProducers, senderStageId, streamIdToSenderKey, System::currentTimeMillis); + } + + @VisibleForTesting + OfMseBlock(OpChainExecutionContext context, + List> asyncProducers, int senderStageId, + Map streamIdToSenderKey, LongSupplier clock) { super(context.getId(), context.getPassiveDeadlineMs(), asyncProducers); _stageId = context.getStageId(); _stats = MultiStageQueryStats.emptyStats(context.getStageId()); _senderStageId = senderStageId; + _streamIdToSenderKey = streamIdToSenderKey; + _clock = clock; + _startTimeMs = clock.getAsLong(); } @Override @@ -365,8 +394,46 @@ protected boolean isSuccess(ReceivingMailbox.MseBlockWithStats element) { } @Override - protected void onMailboxSuccess(ReceivingMailbox.MseBlockWithStats element) { + protected void onMailboxSuccess(ReceivingMailbox.MseBlockWithStats element, + AsyncStream completedStream) { mergeStats(element); + if (_streamIdToSenderKey.isEmpty()) { + return; + } + String senderKey = _streamIdToSenderKey.get(completedStream.getId()); + if (senderKey == null) { + LOGGER.warn("==[UPSTREAM_TIMING]== stage {} stream {} has no senderKey mapping, skipping timing", + _senderStageId, completedStream.getId()); + return; + } + long elapsedMs = _clock.getAsLong() - _startTimeMs; + _senderElapsedMs.merge(senderKey, elapsedMs, Math::max); + LOGGER.debug("==[UPSTREAM_TIMING]== stage {} sender {} EOS at {}ms since receiver start", + _senderStageId, senderKey, elapsedMs); + } + + /** + * Returns the per-sender timing map (senderKey -> max elapsedMs since consumer construction). + * May be empty if no timing data was collected. + */ + public Map getSenderElapsedMs() { + return Collections.unmodifiableMap(_senderElapsedMs); + } + + /** + * Returns per-sender timing with pending (non-completing) senders injected at the current + * wall-clock elapsed time. Completed senders retain their actual measured latency. + */ + public Map getSenderElapsedMsIncludingPending() { + if (_streamIdToSenderKey.isEmpty()) { + return Collections.unmodifiableMap(_senderElapsedMs); + } + long elapsedMs = _clock.getAsLong() - _startTimeMs; + Map result = new HashMap<>(_senderElapsedMs); + for (String senderKey : _streamIdToSenderKey.values()) { + result.putIfAbsent(senderKey, elapsedMs); + } + return Collections.unmodifiableMap(result); } @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingStageClassification.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingStageClassification.java new file mode 100644 index 000000000000..ae566498cb7f --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingStageClassification.java @@ -0,0 +1,138 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.service.dispatch; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.apache.pinot.query.planner.PlanFragment; +import org.apache.pinot.query.planner.physical.DispatchablePlanFragment; +import org.apache.pinot.query.planner.physical.DispatchableSubPlan; +import org.apache.pinot.query.planner.physical.FragmentType; +import org.apache.pinot.query.planner.plannode.MailboxSendNode; +import org.apache.pinot.query.planner.plannode.PlanNode; +import org.apache.pinot.query.routing.QueryServerInstance; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Classifies a query plan's stages for upstream timing extraction. + * + *

Two kinds of stages are consulted: + *

    + *
  • Pure leaf receivers: stages that directly receive from a non-SINGLETON leaf.
  • + *
  • SINGLETON leaf stages receiving from another leaf: a leaf stage that both scans a + * dim table and receives all upstream data via a SINGLETON exchange, but ONLY when its upstream + * sender stage is also a leaf.
  • + *
+ * + *

Only pure leaf servers (non-SINGLETON) are eligible for EMA updates. Intermediate and + * SINGLETON leaf servers are excluded because their timings reflect cascade delays from upstream + * rather than their own scan performance. + */ +final class AdaptiveRoutingStageClassification { + private static final Logger LOGGER = LoggerFactory.getLogger(AdaptiveRoutingStageClassification.class); + + /** Maps "hostname|mailboxPort" -> instanceId for all known servers. */ + final Map _senderKeyToInstanceId; + /** Stage IDs whose UPSTREAM_SERVER_RESPONSE_TIMES_MS we trust. */ + final Set _trustedStageIds; + /** Instance IDs of pure leaf servers eligible for EMA updates in the fallback path. */ + final Set _trackedServers; + + private AdaptiveRoutingStageClassification(Map senderKeyToInstanceId, Set trustedStageIds, + Set trackedServers) { + _senderKeyToInstanceId = senderKeyToInstanceId; + _trustedStageIds = trustedStageIds; + _trackedServers = trackedServers; + } + + /** + * Derives trusted stages and tracked servers from pre-computed roles set at planning time. + * + *

For pure leaves: their receivers are trusted (unless also receiving from a non-leaf). + * For SINGLETON leaf trusted: the stage itself is trusted. + * Contamination filter: any stage that receives from a non-leaf is excluded from trusted. + */ + static AdaptiveRoutingStageClassification classify(DispatchableSubPlan plan) { + Map senderKeyToInstanceId = new HashMap<>(); + Set trustedStageIds = new HashSet<>(); + Set trackedServers = new HashSet<>(); + Set stagesReceivingFromNonLeaf = new HashSet<>(); + + for (DispatchablePlanFragment fragment : plan.getQueryStagesWithoutRoot()) { + Set fragmentServers = fragment.getServerInstanceToWorkerIdMap().keySet(); + for (QueryServerInstance server : fragmentServers) { + String key = AdaptiveRoutingUpstreamTimings.senderKey(server.getHostname(), server.getQueryMailboxPort()); + senderKeyToInstanceId.putIfAbsent(key, server.getInstanceId()); + } + + FragmentType role = fragment.getFragmentType(); + PlanFragment planFragment = fragment.getPlanFragment(); + PlanNode fragmentRoot = planFragment != null ? planFragment.getFragmentRoot() : null; + MailboxSendNode sendNode = fragmentRoot instanceof MailboxSendNode ? (MailboxSendNode) fragmentRoot : null; + + if (role == FragmentType.LEAF) { + // Leaf: receivers are trusted, servers are tracked. + if (sendNode != null) { + for (int receiverStageId : sendNode.getReceiverStageIds()) { + trustedStageIds.add(receiverStageId); + } + } + for (QueryServerInstance server : fragmentServers) { + trackedServers.add(server.getInstanceId()); + } + } else if (role == FragmentType.SINGLETON_LEAF) { + // SINGLETON leaf with leaf sender: this stage's own stats are trusted. + trustedStageIds.add(planFragment.getFragmentId()); + // Its receivers are contaminated (SINGLETON cascade). + if (sendNode != null) { + for (int receiverStageId : sendNode.getReceiverStageIds()) { + stagesReceivingFromNonLeaf.add(receiverStageId); + } + } + } else if (role == FragmentType.INTERMEDIATE && sendNode != null) { + // Intermediate stage: its receivers are contaminated by cascade delays. + for (int receiverStageId : sendNode.getReceiverStageIds()) { + stagesReceivingFromNonLeaf.add(receiverStageId); + } + } else if (role == null && sendNode != null) { + LOGGER.debug("Stage {} has null FragmentType; treating as INTERMEDIATE", + planFragment != null ? planFragment.getFragmentId() : "unknown"); + for (int receiverStageId : sendNode.getReceiverStageIds()) { + stagesReceivingFromNonLeaf.add(receiverStageId); + } + } + } + + // Exclude stages contaminated by non-leaf/SINGLETON senders. + trustedStageIds.removeAll(stagesReceivingFromNonLeaf); + + // Stage 0 (the broker reducer) is always trusted: its sender timings are measured directly + // by the multi-consumer, not propagated through intermediate stages. + trustedStageIds.add(0); + + LOGGER.debug("==[UPSTREAM_TIMING]== classifyStages: trustedStageIds={} trackedServers={} " + + "senderKeyToInstanceId.size={}", trustedStageIds, trackedServers.size(), senderKeyToInstanceId.size()); + + return new AdaptiveRoutingStageClassification(senderKeyToInstanceId, trustedStageIds, trackedServers); + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingUpstreamTimings.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingUpstreamTimings.java new file mode 100644 index 000000000000..5e5863385cae --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingUpstreamTimings.java @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.service.dispatch; + +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Holds per-upstream-sender timing data for adaptive routing, encoded for transport via StatMap STRING keys. + * + *

Format: {@code "key1=elapsedMs1;key2=elapsedMs2"} + *

    + *
  • Entries are separated by {@code ';'}
  • + *
  • Key and value are separated by {@code '='}
  • + *
  • Keys are of the form {@code "hostname|mailboxPort"} (pipe cannot appear in DNS names or port numbers)
  • + *
  • elapsedMs values are non-negative longs
  • + *
+ * + *

This encoding is used in + * {@link org.apache.pinot.query.runtime.operator.BaseMailboxReceiveOperator.StatKey#UPSTREAM_SERVER_RESPONSE_TIMES_MS}. + */ +public class AdaptiveRoutingUpstreamTimings { + + private static final Logger LOGGER = LoggerFactory.getLogger(AdaptiveRoutingUpstreamTimings.class); + + public static final String COLLECT_UPSTREAM_TIMING_KEY = "collectUpstreamTiming"; + + static final char ENTRY_SEPARATOR = ';'; + static final char KV_SEPARATOR = '='; + + private AdaptiveRoutingUpstreamTimings() { + } + + /** + * Returns the key used to identify a sender in the timing map, given its hostname and mailbox port. + *

The {@code '|'} separator is chosen because it cannot appear in DNS hostnames or port numbers. + */ + public static String senderKey(String hostname, int mailboxPort) { + return hostname + "|" + mailboxPort; + } + + /** + * Encode a map of senderKey -> elapsedMs into a string. + * + * @return encoded string, or {@code null} if the map is empty (null = absent in StatMap) + */ + @Nullable + public static String encode(Map timings) { + if (timings.isEmpty()) { + return null; + } + StringBuilder sb = new StringBuilder(); + for (Map.Entry entry : timings.entrySet()) { + if (sb.length() > 0) { + sb.append(ENTRY_SEPARATOR); + } + sb.append(entry.getKey()).append(KV_SEPARATOR).append(entry.getValue()); + } + return sb.toString(); + } + + /** + * Decode a string (possibly null) into a mutable map of senderKey -> elapsedMs. + */ + public static Map decode(@Nullable String encoded) { + Map result = new HashMap<>(); + if (encoded == null || encoded.isEmpty()) { + return result; + } + int start = 0; + int len = encoded.length(); + while (start < len) { + int end = encoded.indexOf(ENTRY_SEPARATOR, start); + if (end < 0) { + end = len; + } + int eq = encoded.indexOf(KV_SEPARATOR, start); + if (eq > start && eq < end) { + String key = encoded.substring(start, eq); + try { + long value = Long.parseLong(encoded.substring(eq + 1, end)); + result.put(key, value); + } catch (NumberFormatException e) { + LOGGER.warn("Skipping malformed timing entry '{}': {}", encoded.substring(start, end), e.getMessage()); + } + } + start = end + 1; + } + return result; + } + + /** + * Merge two encoded timing strings, taking the max elapsedMs per senderKey. + * Either or both arguments may be null (treated as empty). + */ + @Nullable + public static String mergeEncodings(@Nullable String enc1, @Nullable String enc2) { + Map merged = decode(enc1); + Map incoming = decode(enc2); + for (Map.Entry entry : incoming.entrySet()) { + merged.merge(entry.getKey(), entry.getValue(), Math::max); + } + return encode(merged); + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java index 388607808536..7643e8dde84b 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java @@ -47,6 +47,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.LongSupplier; import javax.annotation.Nullable; import org.apache.calcite.runtime.PairList; import org.apache.commons.lang3.tuple.Pair; @@ -104,6 +105,22 @@ * {@code QueryDispatcher} dispatch a query to different workers. */ public class QueryDispatcher { + + static final class CancelOutcome { + static final CancelOutcome NONE = new CancelOutcome(null, Set.of()); + + @Nullable final MultiStageQueryStats _stats; + final Set _respondingServerIds; + + CancelOutcome(@Nullable MultiStageQueryStats stats, Set respondingServerIds) { + _stats = stats; + _respondingServerIds = respondingServerIds; + } + + boolean wasAttempted() { + return !_respondingServerIds.isEmpty(); + } + } private static final Logger LOGGER = LoggerFactory.getLogger(QueryDispatcher.class); private static final String PINOT_BROKER_QUERY_DISPATCHER_FORMAT = "multistage-query-dispatch-%d"; @@ -119,11 +136,13 @@ public class QueryDispatcher { private final Map> _serversByQuery; private final FailureDetector _failureDetector; private final Duration _cancelTimeout; + // Injectable for tests that need to manipulate time (e.g. to simulate a timeout without actually waiting). + private final LongSupplier _clock; public QueryDispatcher(MailboxService mailboxService, FailureDetector failureDetector, @Nullable TlsConfig tlsConfig, boolean enableCancellation, Duration cancelTimeout) { this(mailboxService, failureDetector, tlsConfig, enableCancellation, cancelTimeout, - DispatchClient.KeepAliveConfig.DISABLED); + DispatchClient.KeepAliveConfig.DISABLED, System::currentTimeMillis); } /// Overload that accepts gRPC keep-alive settings for broker dispatch channels. A non-positive `keepAliveTimeMs` @@ -132,11 +151,20 @@ public QueryDispatcher(MailboxService mailboxService, FailureDetector failureDet boolean enableCancellation, Duration cancelTimeout, int keepAliveTimeMs, int keepAliveTimeoutMs, boolean keepAliveWithoutCalls) { this(mailboxService, failureDetector, tlsConfig, enableCancellation, cancelTimeout, - new DispatchClient.KeepAliveConfig(keepAliveTimeMs, keepAliveTimeoutMs, keepAliveWithoutCalls)); + new DispatchClient.KeepAliveConfig(keepAliveTimeMs, keepAliveTimeoutMs, keepAliveWithoutCalls), + System::currentTimeMillis); + } + + @VisibleForTesting + QueryDispatcher(MailboxService mailboxService, FailureDetector failureDetector, @Nullable TlsConfig tlsConfig, + boolean enableCancellation, Duration cancelTimeout, LongSupplier clock) { + this(mailboxService, failureDetector, tlsConfig, enableCancellation, cancelTimeout, + DispatchClient.KeepAliveConfig.DISABLED, clock); } private QueryDispatcher(MailboxService mailboxService, FailureDetector failureDetector, @Nullable TlsConfig tlsConfig, - boolean enableCancellation, Duration cancelTimeout, DispatchClient.KeepAliveConfig keepAliveConfig) { + boolean enableCancellation, Duration cancelTimeout, DispatchClient.KeepAliveConfig keepAliveConfig, + LongSupplier clock) { _cancelTimeout = cancelTimeout; _mailboxService = mailboxService; _executorService = Executors.newFixedThreadPool(2 * Runtime.getRuntime().availableProcessors(), @@ -145,6 +173,7 @@ private QueryDispatcher(MailboxService mailboxService, FailureDetector failureDe _clientGrpcSslContext = initClientSslContext(tlsConfig); _keepAliveConfig = keepAliveConfig; _failureDetector = failureDetector; + _clock = clock; if (enableCancellation) { _serversByQuery = new ConcurrentHashMap<>(); @@ -175,22 +204,38 @@ public QueryResult submitAndReduce(RequestContext context, DispatchableSubPlan d /// in-flight request statistics into {@code statsManager} for use by the adaptive query router. /// When {@code statsManager} is non-null: ///

    - ///
  • Each leaf server is registered as having one more in-flight request via - /// {@link ServerRoutingStatsManager#recordStatsForQuerySubmission} after the fan-out begins.
  • - ///
  • After the full fan-out completes (or fails), each server is decremented via - /// {@link ServerRoutingStatsManager#recordStatsUponResponseArrival} with {@code latency = -1} - /// (no latency is recorded at this stage).
  • + ///
  • Each submitted server is registered as having one more in-flight request via + /// {@link ServerRoutingStatsManager#recordStatsForQuerySubmission} after the fan-out succeeds.
  • + ///
  • For servers that reported leaf-stage timing via {@code UPSTREAM_SERVER_RESPONSE_TIMES_MS} stats, + /// {@link ServerRoutingStatsManager#recordStatsUponResponseArrival} is called with measured elapsed time. + ///
  • For all submitted servers not already recorded by the upstream timing extraction + /// (e.g. intermediate-stage servers, or any server when the reduce phase fails), + /// the finally block calls {@code recordStatsUponResponseArrival} with the total wall-clock + /// elapsed time since dispatch.
  • ///
- /// TODO: Replace the coarse end-of-fanout decrement with per-sender arrival once per-sender EOS - /// interception is in place, and record real leaf-stage latency at that point. public QueryResult submitAndReduce(RequestContext context, DispatchableSubPlan dispatchableSubPlan, long timeoutMs, Map queryOptions, @Nullable ServerRoutingStatsManager statsManager) throws Exception { long requestId = context.getRequestId(); Set servers = new HashSet<>(); - // Tracks servers where recordStatsForQuerySubmission was actually called, so the finally block only - // decrements servers that were incremented — guarding against a partial failure in submit(). Set incrementedServers = new HashSet<>(); + long submitTimeMs = _clock.getAsLong(); + QueryResult result = null; + CancelOutcome cancelOutcome = CancelOutcome.NONE; + + AdaptiveRoutingStageClassification classification = null; + if (statsManager != null) { + classification = AdaptiveRoutingStageClassification.classify(dispatchableSubPlan); + dispatchableSubPlan.getQueryStageMap().get(0).getCustomProperties() + .put(AdaptiveRoutingUpstreamTimings.COLLECT_UPSTREAM_TIMING_KEY, "true"); + for (DispatchablePlanFragment fragment : dispatchableSubPlan.getQueryStagesWithoutRoot()) { + int stageId = fragment.getPlanFragment().getFragmentId(); + if (classification._trustedStageIds.contains(stageId)) { + fragment.getCustomProperties().put(AdaptiveRoutingUpstreamTimings.COLLECT_UPSTREAM_TIMING_KEY, "true"); + } + } + } + try { submit(requestId, dispatchableSubPlan, timeoutMs, servers, queryOptions); // The SSE engine increments before `submit`, but here we increment after because `submit` populates @@ -202,22 +247,32 @@ public QueryResult submitAndReduce(RequestContext context, DispatchableSubPlan d incrementedServers.add(server); } } - QueryResult result = runReducer(dispatchableSubPlan, queryOptions, _mailboxService); + result = runReducer(dispatchableSubPlan, queryOptions, _mailboxService); if (result.getProcessingException() != null) { - cancel(requestId); + cancel(requestId, servers); } return result; } catch (Exception ex) { - return tryRecover(context.getRequestId(), servers, ex); + Pair recovered = tryRecover(context.getRequestId(), servers, ex); + result = recovered.getLeft(); + cancelOutcome = recovered.getRight(); + return result; } catch (Throwable e) { // TODO: Consider always cancel when it returns (early terminate) cancel(requestId); throw e; } finally { if (statsManager != null) { - for (QueryServerInstance server : incrementedServers) { - statsManager.recordStatsUponResponseArrival(requestId, server.getInstanceId(), -1); + Map knownTimings = Map.of(); + if (result != null && !incrementedServers.isEmpty() && !classification._trustedStageIds.isEmpty()) { + try { + knownTimings = extractMaxTimingsPerInstance(result, classification, requestId, cancelOutcome); + } catch (Exception e) { + LOGGER.warn("Failed to apply upstream timings for request {}", requestId, e); + } } + recordPerServerLatencies(statsManager, requestId, incrementedServers, classification, knownTimings, + cancelOutcome, _clock.getAsLong() - submitTimeMs); } if (isQueryCancellationEnabled()) { _serversByQuery.remove(requestId); @@ -225,11 +280,35 @@ public QueryResult submitAndReduce(RequestContext context, DispatchableSubPlan d } } + private static void recordPerServerLatencies(ServerRoutingStatsManager statsManager, long requestId, + Set incrementedServers, AdaptiveRoutingStageClassification classification, + Map knownTimings, CancelOutcome cancelOutcome, long elapsedMs) { + for (QueryServerInstance server : incrementedServers) { + String id = server.getInstanceId(); + long latency; + if (knownTimings.containsKey(id)) { + // Tier 1: actual upstream timing extracted — use it for tracked (leaf) servers, -1 otherwise. + latency = classification._trackedServers.contains(id) ? knownTimings.get(id) : -1L; + } else if (classification._trackedServers.contains(id) && !knownTimings.isEmpty()) { + // Tier 2: tracked leaf server whose timing is missing while other servers had data — mark degraded. + latency = elapsedMs; + } else if (cancelOutcome.wasAttempted() && !cancelOutcome._respondingServerIds.contains(id)) { + // Tier 3: cancel was attempted but this server didn't respond — mark degraded. + latency = elapsedMs; + } else { + // Tier 4: no timing data, but server is responsive (or cancel wasn't attempted). + latency = -1L; + } + LOGGER.debug("==[UPSTREAM_TIMING]== request {} recording server {} latency={}ms", requestId, id, latency); + statsManager.recordStatsUponResponseArrival(requestId, id, latency); + } + } + /// Tries to recover from an exception thrown during query dispatching. /// /// [QueryException] and [TimeoutException] are handled by returning a [QueryResult] with the error code and stats, /// while other exceptions are not known, so they are directly rethrown. - private QueryResult tryRecover(long requestId, Set servers, Exception ex) + private Pair tryRecover(long requestId, Set servers, Exception ex) throws Exception { if (servers.isEmpty()) { throw ex; @@ -250,12 +329,12 @@ private QueryResult tryRecover(long requestId, Set servers, // in case of known exceptions (timeout or query exception), we need can build here the erroneous QueryResult // that include the stats. LOGGER.warn("Query failed with a known exception. Trying to cancel the other opchains"); - MultiStageQueryStats stats = cancelWithStats(requestId, servers); - if (stats == null) { + CancelOutcome outcome = cancelWithStats(requestId, servers); + if (outcome._stats == null) { throw ex; } QueryProcessingException processingException = new QueryProcessingException(errorCode, ex.getMessage()); - return new QueryResult(processingException, stats, 0L); + return Pair.of(new QueryResult(processingException, outcome._stats, 0L), outcome); } public List explain(RequestContext context, DispatchablePlanFragment fragment, long timeoutMs, @@ -343,6 +422,73 @@ private boolean isQueryCancellationEnabled() { return _serversByQuery != null; } + + /** + * Extracts the maximum observed latency per instance from consulted stages' stats. + * Takes the maximum when the same server appears in multiple consulted stages. + */ + @VisibleForTesting + static Map extractMaxTimingsPerInstance(QueryResult result, + AdaptiveRoutingStageClassification classification, long requestId, + CancelOutcome cancelOutcome) { + Map maxTimingPerInstance = new HashMap<>(); + List queryStatsList = result.getQueryStats(); + for (int stageIdx = 0; stageIdx < queryStatsList.size(); stageIdx++) { + if (!classification._trustedStageIds.contains(stageIdx)) { + continue; + } + MultiStageQueryStats.StageStats.Closed stageStats = queryStatsList.get(stageIdx); + if (stageStats != null) { + extractTimingsFromStage(stageStats, stageIdx, classification, requestId, maxTimingPerInstance); + } + } + if (cancelOutcome._stats != null) { + MultiStageQueryStats cancelStats = cancelOutcome._stats; + for (int stageId = cancelStats.getCurrentStageId() + 1; stageId <= cancelStats.getMaxStageId(); stageId++) { + if (!classification._trustedStageIds.contains(stageId)) { + continue; + } + MultiStageQueryStats.StageStats.Closed stageStats = cancelStats.getUpstreamStageStats(stageId); + if (stageStats != null) { + extractTimingsFromStage(stageStats, stageId, classification, requestId, maxTimingPerInstance); + } + } + } + return maxTimingPerInstance; + } + + private static void extractTimingsFromStage(MultiStageQueryStats.StageStats.Closed stageStats, int stageIdx, + AdaptiveRoutingStageClassification classification, long requestId, Map maxTimingPerInstance) { + stageStats.forEach((opType, statMap) -> { + if (opType != MultiStageOperator.Type.MAILBOX_RECEIVE) { + return; + } + @SuppressWarnings("unchecked") + StatMap receiveStats = + (StatMap) statMap; + String encoded = + receiveStats.getString(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS); + if (encoded == null) { + LOGGER.debug("==[UPSTREAM_TIMING]== request {} consulted stage {} MAILBOX_RECEIVE has null timing", + requestId, stageIdx); + return; + } + LOGGER.debug("==[UPSTREAM_TIMING]== request {} consulted stage {} encoded timing: {}", + requestId, stageIdx, encoded); + Map timings = AdaptiveRoutingUpstreamTimings.decode(encoded); + for (Map.Entry entry : timings.entrySet()) { + String instanceId = classification._senderKeyToInstanceId.get(entry.getKey()); + if (instanceId != null) { + maxTimingPerInstance.merge(instanceId, entry.getValue(), Math::max); + } else { + LOGGER.debug("==[UPSTREAM_TIMING]== request {} senderKey={} not found in known servers, skipping", + requestId, entry.getKey()); + } + } + }); + } + + private void execute(long requestId, Set stagePlans, long timeoutMs, Map queryOptions, SendRequest sendRequest, Set serverInstancesOut, BiConsumer resultConsumer) @@ -537,9 +683,9 @@ private boolean cancel(long requestId, @Nullable Set server } @Nullable - private MultiStageQueryStats cancelWithStats(long requestId, @Nullable Set servers) { + private CancelOutcome cancelWithStats(long requestId, @Nullable Set servers) { if (servers == null) { - return null; + return CancelOutcome.NONE; } Deadline deadline = Deadline.after(_cancelTimeout.toMillis(), TimeUnit.MILLISECONDS); @@ -547,11 +693,13 @@ private MultiStageQueryStats cancelWithStats(long requestId, @Nullable Set> dispatchCallbacks = dispatch(sendRequest, servers, deadline, serverInstance -> requestId); + Set respondedServerIds = new HashSet<>(); MultiStageQueryStats stats = MultiStageQueryStats.emptyStats(0); StatMap rootStats = new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); stats.getCurrentStats().addLastOperator(MultiStageOperator.Type.MAILBOX_RECEIVE, rootStats); try { processResults(requestId, servers.size(), (response, server) -> { + respondedServerIds.add(server.getInstanceId()); Map statsByStage = response.getStatsByStageMap(); for (Map.Entry entry : statsByStage.entrySet()) { try (InputStream is = entry.getValue().newInput(); DataInputStream dis = new DataInputStream(is)) { @@ -562,12 +710,12 @@ private MultiStageQueryStats cancelWithStats(long requestId, @Nullable Set opChainMetadata = new HashMap<>(queryOptions); + String collectTiming = stagePlan.getCustomProperties() + .get(AdaptiveRoutingUpstreamTimings.COLLECT_UPSTREAM_TIMING_KEY); + if (collectTiming != null) { + opChainMetadata.put(AdaptiveRoutingUpstreamTimings.COLLECT_UPSTREAM_TIMING_KEY, collectTiming); + } OpChainExecutionContext opChainExecutionContext = - OpChainExecutionContext.fromQueryContext(mailboxService, queryOptions, stageMetadata, workerMetadata.get(0), + OpChainExecutionContext.fromQueryContext(mailboxService, opChainMetadata, stageMetadata, workerMetadata.get(0), null, true, true); PairList resultFields = subPlan.getQueryResultFields(); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java index 6c57ef6ba013..3d8aabe0a7d1 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/MailboxReceiveOperatorTest.java @@ -19,6 +19,7 @@ package org.apache.pinot.query.runtime.operator; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -40,6 +41,7 @@ import org.apache.pinot.query.runtime.operator.MultiStageOperator.Type; import org.apache.pinot.query.runtime.plan.MultiStageQueryStats; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; +import org.apache.pinot.query.service.dispatch.AdaptiveRoutingUpstreamTimings; import org.apache.pinot.segment.spi.memory.DataBuffer; import org.apache.pinot.spi.exception.QueryErrorCode; import org.mockito.Mock; @@ -55,6 +57,7 @@ import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; @@ -264,9 +267,58 @@ public void differentUpstreamStatsProduceEmptyStats() } } + /** + * Verifies that {@link BaseMailboxReceiveOperator#copyStatMaps()} includes per-sender timing for + * senders whose EOS arrived before the query was cancelled or timed out, even when {@code onEos()} + * was never called (because a slow sender never completed). + * + *

This is the timeout path fix: fast-server timings are "trapped" in + * {@code _multiConsumer._senderElapsedMs} until {@code onEos()} fires, but on timeout + * {@code onEos()} never fires. {@code copyStatMaps()} must flush the partial map so that + * {@code applyUpstreamTimingsFromStats} can exempt those fast servers from the full elapsed-time + * penalty in the {@code finally} block. + */ + @Test + public void copyStatMapsIncludesPartialTimingWhenSlowSenderNeverCompletes() { + // mailbox 1 (fast sender): returns EOS immediately. + // mailbox 2 (slow sender): never returns data; the operator will time out waiting for it. + when(_mailboxService.getReceivingMailbox(eq(MAILBOX_ID_1))).thenReturn(_mailbox1); + when(_mailbox1.poll()).thenReturn(OperatorTestUtil.eosWithEmptyStats()); + + when(_mailboxService.getReceivingMailbox(eq(MAILBOX_ID_2))).thenReturn(_mailbox2); + // _mailbox2.poll() returns null by default (slow sender never completes). + + // Use a short deadline so the test does not take too long waiting for the slow sender. + long shortDeadlineMs = System.currentTimeMillis() + 500L; + Map metadata = new HashMap<>(); + metadata.put(AdaptiveRoutingUpstreamTimings.COLLECT_UPSTREAM_TIMING_KEY, "true"); + try (MailboxReceiveOperator operator = getOperator(_stageMetadataBoth, RelDistribution.Type.HASH_DISTRIBUTED, + shortDeadlineMs, metadata)) { + MseBlock block = operator.nextBlock(); + // The operator times out waiting for mailbox 2. + assertTrue(block.isError(), "Expected a timeout error block"); + + // onEos() was NOT called (only one of two senders completed), so _statMap has no timing. + // copyStatMaps() must flush partial timing from _multiConsumer._senderElapsedMs. + StatMap stats = operator.copyStatMaps(); + String encoded = stats.getString(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS); + assertNotNull(encoded, "copyStatMaps() should include partial timing for fast sender that already sent EOS"); + // The sender key is derived from the MailboxInfo hostname/port ("localhost", 1234). + String expectedKey = "localhost|1234"; + assertTrue(encoded.contains(expectedKey), + "Encoded timing '" + encoded + "' should contain the sender key '" + expectedKey + "'"); + } + } + private MailboxReceiveOperator getOperator(StageMetadata stageMetadata, RelDistribution.Type distributionType, long deadlineMs) { - OpChainExecutionContext context = OperatorTestUtil.getOpChainContext(_mailboxService, deadlineMs, stageMetadata); + return getOperator(stageMetadata, distributionType, deadlineMs, Map.of()); + } + + private MailboxReceiveOperator getOperator(StageMetadata stageMetadata, RelDistribution.Type distributionType, + long deadlineMs, Map opChainMetadata) { + OpChainExecutionContext context = + OperatorTestUtil.getOpChainContext(_mailboxService, deadlineMs, stageMetadata, opChainMetadata); MailboxReceiveNode node = mock(MailboxReceiveNode.class); when(node.getDistributionType()).thenReturn(distributionType); when(node.getSenderStageId()).thenReturn(1); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java index 7e77ecaf780f..b6be37684787 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/OperatorTestUtil.java @@ -110,7 +110,12 @@ public static ReceivingMailbox.MseBlockWithStats eosWithStats(List s public static OpChainExecutionContext getOpChainContext(MailboxService mailboxService, long deadlineMs, StageMetadata stageMetadata) { - return new OpChainExecutionContext(mailboxService, 0, "cid", deadlineMs, deadlineMs, "brokerId", Map.of(), + return getOpChainContext(mailboxService, deadlineMs, stageMetadata, Map.of()); + } + + public static OpChainExecutionContext getOpChainContext(MailboxService mailboxService, long deadlineMs, + StageMetadata stageMetadata, Map opChainMetadata) { + return new OpChainExecutionContext(mailboxService, 0, "cid", deadlineMs, deadlineMs, "brokerId", opChainMetadata, stageMetadata, stageMetadata.getWorkerMetadataList().get(0), null, true, true); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumerUpstreamTimingTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumerUpstreamTimingTest.java new file mode 100644 index 000000000000..0b4644d91a56 --- /dev/null +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/utils/BlockingMultiStreamConsumerUpstreamTimingTest.java @@ -0,0 +1,210 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.utils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.pinot.query.mailbox.MailboxService; +import org.apache.pinot.query.mailbox.ReceivingMailbox; +import org.apache.pinot.query.routing.StageMetadata; +import org.apache.pinot.query.routing.WorkerMetadata; +import org.apache.pinot.query.runtime.blocks.SuccessMseBlock; +import org.apache.pinot.query.runtime.operator.OperatorTestUtil; +import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; +import org.apache.pinot.query.service.dispatch.AdaptiveRoutingUpstreamTimings; +import org.apache.pinot.spi.query.QueryThreadContext; +import org.testng.annotations.Test; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + + +/** + * Tests per-sender elapsed-time tracking in {@link BlockingMultiStreamConsumer.OfMseBlock}. + */ +public class BlockingMultiStreamConsumerUpstreamTimingTest { + + @SuppressWarnings("unchecked") + private AsyncStream mockStream(String id) { + AsyncStream stream = mock(AsyncStream.class); + when(stream.getId()).thenReturn(id); + doNothing().when(stream).addOnNewDataListener(any()); + return stream; + } + + private static ReceivingMailbox.MseBlockWithStats eos() { + return new ReceivingMailbox.MseBlockWithStats(SuccessMseBlock.INSTANCE, List.of()); + } + + private OpChainExecutionContext createContext() { + MailboxService mailboxService = mock(MailboxService.class); + when(mailboxService.getHostname()).thenReturn("localhost"); + when(mailboxService.getPort()).thenReturn(1234); + StageMetadata stageMetadata = new StageMetadata(0, + List.of(new WorkerMetadata(0, Map.of(), Map.of())), Map.of()); + return OperatorTestUtil.getOpChainContext(mailboxService, Long.MAX_VALUE, stageMetadata); + } + + /** + * Covers: per-sender tracking, max-dedup for duplicate sender keys, and no-op for unmapped streams. + * + *

Three streams: A (fast), A-worker1 (slow, same sender key as A), and C (no sender key mapping). + * Verifies: A gets max of both workers (700ms), C produces no entry. + */ + @Test + public void testPerSenderElapsedTimeWithMaxDedupAndUnmappedStream() { + AtomicLong clock = new AtomicLong(0L); + + AsyncStream streamA0 = mockStream("mailbox-A-w0"); + AsyncStream streamA1 = mockStream("mailbox-A-w1"); + AsyncStream streamC = mockStream("mailbox-C"); + + String keyA = AdaptiveRoutingUpstreamTimings.senderKey("host-a", 8442); + + when(streamA0.poll()).thenAnswer(inv -> { + clock.set(100L); + return eos(); + }); + when(streamA1.poll()).thenAnswer(inv -> { + clock.set(700L); + return eos(); + }); + when(streamC.poll()).thenAnswer(inv -> { + clock.set(500L); + return eos(); + }); + + try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) { + Map streamIdToSenderKey = new HashMap<>(); + streamIdToSenderKey.put("mailbox-A-w0", keyA); + streamIdToSenderKey.put("mailbox-A-w1", keyA); + // mailbox-C deliberately NOT mapped -> tests null-guard path + + BlockingMultiStreamConsumer.OfMseBlock consumer = new BlockingMultiStreamConsumer.OfMseBlock( + createContext(), + new ArrayList<>(List.of(streamA0, streamA1, streamC)), + /* senderStageId= */ 2, + streamIdToSenderKey, + clock::get); + + consumer.readMseBlockBlocking(); + + Map timings = consumer.getSenderElapsedMs(); + assertEquals(timings.size(), 1, "Only keyA should be recorded (keyC has no mapping)"); + assertEquals((long) timings.get(keyA), 700L, "Should keep max across workers (700 > 100)"); + assertTrue(!timings.containsKey(AdaptiveRoutingUpstreamTimings.senderKey("host-c", 8442)), + "Unmapped stream must not produce a timing entry"); + } + } + + /** + * Full timeout: no senders complete. getSenderElapsedMsIncludingPending() injects elapsed time for all. + * Does not call readMseBlockBlocking() — simulates the state at cancel time when no EOS arrived. + */ + @Test + public void testIncludingPendingFullTimeout() { + AtomicLong clock = new AtomicLong(0L); + String keyA = AdaptiveRoutingUpstreamTimings.senderKey("host-a", 8442); + String keyB = AdaptiveRoutingUpstreamTimings.senderKey("host-b", 8442); + String keyC = AdaptiveRoutingUpstreamTimings.senderKey("host-c", 8442); + + AsyncStream streamA = mockStream("mailbox-A"); + AsyncStream streamB = mockStream("mailbox-B"); + AsyncStream streamC = mockStream("mailbox-C"); + + try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) { + Map streamIdToSenderKey = new HashMap<>(); + streamIdToSenderKey.put("mailbox-A", keyA); + streamIdToSenderKey.put("mailbox-B", keyB); + streamIdToSenderKey.put("mailbox-C", keyC); + + BlockingMultiStreamConsumer.OfMseBlock consumer = new BlockingMultiStreamConsumer.OfMseBlock( + createContext(), + new ArrayList<>(List.of(streamA, streamB, streamC)), + /* senderStageId= */ 2, + streamIdToSenderKey, + clock::get); + + // No reads performed — no EOS received + assertTrue(consumer.getSenderElapsedMs().isEmpty()); + + // Simulate time passing (as if timeout occurred at 9800ms) + clock.set(9800L); + + Map timings = consumer.getSenderElapsedMsIncludingPending(); + assertEquals(timings.size(), 3); + assertEquals((long) timings.get(keyA), 9800L); + assertEquals((long) timings.get(keyB), 9800L); + assertEquals((long) timings.get(keyC), 9800L); + } + } + + /** + * All senders complete normally. getSenderElapsedMsIncludingPending matches getSenderElapsedMs + * even when called later (putIfAbsent is a no-op for already-recorded senders). + */ + @Test + public void testIncludingPendingAllComplete() { + AtomicLong clock = new AtomicLong(0L); + String keyA = AdaptiveRoutingUpstreamTimings.senderKey("host-a", 8442); + String keyB = AdaptiveRoutingUpstreamTimings.senderKey("host-b", 8442); + + AsyncStream streamA = mockStream("mailbox-A"); + AsyncStream streamB = mockStream("mailbox-B"); + + when(streamA.poll()).thenAnswer(inv -> { + clock.set(50L); + return eos(); + }); + when(streamB.poll()).thenAnswer(inv -> { + clock.set(80L); + return eos(); + }); + + try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) { + Map streamIdToSenderKey = new HashMap<>(); + streamIdToSenderKey.put("mailbox-A", keyA); + streamIdToSenderKey.put("mailbox-B", keyB); + + BlockingMultiStreamConsumer.OfMseBlock consumer = new BlockingMultiStreamConsumer.OfMseBlock( + createContext(), + new ArrayList<>(List.of(streamA, streamB)), + /* senderStageId= */ 2, + streamIdToSenderKey, + clock::get); + + consumer.readMseBlockBlocking(); + + // Both completed — includingPending should preserve actual latencies even at later clock + clock.set(5000L); + Map base = consumer.getSenderElapsedMs(); + Map withPending = consumer.getSenderElapsedMsIncludingPending(); + assertEquals(withPending, base, "When all senders completed, includingPending equals base"); + assertEquals((long) withPending.get(keyA), 50L); + assertEquals((long) withPending.get(keyB), 80L); + } + } +} diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingUpstreamTimingsTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingUpstreamTimingsTest.java new file mode 100644 index 000000000000..8c9c06024179 --- /dev/null +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/AdaptiveRoutingUpstreamTimingsTest.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.service.dispatch; + +import java.util.Map; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + + +public class AdaptiveRoutingUpstreamTimingsTest { + + @Test + public void testEncodeMultipleEntriesRoundTrips() { + Map timings = new java.util.LinkedHashMap<>(); + timings.put("host-z|8442", 300L); + timings.put("host-a|8442", 100L); + timings.put("host-m|8442", 200L); + String encoded = AdaptiveRoutingUpstreamTimings.encode(timings); + assertEquals(AdaptiveRoutingUpstreamTimings.decode(encoded), Map.of( + "host-z|8442", 300L, "host-a|8442", 100L, "host-m|8442", 200L)); + } + + @Test + public void testRoundTrip() { + Map original = Map.of("host-a|8442", 100L, "host-b|8442", 200L, "host-c|8442", 300L); + String encoded = AdaptiveRoutingUpstreamTimings.encode(original); + assertEquals(AdaptiveRoutingUpstreamTimings.decode(encoded), original); + } + + @Test + public void testMergeEncodingsTakesMax() { + String enc1 = AdaptiveRoutingUpstreamTimings.encode(Map.of("host-a|8442", 100L, "host-b|8442", 50L)); + String enc2 = AdaptiveRoutingUpstreamTimings.encode(Map.of("host-a|8442", 80L, "host-b|8442", 200L)); + Map merged = AdaptiveRoutingUpstreamTimings.decode( + AdaptiveRoutingUpstreamTimings.mergeEncodings(enc1, enc2)); + assertEquals(merged, Map.of("host-a|8442", 100L, "host-b|8442", 200L)); + } + + @Test + public void testMergeEncodingsDisjointKeys() { + String enc1 = AdaptiveRoutingUpstreamTimings.encode(Map.of("host-a|8442", 100L)); + String enc2 = AdaptiveRoutingUpstreamTimings.encode(Map.of("host-b|8442", 200L)); + Map merged = AdaptiveRoutingUpstreamTimings.decode( + AdaptiveRoutingUpstreamTimings.mergeEncodings(enc1, enc2)); + assertEquals(merged, Map.of("host-a|8442", 100L, "host-b|8442", 200L)); + } +} diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherApplyUpstreamTimingsTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherApplyUpstreamTimingsTest.java new file mode 100644 index 000000000000..b7cac2451b84 --- /dev/null +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherApplyUpstreamTimingsTest.java @@ -0,0 +1,637 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.service.dispatch; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; +import org.apache.pinot.common.datatable.StatMap; +import org.apache.pinot.common.response.broker.ResultTable; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager; +import org.apache.pinot.query.planner.PlanFragment; +import org.apache.pinot.query.planner.physical.DispatchablePlanFragment; +import org.apache.pinot.query.planner.physical.DispatchableSubPlan; +import org.apache.pinot.query.planner.physical.FragmentType; +import org.apache.pinot.query.planner.plannode.MailboxSendNode; +import org.apache.pinot.query.routing.QueryServerInstance; +import org.apache.pinot.query.runtime.operator.BaseMailboxReceiveOperator; +import org.apache.pinot.query.runtime.operator.MultiStageOperator; +import org.apache.pinot.query.runtime.plan.MultiStageQueryStats; +import org.testng.Assert; +import org.testng.annotations.Test; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + + +/** + * Unit tests for upstream timing extraction via {@link AdaptiveRoutingStageClassification#classify} and + * {@link QueryDispatcher#extractMaxTimingsPerInstance}. + * + *

Each test constructs a {@link QueryDispatcher.QueryResult} with hand-crafted + * {@code UPSTREAM_SERVER_RESPONSE_TIMES_MS} stats and a matching {@link DispatchableSubPlan}, + * then asserts that {@link ServerRoutingStatsManager#recordStatsUponResponseArrival} is (or is not) + * called with the expected arguments. + * + *

Two kinds of stages are consulted: direct pure-leaf receivers ({@code stagesReceivingFromLeaves}, + * including stage 0 for 2-stage queries), and SINGLETON leaf stages themselves + * ({@code singletonLeafStageIds}). All other stages — including receivers of SINGLETON leaf stages — + * are excluded to prevent SINGLETON cascade contamination. + * See {@link AdaptiveRoutingStageClassification} for details. + */ +public class QueryDispatcherApplyUpstreamTimingsTest { + + private static final long REQUEST_ID = 42L; + private static final int MAILBOX_PORT = 8442; + private static final ResultTable EMPTY_RESULT = + new ResultTable(new DataSchema(new String[]{}, new DataSchema.ColumnDataType[]{}), List.of()); + + /** + * Mirrors the 3-step sequence in QueryDispatcher's finally block: classify -> extract -> record. + * Returns the classification so tests can inspect _trackedServers. + */ + private static AdaptiveRoutingStageClassification applyUpstreamTimingsFromStats(QueryDispatcher.QueryResult result, + DispatchableSubPlan plan, ServerRoutingStatsManager statsManager, long requestId, + Set recordedInstanceIds) { + return applyUpstreamTimingsFromStats(result, plan, statsManager, requestId, recordedInstanceIds, + QueryDispatcher.CancelOutcome.NONE); + } + + private static AdaptiveRoutingStageClassification applyUpstreamTimingsFromStats(QueryDispatcher.QueryResult result, + DispatchableSubPlan plan, ServerRoutingStatsManager statsManager, long requestId, + Set recordedInstanceIds, QueryDispatcher.CancelOutcome cancelOutcome) { + AdaptiveRoutingStageClassification classification = AdaptiveRoutingStageClassification.classify(plan); + Map maxTimings = QueryDispatcher.extractMaxTimingsPerInstance( + result, classification, requestId, cancelOutcome); + for (Map.Entry entry : maxTimings.entrySet()) { + if (recordedInstanceIds.add(entry.getKey())) { + statsManager.recordStatsUponResponseArrival(requestId, entry.getKey(), entry.getValue()); + } + } + return classification; + } + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /** + * Builds a QueryResult whose stage-1 stats contain a MAILBOX_RECEIVE operator with + * UPSTREAM_SERVER_RESPONSE_TIMES_MS set to {@code encoded}. + */ + private static QueryDispatcher.QueryResult resultWithStage1Timing(String encoded) { + StatMap receiveStats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + receiveStats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, encoded); + + MultiStageQueryStats.StageStats.Closed stage1 = new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(receiveStats)); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(1, stage1); + return new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + } + + private static DispatchablePlanFragment leafFragment(int receiverStageId, QueryServerInstance... servers) { + return leafFragmentWithStageId(0, receiverStageId, servers); + } + + private static DispatchablePlanFragment leafFragmentWithStageId(int stageId, int receiverStageId, + QueryServerInstance... servers) { + DispatchablePlanFragment fragment = mock(DispatchablePlanFragment.class); + Map.Entry>[] entries = new Map.Entry[servers.length]; + for (int i = 0; i < servers.length; i++) { + entries[i] = Map.entry(servers[i], List.of(i)); + } + when(fragment.getServerInstanceToWorkerIdMap()).thenReturn(Map.ofEntries(entries)); + when(fragment.getFragmentType()).thenReturn(FragmentType.LEAF); + + MailboxSendNode sendNode = mock(MailboxSendNode.class); + when(sendNode.getReceiverStageIds()).thenReturn(List.of(receiverStageId)); + PlanFragment planFragment = mock(PlanFragment.class); + when(planFragment.getFragmentId()).thenReturn(stageId); + when(planFragment.getFragmentRoot()).thenReturn(sendNode); + when(fragment.getPlanFragment()).thenReturn(planFragment); + return fragment; + } + + private static DispatchablePlanFragment nonLeafFragment(QueryServerInstance... servers) { + return nonLeafFragmentWithReceivers(List.of(), servers); + } + + private static DispatchablePlanFragment nonLeafFragmentWithReceivers(List receiverStageIds, + QueryServerInstance... servers) { + DispatchablePlanFragment fragment = mock(DispatchablePlanFragment.class); + Map.Entry>[] entries = new Map.Entry[servers.length]; + for (int i = 0; i < servers.length; i++) { + entries[i] = Map.entry(servers[i], List.of(i)); + } + when(fragment.getServerInstanceToWorkerIdMap()).thenReturn(Map.ofEntries(entries)); + when(fragment.getFragmentType()).thenReturn(FragmentType.INTERMEDIATE); + + if (!receiverStageIds.isEmpty()) { + MailboxSendNode sendNode = mock(MailboxSendNode.class); + when(sendNode.getReceiverStageIds()).thenReturn(receiverStageIds); + PlanFragment planFragment = mock(PlanFragment.class); + when(planFragment.getFragmentRoot()).thenReturn(sendNode); + when(fragment.getPlanFragment()).thenReturn(planFragment); + } else { + when(fragment.getPlanFragment()).thenReturn(null); + } + return fragment; + } + + /** + * Wraps one or more fragments into a {@link DispatchableSubPlan} mock. + */ + private static DispatchableSubPlan planWith(DispatchablePlanFragment... fragments) { + DispatchableSubPlan plan = mock(DispatchableSubPlan.class); + TreeSet fragmentSet = new TreeSet<>(Comparator.comparingInt(System::identityHashCode)); + fragmentSet.addAll(Arrays.asList(fragments)); + when(plan.getQueryStagesWithoutRoot()).thenReturn(fragmentSet); + return plan; + } + + // --------------------------------------------------------------------------- + // Tests + // --------------------------------------------------------------------------- + + @Test + public void testRecordsRealLatencyForKnownIndirectSender() { + // Stage-1 stats encode that host-a took 600ms. The plan knows host-a as "instance-a" via a + // leaf fragment that sends to Stage 1. + QueryServerInstance server = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + String encoded = AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=600"; + + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + applyUpstreamTimingsFromStats( + resultWithStage1Timing(encoded), planWith(leafFragment(1, server)), stats, REQUEST_ID, recorded); + + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 600L); + Assert.assertTrue(recorded.contains("instance-a"), "instance-a should be added to recordedInstanceIds"); + } + + @Test + public void testSkipsUnknownSenderKey() { + // Stage stats reference a host that isn't in the plan — should not be recorded. + QueryServerInstance server = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + String encoded = AdaptiveRoutingUpstreamTimings.senderKey("host-unknown", MAILBOX_PORT) + "=300"; + + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + + applyUpstreamTimingsFromStats( + resultWithStage1Timing(encoded), planWith(leafFragment(1, server)), stats, REQUEST_ID, new HashSet<>()); + + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 300L); + } + + + @Test + public void testMultipleStagesAndSendersAllProcessed() { + // Stage 1 receives from leaf-A (host-a); stage 2 receives from leaf-B (host-b). Both recorded. + QueryServerInstance serverA = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + QueryServerInstance serverB = new QueryServerInstance("instance-b", "host-b", 9000, MAILBOX_PORT); + + StatMap s1Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + s1Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=500"); + + StatMap s2Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + s2Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-b", MAILBOX_PORT) + "=250"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(1, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(s1Stats))); + mqStats.mergeUpstream(2, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(s2Stats))); + + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + // Two leaf fragments: leaf-A sends to Stage 1, leaf-B sends to Stage 2. + applyUpstreamTimingsFromStats(result, + planWith(leafFragment(1, serverA), leafFragment(2, serverB)), stats, REQUEST_ID, recorded); + + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 500L); + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-b", 250L); + Assert.assertTrue(recorded.containsAll(List.of("instance-a", "instance-b"))); + } + + @Test + public void testIntermediateStageTimingsIgnored() { + // Stage 1 (leaf receiver) has accurate per-leaf timings for all 3 servers. + // Stage 0 (broker) has no timing data (it receives from the intermediate, not from leaves). + // + // The plan has a leaf fragment (Stage 2) that sends to Stage 1 as its receiver. + // Stage 1 is consulted because it is a pure-leaf receiver. + QueryServerInstance serverA = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + QueryServerInstance serverB = new QueryServerInstance("instance-b", "host-b", 9000, MAILBOX_PORT); + QueryServerInstance serverC = new QueryServerInstance("instance-c", "host-c", 9000, MAILBOX_PORT); + + // Stage 0 (broker reduce): no timing data — it receives from the intermediate, not leaves. + StatMap stage0ReceiveStats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + + // Stage 1 (leaf receiver): true per-leaf timings: fast servers at 50ms, slow at 600ms + StatMap stage1ReceiveStats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage1ReceiveStats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=50;" + + AdaptiveRoutingUpstreamTimings.senderKey("host-b", MAILBOX_PORT) + "=50;" + + AdaptiveRoutingUpstreamTimings.senderKey("host-c", MAILBOX_PORT) + "=600"); + MultiStageQueryStats.StageStats.Closed stage1Closed = new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage1ReceiveStats)); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.getCurrentStats().addLastOperator(MultiStageOperator.Type.MAILBOX_RECEIVE, stage0ReceiveStats); + mqStats.mergeUpstream(1, stage1Closed); + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + + // Plan: leaf fragment (Stage 2) sends to Stage 1; non-leaf fragment (Stage 1) holds the servers. + // Both fragments list servers A/B/C so senderKeyToInstanceId is fully populated. + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + applyUpstreamTimingsFromStats(result, + planWith(leafFragment(1, serverA, serverB, serverC), nonLeafFragment(serverA, serverB, serverC)), + stats, REQUEST_ID, recorded); + + // Accurate leaf timings from Stage 1 must be recorded. + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 50L); + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-b", 50L); + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-c", 600L); + } + + @Test + public void testSameServerInMultipleLeafStagesTakesMax() { + // Server A participates in two leaf stages: Stage 3 (50ms) and Stage 5 (300ms). + // Both leaf fragments send to different receivers (Stage 2 and Stage 4 respectively). + // The maximum of the two leaf observations (300ms) should be recorded. + QueryServerInstance serverA = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + + StatMap stage2Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage2Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=50"); + + StatMap stage4Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage4Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=300"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(2, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage2Stats))); + mqStats.mergeUpstream(4, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage4Stats))); + + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + // Two leaf fragments for server A: one sends to Stage 2, another sends to Stage 4. + applyUpstreamTimingsFromStats(result, + planWith(leafFragment(2, serverA), leafFragment(4, serverA)), stats, REQUEST_ID, recorded); + + // The max of the two leaf observations (300ms) should be the single recorded value. + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 300L); + } + + private static DispatchablePlanFragment singletonLeafFragment(int leafStageId, boolean timingTrusted, + QueryServerInstance... servers) { + return singletonLeafFragment(leafStageId, timingTrusted, List.of(), servers); + } + + private static DispatchablePlanFragment singletonLeafFragment(int leafStageId, boolean timingTrusted, + List receiverStageIds, QueryServerInstance... servers) { + DispatchablePlanFragment fragment = mock(DispatchablePlanFragment.class); + Map.Entry>[] entries = new Map.Entry[servers.length]; + for (int i = 0; i < servers.length; i++) { + entries[i] = Map.entry(servers[i], List.of(i)); + } + when(fragment.getServerInstanceToWorkerIdMap()).thenReturn(Map.ofEntries(entries)); + when(fragment.getFragmentType()).thenReturn( + timingTrusted ? FragmentType.SINGLETON_LEAF : FragmentType.INTERMEDIATE); + + MailboxSendNode sendNode = mock(MailboxSendNode.class); + when(sendNode.getReceiverStageIds()).thenReturn(receiverStageIds); + PlanFragment planFragment = mock(PlanFragment.class); + when(planFragment.getFragmentId()).thenReturn(leafStageId); + when(planFragment.getFragmentRoot()).thenReturn(sendNode); + when(fragment.getPlanFragment()).thenReturn(planFragment); + return fragment; + } + + @Test + public void testMixedLeafWithSingletonReceiveIsSkipped() { + // A SINGLETON leaf fragment (dim table lookup-join). Its receiver stage 1 is NOT trusted + // (SINGLETON receivers are contaminated). Its servers are NOT tracked. + QueryServerInstance server = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + String encoded = AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=800"; + + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + // SINGLETON leaf (stageId=2), not timing-trusted (sender is not a leaf in this test). + AdaptiveRoutingStageClassification classification = applyUpstreamTimingsFromStats( + resultWithStage1Timing(encoded), + planWith(singletonLeafFragment(2, false, server)), + stats, REQUEST_ID, recorded); + + // Stage 1 (receiver) is not consulted — the inflated 800ms timing is never recorded. + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 800L); + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", -1L); + Assert.assertFalse(classification._trackedServers.contains("instance-a"), + "instance-a must NOT be in trackedServers so the finally block records -1"); + } + + + @Test + public void testPureLeavesStillRecordedWhenMixedLeafAlsoPresent() { + // When a query has both a pure leaf (fact table) and a SINGLETON leaf (dim lookup), + // the pure leaf timing is still recorded correctly. The SINGLETON leaf is silently skipped. + QueryServerInstance server = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + + // Stage 1 receives from the pure leaf (trusted receiver). + // Stage 2 receives from the SINGLETON leaf (not trusted). + StatMap stage1Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage1Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=90"); + + StatMap stage2Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage2Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=950"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(1, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage1Stats))); + mqStats.mergeUpstream(2, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage2Stats))); + + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + // Pure leaf sends to stage 1 (trusted). SINGLETON leaf (stage 3) has no trusted receivers. + applyUpstreamTimingsFromStats(result, + planWith(leafFragment(1, server), singletonLeafFragment(3, false, server)), + stats, REQUEST_ID, recorded); + + // Pure leaf timing (90ms) recorded, NOT the inflated SINGLETON timing (950ms). + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 90L); + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 950L); + } + + @Test + public void testSingletonLeafStageItselfIsConsultedWhenSenderIsLeaf() { + QueryServerInstance impacted = new QueryServerInstance("impacted", "host-impacted", 9000, MAILBOX_PORT); + QueryServerInstance normal = new QueryServerInstance("normal", "host-normal", 9000, MAILBOX_PORT); + + // Stage 4 (the SINGLETON leaf stage) receives from stage 5 (a pure leaf scan stage). + StatMap stage4Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage4Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-impacted", MAILBOX_PORT) + "=911;" + + AdaptiveRoutingUpstreamTimings.senderKey("host-normal", MAILBOX_PORT) + "=2"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(4, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage4Stats))); + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + + // SINGLETON leaf at stageId=4, marked TIMING_TRUSTED (sender stage 5 is a leaf). + // Stage 5 is a pure leaf running on the same servers. + DispatchablePlanFragment singletonLeaf = singletonLeafFragment(4, true, impacted, normal); + DispatchablePlanFragment senderLeaf = leafFragmentWithStageId(5, 4, impacted, normal); + DispatchablePlanFragment intermediate = nonLeafFragment(impacted, normal); + + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + applyUpstreamTimingsFromStats( + result, planWith(singletonLeaf, senderLeaf, intermediate), stats, REQUEST_ID, recorded); + + // Stage 4 is TIMING_TRUSTED -> its stats are consulted. + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "impacted", 911L); + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "normal", 2L); + Assert.assertTrue(recorded.containsAll(List.of("impacted", "normal"))); + } + + + /** + * Covers the relay/intermediate-server contamination case (e.g. a SINGLETON relay stage-1 server + * like {@code 055f50e85c4876db1} that waits for a slow upstream and therefore has an inflated + * wall-clock elapsed time). The intermediate server must NOT be in {@code _trackedServers} so + * the caller's finally block records it at -1 (not at wall-clock). + * + *

Topology: leaf (stage 2) sends non-SINGLETON to stage 1 (intermediate relay). Stage 1's + * UPSTREAM_SERVER_RESPONSE_TIMES_MS gives accurate leaf timings. The relay server runs only the + * intermediate stage ({@code nonLeafFragment}) and must NOT have a real latency recorded for it. + */ + @Test + public void testIntermediateRelayServerNotInTrackedServers() { + QueryServerInstance leafServer = new QueryServerInstance("leaf-instance", "host-leaf", 9000, MAILBOX_PORT); + QueryServerInstance relayServer = new QueryServerInstance("relay-instance", "host-relay", 9000, MAILBOX_PORT); + + // Stage 1 (intermediate relay) records leaf server timing accurately. + String encoded = AdaptiveRoutingUpstreamTimings.senderKey("host-leaf", MAILBOX_PORT) + "=80"; + + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + // leaf fragment sends to stage 1; relay server is in a non-leaf intermediate fragment. + AdaptiveRoutingStageClassification classification = applyUpstreamTimingsFromStats( + resultWithStage1Timing(encoded), + planWith(leafFragment(1, leafServer), nonLeafFragment(relayServer)), + stats, REQUEST_ID, recorded); + + // Leaf server gets its real latency recorded. + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "leaf-instance", 80L); + // Relay server is NOT recorded by applyUpstreamTimingsFromStats — the caller's finally block + // will handle it (not in trackedServers, so it records -1). + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "relay-instance", -1L); + Assert.assertFalse(recorded.contains("relay-instance"), + "relay-instance must NOT be in recordedInstanceIds (handled by caller)"); + Assert.assertFalse(classification._trackedServers.contains("relay-instance"), + "relay-instance must NOT be in trackedServers so the finally block records -1"); + } + + @Test + public void testSingletonLeafReceivingFromIntermediateNotConsulted() { + QueryServerInstance serverA = new QueryServerInstance("server-a", "host-a", 9000, MAILBOX_PORT); + QueryServerInstance serverB = new QueryServerInstance("server-b", "host-b", 9000, MAILBOX_PORT); + + // Stage 2 (SINGLETON leaf) records timing for its stage-3 upstream senders. + StatMap stage2Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage2Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=600;" + + AdaptiveRoutingUpstreamTimings.senderKey("host-b", MAILBOX_PORT) + "=10"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(2, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage2Stats))); + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + + // Stage 2 is a SINGLETON leaf. NOT timing-trusted (sender stage 3 is intermediate, not a leaf). + DispatchablePlanFragment singletonLeaf = singletonLeafFragment(2, false, serverA, serverB); + DispatchablePlanFragment intermediate = nonLeafFragment(serverA, serverB); + + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + AdaptiveRoutingStageClassification classification = applyUpstreamTimingsFromStats( + result, planWith(singletonLeaf, intermediate), stats, REQUEST_ID, recorded); + + // Stage 2 is NOT consulted (not timing-trusted) -> contaminated timings never recorded. + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "server-a", 600L); + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "server-b", 10L); + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "server-a", -1L); + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "server-b", -1L); + Assert.assertFalse(classification._trackedServers.contains("server-a")); + Assert.assertFalse(classification._trackedServers.contains("server-b")); + } + + // --------------------------------------------------------------------------- + // Cancel stats tests + // --------------------------------------------------------------------------- + + @Test + public void testCancelStatsProvideTimingsForStagesMissingFromResult() { + // Stage 1 timings come from the normal query result. Stage 2 timings come only from cancelStats + // (the broker didn't see stage 2 during normal reduce because the query was cancelled early). + QueryServerInstance serverA = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + QueryServerInstance serverB = new QueryServerInstance("instance-b", "host-b", 9000, MAILBOX_PORT); + + // Normal result has stage 1 timing for server A. + StatMap stage1Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage1Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=100"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(1, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage1Stats))); + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + + // Cancel stats have stage 2 timing for server B (gathered during cancel). + StatMap cancelStage2Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + cancelStage2Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-b", MAILBOX_PORT) + "=450"); + MultiStageQueryStats cancelStats = MultiStageQueryStats.emptyStats(0); + cancelStats.mergeUpstream(2, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(cancelStage2Stats))); + + // Plan: leaf-A sends to stage 1, leaf-B sends to stage 2. + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + applyUpstreamTimingsFromStats(result, + planWith(leafFragment(1, serverA), leafFragment(2, serverB)), + stats, REQUEST_ID, recorded, new QueryDispatcher.CancelOutcome(cancelStats, Set.of())); + + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 100L); + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-b", 450L); + Assert.assertTrue(recorded.containsAll(List.of("instance-a", "instance-b"))); + } + + @Test + public void testCancelStatsMergedWithResultTakingMax() { + // Same server appears in both the normal result (stage 1) and cancelStats (stage 2). + // extractMaxTimingsPerInstance should take the max across both sources. + QueryServerInstance server = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + + // Normal result: stage 1 reports server at 200ms. + StatMap stage1Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + stage1Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=200"); + + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + mqStats.mergeUpstream(1, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(stage1Stats))); + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + + // Cancel stats: stage 2 reports same server at 700ms. + StatMap cancelStage2Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + cancelStage2Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=700"); + MultiStageQueryStats cancelStats = MultiStageQueryStats.emptyStats(0); + cancelStats.mergeUpstream(2, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(cancelStage2Stats))); + + // Plan: leaf sends to both stage 1 and stage 2. + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + applyUpstreamTimingsFromStats(result, + planWith(leafFragment(1, server), leafFragment(2, server)), + stats, REQUEST_ID, recorded, new QueryDispatcher.CancelOutcome(cancelStats, Set.of())); + + // Max of 200 and 700 -> 700ms. + verify(stats).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 700L); + } + + @Test + public void testCancelStatsUntrustedStagesSkipped() { + // Cancel stats contain timings for a non-trusted stage (intermediate). These should be ignored. + QueryServerInstance server = new QueryServerInstance("instance-a", "host-a", 9000, MAILBOX_PORT); + + // Empty normal result (no upstream stage stats). + MultiStageQueryStats mqStats = MultiStageQueryStats.emptyStats(0); + QueryDispatcher.QueryResult result = new QueryDispatcher.QueryResult(EMPTY_RESULT, mqStats, 0L); + + // Cancel stats have stage 1 timing, but stage 1 is NOT a trusted stage (no leaf sends to it). + StatMap cancelStage1Stats = + new StatMap<>(BaseMailboxReceiveOperator.StatKey.class); + cancelStage1Stats.merge(BaseMailboxReceiveOperator.StatKey.UPSTREAM_SERVER_RESPONSE_TIMES_MS, + AdaptiveRoutingUpstreamTimings.senderKey("host-a", MAILBOX_PORT) + "=500"); + MultiStageQueryStats cancelStats = MultiStageQueryStats.emptyStats(0); + cancelStats.mergeUpstream(1, new MultiStageQueryStats.StageStats.Closed( + List.of(MultiStageOperator.Type.MAILBOX_RECEIVE), List.of(cancelStage1Stats))); + + // Plan: only a non-leaf fragment (no leaf sends to stage 1). + ServerRoutingStatsManager stats = mock(ServerRoutingStatsManager.class); + Set recorded = new HashSet<>(); + + applyUpstreamTimingsFromStats(result, + planWith(nonLeafFragment(server)), + stats, REQUEST_ID, recorded, new QueryDispatcher.CancelOutcome(cancelStats, Set.of())); + + verify(stats, never()).recordStatsUponResponseArrival(REQUEST_ID, "instance-a", 500L); + Assert.assertTrue(recorded.isEmpty()); + } +} diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java index 889486043c92..06c9de4e4bf5 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java @@ -30,6 +30,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongSupplier; import org.apache.pinot.common.failuredetector.FailureDetector; import org.apache.pinot.common.metrics.BrokerMetrics; import org.apache.pinot.common.proto.Worker; @@ -257,6 +258,46 @@ public void testStatsManagerNotCalledWhenSubmitFails() Mockito.reset(failingQueryServer); } + /** + * When no indirect timing is extracted (e.g. reduce phase fails before stats propagate), the + * finally block must still decrement in-flight for all submitted servers. With no known timings + * and no cancel attempted, all servers fall into Tier 4 (no timing data, server responsive) + * and get -1 (decrement only, no EMA update) rather than the contaminated wall-clock. + */ + @Test + public void testStatsManagerRecordsSubmissionAndArrivalForDispatchedServers() + throws Exception { + // Clock ticks by tickMs on each call: submitTimeMs = 1000, then elapsedMs = 1100 - 1000 = 100. + final long tickMs = 100L; + AtomicLong fakeClockMs = new AtomicLong(1_000L); + QueryDispatcher dispatcher = newDispatcher(() -> fakeClockMs.getAndAdd(tickMs)); + + ServerRoutingStatsManager statsManager = Mockito.mock(ServerRoutingStatsManager.class); + String sql = "SELECT * FROM a"; + long requestId = REQUEST_ID_GEN.getAndIncrement(); + RequestContext context = new DefaultRequestContext(); + context.setRequestId(requestId); + DispatchableSubPlan plan = _queryEnvironment.planQuery(sql); + + Set expectedInstanceIds = getExpectedInstanceIds(plan); + try { + try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) { + dispatcher.submitAndReduce(context, plan, 10_000L, Map.of(), statsManager); + } catch (NullPointerException e) { + // expected: reduce phase fails with mocked MailboxService + } + for (String instanceId : expectedInstanceIds) { + Mockito.verify(statsManager).recordStatsForQuerySubmission(requestId, instanceId); + // No indirect timing was extracted (mocked MailboxService -> no real stats), so + // knownTimings is empty -> all servers get -1 (Tier 4: no timing data). + Mockito.verify(statsManager).recordStatsUponResponseArrival( + Mockito.eq(requestId), Mockito.eq(instanceId), Mockito.eq(-1L)); + } + } finally { + dispatcher.shutdown(); + } + } + @Test public void testRealStatsManagerInflightReturnsToZero() throws Exception { @@ -288,14 +329,7 @@ public void testRealStatsManagerInflightReturnsToZero() RequestContext context = new DefaultRequestContext(); context.setRequestId(requestId); DispatchableSubPlan plan = _queryEnvironment.planQuery(sql); - - Set expectedInstanceIds = new HashSet<>(); - for (DispatchablePlanFragment fragment : plan.getQueryStagesWithoutRoot()) { - for (QueryServerInstance server : fragment.getServerInstanceToWorkerIdMap().keySet()) { - expectedInstanceIds.add(server.getInstanceId()); - } - } - Assert.assertFalse(expectedInstanceIds.isEmpty()); + Set expectedInstanceIds = getExpectedInstanceIds(plan); try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) { _queryDispatcher.submitAndReduce(context, plan, 10_000L, Map.of(), statsManager); @@ -321,4 +355,105 @@ public void testRealStatsManagerInflightReturnsToZero() statsManager.shutDown(); } + + @Test + public void testNoStatsRecordedWhenAdaptiveRoutingDisabled() + throws Exception { + // When statsManager is null (adaptive routing disabled), submitAndReduce must not + // dereference it — the null guards in submitAndReduce must prevent any NPE from the + // stats path. Any exception thrown here should come from the mocked MailboxService + // (reduce phase), not from a null statsManager dereference. + String sql = "SELECT * FROM a"; + long requestId = REQUEST_ID_GEN.getAndIncrement(); + RequestContext context = new DefaultRequestContext(); + context.setRequestId(requestId); + DispatchableSubPlan plan = _queryEnvironment.planQuery(sql); + + try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) { + _queryDispatcher.submitAndReduce(context, plan, 10_000L, Map.of(), null); + } catch (NullPointerException e) { + // Acceptable: reduce phase NPEs because MailboxService is mocked. + // Verify the NPE did NOT originate from the stats-manager null-guard path. + for (StackTraceElement frame : e.getStackTrace()) { + Assert.assertFalse( + frame.getMethodName().contains("recordStats") || frame.getMethodName().contains("applyUpstreamTimings"), + "NPE must not originate from stats recording path, but got: " + frame); + } + } + } + + /** + * When {@code submit()} throws {@link TimeoutException} (one server never ACKs the dispatch), + * {@code submitAndReduce()} must catch it via {@code tryRecover()} and return a failed + * {@code QueryResult} — not propagate the exception. + * + *

Because {@code submit()} threw before the submission-stats loop ran, {@code incrementedServers} + * is empty and the finally block's tiered latency recorder has nothing to update, so + * {@code statsManager} receives zero interactions. + */ + @Test + public void testSubmitAndReduceReturnsResultWhenSubmitTimesOut() + throws Exception { + // All servers hang on submit so processResults() times out after the short deadline. + CountDownLatch neverClosingLatch = new CountDownLatch(1); + List allServers = new ArrayList<>(_queryServerMap.values()); + for (QueryServer server : allServers) { + Mockito.doAnswer(invocationOnMock -> { + neverClosingLatch.await(); + StreamObserver observer = invocationOnMock.getArgument(1); + observer.onCompleted(); + return null; + }).when(server).submit(Mockito.any(), Mockito.any()); + } + + ServerRoutingStatsManager statsManager = Mockito.mock(ServerRoutingStatsManager.class); + String sql = "SELECT * FROM a"; + long requestId = REQUEST_ID_GEN.getAndIncrement(); + RequestContext context = new DefaultRequestContext(); + context.setRequestId(requestId); + DispatchableSubPlan plan = _queryEnvironment.planQuery(sql); + + try { + try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) { + // submit() times out because all servers never ACK -> tryRecover() handles TimeoutException. + // Depending on whether cancelWithStats succeeds, this either returns a QueryResult with a + // processing exception or throws a RuntimeException wrapping the cancel failure. + QueryDispatcher.QueryResult result = + _queryDispatcher.submitAndReduce(context, plan, 200L, Map.of(), statsManager); + Assert.assertNotNull(result.getProcessingException(), + "Expected a processing exception in the result when submit times out"); + } + } catch (RuntimeException e) { + // Cancel phase may also throw if the hanging servers don't respond to the cancel RPC. + Assert.assertTrue(e.getMessage().contains("Error dispatching query"), + "Expected dispatch error from cancel phase, got: " + e.getMessage()); + } finally { + neverClosingLatch.countDown(); + for (QueryServer server : allServers) { + Mockito.reset(server); + } + } + + // submit() threw before recordStatsForQuerySubmission ran -> incrementedServers is empty. + // The finally block's tiered latency recorder iterates an empty incrementedServers set, + // so statsManager receives no interactions. + Mockito.verifyNoInteractions(statsManager); + } + + /** Creates a local {@link QueryDispatcher} wired to the shared query servers with an injected clock. */ + private QueryDispatcher newDispatcher(LongSupplier clock) { + return new QueryDispatcher(Mockito.mock(MailboxService.class), Mockito.mock(FailureDetector.class), null, true, + Duration.ofSeconds(1), clock); + } + + private Set getExpectedInstanceIds(DispatchableSubPlan plan) { + Set expectedInstanceIds = new HashSet<>(); + for (DispatchablePlanFragment fragment : plan.getQueryStagesWithoutRoot()) { + for (QueryServerInstance server : fragment.getServerInstanceToWorkerIdMap().keySet()) { + expectedInstanceIds.add(server.getInstanceId()); + } + } + Assert.assertFalse(expectedInstanceIds.isEmpty()); + return expectedInstanceIds; + } }