@@ -88,6 +88,40 @@ class Database:
8888 def __init__ (self , executor : ApiExecutor ) -> None :
8989 self ._executor = executor
9090
91+ def _get_doc_serializer (
92+ self ,
93+ doc_serializer : Optional [Serializer [T ]] = None ,
94+ ) -> Serializer [T ]:
95+ """Figure out the document serializer, defaulting to `Json`.
96+
97+ Args:
98+ doc_serializer (Serializer | None): Optional serializer.
99+
100+ Returns:
101+ Serializer: Either the passed serializer or the default one.
102+ """
103+ if doc_serializer is None :
104+ return cast (Serializer [T ], self .serializer )
105+ else :
106+ return doc_serializer
107+
108+ def _get_doc_deserializer (
109+ self ,
110+ doc_deserializer : Optional [Deserializer [U , V ]] = None ,
111+ ) -> Deserializer [U , V ]:
112+ """Figure out the document deserializer, defaulting to `Json`.
113+
114+ Args:
115+ doc_deserializer (Deserializer | None): Optional deserializer.
116+
117+ Returns:
118+ Deserializer: Either the passed deserializer or the default one.
119+ """
120+ if doc_deserializer is None :
121+ return cast (Deserializer [U , V ], self .deserializer )
122+ else :
123+ return doc_deserializer
124+
91125 @property
92126 def connection (self ) -> Connection :
93127 """Return the HTTP connection."""
@@ -390,17 +424,11 @@ def collection(
390424 Returns:
391425 StandardCollection: Collection API wrapper.
392426 """
393- if doc_serializer is None :
394- serializer = cast (Serializer [T ], self .serializer )
395- else :
396- serializer = doc_serializer
397- if doc_deserializer is None :
398- deserializer = cast (Deserializer [U , V ], self .deserializer )
399- else :
400- deserializer = doc_deserializer
401-
402427 return StandardCollection [T , U , V ](
403- self ._executor , name , serializer , deserializer
428+ self ._executor ,
429+ name ,
430+ self ._get_doc_serializer (doc_serializer ),
431+ self ._get_doc_deserializer (doc_deserializer ),
404432 )
405433
406434 async def collections (
@@ -604,16 +632,11 @@ async def create_collection(
604632 def response_handler (resp : Response ) -> StandardCollection [T , U , V ]:
605633 if not resp .is_success :
606634 raise CollectionCreateError (resp , request )
607- if doc_serializer is None :
608- serializer = cast (Serializer [T ], self .serializer )
609- else :
610- serializer = doc_serializer
611- if doc_deserializer is None :
612- deserializer = cast (Deserializer [U , V ], self .deserializer )
613- else :
614- deserializer = doc_deserializer
615635 return StandardCollection [T , U , V ](
616- self ._executor , name , serializer , deserializer
636+ self ._executor ,
637+ name ,
638+ self ._get_doc_serializer (doc_serializer ),
639+ self ._get_doc_deserializer (doc_deserializer ),
617640 )
618641
619642 return await self ._executor .execute (request , response_handler )
@@ -661,16 +684,30 @@ def response_handler(resp: Response) -> bool:
661684
662685 return await self ._executor .execute (request , response_handler )
663686
664- def graph (self , name : str ) -> Graph :
687+ def graph (
688+ self ,
689+ name : str ,
690+ doc_serializer : Optional [Serializer [T ]] = None ,
691+ doc_deserializer : Optional [Deserializer [U , V ]] = None ,
692+ ) -> Graph [T , U , V ]:
665693 """Return the graph API wrapper.
666694
667695 Args:
668696 name (str): Graph name.
697+ doc_serializer (Serializer): Custom document serializer.
698+ This will be used only for document operations.
699+ doc_deserializer (Deserializer): Custom document deserializer.
700+ This will be used only for document operations.
669701
670702 Returns:
671703 Graph: Graph API wrapper.
672704 """
673- return Graph (self ._executor , name )
705+ return Graph [T , U , V ](
706+ self ._executor ,
707+ name ,
708+ self ._get_doc_serializer (doc_serializer ),
709+ self ._get_doc_deserializer (doc_deserializer ),
710+ )
674711
675712 async def has_graph (self , name : str ) -> Result [bool ]:
676713 """Check if a graph exists in the database.
@@ -720,17 +757,23 @@ def response_handler(resp: Response) -> List[GraphProperties]:
720757 async def create_graph (
721758 self ,
722759 name : str ,
760+ doc_serializer : Optional [Serializer [T ]] = None ,
761+ doc_deserializer : Optional [Deserializer [U , V ]] = None ,
723762 edge_definitions : Optional [Sequence [Json ]] = None ,
724763 is_disjoint : Optional [bool ] = None ,
725764 is_smart : Optional [bool ] = None ,
726765 options : Optional [GraphOptions | Json ] = None ,
727766 orphan_collections : Optional [Sequence [str ]] = None ,
728767 wait_for_sync : Optional [bool ] = None ,
729- ) -> Result [Graph ]:
768+ ) -> Result [Graph [ T , U , V ] ]:
730769 """Create a new graph.
731770
732771 Args:
733772 name (str): Graph name.
773+ doc_serializer (Serializer): Custom document serializer.
774+ This will be used only for document operations.
775+ doc_deserializer (Deserializer): Custom document deserializer.
776+ This will be used only for document operations.
734777 edge_definitions (list | None): List of edge definitions, where each edge
735778 definition entry is a dictionary with fields "collection" (name of the
736779 edge collection), "from" (list of vertex collection names) and "to"
@@ -782,10 +825,15 @@ async def create_graph(
782825 params = params ,
783826 )
784827
785- def response_handler (resp : Response ) -> Graph :
786- if resp .is_success :
787- return Graph (self ._executor , name )
788- raise GraphCreateError (resp , request )
828+ def response_handler (resp : Response ) -> Graph [T , U , V ]:
829+ if not resp .is_success :
830+ raise GraphCreateError (resp , request )
831+ return Graph [T , U , V ](
832+ self ._executor ,
833+ name ,
834+ self ._get_doc_serializer (doc_serializer ),
835+ self ._get_doc_deserializer (doc_deserializer ),
836+ )
789837
790838 return await self ._executor .execute (request , response_handler )
791839
0 commit comments