@@ -276,9 +276,167 @@ class DistributedPlanner {
276276 }
277277
278278 // Case 2: Sharded -- N RemoteScans + UNION ALL
279- const auto & shard_list = shards_.get_shards (table->table_name );
279+ // Optimization (#27): if WHERE contains shard_key = <literal> or
280+ // shard_key IN (<literals>), route to only the relevant shard(s).
281+ const auto & full_shard_list = shards_.get_shards (table->table_name );
282+ std::vector<ShardInfo> pruned = prune_shards (table, where_expr, full_shard_list);
280283 return make_sharded_union (table, where_expr, nullptr , 0 , nullptr , 0 ,
281- nullptr , nullptr , 0 , -1 , false , shard_list);
284+ nullptr , nullptr , 0 , -1 , false , pruned);
285+ }
286+
287+ // Shard pruning (#27): analyze WHERE for shard_key = <literal> or
288+ // shard_key IN (<literal_list>). Returns a subset of shards when pruning
289+ // is possible, otherwise returns the full shard list.
290+ std::vector<ShardInfo> prune_shards (const TableInfo* table,
291+ const sql_parser::AstNode* where_expr,
292+ const std::vector<ShardInfo>& all_shards) {
293+ if (!where_expr || all_shards.empty ()) return all_shards;
294+
295+ sql_parser::StringRef shard_key = shards_.get_shard_key (table->table_name );
296+ if (!shard_key.ptr || shard_key.len == 0 ) return all_shards;
297+
298+ // Try to extract shard key literal values from WHERE expression
299+ std::vector<size_t > target_indices;
300+ extract_shard_targets (where_expr, shard_key, table->table_name ,
301+ all_shards.size (), target_indices);
302+
303+ if (target_indices.empty ()) return all_shards;
304+
305+ // Deduplicate and collect matching shards
306+ std::vector<bool > included (all_shards.size (), false );
307+ for (size_t idx : target_indices) {
308+ if (idx < all_shards.size ()) included[idx] = true ;
309+ }
310+ std::vector<ShardInfo> result;
311+ for (size_t i = 0 ; i < all_shards.size (); ++i) {
312+ if (included[i]) result.push_back (all_shards[i]);
313+ }
314+ return result.empty () ? all_shards : result;
315+ }
316+
317+ // Walk a WHERE expression looking for shard_key = <literal> or
318+ // shard_key IN (<literal>, ...). Populates target_indices with
319+ // the shard index for each matched literal.
320+ void extract_shard_targets (const sql_parser::AstNode* expr,
321+ sql_parser::StringRef shard_key,
322+ sql_parser::StringRef table_name,
323+ size_t num_shards,
324+ std::vector<size_t >& target_indices) {
325+ if (!expr) return ;
326+
327+ // Check for shard_key = <literal>
328+ if (expr->type == sql_parser::NodeType::NODE_BINARY_OP) {
329+ sql_parser::StringRef op = expr->value ();
330+ if (op.len == 1 && op.ptr [0 ] == ' =' ) {
331+ const sql_parser::AstNode* left_node = expr->first_child ;
332+ const sql_parser::AstNode* right_node = left_node ? left_node->next_sibling : nullptr ;
333+ if (left_node && right_node) {
334+ // Check if one side is the shard key column and the other is a literal
335+ const sql_parser::AstNode* col_node = nullptr ;
336+ const sql_parser::AstNode* lit_node = nullptr ;
337+ if (is_shard_key_ref (left_node, shard_key) && is_literal (right_node)) {
338+ col_node = left_node; lit_node = right_node;
339+ } else if (is_shard_key_ref (right_node, shard_key) && is_literal (left_node)) {
340+ col_node = right_node; lit_node = left_node;
341+ }
342+ if (col_node && lit_node) {
343+ size_t idx = literal_to_shard_index (lit_node, table_name, num_shards);
344+ target_indices.push_back (idx);
345+ return ;
346+ }
347+ }
348+ }
349+ // Recurse into AND branches
350+ if (op.len == 3 &&
351+ (op.ptr [0 ] == ' A' || op.ptr [0 ] == ' a' ) &&
352+ (op.ptr [1 ] == ' N' || op.ptr [1 ] == ' n' ) &&
353+ (op.ptr [2 ] == ' D' || op.ptr [2 ] == ' d' )) {
354+ const sql_parser::AstNode* left_node = expr->first_child ;
355+ const sql_parser::AstNode* right_node = left_node ? left_node->next_sibling : nullptr ;
356+ // For AND, either branch matching is sufficient (both must be true,
357+ // so if one constrains the shard key, we can prune).
358+ std::vector<size_t > left_targets, right_targets;
359+ extract_shard_targets (left_node, shard_key, table_name, num_shards, left_targets);
360+ extract_shard_targets (right_node, shard_key, table_name, num_shards, right_targets);
361+ // Use whichever branch found shard targets (prefer the more selective one)
362+ if (!left_targets.empty () && !right_targets.empty ()) {
363+ // Intersect: both constraints must hold
364+ std::vector<bool > lset (num_shards, false ), rset (num_shards, false );
365+ for (auto i : left_targets) if (i < num_shards) lset[i] = true ;
366+ for (auto i : right_targets) if (i < num_shards) rset[i] = true ;
367+ for (size_t i = 0 ; i < num_shards; ++i) {
368+ if (lset[i] && rset[i]) target_indices.push_back (i);
369+ }
370+ } else if (!left_targets.empty ()) {
371+ target_indices.insert (target_indices.end (), left_targets.begin (), left_targets.end ());
372+ } else if (!right_targets.empty ()) {
373+ target_indices.insert (target_indices.end (), right_targets.begin (), right_targets.end ());
374+ }
375+ return ;
376+ }
377+ }
378+
379+ // Check for shard_key IN (literal_list)
380+ if (expr->type == sql_parser::NodeType::NODE_IN_LIST) {
381+ const sql_parser::AstNode* col_expr = expr->first_child ;
382+ if (col_expr && is_shard_key_ref (col_expr, shard_key)) {
383+ for (const sql_parser::AstNode* item = col_expr->next_sibling ; item; item = item->next_sibling ) {
384+ if (is_literal (item)) {
385+ target_indices.push_back (literal_to_shard_index (item, table_name, num_shards));
386+ } else {
387+ // Non-literal in IN list -- can't prune
388+ target_indices.clear ();
389+ return ;
390+ }
391+ }
392+ }
393+ }
394+ }
395+
396+ bool is_shard_key_ref (const sql_parser::AstNode* node, sql_parser::StringRef shard_key) const {
397+ if (!node) return false ;
398+ if (node->type == sql_parser::NodeType::NODE_COLUMN_REF ||
399+ node->type == sql_parser::NodeType::NODE_IDENTIFIER) {
400+ return node->value ().equals_ci (shard_key.ptr , shard_key.len );
401+ }
402+ if (node->type == sql_parser::NodeType::NODE_QUALIFIED_NAME) {
403+ // table.column -- check the column part
404+ const sql_parser::AstNode* c = node->first_child ;
405+ if (c && c->next_sibling ) {
406+ return c->next_sibling ->value ().equals_ci (shard_key.ptr , shard_key.len );
407+ }
408+ }
409+ return false ;
410+ }
411+
412+ static bool is_literal (const sql_parser::AstNode* node) {
413+ if (!node) return false ;
414+ return node->type == sql_parser::NodeType::NODE_LITERAL_INT ||
415+ node->type == sql_parser::NodeType::NODE_LITERAL_FLOAT ||
416+ node->type == sql_parser::NodeType::NODE_LITERAL_STRING;
417+ }
418+
419+ size_t literal_to_shard_index (const sql_parser::AstNode* lit,
420+ sql_parser::StringRef table_name,
421+ size_t num_shards) const {
422+ if (!lit || num_shards == 0 ) return 0 ;
423+ if (lit->type == sql_parser::NodeType::NODE_LITERAL_INT) {
424+ sql_parser::StringRef sv = lit->value ();
425+ int64_t val = 0 ;
426+ if (sv.ptr && sv.len > 0 ) val = std::strtoll (sv.ptr , nullptr , 10 );
427+ return shards_.shard_index_for_int (table_name, val);
428+ }
429+ if (lit->type == sql_parser::NodeType::NODE_LITERAL_STRING) {
430+ sql_parser::StringRef sv = lit->value ();
431+ return shards_.shard_index_for_string (table_name, sv.ptr , sv.len );
432+ }
433+ if (lit->type == sql_parser::NodeType::NODE_LITERAL_FLOAT) {
434+ sql_parser::StringRef sv = lit->value ();
435+ double dv = sv.ptr ? std::strtod (sv.ptr , nullptr ) : 0.0 ;
436+ int64_t iv = static_cast <int64_t >(dv);
437+ return shards_.shard_index_for_int (table_name, iv);
438+ }
439+ return 0 ;
282440 }
283441
284442 // Build N RemoteScans with UNION ALL
0 commit comments