From baadff4aa6c82e78166a6e7aa43af119367a59a7 Mon Sep 17 00:00:00 2001 From: damaozi <1811866786@qq.com> Date: Sat, 14 Feb 2026 02:05:00 +0800 Subject: [PATCH] feat: add return_fields parameter to search methods (#955) Add optional return_fields parameter to search_by_embedding, search_by_keywords_like, search_by_keywords_tfidf, and search_by_fulltext methods across all graph DB backends (neo4j, neo4j_community, polardb). When return_fields is specified (e.g., ['memory', 'status', 'tags']), the requested fields are included in each result dict alongside 'id' and 'score', eliminating the need for N+1 get_node() calls. Default is None, preserving full backward compatibility. Changes: - base.py: Updated docstring for search_by_embedding - neo4j.py: Added return_fields to search_by_embedding, modified Cypher RETURN clause and record construction - neo4j_community.py: Added return_fields to search_by_embedding, added _fetch_return_fields helper for direct vec_db path - polardb.py: Added return_fields to all 4 search methods, added _extract_fields_from_properties helper for JSON property extraction Closes #955 fix: add field name validation to prevent query injection in return_fields - Add _validate_return_fields() to BaseGraphDB base class with regex validation - Apply validation in neo4j.py, neo4j_community.py, polardb.py before field name concatenation - Add return_fields parameter to base class abstract method signature - Revert unrelated .get(node_id) change back to .get(node_id, None) (superseded: the final diff in this patch still changes it to .get(node_id) in neo4j_community.py) - Add TestFieldNameValidation and TestNeo4jCommunitySearchReturnFields test classes (7 new tests) fix: resolve ruff lint and format issues for CI compliance --- src/memos/api/middleware/__init__.py | 9 +- src/memos/api/utils/api_keys.py | 2 +- src/memos/graph_dbs/base.py | 32 +- src/memos/graph_dbs/neo4j.py | 27 +- src/memos/graph_dbs/neo4j_community.py | 87 ++++- src/memos/graph_dbs/polardb.py | 104 +++++- .../read_skill_memory/process_skill_memory.py | 4 +- 
tests/graph_dbs/test_search_return_fields.py | 306 ++++++++++++++++++ 8 files changed, 541 insertions(+), 30 deletions(-) create mode 100644 tests/graph_dbs/test_search_return_fields.py diff --git a/src/memos/api/middleware/__init__.py b/src/memos/api/middleware/__init__.py index 64cbc5c60..fd39252f5 100644 --- a/src/memos/api/middleware/__init__.py +++ b/src/memos/api/middleware/__init__.py @@ -1,13 +1,14 @@ """Krolik middleware extensions for MemOS.""" -from .auth import verify_api_key, require_scope, require_admin, require_read, require_write +from .auth import require_admin, require_read, require_scope, require_write, verify_api_key from .rate_limit import RateLimitMiddleware + __all__ = [ - "verify_api_key", - "require_scope", + "RateLimitMiddleware", "require_admin", "require_read", + "require_scope", "require_write", - "RateLimitMiddleware", + "verify_api_key", ] diff --git a/src/memos/api/utils/api_keys.py b/src/memos/api/utils/api_keys.py index 559ddd355..29b493fd0 100644 --- a/src/memos/api/utils/api_keys.py +++ b/src/memos/api/utils/api_keys.py @@ -5,8 +5,8 @@ """ import hashlib -import os import secrets + from dataclasses import dataclass from datetime import datetime, timedelta diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index 130b66a3d..0bc4a54f8 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -1,12 +1,35 @@ +import re + from abc import ABC, abstractmethod from typing import Any, Literal +# Pattern for valid field names: alphanumeric and underscores, must start with letter or underscore +_VALID_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + class BaseGraphDB(ABC): """ Abstract base class for a graph database interface used in a memory-augmented RAG system. """ + @staticmethod + def _validate_return_fields(return_fields: list[str] | None) -> list[str]: + """Validate and sanitize return_fields to prevent query injection. 
+ + Only allows alphanumeric characters and underscores in field names. + Silently drops invalid field names. + + Args: + return_fields: List of field names to validate. + + Returns: + List of valid field names. + """ + if not return_fields: + return [] + return [f for f in return_fields if _VALID_FIELD_NAME_RE.match(f)] + # Node (Memory) Management @abstractmethod def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @@ -144,16 +167,23 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: # Search / recall operations @abstractmethod - def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]: + def search_by_embedding( + self, vector: list[float], top_k: int = 5, return_fields: list[str] | None = None, **kwargs + ) -> list[dict]: """ Retrieve node IDs based on vector similarity. Args: vector (list[float]): The embedding vector representing query semantics. top_k (int): Number of top similar nodes to retrieve. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method may internally call a VecDB (e.g., Qdrant) or store embeddings in the graph DB itself. 
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 23ce2408b..079b1c1b8 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -818,6 +818,7 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -832,9 +833,14 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result + dict will contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method uses Neo4j native vector indexing to search for similar nodes. 
@@ -886,11 +892,20 @@ def search_by_embedding( if where_clauses: where_clause = "WHERE " + " AND ".join(where_clauses) + return_clause = "RETURN node.id AS id, score" + if return_fields: + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"node.{field} AS {field}" for field in validated_fields if field != "id" + ) + if extra_fields: + return_clause = f"RETURN node.id AS id, score, {extra_fields}" + query = f""" CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) YIELD node, score {where_clause} - RETURN node.id AS id, score + {return_clause} """ parameters = {"embedding": vector, "k": top_k} @@ -920,7 +935,15 @@ def search_by_embedding( print(f"[search_by_embedding] query: {query},parameters: {parameters}") with self.driver.session(database=self.db_name) as session: result = session.run(query, parameters) - records = [{"id": record["id"], "score": record["score"]} for record in result] + records = [] + for record in result: + item = {"id": record["id"], "score": record["score"]} + if return_fields: + record_keys = record.keys() + for field in return_fields: + if field != "id" and field in record_keys: + item[field] = record[field] + records.append(item) # Threshold filtering after retrieval if threshold is not None: diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index e34313fa2..2dbef11b8 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -246,6 +246,39 @@ def get_children_with_embeddings( return child_nodes + def _fetch_return_fields( + self, + ids: list[str], + score_map: dict[str, float], + return_fields: list[str], + ) -> list[dict]: + """Fetch additional fields from Neo4j for given node IDs.""" + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"n.{field} AS {field}" for field in validated_fields if field != "id" + ) + return_clause = "RETURN n.id AS id" + if 
extra_fields: + return_clause = f"RETURN n.id AS id, {extra_fields}" + + query = f""" + MATCH (n:Memory) + WHERE n.id IN $ids + {return_clause} + """ + with self.driver.session(database=self.db_name) as session: + neo4j_results = session.run(query, {"ids": ids}) + results = [] + for record in neo4j_results: + node_id = record["id"] + item = {"id": node_id, "score": score_map.get(node_id)} + record_keys = record.keys() + for field in return_fields: + if field != "id" and field in record_keys: + item[field] = record[field] + results.append(item) + return results + # Search / recall operations def search_by_embedding( self, @@ -258,6 +291,7 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -273,9 +307,14 @@ def search_by_embedding( filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method uses an external vector database (not Neo4j) to perform the search. 
@@ -320,7 +359,14 @@ def search_by_embedding( # If no filter or knowledgebase_ids provided, return vector search results directly if not filter and not knowledgebase_ids: - return [{"id": r.id, "score": r.score} for r in vec_results] + if not return_fields: + return [{"id": r.id, "score": r.score} for r in vec_results] + # Need to fetch additional fields from Neo4j + vec_ids = [r.id for r in vec_results] + if not vec_ids: + return [] + score_map = {r.id: r.score for r in vec_results} + return self._fetch_return_fields(vec_ids, score_map, return_fields) # Extract IDs from vector search results vec_ids = [r.id for r in vec_results] @@ -363,22 +409,49 @@ def search_by_embedding( if filter_params: params.update(filter_params) + # Build RETURN clause with optional extra fields + return_clause = "RETURN n.id AS id" + if return_fields: + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"n.{field} AS {field}" for field in validated_fields if field != "id" + ) + if extra_fields: + return_clause = f"RETURN n.id AS id, {extra_fields}" + # Query Neo4j to filter results query = f""" MATCH (n:Memory) {where_clause} - RETURN n.id AS id + {return_clause} """ logger.info(f"[search_by_embedding] query: {query}, params: {params}") with self.driver.session(database=self.db_name) as session: neo4j_results = session.run(query, params) - filtered_ids = {record["id"] for record in neo4j_results} + if return_fields: + # Build a map of id -> extra fields from Neo4j results + neo4j_data = {} + for record in neo4j_results: + node_id = record["id"] + record_keys = record.keys() + neo4j_data[node_id] = { + field: record[field] + for field in return_fields + if field != "id" and field in record_keys + } + filtered_ids = set(neo4j_data.keys()) + else: + filtered_ids = {record["id"] for record in neo4j_results} # Filter vector results by Neo4j filtered IDs and return with scores - filtered_results = [ - {"id": r.id, "score": r.score} for r in 
vec_results if r.id in filtered_ids - ] + filtered_results = [] + for r in vec_results: + if r.id in filtered_ids: + item = {"id": r.id, "score": r.score} + if return_fields and r.id in neo4j_data: + item.update(neo4j_data[r.id]) + filtered_results.append(item) return filtered_results @@ -1102,7 +1175,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]] # Merge embeddings into parsed nodes for parsed_node in parsed_nodes: node_id = parsed_node["id"] - parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None) + parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id) return parsed_nodes diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f0a23e39b..5044564c3 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1690,6 +1690,36 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + def _extract_fields_from_properties( + self, properties: Any, return_fields: list[str] + ) -> dict[str, Any]: + """Extract requested fields from a PolarDB properties agtype/JSON value. + + Args: + properties: The raw properties value from a PolarDB row (agtype or JSON string). + return_fields: List of field names to extract. + + Returns: + dict with field_name -> value for each requested field found in properties. 
+ """ + result = {} + return_fields = self._validate_return_fields(return_fields) + if not properties or not return_fields: + return result + try: + if isinstance(properties, str): + props = json.loads(properties) + elif isinstance(properties, dict): + props = properties + else: + props = json.loads(str(properties)) + except (json.JSONDecodeError, TypeError, ValueError): + return result + for field in return_fields: + if field != "id" and field in props: + result[field] = props[field] + return result + @timed def search_by_keywords_like( self, @@ -1700,6 +1730,7 @@ def search_by_keywords_like( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: where_clauses = [] @@ -1751,10 +1782,14 @@ def search_by_keywords_like( where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - query = f""" - SELECT + select_clause = """SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text + agtype_object_field_text(properties, 'memory') as memory_text""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} """ @@ -1775,7 +1810,11 @@ def search_by_keywords_like( id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): id_val = id_val[1:-1] - output.append({"id": id_val}) + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) logger.info( f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) @@ -1795,6 +1834,7 @@ def search_by_keywords_tfidf( knowledgebase_ids: list[str] | None = None, tsvector_field: str 
= "properties_tsvector_zh", tsquery_config: str = "jiebaqry", + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: where_clauses = [] @@ -1850,10 +1890,14 @@ def search_by_keywords_tfidf( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" # Build fulltext search query - query = f""" - SELECT + select_clause = """SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text + agtype_object_field_text(properties, 'memory') as memory_text""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} """ @@ -1874,7 +1918,11 @@ def search_by_keywords_tfidf( id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): id_val = id_val[1:-1] - output.append({"id": id_val}) + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) logger.info( f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" @@ -1897,6 +1945,7 @@ def search_by_fulltext( knowledgebase_ids: list[str] | None = None, tsvector_field: str = "properties_tsvector_zh", tsquery_config: str = "jiebacfg", + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -1914,10 +1963,12 @@ def search_by_fulltext( filter: filter conditions with 'and' or 'or' logic for search results. tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + return_fields: additional node fields to include in results **kwargs: other parameters (e.g. cube_name) Returns: - list[dict]: result list containing id and score + list[dict]: result list containing id and score. 
+ If return_fields is specified, each dict also includes the requested fields. """ logger.info( f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" @@ -1982,11 +2033,15 @@ def search_by_fulltext( logger.info(f"[search_by_fulltext] where_clause: {where_clause}") # Build fulltext search query - query = f""" - SELECT + select_clause = f"""SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, agtype_object_field_text(properties, 'memory') as memory_text, - ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} ORDER BY rank DESC @@ -2013,7 +2068,15 @@ def search_by_fulltext( # Apply threshold filter if specified if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[ + 3 + ] # properties column (after old_id, memory_text, rank) + item.update( + self._extract_fields_from_properties(properties, return_fields) + ) + output.append(item) elapsed_time = time.time() - start_time logger.info( f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" @@ -2034,10 +2097,17 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ Retrieve node IDs based on vector similarity using PostgreSQL vector operations. + + Args: + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). 
When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). """ # Build WHERE clause dynamically like nebular.py logger.info( @@ -2178,7 +2248,13 @@ def search_by_embedding( score_val = float(score) score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[1] # properties column + item.update( + self._extract_fields_from_properties(properties, return_fields) + ) + output.append(item) return output[:top_k] finally: self._return_connection(conn) diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index d39955ac2..a9a727b08 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -1019,7 +1019,9 @@ def process_skill_memory_fine( **kwargs, ) -> list[TextualMemoryItem]: skills_repo_backend = _get_skill_file_storage_location() - oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config) + oss_client, _missing_keys, flag = _skill_init( + skills_repo_backend, oss_config, skills_dir_config + ) if not flag: return [] diff --git a/tests/graph_dbs/test_search_return_fields.py b/tests/graph_dbs/test_search_return_fields.py new file mode 100644 index 000000000..82a50308b --- /dev/null +++ b/tests/graph_dbs/test_search_return_fields.py @@ -0,0 +1,306 @@ +""" +Regression tests for issue #955: search methods support specifying return fields. + +Tests that search_by_embedding (and other search methods) accept a `return_fields` +parameter and include the requested fields in the result dicts, eliminating the +need for N+1 get_node() calls. 
+""" + +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.graph_db import Neo4jGraphDBConfig + + +@pytest.fixture +def neo4j_config(): + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_memory_db", + auto_create=False, + embedding_dimension=3, + ) + + +@pytest.fixture +def neo4j_db(neo4j_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(neo4j_config) + db.driver = mock_driver + yield db + + +class TestNeo4jSearchReturnFields: + """Tests for Neo4jGraphDB.search_by_embedding with return_fields.""" + + def test_return_fields_included_in_results(self, neo4j_db): + """return_fields values are present in each result dict.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + node_id = str(uuid.uuid4()) + session_mock.run.return_value = [ + {"id": node_id, "score": 0.95, "memory": "hello", "status": "activated"}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory", "status"], + ) + + assert len(results) == 1 + assert results[0]["id"] == node_id + assert results[0]["score"] == 0.95 + assert results[0]["memory"] == "hello" + assert results[0]["status"] == "activated" + + def test_backward_compatible_without_return_fields(self, neo4j_db): + """Without return_fields, only id and score are returned (old behavior).""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.9}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + ) + + assert len(results) == 1 + assert set(results[0].keys()) == {"id", "score"} + + def test_cypher_return_clause_includes_fields(self, 
neo4j_db): + """Cypher RETURN clause contains the requested fields.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory", "tags"], + ) + + query = session_mock.run.call_args[0][0] + assert "node.memory AS memory" in query + assert "node.tags AS tags" in query + + def test_cypher_return_clause_default(self, neo4j_db): + """Without return_fields, RETURN clause only has id and score.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + ) + + query = session_mock.run.call_args[0][0] + assert "RETURN node.id AS id, score" in query + assert "node.memory" not in query + + def test_return_fields_skips_id_field(self, neo4j_db): + """Passing 'id' in return_fields does not duplicate it in RETURN clause.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["id", "memory"], + ) + + query = session_mock.run.call_args[0][0] + # 'id' should appear only once (as node.id AS id), not duplicated + assert query.count("node.id AS id") == 1 + assert "node.memory AS memory" in query + + def test_threshold_filtering_still_works_with_return_fields(self, neo4j_db): + """Threshold filtering works correctly when return_fields is specified.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.9, "memory": "high score"}, + {"id": str(uuid.uuid4()), "score": 0.3, "memory": "low score"}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + 
user_name="test_user", + threshold=0.5, + return_fields=["memory"], + ) + + assert len(results) == 1 + assert results[0]["memory"] == "high score" + + +class TestPolarDBExtractFieldsFromProperties: + """Tests for PolarDBGraphDB._extract_fields_from_properties helper.""" + + @pytest.fixture + def polardb_instance(self): + """Create a minimal PolarDB instance for testing the helper method.""" + with patch("memos.graph_dbs.polardb.PolarDBGraphDB.__init__", return_value=None): + from memos.graph_dbs.polardb import PolarDBGraphDB + + db = PolarDBGraphDB.__new__(PolarDBGraphDB) + yield db + + def test_extract_from_json_string(self, polardb_instance): + """Extract fields from a JSON string properties value.""" + props = '{"id": "abc", "memory": "hello", "status": "activated", "tags": ["a"]}' + result = polardb_instance._extract_fields_from_properties( + props, ["memory", "status", "tags"] + ) + assert result == {"memory": "hello", "status": "activated", "tags": ["a"]} + + def test_extract_from_dict(self, polardb_instance): + """Extract fields from a dict properties value.""" + props = {"id": "abc", "memory": "hello", "status": "activated"} + result = polardb_instance._extract_fields_from_properties(props, ["memory", "status"]) + assert result == {"memory": "hello", "status": "activated"} + + def test_extract_skips_id(self, polardb_instance): + """'id' field is skipped even if requested.""" + props = '{"id": "abc", "memory": "hello"}' + result = polardb_instance._extract_fields_from_properties(props, ["id", "memory"]) + assert result == {"memory": "hello"} + + def test_extract_missing_fields(self, polardb_instance): + """Missing fields are silently skipped.""" + props = '{"id": "abc", "memory": "hello"}' + result = polardb_instance._extract_fields_from_properties(props, ["memory", "nonexistent"]) + assert result == {"memory": "hello"} + + def test_extract_empty_properties(self, polardb_instance): + """Empty/None properties return empty dict.""" + assert 
polardb_instance._extract_fields_from_properties(None, ["memory"]) == {} + assert polardb_instance._extract_fields_from_properties("", ["memory"]) == {} + + def test_extract_invalid_json(self, polardb_instance): + """Invalid JSON returns empty dict without raising.""" + result = polardb_instance._extract_fields_from_properties("not-json", ["memory"]) + assert result == {} + + +class TestFieldNameValidation: + """Tests for _validate_return_fields injection prevention.""" + + def test_valid_field_names_pass(self): + from memos.graph_dbs.base import BaseGraphDB + + result = BaseGraphDB._validate_return_fields(["memory", "status", "tags", "user_name"]) + assert result == ["memory", "status", "tags", "user_name"] + + def test_invalid_field_names_rejected(self): + from memos.graph_dbs.base import BaseGraphDB + + # Cypher injection attempts + result = BaseGraphDB._validate_return_fields( + [ + "memory} RETURN n //", + "status; DROP", + "valid_field", + "a.b", + "field name", + "", + ] + ) + assert result == ["valid_field"] + + def test_none_returns_empty(self): + from memos.graph_dbs.base import BaseGraphDB + + assert BaseGraphDB._validate_return_fields(None) == [] + + def test_empty_list_returns_empty(self): + from memos.graph_dbs.base import BaseGraphDB + + assert BaseGraphDB._validate_return_fields([]) == [] + + def test_injection_in_cypher_query_prevented(self, neo4j_db): + """Malicious field names should not appear in the Cypher query.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory} RETURN n //", "valid_field"], + ) + + query = session_mock.run.call_args[0][0] + # Injection attempt should NOT appear in query + assert "memory}" not in query + assert "RETURN n //" not in query + # Valid field should appear + assert "node.valid_field AS valid_field" in query + + +class 
TestNeo4jCommunitySearchReturnFields: + """Tests for Neo4jCommunityGraphDB._fetch_return_fields with return_fields.""" + + @pytest.fixture + def neo4j_community_db(self): + """Create a minimal Neo4jCommunityGraphDB instance by patching __init__.""" + with patch( + "memos.graph_dbs.neo4j_community.Neo4jCommunityGraphDB.__init__", return_value=None + ): + from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB + + db = Neo4jCommunityGraphDB.__new__(Neo4jCommunityGraphDB) + db.driver = MagicMock() + db.db_name = "test_memory_db" + yield db + + def test_fetch_return_fields_queries_neo4j(self, neo4j_community_db): + """_fetch_return_fields builds correct Cypher and returns fields.""" + session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": "node-1", "memory": "hello", "status": "activated"}, + ] + + results = neo4j_community_db._fetch_return_fields( + ids=["node-1"], + score_map={"node-1": 0.95}, + return_fields=["memory", "status"], + ) + + assert len(results) == 1 + assert results[0]["id"] == "node-1" + assert results[0]["score"] == 0.95 + assert results[0]["memory"] == "hello" + assert results[0]["status"] == "activated" + + query = session_mock.run.call_args[0][0] + assert "n.memory AS memory" in query + assert "n.status AS status" in query + + def test_fetch_return_fields_validates_names(self, neo4j_community_db): + """_fetch_return_fields rejects invalid field names.""" + session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_community_db._fetch_return_fields( + ids=["node-1"], + score_map={"node-1": 0.95}, + return_fields=["memory} RETURN n //", "valid_field"], + ) + + query = session_mock.run.call_args[0][0] + assert "memory}" not in query + assert "n.valid_field AS valid_field" in query