From df2dc6230ea7f8ddcadf53fad45b90a723ecb138 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Fri, 28 Feb 2025 10:04:51 -0800 Subject: [PATCH] WIP fix for serde + extract_fields --- hamilton/function_modifiers/expanders.py | 80 +++++++++++++++++------- 1 file changed, 58 insertions(+), 22 deletions(-) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 62610f898..d791ff39a 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -731,6 +731,36 @@ def _validate_extract_fields(fields: dict): ) +async def dict_generator_async( + *args, + fn, + fill_with, + fields, + **kwargs, +): + dict_generated = await fn(*args, **kwargs) + if fill_with is not None: + for field in fields: + if field not in dict_generated: + dict_generated[field] = fill_with + return dict_generated + + +async def dict_generator( + *args, + fn, + fill_with, + fields, + **kwargs, +): + dict_generated = fn(*args, **kwargs) + if fill_with is not None: + for field in fields: + if field not in dict_generated: + dict_generated[field] = fill_with + return dict_generated + + class extract_fields(base.SingleNodeNodeTransformer): """Extracts fields from a dictionary of output.""" @@ -804,29 +834,35 @@ def transform_node( """ fn = node_.callable base_doc = node_.documentation - + dict_generator_fn = ( + functools.partial(dict_generator, fn=fn, fill_with=self.fill_with, fields=self.fields) + if not (inspect.iscoroutinefunction(fn)) + else functools.partial( + dict_generator_async, fn=fn, fill_with=self.fill_with, fields=self.fields + ) + ) # if fn is async - if inspect.iscoroutinefunction(fn): - - async def dict_generator(*args, **kwargs): - dict_generated = await fn(*args, **kwargs) - if self.fill_with is not None: - for field in self.fields: - if field not in dict_generated: - dict_generated[field] = self.fill_with - return dict_generated - - else: - - def dict_generator(*args, **kwargs): - dict_generated = fn(*args, **kwargs) - if self.fill_with is not None: - for field in self.fields: - if field not in dict_generated: - dict_generated[field] = self.fill_with - return dict_generated - - output_nodes = [node_.copy_with(callabl=dict_generator)] + # if inspect.iscoroutinefunction(fn): + # + # async def dict_generator(*args, **kwargs): + # dict_generated = await fn(*args, **kwargs) + # if self.fill_with is not None: + # for field in self.fields: + # if field not in dict_generated: + # dict_generated[field] = self.fill_with + # return dict_generated + # + # else: + # + # def dict_generator(*args, **kwargs): + # dict_generated = fn(*args, **kwargs) + # if self.fill_with is not None: + # for field in self.fields: + # if field not in dict_generated: + # dict_generated[field] = self.fill_with + # return dict_generated + + output_nodes = [node_.copy_with(callabl=dict_generator_fn)] for field, field_type in self.fields.items(): doc_string = base_doc # default doc string of base function.