diff --git a/examples/materialization/using_types/notebook.ipynb b/examples/materialization/using_types/notebook.ipynb
new file mode 100644
index 000000000..e46ba15fd
--- /dev/null
+++ b/examples/materialization/using_types/notebook.ipynb
@@ -0,0 +1,175 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "id": "initial_id",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2024-06-25T00:00:13.662458Z",
+ "start_time": "2024-06-25T00:00:06.982077Z"
+ }
+ },
+ "source": "%load_ext hamilton.plugins.jupyter_magic\n",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "execution_count": 1
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-06-25T00:00:25.003646Z",
+ "start_time": "2024-06-25T00:00:24.322577Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%%cell_to_module simple_etl --display\n",
+ "import pandas as pd\n",
+ "from sklearn import datasets\n",
+ "\n",
+ "from hamilton.htypes import DataLoaderMetadata, DataSaverMetadata\n",
+ "\n",
+ "\n",
+ "def raw_data() -> tuple[pd.DataFrame, DataLoaderMetadata]:\n",
+ " data = datasets.load_digits()\n",
+ " df = pd.DataFrame(data.data, columns=[f\"feature_{i}\" for i in range(data.data.shape[1])])\n",
+ " return df, DataLoaderMetadata.from_dataframe(df)\n",
+ "\n",
+ "\n",
+ "def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:\n",
+ " return raw_data\n",
+ "\n",
+ "\n",
+ "def saved_data(transformed_data: pd.DataFrame, filepath: str) -> DataSaverMetadata:\n",
+ " transformed_data.to_csv(filepath)\n",
+ " return DataSaverMetadata.from_file_and_dataframe(filepath, transformed_data)\n"
+ ],
+ "id": "efd6c1b2417bb9cf",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 2
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-06-25T00:00:37.889540Z",
+ "start_time": "2024-06-25T00:00:35.994131Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "from hamilton_sdk import adapters\n",
+ "\n",
+ "from hamilton import driver\n",
+ "\n",
+ "tracker = adapters.HamiltonTracker(\n",
+ " project_id=7, # modify this as needed\n",
+ " username=\"elijah@dagworks.io\",\n",
+ " dag_name=\"my_version_of_the_dag\",\n",
+ " tags={\"environment\": \"DEV\", \"team\": \"MY_TEAM\", \"version\": \"X\"},\n",
+ ")\n",
+ "dr = driver.Builder().with_config({}).with_modules(simple_etl).with_adapters(tracker).build()\n",
+ "dr.display_all_functions()"
+ ],
+ "id": "e9252f2a09228330",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 3
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-06-25T00:00:53.746596Z",
+ "start_time": "2024-06-25T00:00:52.320439Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "dr.execute([\"saved_data\"], inputs={\"filepath\": \"data.csv\"})",
+ "id": "86c0d0f7da9a472b",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Capturing execution run. Results can be found at http://localhost:8241/dashboard/project/7/runs/25\n",
+ "\n",
+ "\n",
+ "Captured execution run. Results can be found at http://localhost:8241/dashboard/project/7/runs/25\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'saved_data': }"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 4
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": "",
+ "id": "e108601ca3a88aab"
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/materialization/using_types/run.py b/examples/materialization/using_types/run.py
new file mode 100644
index 000000000..eec65a491
--- /dev/null
+++ b/examples/materialization/using_types/run.py
@@ -0,0 +1,28 @@
+import logging
+
+import simple_etl
+from hamilton_sdk import adapters
+
+from hamilton import driver
+from hamilton.log_setup import setup_logging
+
+setup_logging(logging.DEBUG)
+
+tracker = adapters.HamiltonTracker(
+ project_id=15, # modify this as needed
+ username="elijah@dagworks.io",
+ dag_name="my_version_of_the_dag",
+ tags={"environment": "DEV", "team": "MY_TEAM", "version": "X"},
+) # note this slows down execution because there's 60 columns.
+# 30 columns adds about 1 second.
+# 60 is therefore 2 seconds.
+
+dr = driver.Builder().with_config({}).with_modules(simple_etl).with_adapters(tracker).build()
+dr.display_all_functions("simple_etl.png")
+
+import time
+
+start = time.time()
+print(start)
+dr.execute(["saved_data"], inputs={"filepath": "data.csv"})
+print(time.time() - start)
diff --git a/examples/materialization/using_types/simple_etl.png b/examples/materialization/using_types/simple_etl.png
new file mode 100644
index 000000000..a6595a2ed
Binary files /dev/null and b/examples/materialization/using_types/simple_etl.png differ
diff --git a/examples/materialization/using_types/simple_etl.py b/examples/materialization/using_types/simple_etl.py
new file mode 100644
index 000000000..73ae688a1
--- /dev/null
+++ b/examples/materialization/using_types/simple_etl.py
@@ -0,0 +1,45 @@
+from hamilton.telemetry import disable_telemetry
+
+disable_telemetry()
+import logging
+
+import pandas as pd
+from sklearn import datasets
+
+from hamilton import node
+from hamilton.function_modifiers import loader, saver
+from hamilton.io import utils as io_utils
+from hamilton.log_setup import setup_logging
+
+setup_logging(logging.INFO)
+
+
+@loader()
+def raw_data() -> tuple[pd.DataFrame, dict]:
+ data = datasets.load_digits()
+ df = pd.DataFrame(data.data, columns=[f"feature_{i}" for i in range(data.data.shape[1])])
+ metadata = io_utils.get_dataframe_metadata(df)
+ return df, metadata
+
+
+def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:
+ return raw_data
+
+
+@saver()
+def saved_data(transformed_data: pd.DataFrame, filepath: str) -> dict:
+ transformed_data.to_csv(filepath)
+ metadata = io_utils.get_file_and_dataframe_metadata(filepath, transformed_data)
+ return metadata
+
+
+if __name__ == "__main__":
+ import time
+
+ from hamilton_sdk.tracking import runs
+
+ df, metadata = raw_data()
+ t1 = time.time()
+ stats = runs.process_result(df, node.Node.from_fn(raw_data))
+ t2 = time.time()
+ print(t2 - t1)
diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py
index dfadbbf90..3a86d8b60 100644
--- a/hamilton/execution/graph_functions.py
+++ b/hamilton/execution/graph_functions.py
@@ -218,7 +218,6 @@ def dfs_traverse(
except Exception as e:
pre_node_execute_errored = True
raise e
-
if adapter.does_method("do_node_execute", is_async=False):
result = adapter.call_lifecycle_method_sync(
"do_node_execute",
diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py
index 1f52edcb2..a1baf34e0 100644
--- a/hamilton/function_modifiers/__init__.py
+++ b/hamilton/function_modifiers/__init__.py
@@ -92,3 +92,5 @@
# materialization stuff
load_from = adapters.load_from
save_to = adapters.save_to
+loader = macros.loader
+saver = macros.saver
diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py
index 0fe712f90..48c4d9480 100644
--- a/hamilton/function_modifiers/adapters.py
+++ b/hamilton/function_modifiers/adapters.py
@@ -265,7 +265,6 @@ def filter_function(_inject_parameter=inject_parameter, **kwargs):
def inject_nodes(
self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
) -> Tuple[Collection[node.Node], Dict[str, str]]:
- pass
"""Generates two nodes:
1. A node that loads the data from the data source, and returns that + metadata
2. A node that takes the data from the data source, injects it into, and runs, the function.
diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py
index d855448cf..89e745070 100644
--- a/hamilton/function_modifiers/macros.py
+++ b/hamilton/function_modifiers/macros.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import pandas as pd
+import typing_inspect
from hamilton import models, node
from hamilton.dev_utils.deprecation import deprecated
@@ -12,6 +13,7 @@
from hamilton.function_modifiers.configuration import ConfigResolver
from hamilton.function_modifiers.delayed import resolve as delayed_resolve
from hamilton.function_modifiers.dependencies import (
+ InvalidDecoratorException,
LiteralDependency,
SingleDependency,
UpstreamDependency,
@@ -870,3 +872,106 @@ def optional_config(self) -> Dict[str, Any]:
#
# def __init__(self, *transforms: Applicable, collapse=False):
# super(flow, self).__init__(*transforms, collapse=collapse, _chain=False)
+
+
+class loader(base.NodeCreator):
+ """Class to capture metadata."""
+
+ # def __init__(self, og_function: Callable):
+ # self.og_function = og_function
+ # super(loader,self).__init__()
+
+ def validate(self, fn: Callable):
+ print("called validate loader")
+ return_annotation = inspect.signature(fn).return_annotation
+ if return_annotation is inspect.Signature.empty:
+ raise InvalidDecoratorException(
+ f"Function: {fn.__qualname__} must have a return annotation."
+ )
+ # check that the type is a tuple[TYPE, dict]:
+ if not typing_inspect.is_tuple_type(return_annotation):
+ raise InvalidDecoratorException(f"Function: {fn.__qualname__} must return a tuple.")
+ if len(typing_inspect.get_args(return_annotation)) != 2:
+ raise InvalidDecoratorException(
+ f"Function: {fn.__qualname__} must return a tuple of length 2."
+ )
+ if not typing_inspect.get_args(return_annotation)[1] == dict:
+ raise InvalidDecoratorException(
+ f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
+ )
+
+ def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
+ """
+ Generates two nodes.
+ The first one is just the fn - with a slightly different name,
+ the second one uses the proper function name, but only returns
+ the first part of the tuple that the first returns.
+ We have to add tags appropriately.
+ :param fn:
+ :param config:
+ :return:
+ """
+ _name = "loader"
+ og_node = node.Node.from_fn(fn, name=_name)
+ new_tags = og_node.tags.copy()
+ new_tags.update(
+ {
+ "hamilton.data_loader": True,
+ "hamilton.data_loader.has_metadata": True,
+ "hamilton.data_loader.source": f"{fn.__name__}",
+ "hamilton.data_loader.classname": f"{fn.__name__}()",
+ "hamilton.data_loader.node": _name,
+ }
+ )
+
+ def filter_function(**kwargs):
+ return kwargs[f"{fn.__name__}.{_name}"][0]
+
+ filter_node = node.Node(
+ name=fn.__name__, # use original function name
+ callabl=filter_function,
+ typ=typing_inspect.get_args(og_node.type)[0],
+ input_types={f"{fn.__name__}.{_name}": og_node.type},
+ tags={
+ "hamilton.data_loader": True,
+ "hamilton.data_loader.has_metadata": False,
+ "hamilton.data_loader.source": f"{fn.__name__}",
+ "hamilton.data_loader.classname": f"{fn.__name__}()",
+ "hamilton.data_loader.node": fn.__name__,
+ },
+ )
+
+ return [og_node.copy_with(tags=new_tags, namespace=(fn.__name__,)), filter_node]
+
+
+class saver(base.NodeCreator):
+ """Class to capture metadata."""
+
+ def validate(self, fn: Callable):
+ print("called validate")
+ return_annotation = inspect.signature(fn).return_annotation
+ if return_annotation is inspect.Signature.empty:
+ raise InvalidDecoratorException(
+ f"Function: {fn.__qualname__} must have a return annotation."
+ )
+ # check that the return type is a dict
+ if return_annotation not in (dict, Dict):
+ raise InvalidDecoratorException(f"Function: {fn.__qualname__} must return a dict.")
+
+ def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
+ """
+        All this does is add tags.
+ :param fn:
+ :param config:
+ :return:
+ """
+ og_node = node.Node.from_fn(fn)
+ new_tags = og_node.tags.copy()
+ new_tags.update(
+ {
+ "hamilton.data_saver": True,
+ "hamilton.data_saver.sink": f"{og_node.name}",
+ "hamilton.data_saver.classname": f"{fn.__name__}()",
+ }
+ )
+ return [og_node.copy_with(tags=new_tags)]
diff --git a/hamilton/graph.py b/hamilton/graph.py
index 56edacff7..5847ef6a2 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -500,7 +500,9 @@ def _get_legend(
node_style.update(**modifier_style)
seen_node_types.add("materializer")
- if n.tags.get("hamilton.data_loader") and "load_data." in n.name:
+ if n.tags.get("hamilton.data_loader") and (
+ "load_data." in n.name or "loader" == n.tags.get("hamilton.data_loader.node")
+ ):
materializer_type = n.tags["hamilton.data_loader.classname"]
label = _get_node_label(n, type_string=materializer_type)
modifier_style = _get_function_modifier_style("materializer")
diff --git a/hamilton/node.py b/hamilton/node.py
index 4d4a4a764..2c6e5c73f 100644
--- a/hamilton/node.py
+++ b/hamilton/node.py
@@ -266,6 +266,9 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
return_type = typing.get_type_hints(fn, **type_hint_kwargs).get("return")
if return_type is None:
raise ValueError(f"Missing type hint for return value in function {fn.__qualname__}.")
+ module = inspect.getmodule(fn).__name__
+ tags = {"module": module}
+
node_source = NodeType.STANDARD
# TODO - extract this into a function + clean up!
if typing_inspect.is_generic_type(return_type):
@@ -277,8 +280,7 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
if typing_inspect.get_origin(hint) == Collect:
node_source = NodeType.COLLECT
break
- module = inspect.getmodule(fn).__name__
- tags = {"module": module}
+
if hasattr(fn, "__config_decorated__"):
tags["hamilton.config"] = ",".join(fn.__config_decorated__)
return Node(
diff --git a/ui/sdk/src/hamilton_sdk/tracking/pandas_col_stats.py b/ui/sdk/src/hamilton_sdk/tracking/pandas_col_stats.py
index c59a92da5..a013b1cad 100644
--- a/ui/sdk/src/hamilton_sdk/tracking/pandas_col_stats.py
+++ b/ui/sdk/src/hamilton_sdk/tracking/pandas_col_stats.py
@@ -1,5 +1,6 @@
from typing import Dict, List, Union
+import numpy as np
import pandas as pd
from hamilton_sdk.tracking import dataframe_stats as dfs
@@ -45,19 +46,8 @@ def quantiles(col: pd.Series, quantile_cuts: List[float]) -> Dict[float, float]:
def histogram(col: pd.Series, num_hist_bins: int = 10) -> Dict[str, int]:
- try:
- hist_counts = (
- col.value_counts(
- bins=num_hist_bins,
- )
- .sort_index()
- .to_dict()
- )
- except ValueError:
- return {}
- except AttributeError:
- return {}
- return {str(interval): interval_value for interval, interval_value in hist_counts.items()}
+ hist, bins = np.histogram(col, bins=num_hist_bins)
+ return {str(interval): interval_value for interval, interval_value in zip(bins, hist)}
def numeric_column_stats(
diff --git a/ui/sdk/src/hamilton_sdk/tracking/pandas_stats.py b/ui/sdk/src/hamilton_sdk/tracking/pandas_stats.py
index f190b8b08..a0763448f 100644
--- a/ui/sdk/src/hamilton_sdk/tracking/pandas_stats.py
+++ b/ui/sdk/src/hamilton_sdk/tracking/pandas_stats.py
@@ -1,3 +1,4 @@
+import logging
from typing import Any, Dict, Union
import pandas as pd
@@ -12,7 +13,11 @@
- for object types we should :shrug:
"""
+
dr = driver.Builder().with_modules(pcs).with_config({"config_key": "config_value"}).build()
+logger = logging.getLogger(__name__)
+
+import time
def _compute_stats(df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
@@ -44,11 +49,16 @@ def _compute_stats(df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
def execute_col(
target_output: str, col: pd.Series, name: Union[str, int], position: int
) -> Dict[str, Any]:
- """Get stats on a column."""
+ """Get stats on a column.
+ TODO: profile this and see where we can speed things up.
+ """
try:
+ t1 = time.time()
res = dr.execute(
[target_output], inputs={"col": col, "name": name, "position": position}
)
+
+ logger.info(f"Computed stats for column {name}, time taken was {time.time() - t1}")
res = res[target_output].to_dict()
except Exception:
# minimum that we want -- ideally we have hamilton handle errors and do best effort.
diff --git a/ui/sdk/src/hamilton_sdk/tracking/stats.py b/ui/sdk/src/hamilton_sdk/tracking/stats.py
index 0b968e27f..44f5a5c3e 100644
--- a/ui/sdk/src/hamilton_sdk/tracking/stats.py
+++ b/ui/sdk/src/hamilton_sdk/tracking/stats.py
@@ -100,10 +100,16 @@ def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> Stats
if isinstance(result[1], dict):
try:
# double check that it's JSON serializable
- raw_data = json.dumps(result[1])
+ raw_data = (
+ json.dumps(result[1])
+ if isinstance(result[1], dict)
+ else json.dumps(result[1].to_dict())
+ )
_metadata = json.loads(raw_data)
except Exception:
- _metadata = str(result[1])
+ _metadata = (
+ str(result[1]) if isinstance(result[1], dict) else str(result[1].to_dict())
+ )
if len(_metadata) > 1000:
_metadata = _metadata[:1000] + "..."
else: