Skip to content

Commit 9dc1353

Browse files
committed
Refactoring serializers
1 parent ba45028 commit 9dc1353

File tree

2 files changed

+170
-29
lines changed

2 files changed

+170
-29
lines changed

arangoasync/database.py

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,40 @@ class Database:
8888
def __init__(self, executor: ApiExecutor) -> None:
8989
self._executor = executor
9090

91+
def _get_doc_serializer(
92+
self,
93+
doc_serializer: Optional[Serializer[T]] = None,
94+
) -> Serializer[T]:
95+
"""Figure out the document serializer, defaulting to `Json`.
96+
97+
Args:
98+
doc_serializer (Serializer | None): Optional serializer.
99+
100+
Returns:
101+
Serializer: Either the passed serializer or the default one.
102+
"""
103+
if doc_serializer is None:
104+
return cast(Serializer[T], self.serializer)
105+
else:
106+
return doc_serializer
107+
108+
def _get_doc_deserializer(
109+
self,
110+
doc_deserializer: Optional[Deserializer[U, V]] = None,
111+
) -> Deserializer[U, V]:
112+
"""Figure out the document deserializer, defaulting to `Json`.
113+
114+
Args:
115+
doc_deserializer (Deserializer | None): Optional deserializer.
116+
117+
Returns:
118+
Deserializer: Either the passed deserializer or the default one.
119+
"""
120+
if doc_deserializer is None:
121+
return cast(Deserializer[U, V], self.deserializer)
122+
else:
123+
return doc_deserializer
124+
91125
@property
92126
def connection(self) -> Connection:
93127
"""Return the HTTP connection."""
@@ -390,17 +424,11 @@ def collection(
390424
Returns:
391425
StandardCollection: Collection API wrapper.
392426
"""
393-
if doc_serializer is None:
394-
serializer = cast(Serializer[T], self.serializer)
395-
else:
396-
serializer = doc_serializer
397-
if doc_deserializer is None:
398-
deserializer = cast(Deserializer[U, V], self.deserializer)
399-
else:
400-
deserializer = doc_deserializer
401-
402427
return StandardCollection[T, U, V](
403-
self._executor, name, serializer, deserializer
428+
self._executor,
429+
name,
430+
self._get_doc_serializer(doc_serializer),
431+
self._get_doc_deserializer(doc_deserializer),
404432
)
405433

406434
async def collections(
@@ -604,16 +632,11 @@ async def create_collection(
604632
def response_handler(resp: Response) -> StandardCollection[T, U, V]:
605633
if not resp.is_success:
606634
raise CollectionCreateError(resp, request)
607-
if doc_serializer is None:
608-
serializer = cast(Serializer[T], self.serializer)
609-
else:
610-
serializer = doc_serializer
611-
if doc_deserializer is None:
612-
deserializer = cast(Deserializer[U, V], self.deserializer)
613-
else:
614-
deserializer = doc_deserializer
615635
return StandardCollection[T, U, V](
616-
self._executor, name, serializer, deserializer
636+
self._executor,
637+
name,
638+
self._get_doc_serializer(doc_serializer),
639+
self._get_doc_deserializer(doc_deserializer),
617640
)
618641

619642
return await self._executor.execute(request, response_handler)
@@ -661,16 +684,30 @@ def response_handler(resp: Response) -> bool:
661684

662685
return await self._executor.execute(request, response_handler)
663686

664-
def graph(self, name: str) -> Graph:
687+
def graph(
688+
self,
689+
name: str,
690+
doc_serializer: Optional[Serializer[T]] = None,
691+
doc_deserializer: Optional[Deserializer[U, V]] = None,
692+
) -> Graph[T, U, V]:
665693
"""Return the graph API wrapper.
666694
667695
Args:
668696
name (str): Graph name.
697+
doc_serializer (Serializer): Custom document serializer.
698+
This will be used only for document operations.
699+
doc_deserializer (Deserializer): Custom document deserializer.
700+
This will be used only for document operations.
669701
670702
Returns:
671703
Graph: Graph API wrapper.
672704
"""
673-
return Graph(self._executor, name)
705+
return Graph[T, U, V](
706+
self._executor,
707+
name,
708+
self._get_doc_serializer(doc_serializer),
709+
self._get_doc_deserializer(doc_deserializer),
710+
)
674711

675712
async def has_graph(self, name: str) -> Result[bool]:
676713
"""Check if a graph exists in the database.
@@ -720,17 +757,23 @@ def response_handler(resp: Response) -> List[GraphProperties]:
720757
async def create_graph(
721758
self,
722759
name: str,
760+
doc_serializer: Optional[Serializer[T]] = None,
761+
doc_deserializer: Optional[Deserializer[U, V]] = None,
723762
edge_definitions: Optional[Sequence[Json]] = None,
724763
is_disjoint: Optional[bool] = None,
725764
is_smart: Optional[bool] = None,
726765
options: Optional[GraphOptions | Json] = None,
727766
orphan_collections: Optional[Sequence[str]] = None,
728767
wait_for_sync: Optional[bool] = None,
729-
) -> Result[Graph]:
768+
) -> Result[Graph[T, U, V]]:
730769
"""Create a new graph.
731770
732771
Args:
733772
name (str): Graph name.
773+
doc_serializer (Serializer): Custom document serializer.
774+
This will be used only for document operations.
775+
doc_deserializer (Deserializer): Custom document deserializer.
776+
This will be used only for document operations.
734777
edge_definitions (list | None): List of edge definitions, where each edge
735778
definition entry is a dictionary with fields "collection" (name of the
736779
edge collection), "from" (list of vertex collection names) and "to"
@@ -782,10 +825,15 @@ async def create_graph(
782825
params=params,
783826
)
784827

785-
def response_handler(resp: Response) -> Graph:
786-
if resp.is_success:
787-
return Graph(self._executor, name)
788-
raise GraphCreateError(resp, request)
828+
def response_handler(resp: Response) -> Graph[T, U, V]:
829+
if not resp.is_success:
830+
raise GraphCreateError(resp, request)
831+
return Graph[T, U, V](
832+
self._executor,
833+
name,
834+
self._get_doc_serializer(doc_serializer),
835+
self._get_doc_deserializer(doc_deserializer),
836+
)
789837

790838
return await self._executor.execute(request, response_handler)
791839

arangoasync/graph.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,43 @@
1+
__all__ = ["Graph"]
2+
3+
4+
from typing import Generic, TypeVar
5+
6+
from arangoasync.collection import EdgeCollection, VertexCollection
7+
from arangoasync.exceptions import GraphListError
18
from arangoasync.executor import ApiExecutor
9+
from arangoasync.request import Method, Request
10+
from arangoasync.response import Response
11+
from arangoasync.result import Result
12+
from arangoasync.serialization import Deserializer, Serializer
13+
from arangoasync.typings import GraphProperties, Json, Jsons
214

15+
T = TypeVar("T") # Serializer type
16+
U = TypeVar("U") # Deserializer loads
17+
V = TypeVar("V") # Deserializer loads_many
318

4-
class Graph:
19+
20+
class Graph(Generic[T, U, V]):
521
"""Graph API wrapper, representing a graph in ArangoDB.
622
723
Args:
8-
executor: API executor. Required to execute the API requests.
24+
executor (APIExecutor): Required to execute the API requests.
25+
name (str): Graph name.
26+
doc_serializer (Serializer): Document serializer.
27+
doc_deserializer (Deserializer): Document deserializer.
928
"""
1029

11-
def __init__(self, executor: ApiExecutor, name: str) -> None:
30+
def __init__(
31+
self,
32+
executor: ApiExecutor,
33+
name: str,
34+
doc_serializer: Serializer[T],
35+
doc_deserializer: Deserializer[U, V],
36+
) -> None:
1237
self._executor = executor
1338
self._name = name
39+
self._doc_serializer = doc_serializer
40+
self._doc_deserializer = doc_deserializer
1441

1542
def __repr__(self) -> str:
1643
return f"<Graph {self._name}>"
@@ -19,3 +46,69 @@ def __repr__(self) -> str:
1946
def name(self) -> str:
2047
"""Name of the graph."""
2148
return self._name
49+
50+
@property
51+
def serializer(self) -> Serializer[Json]:
52+
"""Return the serializer."""
53+
return self._executor.serializer
54+
55+
@property
56+
def deserializer(self) -> Deserializer[Json, Jsons]:
57+
"""Return the deserializer."""
58+
return self._executor.deserializer
59+
60+
async def properties(self) -> Result[GraphProperties]:
61+
"""Get the properties of the graph.
62+
63+
Returns:
64+
GraphProperties: Properties of the graph.
65+
66+
Raises:
67+
GraphListError: If the operation fails.
68+
69+
References:
70+
- `get-a-graph <https://docs.arangodb.com/3.12/develop/http-api/graphs/named-graphs/#get-a-graph>`__
71+
""" # noqa: E501
72+
request = Request(method=Method.GET, endpoint=f"/_api/gharial/{self._name}")
73+
74+
def response_handler(resp: Response) -> GraphProperties:
75+
if not resp.is_success:
76+
raise GraphListError(resp, request)
77+
body = self.deserializer.loads(resp.raw_body)
78+
return GraphProperties(body["graph"])
79+
80+
return await self._executor.execute(request, response_handler)
81+
82+
def vertex_collection(self, name: str) -> VertexCollection[T, U, V]:
83+
"""Returns the vertex collection API wrapper.
84+
85+
Args:
86+
name (str): Vertex collection name.
87+
88+
Returns:
89+
VertexCollection: Vertex collection API wrapper.
90+
"""
91+
return VertexCollection[T, U, V](
92+
executor=self._executor,
93+
graph=self._name,
94+
name=name,
95+
doc_serializer=self._doc_serializer,
96+
doc_deserializer=self._doc_deserializer,
97+
)
98+
99+
def edge_collection(self, name: str) -> EdgeCollection[T, U, V]:
100+
"""Returns the edge collection API wrapper.
101+
102+
Args:
103+
name (str): Edge collection name.
104+
105+
Returns:
106+
EdgeCollection: Edge collection API wrapper.
107+
"""
108+
return EdgeCollection[T, U, V](
109+
executor=self._executor,
110+
graph=self._name,
111+
name=name,
112+
doc_serializer=self._doc_serializer,
113+
doc_deserializer=self._doc_deserializer,
114+
)

0 commit comments

Comments
 (0)