diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index 92a8b763a..c0d7f249a 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -784,10 +784,13 @@ def _add_original_function_to_nodes(fn: Callable, nodes: List[node.Node]) -> Lis out = [] for node_ in nodes: current_originating_functions = node_.originating_functions - new_originating_functions = ( - current_originating_functions if current_originating_functions is not None else () - ) + (fn,) - out.append(node_.copy_with(originating_functions=new_originating_functions)) + if current_originating_functions and fn not in current_originating_functions: + new_originating_functions = ( + current_originating_functions if current_originating_functions is not None else () + ) + (fn,) + out.append(node_.copy_with(originating_functions=new_originating_functions)) + else: + out.append(node_) return out diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 1b5e726d3..7f6024070 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -265,16 +265,21 @@ def replacement_function( new_input_types[param] = ( val # We just use the standard one, nothing is getting replaced ) + partial_func = functools.partial( + replacement_function, + **{parameter: val.value for parameter, val in literal_dependencies.items()}, + ) nodes.append( node_.copy_with( name=output_node, doc_string=docstring, # TODO -- change docstring - callabl=functools.partial( - replacement_function, - **{parameter: val.value for parameter, val in literal_dependencies.items()}, - ), + callabl=partial_func, input_types=new_input_types, include_refs=False, # Include refs is here as this is earlier than compile time + originating_functions=( + fn, + partial_func, + ), # TODO -- figure out why this isn't getting replaced later... ) )