diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py index 282f6165d..9770bff77 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py @@ -1207,6 +1207,11 @@ def export_keys_values( f"Rank {dist.get_rank()} has accumulated count {accumulated_counts} which is different from expected {local_max_rows}, " f"difference: {accumulated_counts - local_max_rows}" ) + + if len(keys_list) == 0: + return torch.empty(0, dtype=torch.int64, device=device), torch.empty( + 0, 0, device=device + ) return torch.cat(keys_list), torch.cat(values_list, dim=0) def incremental_dump( diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py index 22eae8737..6aaaa57ef 100644 --- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py +++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py @@ -1143,3 +1143,36 @@ def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS): del os.environ["DEMB_DETERMINISM_MODE"] print("all check passed") + + +def test_export_keys_values_empty_table(): + """export_keys_values() on a never-used table must return empty tensors + (not crash on torch.cat([])) -- covers the empty keys_list guard.""" + assert torch.cuda.is_available() + device = torch.device("cuda:0") + + opt = DynamicEmbTableOptions( + dim=8, + init_capacity=1024, + max_capacity=1024, + index_type=torch.int64, + embedding_dtype=torch.float32, + device_id=0, + score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, + caching=False, + local_hbm_for_values=1024**3, + ) + bdebt = BatchedDynamicEmbeddingTablesV2( + table_names=["t0"], + table_options=[opt], + feature_table_map=[0], + pooling_mode=DynamicEmbPoolingMode.SUM, + optimizer=EmbOptimType.SGD, + learning_rate=0.1, + ) + + keys, values = bdebt.export_keys_values("t0", device) + + assert keys.shape == (0,) + assert keys.dtype == torch.int64 + assert values.shape[0] == 0