diff --git a/pinot-common/src/main/java/org/apache/pinot/common/assignment/InstanceAssignmentConfigUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/assignment/InstanceAssignmentConfigUtils.java
index d798c5b2b8bc..d63b8810b12f 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/assignment/InstanceAssignmentConfigUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/assignment/InstanceAssignmentConfigUtils.java
@@ -42,13 +42,19 @@ private InstanceAssignmentConfigUtils() {
/**
* Returns whether the COMPLETED segments should be relocated (offloaded from CONSUMING instances to COMPLETED
* instances) for a LLC real-time table based on the given table config.
- *
COMPLETED segments should be relocated iff the COMPLETED instance assignment is configured or (for
- * backward-compatibility) COMPLETED server tag is overridden to be different from the CONSUMING server tag.
+ *
instanceAssignmentConfigMap = tableConfig.getInstanceAssignmentConfigMap();
return (instanceAssignmentConfigMap != null
&& instanceAssignmentConfigMap.get(InstancePartitionsType.COMPLETED.toString()) != null)
+ || InstancePartitionsUtils.hasPreConfiguredInstancePartitions(tableConfig, InstancePartitionsType.COMPLETED)
|| TagNameUtils.isRelocateCompletedSegments(tableConfig.getTenantConfig());
}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentTest.java
index 6cc9640fdb25..fabb750057d2 100644
--- a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentTest.java
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeReplicaGroupSegmentAssignmentTest.java
@@ -20,6 +20,7 @@
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -27,6 +28,7 @@
import org.apache.helix.HelixManager;
import org.apache.helix.store.zk.ZkHelixPropertyStore;
import org.apache.helix.zookeeper.datamodel.ZNRecord;
+import org.apache.pinot.common.assignment.InstanceAssignmentConfigUtils;
import org.apache.pinot.common.assignment.InstancePartitions;
import org.apache.pinot.common.restlet.resources.RebalanceConfig;
import org.apache.pinot.common.utils.LLCSegmentName;
@@ -622,6 +624,232 @@ public void testSubsetPartitionAssignment() {
"Partitions 0 and 7 should map to different instance sets to avoid hotspots");
}
+  /**
+   * Regression for imported COMPLETED instance partitions during rebalance.
+   *
+   * <p>Prod scenario: Table B imports Table A's CONSUMING and COMPLETED instance partitions via
+   * {@code instancePartitionsMap} (no {@code instanceAssignmentConfigMap} for COMPLETED). During rebalance,
+   * {@link org.apache.pinot.controller.helix.core.rebalance.TableRebalancer#getInstancePartitionsMap} only
+   * loads COMPLETED IPs when {@link InstanceAssignmentConfigUtils#shouldRelocateCompletedSegments} returns true.
+   *
+   * <p>Bug: {@code shouldRelocateCompletedSegments} ignored {@code instancePartitionsMap}, so rebalance
+   * passed only CONSUMING IPs to segment assignment and completed segments stayed on one server per partition.
+   *
+   * <p>This test mirrors {@code TableRebalancer}: the rebalance {@code instancePartitionsMap} is built from
+   * {@code shouldRelocateCompletedSegments(tableConfig)}, not by always passing COMPLETED IPs directly.
+   */
+  @Test
+  public void testImportedInstancePartitionsWithMultipleServersPerPartition() {
+    int numReplicas = 2;
+    int numReplicaGroups = numReplicas;
+    int numServers = 64;
+    int numInstancePartitions = 8;
+    int numServersPerPartitionPerRG = numServers / numReplicaGroups / numInstancePartitions; // 4
+    int numStreamPartitions = 8; // exact prod case: 1:1 mapping with instance partitions
+    int numSegmentsPerPartition = 20; // 19 completed + 1 consuming per stream partition
+    String serverPrefix = "Server_";
+
+    List<String> allServers = SegmentAssignmentTestUtils.getNameList(serverPrefix, numServers);
+
+    // Table B: imports CONSUMING/COMPLETED IPs from Table A; no instanceAssignmentConfigMap for COMPLETED.
+    Map<String, String> streamConfigs = FakeStreamConfigUtils.getDefaultLowLevelStreamConfigs().getStreamConfigsMap();
+    Map<InstancePartitionsType, String> importedInstancePartitions = new HashMap<>();
+    importedInstancePartitions.put(InstancePartitionsType.CONSUMING, "sourceTable_CONSUMING");
+    importedInstancePartitions.put(InstancePartitionsType.COMPLETED, "sourceTable_COMPLETED");
+    TableConfig tableConfig =
+        new TableConfigBuilder(TableType.REALTIME).setTableName(RAW_TABLE_NAME).setNumReplicas(numReplicas)
+            .setStreamConfigs(streamConfigs)
+            .setInstancePartitionsMap(importedInstancePartitions)
+            .setSegmentAssignmentConfigMap(Collections.singletonMap(InstancePartitionsType.COMPLETED.toString(),
+                new SegmentAssignmentConfig(AssignmentStrategy.REPLICA_GROUP_SEGMENT_ASSIGNMENT_STRATEGY)))
+            .build();
+    assertTrue(InstanceAssignmentConfigUtils.shouldRelocateCompletedSegments(tableConfig),
+        "Imported COMPLETED instance partitions must trigger completed-segment relocation during rebalance");
+    SegmentAssignment segmentAssignment =
+        SegmentAssignmentFactory.getSegmentAssignment(createHelixManager(), tableConfig, null);
+
+    // Both CONSUMING and COMPLETED instance partitions use the same 64 servers
+    // (imported from the same source table). 8 explicit partitions, 2 RGs, 4 servers per partition.
+    InstancePartitions consumingInstancePartitions = new InstancePartitions(CONSUMING_INSTANCE_PARTITIONS_NAME);
+    InstancePartitions completedInstancePartitions = new InstancePartitions(COMPLETED_INSTANCE_PARTITIONS_NAME);
+    for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
+      for (int partitionId = 0; partitionId < numInstancePartitions; partitionId++) {
+        List<String> serversInPartition = new ArrayList<>(numServersPerPartitionPerRG);
+        int baseIndex = replicaGroupId * (numServers / numReplicaGroups)
+            + partitionId * numServersPerPartitionPerRG;
+        for (int i = 0; i < numServersPerPartitionPerRG; i++) {
+          serversInPartition.add(allServers.get(baseIndex + i));
+        }
+        consumingInstancePartitions.setInstances(partitionId, replicaGroupId, serversInPartition);
+        completedInstancePartitions.setInstances(partitionId, replicaGroupId,
+            new ArrayList<>(serversInPartition));
+      }
+    }
+
+    // Create segments: 8 stream partitions × 20 segments each = 160 total
+    List<String> segments = new ArrayList<>();
+    for (int partitionId = 0; partitionId < numStreamPartitions; partitionId++) {
+      for (int seqNum = 0; seqNum < numSegmentsPerPartition; seqNum++) {
+        segments.add(new LLCSegmentName(RAW_TABLE_NAME, partitionId, seqNum,
+            System.currentTimeMillis()).getSegmentName());
+      }
+    }
+
+    // Build currentAssignment: all segments initially pinned to first server in each partition
+    // (simulates the state produced by the buggy path where COMPLETED IPs were not loaded).
+    Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
+    for (int partitionId = 0; partitionId < numStreamPartitions; partitionId++) {
+      for (int seqNum = 0; seqNum < numSegmentsPerPartition; seqNum++) {
+        String segmentName = segments.get(partitionId * numSegmentsPerPartition + seqNum);
+        int instancePartitionId = partitionId % numInstancePartitions;
+        List<String> instancesForSegment = new ArrayList<>(numReplicaGroups);
+        for (int rg = 0; rg < numReplicaGroups; rg++) {
+          instancesForSegment.add(
+              consumingInstancePartitions.getInstances(instancePartitionId, rg).get(0));
+        }
+        boolean isConsuming = (seqNum == numSegmentsPerPartition - 1);
+        String state = isConsuming ? SegmentStateModel.CONSUMING : SegmentStateModel.ONLINE;
+        currentAssignment.put(segmentName,
+            SegmentAssignmentUtils.getInstanceStateMap(instancesForSegment, state));
+      }
+    }
+
+    // Mirror TableRebalancer.getInstancePartitionsMap(): CONSUMING always; COMPLETED only when relocation applies.
+    Map<InstancePartitionsType, InstancePartitions> instancePartitionsMap =
+        buildRebalanceInstancePartitionsMap(tableConfig, consumingInstancePartitions, completedInstancePartitions);
+
+    // Simulate pre-fix rebalance: only CONSUMING IPs in the map (COMPLETED omitted even though import is configured).
+    // Completed segments stay pinned to one server per stream partition.
+    Map<InstancePartitionsType, InstancePartitions> consumingOnlyPartitionsMap = new TreeMap<>();
+    consumingOnlyPartitionsMap.put(InstancePartitionsType.CONSUMING, consumingInstancePartitions);
+    RebalanceConfig rebalanceConfig = new RebalanceConfig();
+    rebalanceConfig.setIncludeConsuming(true);
+    Map<String, Map<String, String>> consumingOnlyAssignment =
+        segmentAssignment.rebalanceTable(currentAssignment, consumingOnlyPartitionsMap, null, null,
+            rebalanceConfig);
+    int expectedCompletedSegmentServerCountWhenCompletedIpsOmitted = numStreamPartitions * numReplicaGroups;
+    assertEquals(countServersWithCompletedSegments(consumingOnlyAssignment),
+        expectedCompletedSegmentServerCountWhenCompletedIpsOmitted,
+        "Without COMPLETED instance partitions (buggy rebalance path), completed segments should stay on "
+            + expectedCompletedSegmentServerCountWhenCompletedIpsOmitted + " servers");
+
+    assertTrue(instancePartitionsMap.containsKey(InstancePartitionsType.COMPLETED),
+        "Rebalance must include COMPLETED instance partitions when imported via instancePartitionsMap");
+    Map<String, Map<String, String>> newAssignment =
+        segmentAssignment.rebalanceTable(currentAssignment, instancePartitionsMap, null, null, rebalanceConfig);
+
+    // COMPLETED segments should be distributed across all 64 servers
+    HashSet<String> completedServers = collectServersWithCompletedSegments(newAssignment);
+    assertEquals(completedServers.size(), numServers,
+        "All " + numServers + " servers should have COMPLETED segments, but only "
+            + completedServers.size() + " were used.");
+
+    // Verify per-partition spread for COMPLETED segments
+    for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
+      for (int instPartId = 0; instPartId < numInstancePartitions; instPartId++) {
+        List<String> partitionServers = completedInstancePartitions.getInstances(instPartId, replicaGroupId);
+        HashSet<String> usedInPartition = new HashSet<>();
+        for (String server : partitionServers) {
+          if (completedServers.contains(server)) {
+            usedInPartition.add(server);
+          }
+        }
+        assertEquals(usedInPartition.size(), numServersPerPartitionPerRG,
+            "COMPLETED: instance partition " + instPartId + " in RG " + replicaGroupId + " should use all "
+                + numServersPerPartitionPerRG + " servers, but only " + usedInPartition.size() + " were used");
+      }
+    }
+
+    // Verify per-stream-partition coverage: for each stream partition, the set of servers across all
+    // of its completed segments should equal the full set of servers in that instance partition (per RG).
+    for (int partitionId = 0; partitionId < numStreamPartitions; partitionId++) {
+      int instancePartitionId = partitionId % numInstancePartitions;
+      for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
+        List<String> expectedServers = completedInstancePartitions.getInstances(instancePartitionId, replicaGroupId);
+        HashSet<String> actualServers = new HashSet<>();
+        for (int seqNum = 0; seqNum < numSegmentsPerPartition - 1; seqNum++) {
+          String segmentName = segments.get(partitionId * numSegmentsPerPartition + seqNum);
+          Map<String, String> instanceStateMap = newAssignment.get(segmentName);
+          for (String server : instanceStateMap.keySet()) {
+            if (expectedServers.contains(server)) {
+              actualServers.add(server);
+            }
+          }
+        }
+        assertEquals(actualServers, new HashSet<>(expectedServers),
+            "Stream partition " + partitionId + " in RG " + replicaGroupId
+                + " should have completed segments on all " + numServersPerPartitionPerRG
+                + " servers in instance partition " + instancePartitionId
+                + ", but only used: " + actualServers);
+      }
+    }
+
+    // --- Bootstrap variant ---
+    RebalanceConfig bootstrapConfig = new RebalanceConfig();
+    bootstrapConfig.setIncludeConsuming(true);
+    bootstrapConfig.setBootstrap(true);
+    Map<String, Map<String, String>> bootstrapAssignment =
+        segmentAssignment.rebalanceTable(currentAssignment, instancePartitionsMap, null, null, bootstrapConfig);
+
+    HashSet<String> bootstrapCompletedServers = collectServersWithCompletedSegments(bootstrapAssignment);
+    assertEquals(bootstrapCompletedServers.size(), numServers,
+        "Bootstrap: all " + numServers + " servers should have COMPLETED segments, but only "
+            + bootstrapCompletedServers.size() + " were used.");
+
+    // Same per-stream-partition check for bootstrap
+    for (int partitionId = 0; partitionId < numStreamPartitions; partitionId++) {
+      int instancePartitionId = partitionId % numInstancePartitions;
+      for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
+        List<String> expectedServers = completedInstancePartitions.getInstances(instancePartitionId, replicaGroupId);
+        HashSet<String> actualServers = new HashSet<>();
+        for (int seqNum = 0; seqNum < numSegmentsPerPartition - 1; seqNum++) {
+          String segmentName = segments.get(partitionId * numSegmentsPerPartition + seqNum);
+          Map<String, String> instanceStateMap = bootstrapAssignment.get(segmentName);
+          for (String server : instanceStateMap.keySet()) {
+            if (expectedServers.contains(server)) {
+              actualServers.add(server);
+            }
+          }
+        }
+        assertEquals(actualServers, new HashSet<>(expectedServers),
+            "Bootstrap: stream partition " + partitionId + " in RG " + replicaGroupId
+                + " should have completed segments on all " + numServersPerPartitionPerRG
+                + " servers in instance partition " + instancePartitionId
+                + ", but only used: " + actualServers);
+      }
+    }
+  }
+
+  /**
+   * Builds the instance-partitions map the same way as
+   * {@link org.apache.pinot.controller.helix.core.rebalance.TableRebalancer#getInstancePartitionsMap}.
+   */
+  private static Map<InstancePartitionsType, InstancePartitions> buildRebalanceInstancePartitionsMap(
+      TableConfig tableConfig, InstancePartitions consumingInstancePartitions,
+      InstancePartitions completedInstancePartitions) {
+    Map<InstancePartitionsType, InstancePartitions> instancePartitionsMap = new TreeMap<>();
+    instancePartitionsMap.put(InstancePartitionsType.CONSUMING, consumingInstancePartitions);
+    if (InstanceAssignmentConfigUtils.shouldRelocateCompletedSegments(tableConfig)) {
+      instancePartitionsMap.put(InstancePartitionsType.COMPLETED, completedInstancePartitions);
+    }
+    return instancePartitionsMap;
+  }
+
+  private static HashSet<String> collectServersWithCompletedSegments(
+      Map<String, Map<String, String>> assignment) {
+    HashSet<String> completedServers = new HashSet<>();
+    for (Map.Entry<String, Map<String, String>> entry : assignment.entrySet()) {
+      if (entry.getValue().containsValue(SegmentStateModel.ONLINE)) { // any ONLINE replica => completed segment
+        completedServers.addAll(entry.getValue().keySet());
+      }
+    }
+    return completedServers;
+  }
+
+  private static int countServersWithCompletedSegments(Map<String, Map<String, String>> assignment) {
+    return collectServersWithCompletedSegments(assignment).size(); // distinct servers, not segment replicas
+  }
+
private HelixManager createHelixManager() {
HelixManager helixManager = mock(HelixManager.class);
ZkHelixPropertyStore propertyStore = mock(ZkHelixPropertyStore.class);