From 76c3f3f122f9731b89bb2f02d6950d60d337adb3 Mon Sep 17 00:00:00 2001 From: Joshua Catt Date: Fri, 3 Apr 2026 14:14:25 -0400 Subject: [PATCH] added llm_complete --- README.md | 17 +++++++ src/datacustomcode/__init__.py | 2 + src/datacustomcode/ai/__init__.py | 35 +++++++++++++ src/datacustomcode/ai/llm.py | 83 +++++++++++++++++++++++++++++++ tests/ai/__init__.py | 14 ++++++ tests/ai/test_llm.py | 51 +++++++++++++++++++ 6 files changed, 202 insertions(+) create mode 100644 src/datacustomcode/ai/__init__.py create mode 100644 src/datacustomcode/ai/llm.py create mode 100644 tests/ai/__init__.py create mode 100644 tests/ai/test_llm.py diff --git a/README.md b/README.md index d3d601a..c5d4e39 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,23 @@ client.write_to_dlo('output_DLO') > [!WARNING] > Currently we only support reading from DMOs and writing to DMOs or reading from DLOs and writing to DLOs, but they cannot mix. +## LLM Capabilities + +* `llm_complete(prompt_col, *, model_id=..., max_tokens=...)` – Generate AI completions from a prompt column (`model_id` and `max_tokens` are keyword-only) + +For example: +```python +from datacustomcode.ai import llm_complete +from pyspark.sql.functions import concat_ws, lit, col + +# Generate summaries +prompt = concat_ws(" ", + lit("Summarize:"), + col("Name__c"), + col("Description__c") +) +df = df.withColumn("summary", llm_complete(prompt)) +``` ## CLI diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py index 2662e74..831d3f4 100644 --- a/src/datacustomcode/__init__.py +++ b/src/datacustomcode/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from datacustomcode.ai import llm_complete from datacustomcode.client import Client from datacustomcode.credentials import AuthType, Credentials from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader @@ -28,4 +29,5 @@ "LocalProxyClientProvider", "PrintDataCloudWriter", "QueryAPIDataCloudReader", + "llm_complete", ] diff --git a/src/datacustomcode/ai/__init__.py b/src/datacustomcode/ai/__init__.py new file mode 100644 index 0000000..d81d34b --- /dev/null +++ b/src/datacustomcode/ai/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LLM capabilities for Data Cloud Custom Code. + +This module provides LLM functions. + +Available functions: + llm_complete: Generate completions from a prompt column + +Example: + from datacustomcode.ai import llm_complete + from pyspark.sql.functions import col + + df = spark.read.table("Account_std__dll") + df = df.withColumn("summary", llm_complete("Name__c")) +""" + +from datacustomcode.ai.llm import llm_complete + +__all__ = [ + "llm_complete", +] diff --git a/src/datacustomcode/ai/llm.py b/src/datacustomcode/ai/llm.py new file mode 100644 index 0000000..6df631c --- /dev/null +++ b/src/datacustomcode/ai/llm.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, Salesforce, Inc. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LLM completion functions for Data Cloud Custom Code. +""" + +from typing import Union + +from pyspark.sql import Column +from pyspark.sql.functions import call_function, lit + +# Default values for llm_complete function +# TODO: Validate these defaults +_DEFAULT_MODEL_ID = "sfdc_ai__DefaultGPT4Omni" +_DEFAULT_MAX_TOKENS = 200 +_LLM_GATEWAY_UDF_NAME = "llm_gateway_generate" + + +def llm_complete( + prompt_col: Union[Column, str], + *, + model_id: str = _DEFAULT_MODEL_ID, + max_tokens: int = _DEFAULT_MAX_TOKENS, +) -> Column: + """Returns the AI-generated text response as a string column. + + Args: + prompt_col: Column or column name containing the prompt text. + The prompt should be a string value (max 32KB recommended). + Use string functions like concat_ws(), format_string(), etc. + to construct complex prompts from multiple columns. + model_id: Defaults to "sfdc_ai__DefaultGPT4Omni". + Available models depend on your org's configuration. + max_tokens: Maximum tokens in the response. Defaults to 200. + Higher values allow longer responses but increase latency and cost. + + Returns: + Column of StringType with AI-generated response. + Returns null if the input prompt is null. + + Raises: + TypeError: If prompt_col is not a Column or string. + ValueError: If max_tokens is not positive. 
+ """ + # Input validation + if not isinstance(prompt_col, (Column, str)): + raise TypeError( + f"prompt_col must be a Column or str, got {type(prompt_col).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError(f"max_tokens must be a positive integer, got {max_tokens}") + + # Convert string column name to Column + if isinstance(prompt_col, str): + from pyspark.sql.functions import col + + prompt_col = col(prompt_col) + + from pyspark.sql.functions import named_struct + + template = "{prompt}" + values_struct = named_struct(lit("prompt"), prompt_col) + + return call_function( + _LLM_GATEWAY_UDF_NAME, + lit(template), + values_struct, + lit(model_id), + lit(max_tokens), + ) diff --git a/tests/ai/__init__.py b/tests/ai/__init__.py new file mode 100644 index 0000000..93988ff --- /dev/null +++ b/tests/ai/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/ai/test_llm.py b/tests/ai/test_llm.py new file mode 100644 index 0000000..d24221c --- /dev/null +++ b/tests/ai/test_llm.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for datacustomcode.ai.llm module.""" + +import pytest + +from datacustomcode.ai import llm_complete + + +class TestLlmComplete: + """Tests for llm_complete function.""" + + def test_invalid_prompt_col_type_int(self): + """Test that invalid prompt_col type raises TypeError.""" + with pytest.raises(TypeError, match="prompt_col must be a Column or str"): + llm_complete(123) + + def test_invalid_max_tokens_type_string(self): + """Test that string max_tokens raises ValueError.""" + with pytest.raises(ValueError, match="max_tokens must be a positive integer"): + llm_complete("test_col", max_tokens="invalid") + + def test_invalid_max_tokens_type_float(self): + """Test that float max_tokens raises ValueError.""" + with pytest.raises(ValueError, match="max_tokens must be a positive integer"): + llm_complete("test_col", max_tokens=100.5) + + def test_negative_max_tokens(self): + """Test that negative max_tokens raises ValueError.""" + with pytest.raises(ValueError, match="max_tokens must be a positive integer"): + llm_complete("test_col", max_tokens=-1) + + def test_zero_max_tokens(self): + """Test that zero max_tokens raises ValueError.""" + with pytest.raises(ValueError, match="max_tokens must be a positive integer"): + llm_complete("test_col", max_tokens=0) + +