48 changes: 42 additions & 6 deletions contrib/hamilton/contrib/dagworks/conversational_rag/README.md
@@ -23,7 +23,9 @@ This module shows a conversational retrieval augmented generation (RAG) example
Apache Hamilton. It shows you how you might structure your code with Apache Hamilton to
create a RAG pipeline that takes into account conversation.

This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) + and in memory vector store and the OpenAI LLM provider.
This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) as an in-memory vector store, with multi-provider LLM support.
It supports **OpenAI** (default) and **[MiniMax](https://www.minimax.io/)** as LLM providers,
switchable via Hamilton's `@config.when` pattern.
The implementation of the FAISS vector store uses the LangChain wrapper around it.
That's because this was the simplest way to get this example up without requiring
anyone to host and manage a proper vector store.
@@ -57,14 +59,15 @@ Here we just ask for the final result, but if you wanted to, you could ask for o
you can then introspect or log for debugging/evaluation purposes. Note that if you want more platform integrations,
you can add adapters that will do this automatically for you, e.g. the `PrintLn` adapter used here.

**Using OpenAI (default):**
```python
# import the module
from hamilton import driver
from hamilton import lifecycle
dr = (
driver.Builder()
.with_modules(conversational_rag)
.with_config({})
.with_config({}) # defaults to OpenAI
# this prints the inputs and outputs of each step.
.with_adapters(lifecycle.PrintLn(verbosity=2))
.build()
@@ -102,6 +105,34 @@ result = dr.execute(
print(result)
```

**Using MiniMax:**

Set `MINIMAX_API_KEY` in your environment, then pass `{"provider": "minimax"}` in the config:
```python
from hamilton import driver, lifecycle
dr = (
driver.Builder()
.with_modules(conversational_rag)
.with_config({"provider": "minimax"})
.with_adapters(lifecycle.PrintLn(verbosity=2))
.build()
)
result = dr.execute(
["conversational_rag_response"],
inputs={
"input_texts": [
"harrison worked at kensho",
"stefan worked at Stitch Fix",
],
"question": "where did stefan work?",
"chat_history": []
},
)
print(result)
```
MiniMax uses the [MiniMax-M2.7](https://www.minimax.io/) model with a 1M-token context window,
via an OpenAI-compatible API endpoint.

# How to extend this module
What you'd most likely want to do is:

@@ -112,16 +143,21 @@ What you'd most likely want to do is:
With (1) you can import any vector store/library that you want. You should draw out
the process you would like, and that should then map to Apache Hamilton functions.
With (2) you can import any LLM provider that you want, just use `@config.when` if you
want to switch between multiple providers.
want to switch between multiple providers. OpenAI and MiniMax are already supported.
With (3) you can add more functions that create parts of the prompt.
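For (3), a hedged sketch of what an extra prompt-part function might look like. `tone_instruction` is a hypothetical node name, and this `answer_prompt` variant only approximates the module's real template:

```python
# Hypothetical sketch: adding a new prompt component as a Hamilton function.
# The function name becomes a DAG node, so downstream prompt builders can
# request the fragment simply by declaring it as a parameter.


def tone_instruction() -> str:
    """A reusable prompt fragment instructing the LLM on answer style."""
    return "Answer concisely, and say 'I don't know' if the context is insufficient."


def answer_prompt(context: str, standalone_question: str, tone_instruction: str) -> str:
    """Variant of the module's answer prompt that injects the extra fragment."""
    template = (
        "{tone}\n\n"
        "Answer the question based only on the following context:\n"
        "{context}\n\n"
        "Question: {question}"
    )
    return template.format(
        tone=tone_instruction, context=context, question=standalone_question
    )
```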

# Configuration Options
There is no configuration needed for this module.

| Config Key | Values | Description |
|-----------|--------|-------------|
| `provider` | `"minimax"` | Use MiniMax M2.7 as the LLM. Requires `MINIMAX_API_KEY` env var. |
| *(empty)* | | Default: uses OpenAI. Requires `OPENAI_API_KEY` env var. |

# Limitations

You need to have the OPENAI_API_KEY in your environment.
It should be accessible from your code by doing `os.environ["OPENAI_API_KEY"]`.
You need to have the appropriate API key in your environment:
- **OpenAI** (default): `OPENAI_API_KEY`
- **MiniMax**: `MINIMAX_API_KEY`
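A small pre-flight check makes a missing key fail fast with a clear message, rather than surfacing as a confusing HTTP error mid-pipeline. `require_api_key` is a hypothetical helper, not part of the module:

```python
import os


def require_api_key(provider: str = "openai") -> None:
    """Hypothetical pre-flight check: raise early if the chosen provider's key is unset."""
    key_name = "MINIMAX_API_KEY" if provider == "minimax" else "OPENAI_API_KEY"
    if not os.environ.get(key_name):
        raise EnvironmentError(f"Set {key_name} in your environment before building the driver.")
```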

The code does not check the context length, so it may fail if the context passed is too long
for the LLM you send it to.
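A crude guard can catch the worst cases before the API call. This sketch is an assumption-laden heuristic (roughly 4 characters per token; the real limit depends on the model you use), not an exact tokenizer:

```python
def check_context_length(prompt: str, max_tokens: int = 128_000) -> str:
    """Hypothetical guard against over-long prompts.

    Uses a rough ~4-characters-per-token estimate; swap in a real tokenizer
    (e.g. tiktoken for OpenAI models) for an exact count.
    """
    approx_tokens = len(prompt) // 4  # crude estimate, not a real token count
    if approx_tokens > max_tokens:
        raise ValueError(
            f"Prompt is ~{approx_tokens} tokens, over the {max_tokens}-token limit; "
            "trim the retrieved context or chat history before calling the LLM."
        )
    return prompt
```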
75 changes: 69 additions & 6 deletions contrib/hamilton/contrib/dagworks/conversational_rag/__init__.py
@@ -16,10 +16,12 @@
# under the License.

import logging
import os

logger = logging.getLogger(__name__)

from hamilton import contrib
from hamilton.function_modifiers import config

with contrib.catch_import_errors(__name__, __file__, logger):
import openai
@@ -53,8 +55,9 @@ def standalone_question_prompt(chat_history: list[str], question: str) -> str:
).format(chat_history=chat_history_str, question=question)


def standalone_question(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str:
"""Asks the LLM to create a standalone question from the prompt.
@config.when_not(provider="minimax")
def standalone_question__openai(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str:
"""Asks OpenAI to create a standalone question from the prompt.

:param standalone_question_prompt: the prompt with context.
:param llm_client: the llm client to use.
@@ -67,6 +70,21 @@ def standalone_question(standalone_question_prompt: str, llm_client: openai.Open
return response.choices[0].message.content


@config.when(provider="minimax")
def standalone_question__minimax(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str:
"""Asks MiniMax to create a standalone question from the prompt.

:param standalone_question_prompt: the prompt with context.
:param llm_client: the llm client to use.
:return: the standalone question.
"""
response = llm_client.chat.completions.create(
model="MiniMax-M2.7",
messages=[{"role": "user", "content": standalone_question_prompt}],
)
return response.choices[0].message.content


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
"""A Vector store. This function populates and creates one for querying.

@@ -112,13 +130,31 @@ def answer_prompt(context: str, standalone_question: str) -> str:
return template.format(context=context, question=standalone_question)


def llm_client() -> openai.OpenAI:
"""The LLM client to use for the RAG model."""
@config.when_not(provider="minimax")
def llm_client__openai() -> openai.OpenAI:
"""The OpenAI LLM client (default).

Uses the OPENAI_API_KEY environment variable for authentication.
"""
return openai.OpenAI()


def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -> str:
"""Creates the RAG response from the LLM model for the given prompt.
@config.when(provider="minimax")
def llm_client__minimax() -> openai.OpenAI:
"""The MiniMax LLM client via OpenAI-compatible API.

Uses the MINIMAX_API_KEY environment variable for authentication.
MiniMax provides an OpenAI-compatible endpoint at https://api.minimax.io/v1.
"""
return openai.OpenAI(
base_url="https://api.minimax.io/v1",
        api_key=os.environ["MINIMAX_API_KEY"],  # fail fast if unset, rather than silently falling back to OPENAI_API_KEY
)


@config.when_not(provider="minimax")
def conversational_rag_response__openai(answer_prompt: str, llm_client: openai.OpenAI) -> str:
"""Creates the RAG response using OpenAI.

:param answer_prompt: the prompt to send to the LLM.
:param llm_client: the LLM client to use.
@@ -131,11 +167,29 @@ def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -
return response.choices[0].message.content


@config.when(provider="minimax")
def conversational_rag_response__minimax(answer_prompt: str, llm_client: openai.OpenAI) -> str:
"""Creates the RAG response using MiniMax M2.7.

    MiniMax M2.7 is a high-performance model with a 1M-token context window.

:param answer_prompt: the prompt to send to the LLM.
:param llm_client: the LLM client to use.
:return: the response from the LLM.
"""
response = llm_client.chat.completions.create(
model="MiniMax-M2.7",
messages=[{"role": "user", "content": answer_prompt}],
)
return response.choices[0].message.content


if __name__ == "__main__":
import __init__ as conversational_rag

from hamilton import driver, lifecycle

# Default: uses OpenAI (config={} or config={"provider": "openai"})
dr = (
driver.Builder()
.with_modules(conversational_rag)
@@ -176,3 +230,12 @@ def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -
},
)
)

# To use MiniMax instead, set MINIMAX_API_KEY and use:
# dr = (
# driver.Builder()
# .with_modules(conversational_rag)
# .with_config({"provider": "minimax"})
# .with_adapters(lifecycle.PrintLn(verbosity=2))
# .build()
# )
@@ -1,6 +1,6 @@
{
"schema": "1.0",
"use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"],
"use_case_tags": ["LLM", "openai", "minimax", "RAG", "retrieval augmented generation", "FAISS"],
"secondary_tags": {
"language": "English"
}