diff --git a/.github/workflows/hamilton-lsp.yml b/.github/workflows/hamilton-lsp.yml index 847436636..5a0e50add 100644 --- a/.github/workflows/hamilton-lsp.yml +++ b/.github/workflows/hamilton-lsp.yml @@ -27,7 +27,8 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - uses: astral-sh/setup-uv@v6 + - name: Install uv + run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: Install dependencies run: | uv pip install --system -e ${{ github.workspace }} -e . pytest diff --git a/.github/workflows/hamilton-main.yml b/.github/workflows/hamilton-main.yml index 2274860da..657c14a22 100644 --- a/.github/workflows/hamilton-main.yml +++ b/.github/workflows/hamilton-main.yml @@ -23,13 +23,13 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Install uv and set the python version - uses: astral-sh/setup-uv@v6 + - name: Install uv + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Set up Python + uses: actions/setup-python@v5 with: python-version: '3.10' - enable-cache: true - cache-dependency-glob: "uv.lock" - activate-environment: true - name: Check linting with prek run: | @@ -73,13 +73,13 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Install uv and set the python version - uses: astral-sh/setup-uv@v6 + - name: Install uv + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Set up Python + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - enable-cache: true - cache-dependency-glob: "uv.lock" - activate-environment: true - name: Test hamilton main package run: | diff --git a/hamilton/graph_utils.py b/hamilton/graph_utils.py index 1913ef3c0..b41c15f17 100644 --- a/hamilton/graph_utils.py +++ b/hamilton/graph_utils.py @@ -21,7 +21,9 @@ def is_submodule(child: ModuleType, parent: ModuleType): - return parent.__name__ in child.__name__ + if child is None: + return False + return child.__name__ == parent.__name__ or child.__name__.startswith(parent.__name__ + ".") def find_functions(function_module: ModuleType) -> list[tuple[str, Callable]]: diff --git a/tests/test_graph.py b/tests/test_graph.py index 1d802b02f..bbf473d79 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -20,6 +20,7 @@ import sys import uuid from itertools import permutations +from types import ModuleType import pandas as pd import pytest @@ -54,6 +55,38 @@ import tests.resources.typing_vs_not_typing +@pytest.mark.parametrize( + ("child_name", "parent_name", "expected"), + [ + ("foo", "foo", True), # same module + ("foo.bar", "foo", True), # direct child + ("foo.bar.baz", "foo", True), # nested child + ("foo.bar.baz", "foo.bar", True), # nested child of subpackage + ("foobar", "foo", False), # not a submodule, just a prefix without dot separator + ("hamilton.function_modifiers", "modifiers", False), # substring match, not a submodule + ("hamilton.function_modifiers.dependencies", "modifiers", False), # substring deeper + ("x.foo.y", "foo", False), # parent name in the middle, not a prefix + ("bar", "foo", False), # completely unrelated + ], + ids=[ + "same_module", + "direct_child", + "nested_child", + "nested_child_of_subpackage", + "prefix_without_dot", + "substring_not_submodule", + "substring_deeper", + "parent_in_middle", + "unrelated", + ], +) +def test_is_submodule(child_name, parent_name, expected): + """Tests that is_submodule correctly checks module hierarchy using prefix matching.""" + child = ModuleType(child_name) + parent = ModuleType(parent_name) + assert hamilton.graph_utils.is_submodule(child, parent) == expected + + def test_find_functions(): """Tests that we filter out _ functions when passed a module and don't pull in anything from the imports.""" expected = [ @@ -66,6 +99,39 @@ def test_find_functions(): assert actual == expected +def test_find_functions_excludes_imports_with_substring_module_name(): + """Regression test: imported functions should not be included when the user module's + name is a substring of the imported function's module path. + + Previously, is_submodule used `parent.__name__ in child.__name__` (substring match), + which caused e.g. a module named 'modifiers' to pull in functions from + 'hamilton.function_modifiers'. + """ + # Create a fake module named "modifiers" with one real function and two imports + mod = ModuleType("modifiers") + + def my_func(x: int) -> int: + return x * 2 + + # Assign the function to the module so inspect.getmodule can resolve it + my_func.__module__ = "modifiers" + mod.my_func = my_func + mod.source = fm.source + mod.value = fm.value + + # Register in sys.modules so inspect.getmodule can find it + sys.modules["modifiers"] = mod + try: + actual = hamilton.graph_utils.find_functions(mod) + actual_names = [name for name, _ in actual] + assert actual_names == ["my_func"], ( + f"Expected only ['my_func'] but got {actual_names}. " + "Imported functions from hamilton.function_modifiers should not be included." + ) + finally: + del sys.modules["modifiers"] + + def test_find_functions_from_temporary_function_module(): """Tests that we handle the TemporaryFunctionModule object correctly.""" expected = [