From 764319abbe7023b70a0a4a29b838fa6d0f736190 Mon Sep 17 00:00:00 2001 From: Octopus Date: Mon, 23 Mar 2026 11:00:53 -0500 Subject: [PATCH] feat: add MiniMax as alternative LLM provider for RAG dataflows Add MiniMax M2.7 as an alternative LLM provider alongside OpenAI in both faiss_rag and conversational_rag dataflows using Hamilton's @config.when pattern. Changes: - Use @config.when_not(provider="minimax") for OpenAI (backward-compatible default) - Use @config.when(provider="minimax") for MiniMax via OpenAI-compatible API - Update valid_configs.jsonl with minimax configuration - Update tags.json with minimax tag - Update README.md with MiniMax usage documentation - Add 35 unit tests + 6 integration tests MiniMax M2.7 features: - 1M token context window - OpenAI-compatible API at https://api.minimax.io/v1 - Configurable via MINIMAX_API_KEY environment variable --- .../dagworks/conversational_rag/README.md | 48 ++- .../dagworks/conversational_rag/__init__.py | 75 ++++- .../dagworks/conversational_rag/tags.json | 2 +- .../test_conversational_rag.py | 296 ++++++++++++++++++ .../conversational_rag/valid_configs.jsonl | 3 +- .../contrib/dagworks/faiss_rag/README.md | 48 ++- .../contrib/dagworks/faiss_rag/__init__.py | 55 +++- .../contrib/dagworks/faiss_rag/tags.json | 2 +- .../dagworks/faiss_rag/test_faiss_rag.py | 265 ++++++++++++++++ .../dagworks/faiss_rag/valid_configs.jsonl | 3 +- 10 files changed, 771 insertions(+), 26 deletions(-) create mode 100644 contrib/hamilton/contrib/dagworks/conversational_rag/test_conversational_rag.py create mode 100644 contrib/hamilton/contrib/dagworks/faiss_rag/test_faiss_rag.py diff --git a/contrib/hamilton/contrib/dagworks/conversational_rag/README.md b/contrib/hamilton/contrib/dagworks/conversational_rag/README.md index e0b092115..56801aa9e 100644 --- a/contrib/hamilton/contrib/dagworks/conversational_rag/README.md +++ b/contrib/hamilton/contrib/dagworks/conversational_rag/README.md @@ -23,7 +23,9 @@ This module shows a conversational retrieval augmented generation (RAG) example Apache Hamilton. It shows you how you might structure your code with Apache Hamilton to create a RAG pipeline that takes into account conversation. -This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) + and in memory vector store and the OpenAI LLM provider. +This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) + an in memory vector store with multi-provider LLM support. +It supports **OpenAI** (default) and **[MiniMax](https://www.minimax.io/)** as LLM providers, +switchable via Hamilton's `@config.when` pattern. The implementation of the FAISS vector store uses the LangChain wrapper around it. That's because this was the simplest way to get this example up without requiring someone having to host and manage a proper vector store. @@ -57,6 +59,7 @@ Here we just ask for the final result, but if you wanted to, you could ask for o you can then introspect or log for debugging/evaluation purposes. Note if you want more platform integrations, you can add adapters that will do this automatically for you, e.g. like we have the `PrintLn` adapter here. +**Using OpenAI (default):** ```python # import the module from hamilton import driver @@ -64,7 +67,7 @@ from hamilton import lifecycle dr = ( driver.Builder() .with_modules(conversational_rag) - .with_config({}) + .with_config({}) # defaults to OpenAI # this prints the inputs and outputs of each step. .with_adapters(lifecycle.PrintLn(verbosity=2)) .build() @@ -102,6 +105,34 @@ result = dr.execute( print(result) ``` +**Using MiniMax:** + +Set `MINIMAX_API_KEY` in your environment, then pass `{"provider": "minimax"}` in the config: +```python +from hamilton import driver, lifecycle +dr = ( + driver.Builder() + .with_modules(conversational_rag) + .with_config({"provider": "minimax"}) + .with_adapters(lifecycle.PrintLn(verbosity=2)) + .build() +) +result = dr.execute( + ["conversational_rag_response"], + inputs={ + "input_texts": [ + "harrison worked at kensho", + "stefan worked at Stitch Fix", + ], + "question": "where did stefan work?", + "chat_history": [] + }, +) +print(result) +``` +MiniMax uses the [MiniMax-M2.7](https://www.minimax.io/) model with a 1M token context window +via an OpenAI-compatible API endpoint. + # How to extend this module What you'd most likely want to do is: @@ -112,16 +143,21 @@ What you'd most likely want to do is: With (1) you can import any vector store/library that you want. You should draw out the process you would like, and that should then map to Apache Hamilton functions. With (2) you can import any LLM provider that you want, just use `@config.when` if you -want to switch between multiple providers. +want to switch between multiple providers. OpenAI and MiniMax are already supported. With (3) you can add more functions that create parts of the prompt. # Configuration Options -There is no configuration needed for this module. + +| Config Key | Values | Description | +|-----------|--------|-------------| +| `provider` | `"minimax"` | Use MiniMax M2.7 as the LLM. Requires `MINIMAX_API_KEY` env var. | +| *(empty)* | | Default: uses OpenAI. Requires `OPENAI_API_KEY` env var. | # Limitations -You need to have the OPENAI_API_KEY in your environment. -It should be accessible from your code by doing `os.environ["OPENAI_API_KEY"]`. +You need to have the appropriate API key in your environment: +- **OpenAI** (default): `OPENAI_API_KEY` +- **MiniMax**: `MINIMAX_API_KEY` The code does not check the context length, so it may fail if the context passed is too long for the LLM you send it to. diff --git a/contrib/hamilton/contrib/dagworks/conversational_rag/__init__.py b/contrib/hamilton/contrib/dagworks/conversational_rag/__init__.py index 350133032..2a3039f1c 100644 --- a/contrib/hamilton/contrib/dagworks/conversational_rag/__init__.py +++ b/contrib/hamilton/contrib/dagworks/conversational_rag/__init__.py @@ -16,10 +16,12 @@ # under the License. import logging +import os logger = logging.getLogger(__name__) from hamilton import contrib +from hamilton.function_modifiers import config with contrib.catch_import_errors(__name__, __file__, logger): import openai @@ -53,8 +55,9 @@ def standalone_question_prompt(chat_history: list[str], question: str) -> str: ).format(chat_history=chat_history_str, question=question) -def standalone_question(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str: - """Asks the LLM to create a standalone question from the prompt. +@config.when_not(provider="minimax") +def standalone_question__openai(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str: + """Asks OpenAI to create a standalone question from the prompt. :param standalone_question_prompt: the prompt with context. :param llm_client: the llm client to use. @@ -67,6 +70,21 @@ def standalone_question(standalone_question_prompt: str, llm_client: openai.Open return response.choices[0].message.content +@config.when(provider="minimax") +def standalone_question__minimax(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str: + """Asks MiniMax to create a standalone question from the prompt. + + :param standalone_question_prompt: the prompt with context. + :param llm_client: the llm client to use. + :return: the standalone question. + """ + response = llm_client.chat.completions.create( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": standalone_question_prompt}], + ) + return response.choices[0].message.content + + def vector_store(input_texts: list[str]) -> VectorStoreRetriever: """A Vector store. This function populates and creates one for querying. @@ -112,13 +130,31 @@ def answer_prompt(context: str, standalone_question: str) -> str: return template.format(context=context, question=standalone_question) -def llm_client() -> openai.OpenAI: - """The LLM client to use for the RAG model.""" +@config.when_not(provider="minimax") +def llm_client__openai() -> openai.OpenAI: + """The OpenAI LLM client (default). + + Uses the OPENAI_API_KEY environment variable for authentication. + """ return openai.OpenAI() -def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -> str: - """Creates the RAG response from the LLM model for the given prompt. +@config.when(provider="minimax") +def llm_client__minimax() -> openai.OpenAI: + """The MiniMax LLM client via OpenAI-compatible API. + + Uses the MINIMAX_API_KEY environment variable for authentication. + MiniMax provides an OpenAI-compatible endpoint at https://api.minimax.io/v1. + """ + return openai.OpenAI( + base_url="https://api.minimax.io/v1", + api_key=os.environ.get("MINIMAX_API_KEY"), + ) + + +@config.when_not(provider="minimax") +def conversational_rag_response__openai(answer_prompt: str, llm_client: openai.OpenAI) -> str: + """Creates the RAG response using OpenAI. :param answer_prompt: the prompt to send to the LLM. :param llm_client: the LLM client to use. @@ -131,11 +167,29 @@ def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) - return response.choices[0].message.content +@config.when(provider="minimax") +def conversational_rag_response__minimax(answer_prompt: str, llm_client: openai.OpenAI) -> str: + """Creates the RAG response using MiniMax M2.7. + + MiniMax M2.7 is a high-performance model with 1M token context window. + + :param answer_prompt: the prompt to send to the LLM. + :param llm_client: the LLM client to use. + :return: the response from the LLM. + """ + response = llm_client.chat.completions.create( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": answer_prompt}], + ) + return response.choices[0].message.content + + if __name__ == "__main__": import __init__ as conversational_rag from hamilton import driver, lifecycle + # Default: uses OpenAI (config={} or config={"provider": "openai"}) dr = ( driver.Builder() .with_modules(conversational_rag) @@ -176,3 +230,12 @@ def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) - }, ) ) + + # To use MiniMax instead, set MINIMAX_API_KEY and use: + # dr = ( + # driver.Builder() + # .with_modules(conversational_rag) + # .with_config({"provider": "minimax"}) + # .with_adapters(lifecycle.PrintLn(verbosity=2)) + # .build() + # ) diff --git a/contrib/hamilton/contrib/dagworks/conversational_rag/tags.json b/contrib/hamilton/contrib/dagworks/conversational_rag/tags.json index 87ee694c0..990f7b10b 100644 --- a/contrib/hamilton/contrib/dagworks/conversational_rag/tags.json +++ b/contrib/hamilton/contrib/dagworks/conversational_rag/tags.json @@ -1,6 +1,6 @@ { "schema": "1.0", - "use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"], + "use_case_tags": ["LLM", "openai", "minimax", "RAG", "retrieval augmented generation", "FAISS"], "secondary_tags": { "language": "English" } diff --git a/contrib/hamilton/contrib/dagworks/conversational_rag/test_conversational_rag.py b/contrib/hamilton/contrib/dagworks/conversational_rag/test_conversational_rag.py new file mode 100644 index 000000000..0a315c5eb --- /dev/null +++ b/contrib/hamilton/contrib/dagworks/conversational_rag/test_conversational_rag.py @@ -0,0 +1,296 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the conversational_rag dataflow with multi-provider support.""" + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert( + 0, + os.path.dirname(os.path.abspath(__file__)), +) + +import openai + +from hamilton import driver + +from hamilton.contrib.dagworks import conversational_rag + + +# ──────────────────────────── Unit Tests ──────────────────────────── + + +class TestStandaloneQuestionPrompt: + """Tests for standalone_question_prompt (provider-independent).""" + + def test_includes_chat_history(self): + result = conversational_rag.standalone_question_prompt( + chat_history=["Human: Hi", "AI: Hello"], + question="Where did he work?", + ) + assert "Human: Hi" in result + assert "AI: Hello" in result + + def test_includes_question(self): + result = conversational_rag.standalone_question_prompt( + chat_history=[], + question="What is Hamilton?", + ) + assert "What is Hamilton?" in result + + def test_empty_chat_history(self): + result = conversational_rag.standalone_question_prompt( + chat_history=[], + question="test question", + ) + assert "test question" in result + + +class TestAnswerPrompt: + """Tests for answer_prompt (provider-independent).""" + + def test_includes_context_and_question(self): + result = conversational_rag.answer_prompt( + context="Hamilton builds DAGs", + standalone_question="What is Hamilton?", + ) + assert "Hamilton builds DAGs" in result + assert "What is Hamilton?" in result + + +class TestOpenAIProvider: + """Tests for OpenAI provider configuration.""" + + def test_llm_client_openai_returns_client(self): + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}): + client = conversational_rag.llm_client__openai() + assert isinstance(client, openai.OpenAI) + + def test_standalone_question_openai_calls_correct_model(self): + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "standalone question" + mock_client.chat.completions.create.return_value = mock_response + + result = conversational_rag.standalone_question__openai( + standalone_question_prompt="test prompt", + llm_client=mock_client, + ) + assert result == "standalone question" + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "test prompt"}], + ) + + def test_rag_response_openai_calls_correct_model(self): + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Stitch Fix" + mock_client.chat.completions.create.return_value = mock_response + + result = conversational_rag.conversational_rag_response__openai( + answer_prompt="Where did stefan work?", + llm_client=mock_client, + ) + assert result == "Stitch Fix" + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Where did stefan work?"}], + ) + + +class TestMiniMaxProvider: + """Tests for MiniMax provider configuration.""" + + def test_llm_client_minimax_base_url(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}): + client = conversational_rag.llm_client__minimax() + assert isinstance(client, openai.OpenAI) + assert str(client.base_url).rstrip("/") == "https://api.minimax.io/v1" + + def test_llm_client_minimax_api_key(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "my-key"}): + client = conversational_rag.llm_client__minimax() + assert client.api_key == "my-key" + + def test_standalone_question_minimax_calls_correct_model(self): + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "standalone question" + mock_client.chat.completions.create.return_value = mock_response + + result = conversational_rag.standalone_question__minimax( + standalone_question_prompt="test prompt", + llm_client=mock_client, + ) + assert result == "standalone question" + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs["model"] == "MiniMax-M2.7" + + def test_rag_response_minimax_calls_correct_model(self): + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "MiniMax answer" + mock_client.chat.completions.create.return_value = mock_response + + result = conversational_rag.conversational_rag_response__minimax( + answer_prompt="Where did stefan work?", + llm_client=mock_client, + ) + assert result == "MiniMax answer" + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs["model"] == "MiniMax-M2.7" + + +class TestHamiltonDriverConfig: + """Tests for Hamilton driver configuration.""" + + def test_driver_builds_with_default_config(self): + dr = driver.Builder().with_modules(conversational_rag).with_config({}).build() + assert dr is not None + + def test_driver_builds_with_openai_config(self): + dr = ( + driver.Builder() + .with_modules(conversational_rag) + .with_config({"provider": "openai"}) + .build() + ) + assert dr is not None + + def test_driver_builds_with_minimax_config(self): + dr = ( + driver.Builder() + .with_modules(conversational_rag) + .with_config({"provider": "minimax"}) + .build() + ) + assert dr is not None + + def test_default_config_has_required_nodes(self): + dr = driver.Builder().with_modules(conversational_rag).with_config({}).build() + graph_nodes = {n.name for n in dr.graph.get_nodes()} + assert "llm_client" in graph_nodes + assert "standalone_question" in graph_nodes + assert "conversational_rag_response" in graph_nodes + + def test_minimax_config_has_required_nodes(self): + dr = ( + driver.Builder() + .with_modules(conversational_rag) + .with_config({"provider": "minimax"}) + .build() + ) + graph_nodes = {n.name for n in dr.graph.get_nodes()} + assert "llm_client" in graph_nodes + assert "standalone_question" in graph_nodes + assert "conversational_rag_response" in graph_nodes + + def test_default_config_end_to_end_mocked(self): + """Test end-to-end with mocked OpenAI client.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Stitch Fix" + mock_client.chat.completions.create.return_value = mock_response + + dr = driver.Builder().with_modules(conversational_rag).with_config({}).build() + result = dr.execute( + ["conversational_rag_response"], + overrides={ + "llm_client": mock_client, + "standalone_question": "Where did Stefan work?", + "answer_prompt": "Context: Stefan worked at Stitch Fix\n\nQuestion: Where did Stefan work?", + }, + ) + assert result["conversational_rag_response"] == "Stitch Fix" + + def test_minimax_config_end_to_end_mocked(self): + """Test end-to-end with mocked MiniMax client.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "MiniMax answer" + mock_client.chat.completions.create.return_value = mock_response + + dr = ( + driver.Builder() + .with_modules(conversational_rag) + .with_config({"provider": "minimax"}) + .build() + ) + result = dr.execute( + ["conversational_rag_response"], + overrides={ + "llm_client": mock_client, + "standalone_question": "Where did Stefan work?", + "answer_prompt": "Context: Stefan worked at Stitch Fix\n\nQuestion: Where did Stefan work?", + }, + ) + assert result["conversational_rag_response"] == "MiniMax answer" + + +# ──────────────────────── Integration Tests ───────────────────────── + + +class TestMiniMaxIntegration: + """Integration tests that call the real MiniMax API.""" + + @pytest.fixture + def minimax_api_key(self): + key = os.environ.get("MINIMAX_API_KEY") + if not key: + pytest.skip("MINIMAX_API_KEY not set") + return key + + def test_minimax_client_creation(self, minimax_api_key): + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_api_key}): + client = conversational_rag.llm_client__minimax() + assert isinstance(client, openai.OpenAI) + + def test_minimax_standalone_question_real_api(self, minimax_api_key): + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_api_key}): + client = conversational_rag.llm_client__minimax() + result = conversational_rag.standalone_question__minimax( + standalone_question_prompt="Given the following conversation:\n" + "Human: Who wrote this example?\nAI: Stefan\n" + "Follow Up Input: Where did he work?\n" + "Standalone question:", + llm_client=client, + ) + assert isinstance(result, str) + assert len(result) > 0 + + def test_minimax_conversational_rag_response_real_api(self, minimax_api_key): + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_api_key}): + client = conversational_rag.llm_client__minimax() + result = conversational_rag.conversational_rag_response__minimax( + answer_prompt="Answer the question based only on the following context:\n" + "Stefan worked at Stitch Fix.\n\n" + "Question: Where did Stefan work?", + llm_client=client, + ) + assert isinstance(result, str) + assert len(result) > 0 diff --git a/contrib/hamilton/contrib/dagworks/conversational_rag/valid_configs.jsonl b/contrib/hamilton/contrib/dagworks/conversational_rag/valid_configs.jsonl index b8a6704f8..70ffd523f 100644 --- a/contrib/hamilton/contrib/dagworks/conversational_rag/valid_configs.jsonl +++ b/contrib/hamilton/contrib/dagworks/conversational_rag/valid_configs.jsonl @@ -1 +1,2 @@ -{"description": "Default", "name": "default", "config": {}} +{"description": "Default (OpenAI)", "name": "default", "config": {}} +{"description": "MiniMax", "name": "minimax", "config": {"provider": "minimax"}} diff --git a/contrib/hamilton/contrib/dagworks/faiss_rag/README.md b/contrib/hamilton/contrib/dagworks/faiss_rag/README.md index 18be1478d..a326b8b65 100644 --- a/contrib/hamilton/contrib/dagworks/faiss_rag/README.md +++ b/contrib/hamilton/contrib/dagworks/faiss_rag/README.md @@ -23,7 +23,9 @@ This module shows a simple retrieval augmented generation (RAG) example using Apache Hamilton. It shows you how you might structure your code with Apache Hamilton to create a simple RAG pipeline. -This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) + and in memory vector store and the OpenAI LLM provider. +This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) + an in memory vector store with multi-provider LLM support. +It supports **OpenAI** (default) and **[MiniMax](https://www.minimax.io/)** as LLM providers, +switchable via Hamilton's `@config.when` pattern. The implementation of the FAISS vector store uses the LangChain wrapper around it. That's because this was the simplest way to get this example up without requiring someone having to host and manage a proper vector store. @@ -49,6 +51,8 @@ You can ask to get back any result of an intermediate function by providing the Here we just ask for the final result, but if you wanted to, you could ask for outputs of any of the functions, which you can then introspect or log for debugging/evaluation purposes. Note if you want more platform integrations, you can add adapters that will do this automatically for you, e.g. like we have the `PrintLn` adapter here. + +**Using OpenAI (default):** ```python # import the module from hamilton import driver @@ -56,7 +60,7 @@ from hamilton import lifecycle dr = ( driver.Builder() .with_modules(faiss_rag) - .with_config({}) + .with_config({}) # defaults to OpenAI # this prints the inputs and outputs of each step. .with_adapters(lifecycle.PrintLn(verbosity=2)) .build() @@ -74,6 +78,33 @@ result = dr.execute( print(result) ``` +**Using MiniMax:** + +Set `MINIMAX_API_KEY` in your environment, then pass `{"provider": "minimax"}` in the config: +```python +from hamilton import driver, lifecycle +dr = ( + driver.Builder() + .with_modules(faiss_rag) + .with_config({"provider": "minimax"}) + .with_adapters(lifecycle.PrintLn(verbosity=2)) + .build() +) +result = dr.execute( + ["rag_response"], + inputs={ + "input_texts": [ + "harrison worked at kensho", + "stefan worked at Stitch Fix", + ], + "question": "where did stefan work?", + }, +) +print(result) +``` +MiniMax uses the [MiniMax-M2.7](https://www.minimax.io/) model with a 1M token context window +via an OpenAI-compatible API endpoint. + # How to extend this module What you'd most likely want to do is: @@ -84,16 +115,21 @@ What you'd most likely want to do is: With (1) you can import any vector store/library that you want. You should draw out the process you would like, and that should then map to Apache Hamilton functions. With (2) you can import any LLM provider that you want, just use `@config.when` if you -want to switch between multiple providers. +want to switch between multiple providers. OpenAI and MiniMax are already supported. With (3) you can add more functions that create parts of the prompt. # Configuration Options -There is no configuration needed for this module. + +| Config Key | Values | Description | +|-----------|--------|-------------| +| `provider` | `"minimax"` | Use MiniMax M2.7 as the LLM. Requires `MINIMAX_API_KEY` env var. | +| *(empty)* | | Default: uses OpenAI. Requires `OPENAI_API_KEY` env var. | # Limitations -You need to have the OPENAI_API_KEY in your environment. -It should be accessible from your code by doing `os.environ["OPENAI_API_KEY"]`. +You need to have the appropriate API key in your environment: +- **OpenAI** (default): `OPENAI_API_KEY` +- **MiniMax**: `MINIMAX_API_KEY` The code does not check the context length, so it may fail if the context passed is too long for the LLM you send it to. diff --git a/contrib/hamilton/contrib/dagworks/faiss_rag/__init__.py b/contrib/hamilton/contrib/dagworks/faiss_rag/__init__.py index 0d85e0a60..f4b72ec3e 100644 --- a/contrib/hamilton/contrib/dagworks/faiss_rag/__init__.py +++ b/contrib/hamilton/contrib/dagworks/faiss_rag/__init__.py @@ -16,10 +16,12 @@ # under the License. import logging +import os logger = logging.getLogger(__name__) from hamilton import contrib +from hamilton.function_modifiers import config with contrib.catch_import_errors(__name__, __file__, logger): import openai @@ -75,13 +77,31 @@ def rag_prompt(context: str, question: str) -> str: return template.format(context=context, question=question) -def llm_client() -> openai.OpenAI: - """The LLM client to use for the RAG model.""" +@config.when_not(provider="minimax") +def llm_client__openai() -> openai.OpenAI: + """The OpenAI LLM client (default). + + Uses the OPENAI_API_KEY environment variable for authentication. + """ return openai.OpenAI() -def rag_response(rag_prompt: str, llm_client: openai.OpenAI) -> str: - """Creates the RAG response from the LLM model for the given prompt. +@config.when(provider="minimax") +def llm_client__minimax() -> openai.OpenAI: + """The MiniMax LLM client via OpenAI-compatible API. + + Uses the MINIMAX_API_KEY environment variable for authentication. + MiniMax provides an OpenAI-compatible endpoint at https://api.minimax.io/v1. + """ + return openai.OpenAI( + base_url="https://api.minimax.io/v1", + api_key=os.environ.get("MINIMAX_API_KEY"), + ) + + +@config.when_not(provider="minimax") +def rag_response__openai(rag_prompt: str, llm_client: openai.OpenAI) -> str: + """Creates the RAG response using OpenAI. :param rag_prompt: the prompt to send to the LLM. :param llm_client: the LLM client to use. @@ -94,11 +114,29 @@ def rag_response(rag_prompt: str, llm_client: openai.OpenAI) -> str: return response.choices[0].message.content +@config.when(provider="minimax") +def rag_response__minimax(rag_prompt: str, llm_client: openai.OpenAI) -> str: + """Creates the RAG response using MiniMax M2.7. + + MiniMax M2.7 is a high-performance model with 1M token context window. + + :param rag_prompt: the prompt to send to the LLM. + :param llm_client: the LLM client to use. + :return: the response from the LLM. + """ + response = llm_client.chat.completions.create( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": rag_prompt}], + ) + return response.choices[0].message.content + + if __name__ == "__main__": import __init__ as hamilton_faiss_rag from hamilton import driver, lifecycle + # Default: uses OpenAI (config={} or config={"provider": "openai"}) dr = ( driver.Builder() .with_modules(hamilton_faiss_rag) @@ -120,3 +158,12 @@ def rag_response(rag_prompt: str, llm_client: openai.OpenAI) -> str: }, ) ) + + # To use MiniMax instead, set MINIMAX_API_KEY and use: + # dr = ( + # driver.Builder() + # .with_modules(hamilton_faiss_rag) + # .with_config({"provider": "minimax"}) + # .with_adapters(lifecycle.PrintLn(verbosity=2)) + # .build() + # ) diff --git a/contrib/hamilton/contrib/dagworks/faiss_rag/tags.json b/contrib/hamilton/contrib/dagworks/faiss_rag/tags.json index 87ee694c0..990f7b10b 100644 --- a/contrib/hamilton/contrib/dagworks/faiss_rag/tags.json +++ b/contrib/hamilton/contrib/dagworks/faiss_rag/tags.json @@ -1,6 +1,6 @@ { "schema": "1.0", - "use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"], + "use_case_tags": ["LLM", "openai", "minimax", "RAG", "retrieval augmented generation", "FAISS"], "secondary_tags": { "language": "English" } diff --git a/contrib/hamilton/contrib/dagworks/faiss_rag/test_faiss_rag.py b/contrib/hamilton/contrib/dagworks/faiss_rag/test_faiss_rag.py new file mode 100644 index 000000000..1412d275a --- /dev/null +++ b/contrib/hamilton/contrib/dagworks/faiss_rag/test_faiss_rag.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the faiss_rag dataflow with multi-provider support.""" + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +# Add the parent directory to allow importing the module +sys.path.insert( + 0, + os.path.dirname(os.path.abspath(__file__)), +) + +import openai + +from hamilton import driver + +# Import the module under test +from hamilton.contrib.dagworks import faiss_rag + + +# ──────────────────────────── Unit Tests ──────────────────────────── + + +class TestRagPrompt: + """Tests for the rag_prompt function (provider-independent).""" + + def test_rag_prompt_includes_context(self): + result = faiss_rag.rag_prompt(context="Hamilton is a DAG framework", question="What is Hamilton?") + assert "Hamilton is a DAG framework" in result + + def test_rag_prompt_includes_question(self): + result = faiss_rag.rag_prompt(context="some context", question="What is Hamilton?") + assert "What is Hamilton?" in result + + def test_rag_prompt_format(self): + result = faiss_rag.rag_prompt(context="ctx", question="q") + assert "Answer the question" in result + assert "ctx" in result + assert "q" in result + + +class TestOpenAIProvider: + """Tests for OpenAI provider configuration.""" + + def test_llm_client_openai_returns_openai_client(self): + """Test that the OpenAI client is correctly created.""" + with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}): + client = faiss_rag.llm_client__openai() + assert isinstance(client, openai.OpenAI) + + def test_rag_response_openai_calls_chat_completions(self): + """Test that the OpenAI rag_response calls the correct model.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Stitch Fix" + mock_client.chat.completions.create.return_value = mock_response + + result = faiss_rag.rag_response__openai( + rag_prompt="Where did stefan work?", + llm_client=mock_client, + ) + + assert result == "Stitch Fix" + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Where did stefan work?"}], + ) + + +class TestMiniMaxProvider: + """Tests for MiniMax provider configuration.""" + + def test_llm_client_minimax_returns_openai_client_with_minimax_base_url(self): + """Test that the MiniMax client uses the correct base_url.""" + with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-minimax-key"}): + client = faiss_rag.llm_client__minimax() + assert isinstance(client, openai.OpenAI) + assert "minimax" in str(client.base_url).lower() + + def test_llm_client_minimax_uses_env_api_key(self): + """Test that the MiniMax client reads MINIMAX_API_KEY from env.""" + with patch.dict(os.environ, {"MINIMAX_API_KEY": "my-secret-key"}): + client = faiss_rag.llm_client__minimax() + assert client.api_key == "my-secret-key" + + def test_rag_response_minimax_calls_correct_model(self): + """Test that the MiniMax rag_response calls MiniMax-M2.7.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Stitch Fix" + mock_client.chat.completions.create.return_value = mock_response + + result = faiss_rag.rag_response__minimax( + rag_prompt="Where did stefan work?", + llm_client=mock_client, + ) + + assert result == "Stitch Fix" + mock_client.chat.completions.create.assert_called_once_with( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": "Where did stefan work?"}], + ) + + +class TestHamiltonDriverConfig: + """Tests for Hamilton driver configuration with providers.""" + + def test_driver_builds_with_default_config(self): + """Test that the driver builds successfully with default (empty) config.""" + dr = driver.Builder().with_modules(faiss_rag).with_config({}).build() + assert dr is not None + + def test_driver_builds_with_openai_config(self): + """Test that the driver builds successfully with explicit OpenAI config.""" + dr = driver.Builder().with_modules(faiss_rag).with_config({"provider": "openai"}).build() + assert dr is not None + + def test_driver_builds_with_minimax_config(self): + """Test that the driver builds successfully with MiniMax config.""" + dr = driver.Builder().with_modules(faiss_rag).with_config({"provider": "minimax"}).build() + assert dr is not None + + def test_default_config_includes_openai_functions(self): + """Test that default config resolves to OpenAI provider functions.""" + dr = driver.Builder().with_modules(faiss_rag).with_config({}).build() + graph_nodes = {n.name for n in dr.graph.get_nodes()} + assert "llm_client" in graph_nodes + assert "rag_response" in graph_nodes + + def test_minimax_config_includes_minimax_functions(self): + """Test that minimax config resolves to MiniMax provider functions.""" + dr = driver.Builder().with_modules(faiss_rag).with_config({"provider": "minimax"}).build() + graph_nodes = {n.name for n in dr.graph.get_nodes()} + assert "llm_client" in graph_nodes + assert "rag_response" in graph_nodes + + def test_default_config_executes_openai_rag(self): + """Test end-to-end execution with OpenAI config using mocked client.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Stitch Fix" + mock_client.chat.completions.create.return_value = mock_response + + dr = driver.Builder().with_modules(faiss_rag).with_config({}).build() + result = dr.execute( + ["rag_response"], + inputs={"rag_prompt": "Where did stefan work?"}, + overrides={"llm_client": mock_client, "rag_prompt": "Where did stefan work?"}, + ) + assert result["rag_response"] == "Stitch Fix" + + def test_minimax_config_executes_minimax_rag(self): + """Test end-to-end execution with MiniMax config using mocked client.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "MiniMax response" + mock_client.chat.completions.create.return_value = mock_response + + dr = driver.Builder().with_modules(faiss_rag).with_config({"provider": "minimax"}).build() + result = dr.execute( + ["rag_response"], + inputs={"rag_prompt": "Where did stefan work?"}, + overrides={"llm_client": mock_client, "rag_prompt": "Where did stefan work?"}, + ) + assert result["rag_response"] == "MiniMax response" + + +class TestMiniMaxModelConstants: + """Tests for MiniMax model configuration constants.""" + + def test_minimax_base_url(self): + """Test that MiniMax base URL is correct.""" + with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}): + client = faiss_rag.llm_client__minimax() + assert str(client.base_url).rstrip("/") == "https://api.minimax.io/v1" + + def test_minimax_model_name_is_m27(self): + """Test that MiniMax response uses M2.7 model.""" + mock_client = MagicMock(spec=openai.OpenAI) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "test" + mock_client.chat.completions.create.return_value = mock_response + + faiss_rag.rag_response__minimax(rag_prompt="test", llm_client=mock_client) + + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs["model"] == "MiniMax-M2.7" + + +# ──────────────────────── Integration Tests ───────────────────────── + + +class TestMiniMaxIntegration: + """Integration tests that call the real MiniMax API.""" + + @pytest.fixture + def minimax_api_key(self): + """Get MiniMax API key from environment.""" + key = os.environ.get("MINIMAX_API_KEY") + if not key: + pytest.skip("MINIMAX_API_KEY not set") + return key + + def test_minimax_client_creation(self, minimax_api_key): + """Test creating a real MiniMax client.""" + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_api_key}): + client = faiss_rag.llm_client__minimax() + assert isinstance(client, openai.OpenAI) + assert client.api_key == minimax_api_key + + def test_minimax_rag_response_real_api(self, minimax_api_key): + """Test a real RAG response from MiniMax API.""" + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_api_key}): + client = faiss_rag.llm_client__minimax() + result = faiss_rag.rag_response__minimax( + rag_prompt="Answer the question based only on the following context:\n" + "Stefan worked at Stitch Fix.\n\n" + "Question: Where did Stefan work?", + llm_client=client, + ) + assert isinstance(result, str) + assert len(result) > 0 + + def test_minimax_driver_execution(self, minimax_api_key): + """Test Hamilton driver execution with MiniMax config.""" + with patch.dict(os.environ, {"MINIMAX_API_KEY": minimax_api_key}): + dr = driver.Builder().with_modules(faiss_rag).with_config({"provider": "minimax"}).build() + + client = faiss_rag.llm_client__minimax() + result = dr.execute( + ["rag_response"], + overrides={ + "llm_client": client, + "rag_prompt": "Answer the question based only on the following context:\n" + "Hamilton is a Python library for DAGs.\n\n" + "Question: What is Hamilton?", + }, + ) + assert "rag_response" in result + assert isinstance(result["rag_response"], str) + assert len(result["rag_response"]) > 0 diff --git a/contrib/hamilton/contrib/dagworks/faiss_rag/valid_configs.jsonl b/contrib/hamilton/contrib/dagworks/faiss_rag/valid_configs.jsonl index b8a6704f8..70ffd523f 100644 --- a/contrib/hamilton/contrib/dagworks/faiss_rag/valid_configs.jsonl +++ b/contrib/hamilton/contrib/dagworks/faiss_rag/valid_configs.jsonl @@ -1 +1,2 @@ -{"description": "Default", "name": "default", "config": {}} +{"description": "Default (OpenAI)", "name": "default", "config": {}} +{"description": "MiniMax", "name": "minimax", "config": {"provider": "minimax"}}