From e91522fb1a707cc36bb6539099be412796cd4e4d Mon Sep 17 00:00:00 2001
From: PR Bot
Date: Sat, 28 Mar 2026 20:16:11 +0800
Subject: [PATCH] Add MiniMax as first-class LLM provider

Add MiniMax as a new LLM provider that extends the existing OpenAI
class, leveraging MiniMax's OpenAI-compatible API endpoint. This lets
users run MiniMax's M2.7 and M2.5 model families in prompting and
synthetic data generation workflows.

Key changes:
- New MiniMax class in src/llms/minimax.py extending OpenAI
- Temperature clamping to MiniMax's (0.0, 1.0] range
- Model-specific context lengths (1M for M2.7, 204K for M2.5)
- Model-specific max output lengths (16K for M2.7, 8K for M2.5)
- MINIMAX_API_KEY environment variable support
- 30 unit tests + 3 integration tests in tests/test_minimax.py
- In-project test class in src/tests/llms/test_llms.py
- Registered in src/llms/__init__.py exports

Supported models:
- MiniMax-M2.7 (1M context)
- MiniMax-M2.7-highspeed (1M context)
- MiniMax-M2.5 (204K context)
- MiniMax-M2.5-highspeed (204K context)
---
 src/llms/__init__.py        |   2 +
 src/llms/minimax.py         | 149 ++++++++++++++++
 src/tests/llms/test_llms.py | 216 ++++++++++++++++++++++++
 tests/test_minimax.py       | 326 ++++++++++++++++++++++++++++++++++++
 4 files changed, 693 insertions(+)
 create mode 100644 src/llms/minimax.py
 create mode 100644 tests/test_minimax.py

diff --git a/src/llms/__init__.py b/src/llms/__init__.py
index 51d0ec1..bbce219 100644
--- a/src/llms/__init__.py
+++ b/src/llms/__init__.py
@@ -85,6 +85,7 @@
 from .hf_api_endpoint import HFAPIEndpoint
 from .hf_transformers import HFTransformers
 from .llm import LLM
+from .minimax import MiniMax
 from .mistral_ai import MistralAI
 from .openai import OpenAI
 from .openai_assistant import OpenAIAssistant
@@ -105,6 +106,7 @@
     "HFAPIEndpoint",
     "Together",
     "MistralAI",
+    "MiniMax",
     "Anthropic",
     "Cohere",
     "AI21",
diff --git a/src/llms/minimax.py b/src/llms/minimax.py
new file mode 100644
index 0000000..cfb7a29
--- /dev/null
+++ b/src/llms/minimax.py
@@ -0,0 +1,149 @@
+import os
+from functools import cached_property
+from typing import Any, Callable
+
+import openai
+
+from ..utils import ring_utils as ring
+from .llm import DEFAULT_BATCH_SIZE
+from .openai import OpenAI
+
+# MiniMax model context lengths and max output lengths
+_MINIMAX_CONTEXT_LENGTHS: dict[str, int] = {
+    "MiniMax-M2.7": 1000000,
+    "MiniMax-M2.7-highspeed": 1000000,
+    "MiniMax-M2.5": 204800,
+    "MiniMax-M2.5-highspeed": 204800,
+}
+
+_MINIMAX_MAX_OUTPUT_LENGTHS: dict[str, int] = {
+    "MiniMax-M2.7": 16384,
+    "MiniMax-M2.7-highspeed": 16384,
+    "MiniMax-M2.5": 8192,
+    "MiniMax-M2.5-highspeed": 8192,
+}
+
+_MINIMAX_BASE_URL = "https://api.minimax.io/v1"
+
+
+class MiniMax(OpenAI):
+    """A MiniMax LLM provider that uses MiniMax's OpenAI-compatible API.
+
+    MiniMax provides large language models accessible via an OpenAI-compatible
+    API endpoint. Supported models include ``MiniMax-M2.7``,
+    ``MiniMax-M2.7-highspeed``, ``MiniMax-M2.5``, and
+    ``MiniMax-M2.5-highspeed``.
+
+    Args:
+        model_name: The name of the MiniMax model to use.
+        system_prompt: An optional system prompt to use.
+        api_key: The MiniMax API key. If ``None``, the ``MINIMAX_API_KEY``
+            environment variable will be used.
+        retry_on_fail: Whether to retry on failure.
+        cache_folder_path: The path to the cache folder.
+        **kwargs: Additional keyword arguments passed to the OpenAI client.
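+
+    Example (a minimal sketch; assumes the package is importable as
+    ``datadreamer`` and that a valid ``MINIMAX_API_KEY`` is set)::
+
+        from datadreamer import DataDreamer
+        from datadreamer.llms import MiniMax
+
+        with DataDreamer("./output"):
+            llm = MiniMax("MiniMax-M2.7")
+            texts = llm.run(
+                ["Say hello in one word."],
+                max_new_tokens=10,
+                temperature=0.7,
+                top_p=1.0,
+                n=1,
+                batch_size=1,
+            )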
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        system_prompt: None | str = None,
+        api_key: None | str = None,
+        retry_on_fail: bool = True,
+        cache_folder_path: None | str = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model_name=model_name,
+            system_prompt=system_prompt or "You are a helpful assistant.",
+            api_key=api_key or os.environ.get("MINIMAX_API_KEY"),
+            base_url=_MINIMAX_BASE_URL,
+            retry_on_fail=retry_on_fail,
+            cache_folder_path=cache_folder_path,
+            **kwargs,
+        )
+
+    @cached_property
+    def client(self) -> openai.OpenAI:
+        other_kwargs: dict[str, Any] = {}
+        if self.api_key:
+            other_kwargs["api_key"] = self.api_key
+        return openai.OpenAI(
+            base_url=_MINIMAX_BASE_URL,
+            **other_kwargs,
+            **self.kwargs,
+        )
+
+    @ring.lru(maxsize=128)
+    def get_max_context_length(self, max_new_tokens: int) -> int:
+        """Gets the maximum context length for the model.
+
+        Args:
+            max_new_tokens: The maximum number of tokens that can be generated.
+
+        Returns:
+            The maximum context length.
+        """
+        # Use known context lengths for MiniMax models
+        max_context_length = _MINIMAX_CONTEXT_LENGTHS.get(self.model_name, 204800)
+        # Account for chat format tokens (system prompt + message framing)
+        format_tokens = 4 * 3 + self.count_tokens(self.system_prompt or "")
+        return max_context_length - max_new_tokens - format_tokens
+
+    def _get_max_output_length(self) -> None | int:
+        return _MINIMAX_MAX_OUTPUT_LENGTHS.get(self.model_name, 8192)
+
+    def _run_batch(
+        self,
+        max_length_func: Callable[[list[str]], int],
+        inputs: list[str],
+        max_new_tokens: None | int = None,
+        temperature: float = 1.0,
+        top_p: float = 0.0,
+        n: int = 1,
+        stop: None | str | list[str] = None,
+        repetition_penalty: None | float = None,
+        logit_bias: None | dict[int, float] = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        seed: None | int = None,
+        **kwargs,
+    ) -> list[str] | list[list[str]]:
+        # MiniMax requires temperature in (0.0, 1.0]
+        if temperature == 0.0:
+            temperature = 0.01
+        elif temperature > 1.0:
+            temperature = 1.0
+
+        return super()._run_batch(
+            max_length_func=max_length_func,
+            inputs=inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stop=stop,
+            repetition_penalty=repetition_penalty,
+            logit_bias=logit_bias,
+            batch_size=batch_size,
+            seed=seed,
+            **kwargs,
+        )
+
+    @cached_property
+    def model_card(self) -> None | str:
+        return "https://platform.minimaxi.com/document/Models"
+
+    @cached_property
+    def license(self) -> None | str:
+        return "https://platform.minimaxi.com/document/Terms%20of%20service"
+
+    @cached_property
+    def citation(self) -> None | list[str]:
+        return None
+
+
+__all__ = ["MiniMax"]
diff --git a/src/tests/llms/test_llms.py b/src/tests/llms/test_llms.py
index 88394ed..2756bc7 100644
--- a/src/tests/llms/test_llms.py
+++ b/src/tests/llms/test_llms.py
@@ -32,6 +32,7 @@
     GoogleAIStudio,
     HFAPIEndpoint,
     HFTransformers,
+    MiniMax,
     MistralAI,
     OpenAI,
     OpenAIAssistant,
@@ -3162,6 +3163,221 @@ def chat_mocked(**kwargs):
         assert "client" not in llm.__dict__ and "tokenizer" not in llm.__dict__
 
 
+class TestMiniMax:
+    def test_init(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+            assert llm.model_name == "MiniMax-M2.7"
+            assert llm.base_url == "https://api.minimax.io/v1"
+            assert llm.system_prompt == "You are a helpful assistant."
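+
+    # A minimal sketch of an additional check (assumes pytest's built-in
+    # monkeypatch fixture; MiniMax's client property relies on self.api_key,
+    # so the attribute is stored by the OpenAI parent class):
+    def test_api_key_from_env_var(self, create_datadreamer, monkeypatch):
+        monkeypatch.setenv("MINIMAX_API_KEY", "env-fake-key")
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7")
+            assert llm.api_key == "env-fake-key"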
+
+    def test_metadata(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+            assert llm.model_card == "https://platform.minimaxi.com/document/Models"
+            assert (
+                llm.license
+                == "https://platform.minimaxi.com/document/Terms%20of%20service"
+            )
+            assert llm.citation is None
+
+    def test_count_tokens(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+            token_count = llm.count_tokens("This is a test.")
+            assert isinstance(token_count, int)
+            assert token_count > 0
+
+    def test_get_max_context_length(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+            ctx = llm.get_max_context_length(max_new_tokens=0)
+            # M2.7 has 1M context, minus format tokens
+            assert ctx > 999000
+            assert ctx < 1000000
+
+            llm2 = MiniMax("MiniMax-M2.5", api_key="fake-key")
+            ctx2 = llm2.get_max_context_length(max_new_tokens=0)
+            assert ctx2 > 204000
+            assert ctx2 < 204800
+
+    def test_get_max_output_length(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+            assert llm._get_max_output_length() == 16384
+
+            llm2 = MiniMax("MiniMax-M2.5", api_key="fake-key")
+            assert llm2._get_max_output_length() == 8192
+
+            llm3 = MiniMax("MiniMax-M2.5-highspeed", api_key="fake-key")
+            assert llm3._get_max_output_length() == 8192
+
+            llm4 = MiniMax("MiniMax-M2.7-highspeed", api_key="fake-key")
+            assert llm4._get_max_output_length() == 16384
+
+    def test_temperature_clamping(self, create_datadreamer, mocker):
+        from unittest.mock import MagicMock
+
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+
+            # Create a mock response
+            mock_choice = MagicMock()
+            mock_choice.message.content = "Test response"
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+
+            mocker.patch.object(
+                llm.client.chat.completions,
+                "create",
+                return_value=mock_response,
+            )
+
+            # Run with temperature=0.0 (should be clamped to 0.01)
+            llm.run(
+                ["Test prompt"],
+                max_new_tokens=10,
+                temperature=0.0,
+                top_p=1.0,
+                n=1,
+                batch_size=1,
+            )
+            call_kwargs = llm.client.chat.completions.create.call_args_list[
+                0
+            ].kwargs
+            assert call_kwargs["temperature"] == 0.01
+
+    @typing.no_type_check
+    def test_run(self, create_datadreamer, mocker):
+        from unittest.mock import MagicMock
+
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+
+            # Create a mock response factory
+            def chat_mocked(**kwargs):
+                prompt = kwargs["messages"][-1]["content"]
+                mock_choice = MagicMock()
+                mock_choice.message.content = f"Response to: {prompt}"
+                mock_response = MagicMock()
+                mock_response.choices = [mock_choice]
+                return mock_response
+
+            mocker.patch.object(
+                llm.client.chat.completions,
+                "create",
+                side_effect=chat_mocked,
+            )
+
+            # Simple run
+            generated_texts = llm.run(
+                ["What color is the sky?", "What color are trees?"],
+                max_new_tokens=25,
+                temperature=0.3,
+                top_p=1.0,
+                n=1,
+                stop=None,
+                repetition_penalty=None,
+                logit_bias=None,
+                batch_size=2,
+            )
+            assert generated_texts == [
+                "Response to: What color is the sky?",
+                "Response to: What color are trees?",
+            ]
+
+            # Test return_generator
+            generated_texts_generator = llm.run(
+                ["What color is the sky?", "What color are trees?"],
+                max_new_tokens=25,
+                temperature=0.3,
+                top_p=1.0,
+                n=1,
+                stop=None,
+                repetition_penalty=None,
+                logit_bias=None,
+                batch_size=2,
+                return_generator=True,
+            )
+            assert isinstance(generated_texts_generator, GeneratorType)
+            assert list(generated_texts_generator) == generated_texts
+
+            # Test unload model
+            assert "client" in llm.__dict__
+            llm.unload_model()
+            assert "client" not in llm.__dict__
+
+    def test_custom_system_prompt(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax(
+                "MiniMax-M2.7",
+                system_prompt="You are a coding assistant.",
+                api_key="fake-key",
+            )
+            assert llm.system_prompt == "You are a coding assistant."
+
+    def test_display_name(self, create_datadreamer):
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
+            assert "MiniMax-M2.7" in llm.display_name
+
+    @pytest.mark.skipif(
+        "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key"
+    )
+    def test_integration_run(self, create_datadreamer):
+        """Integration test that runs against the real MiniMax API."""
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7")
+            generated_texts = llm.run(
+                ["Say hello in one word."],
+                max_new_tokens=10,
+                temperature=0.01,
+                top_p=1.0,
+                n=1,
+                batch_size=1,
+            )
+            assert len(generated_texts) == 1
+            assert isinstance(generated_texts[0], str)
+            assert len(generated_texts[0]) > 0
+
+    @pytest.mark.skipif(
+        "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key"
+    )
+    def test_integration_streaming(self, create_datadreamer):
+        """Integration test for generator-based output."""
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7")
+            results = llm.run(
+                ["What is 2+2? Answer with just the number."],
+                max_new_tokens=5,
+                temperature=0.01,
+                top_p=1.0,
+                n=1,
+                batch_size=1,
+                return_generator=True,
+            )
+            results_list = list(results)
+            assert len(results_list) == 1
+            assert "4" in results_list[0]
+
+    @pytest.mark.skipif(
+        "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key"
+    )
+    def test_integration_m27_highspeed(self, create_datadreamer):
+        """Integration test for the M2.7-highspeed model."""
+        with create_datadreamer():
+            llm = MiniMax("MiniMax-M2.7-highspeed")
+            generated_texts = llm.run(
+                ["Say 'yes' or 'no'."],
+                max_new_tokens=5,
+                temperature=0.01,
+                top_p=1.0,
+                n=1,
+                batch_size=1,
+            )
+            assert len(generated_texts) == 1
+            assert isinstance(generated_texts[0], str)
+
+
 class TestPetals:
     pydantic_version = None
     bitsandbytes_version = None
diff --git a/tests/test_minimax.py b/tests/test_minimax.py
new file mode 100644
index 0000000..ddda200
--- /dev/null
+++ b/tests/test_minimax.py
@@ -0,0 +1,326 @@
+"""
+Standalone tests for the MiniMax LLM provider.
+
+These tests verify the MiniMax provider implementation without requiring
+the full DataDreamer project dependencies to be installed. Tests that
+require the full project use the TestMiniMax class in
+src/tests/llms/test_llms.py.
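+
+Integration tests at the bottom of this file call the real MiniMax API and are
+skipped unless the MINIMAX_API_KEY environment variable is set.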
+
+Run with: python -m pytest tests/test_minimax.py -v
+"""
+
+import os
+import re
+
+import pytest
+
+
+def _read_minimax_source():
+    """Read the minimax.py source file."""
+    path = os.path.join(os.path.dirname(__file__), "..", "src", "llms", "minimax.py")
+    with open(path, "r") as f:
+        return f.read()
+
+
+def _read_init_source():
+    """Read the llms __init__.py source file."""
+    path = os.path.join(os.path.dirname(__file__), "..", "src", "llms", "__init__.py")
+    with open(path, "r") as f:
+        return f.read()
+
+
+class TestMiniMaxModuleStructure:
+    """Tests verifying the module structure and exports."""
+
+    def test_minimax_file_exists(self):
+        """Test that minimax.py exists in the llms directory."""
+        path = os.path.join(
+            os.path.dirname(__file__), "..", "src", "llms", "minimax.py"
+        )
+        assert os.path.exists(path)
+
+    def test_minimax_imports_in_init(self):
+        """Test that MiniMax is imported in llms/__init__.py."""
+        source = _read_init_source()
+        assert "from .minimax import MiniMax" in source
+
+    def test_minimax_in_all_exports(self):
+        """Test that MiniMax is in __all__ in llms/__init__.py."""
+        source = _read_init_source()
+        assert '"MiniMax"' in source
+
+    def test_minimax_module_all_export(self):
+        """Test that minimax.py has __all__ = ['MiniMax']."""
+        source = _read_minimax_source()
+        assert '__all__ = ["MiniMax"]' in source
+
+    def test_inherits_from_openai(self):
+        """Test that MiniMax class inherits from OpenAI."""
+        source = _read_minimax_source()
+        assert "class MiniMax(OpenAI):" in source
+
+    def test_imports_openai_parent(self):
+        """Test that minimax.py imports the OpenAI class."""
+        source = _read_minimax_source()
+        assert "from .openai import OpenAI" in source
+
+
+class TestMiniMaxConstants:
+    """Tests verifying the constants defined in minimax.py."""
+
+    def test_base_url(self):
+        """Test that the base URL is correct."""
+        source = _read_minimax_source()
+        assert '_MINIMAX_BASE_URL = "https://api.minimax.io/v1"' in source
+
+    def test_context_lengths_m27(self):
+        """Test M2.7 context length is 1M."""
+        source = _read_minimax_source()
+        assert '"MiniMax-M2.7": 1000000' in source
+
+    def test_context_lengths_m27_highspeed(self):
+        """Test M2.7-highspeed context length is 1M."""
+        source = _read_minimax_source()
+        assert '"MiniMax-M2.7-highspeed": 1000000' in source
+
+    def test_context_lengths_m25(self):
+        """Test M2.5 context length is 204800."""
+        source = _read_minimax_source()
+        assert '"MiniMax-M2.5": 204800' in source
+
+    def test_context_lengths_m25_highspeed(self):
+        """Test M2.5-highspeed context length is 204800."""
+        source = _read_minimax_source()
+        assert '"MiniMax-M2.5-highspeed": 204800' in source
+
+    def test_max_output_m27(self):
+        """Test M2.7 max output is 16384."""
+        source = _read_minimax_source()
+        # Check the output dict contains M2.7: 16384
+        assert re.search(r'"MiniMax-M2\.7":\s*16384', source)
+
+    def test_max_output_m25(self):
+        """Test M2.5 max output is 8192."""
+        source = _read_minimax_source()
+        assert re.search(r'"MiniMax-M2\.5":\s*8192', source)
+
+
+class TestMiniMaxImplementation:
+    """Tests verifying the implementation logic in minimax.py."""
+
+    def test_temperature_clamping_zero(self):
+        """Test that temperature=0.0 is clamped to 0.01."""
+        source = _read_minimax_source()
+        assert "temperature == 0.0" in source
+        assert "temperature = 0.01" in source
+
+    def test_temperature_clamping_above_one(self):
+        """Test that temperature > 1.0 is clamped to 1.0."""
+        source = _read_minimax_source()
+        assert "temperature > 1.0" in source
+        assert "temperature = 1.0" in source
+
+    def test_api_key_from_env(self):
+        """Test that API key can be read from MINIMAX_API_KEY env var."""
+        source = _read_minimax_source()
+        assert 'os.environ.get("MINIMAX_API_KEY")' in source
+
+    def test_default_system_prompt(self):
+        """Test that default system prompt is set."""
+        source = _read_minimax_source()
+        assert '"You are a helpful assistant."' in source
+
+    def test_model_card_url(self):
+        """Test the model card URL."""
+        source = _read_minimax_source()
+        assert "https://platform.minimaxi.com/document/Models" in source
+
+    def test_license_url(self):
+        """Test the license URL."""
+        source = _read_minimax_source()
+        assert "https://platform.minimaxi.com/document/Terms%20of%20service" in source
+
+    def test_citation_none(self):
+        """Test that citation returns None."""
+        source = _read_minimax_source()
+        # Verify the citation method returns None
+        assert re.search(r"def citation.*\n.*return None", source)
+
+    def test_get_max_output_length_method(self):
+        """Test that _get_max_output_length method exists."""
+        source = _read_minimax_source()
+        assert "def _get_max_output_length(self)" in source
+
+    def test_get_max_context_length_method(self):
+        """Test that get_max_context_length method exists."""
+        source = _read_minimax_source()
+        assert "def get_max_context_length(self, max_new_tokens: int)" in source
+
+    def test_run_batch_method(self):
+        """Test that _run_batch method exists and calls super()."""
+        source = _read_minimax_source()
+        assert "def _run_batch(" in source
+        assert "return super()._run_batch(" in source
+
+    def test_client_property(self):
+        """Test that client property creates OpenAI client with correct base URL."""
+        source = _read_minimax_source()
+        assert "def client(self)" in source
+        assert "openai.OpenAI(" in source
+        assert "base_url=_MINIMAX_BASE_URL" in source
+
+    def test_has_docstring(self):
+        """Test that the MiniMax class has a docstring."""
+        source = _read_minimax_source()
+        # After 'class MiniMax(OpenAI):' there should be a docstring
+        class_match = source.find("class MiniMax(OpenAI):")
+        assert class_match >= 0
+        after_class = source[class_match:]
+        assert '"""' in after_class[:200]
+
+    def test_all_model_names_in_context_dict(self):
+        """Test all 4 MiniMax models are in context length dict."""
+        source = _read_minimax_source()
+        models = [
+            "MiniMax-M2.7",
+            "MiniMax-M2.7-highspeed",
+            "MiniMax-M2.5",
+            "MiniMax-M2.5-highspeed",
+        ]
+        for model in models:
+            assert f'"{model}"' in source, f"Model {model} not found in source"
+
+    def test_all_model_names_in_output_dict(self):
+        """Test all 4 MiniMax models are in max output length dict."""
+        source = _read_minimax_source()
+        # Both dicts should have all 4 models
+        context_matches = re.findall(r'"MiniMax-M\d+\.\d+[^"]*"', source)
+        assert len(context_matches) >= 8  # 4 models x 2 dicts
+
+
+class TestMiniMaxTestsInProject:
+    """Tests verifying that MiniMax tests exist in the project test suite."""
+
+    def test_minimax_test_class_exists(self):
+        """Test that TestMiniMax class exists in test_llms.py."""
+        path = os.path.join(
+            os.path.dirname(__file__),
+            "..",
+            "src",
+            "tests",
+            "llms",
+            "test_llms.py",
+        )
+        with open(path, "r") as f:
+            source = f.read()
+        assert "class TestMiniMax:" in source
+
+    def test_minimax_imported_in_tests(self):
+        """Test that MiniMax is imported in test_llms.py."""
+        path = os.path.join(
+            os.path.dirname(__file__),
+            "..",
+            "src",
+            "tests",
+            "llms",
"test_llms.py", + ) + with open(path, "r") as f: + source = f.read() + assert "MiniMax," in source + + def test_integration_tests_exist(self): + """Test that integration tests exist with MINIMAX_API_KEY skip.""" + path = os.path.join( + os.path.dirname(__file__), + "..", + "src", + "tests", + "llms", + "test_llms.py", + ) + with open(path, "r") as f: + source = f.read() + assert "MINIMAX_API_KEY" in source + assert "test_integration_run" in source + + +class TestMiniMaxIntegration: + """Integration tests that require MINIMAX_API_KEY environment variable.""" + + @pytest.mark.skipif( + "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key" + ) + def test_api_chat_completion(self): + """Integration test: direct API call to MiniMax.""" + import openai + + client = openai.OpenAI( + api_key=os.environ["MINIMAX_API_KEY"], + base_url="https://api.minimax.io/v1", + ) + response = client.chat.completions.create( + model="MiniMax-M2.7", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Say hello in one word."}, + ], + max_tokens=10, + temperature=0.01, + ) + assert len(response.choices) > 0 + assert isinstance(response.choices[0].message.content, str) + assert len(response.choices[0].message.content) > 0 + + @pytest.mark.skipif( + "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key" + ) + def test_api_m27_highspeed(self): + """Integration test: M2.7-highspeed model.""" + import openai + + client = openai.OpenAI( + api_key=os.environ["MINIMAX_API_KEY"], + base_url="https://api.minimax.io/v1", + ) + response = client.chat.completions.create( + model="MiniMax-M2.7-highspeed", + messages=[ + {"role": "user", "content": "Say 'yes' or 'no'."}, + ], + max_tokens=5, + temperature=0.01, + ) + assert len(response.choices) > 0 + assert isinstance(response.choices[0].message.content, str) + + @pytest.mark.skipif( + "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key" + ) + def test_api_temperature_edge(self): + """Integration test: verify low temperature works (near 0).""" + import openai + + client = openai.OpenAI( + api_key=os.environ["MINIMAX_API_KEY"], + base_url="https://api.minimax.io/v1", + ) + response = client.chat.completions.create( + model="MiniMax-M2.7", + messages=[ + {"role": "user", "content": "What is 2+2? Just the number."}, + ], + max_tokens=200, + temperature=0.01, + ) + content = response.choices[0].message.content + # M2.7 may include tags; strip them before checking + import re + + clean = re.sub(r".*?", "", content, flags=re.DOTALL).strip() + assert "4" in clean + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])