Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions src/google/adk/utils/instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from ..sessions.state import State

__all__ = [
'inject_session_state',
"inject_session_state",
]

logger = logging.getLogger('google_adk.' + __name__)
logger = logging.getLogger("google_adk." + __name__)


async def inject_session_state(
Expand Down Expand Up @@ -76,18 +76,29 @@ async def _async_sub(pattern, repl_async_fn, string) -> str:
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
return "".join(result)

async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
matched_text = match.group()

# Check for exactly double braces (escaping)
if (
matched_text.startswith("{{")
and matched_text.endswith("}}")
and not matched_text.startswith("{{{")
and not matched_text.endswith("}}}")
):
return matched_text[1:-1]

var_name = matched_text.lstrip("{").rstrip("}").strip()
optional = False
if var_name.endswith('?'):
if var_name.endswith("?"):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
var_name = var_name.removesuffix("?")
if var_name.startswith("artifact."):
var_name = var_name.removeprefix("artifact.")
if invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
raise ValueError("Artifact service is not initialized.")
artifact = await invocation_context.artifact_service.load_artifact(
app_name=invocation_context.session.app_name,
user_id=invocation_context.session.user_id,
Expand All @@ -97,31 +108,31 @@ async def _replace_match(match) -> str:
if artifact is None:
if optional:
logger.debug(
'Artifact %s not found, replacing with empty string', var_name
"Artifact %s not found, replacing with empty string", var_name
)
return ''
return ""
else:
raise KeyError(f'Artifact {var_name} not found.')
raise KeyError(f"Artifact {var_name} not found.")
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in invocation_context.session.state:
value = invocation_context.session.state[var_name]
if value is None:
return ''
return ""
return str(value)
else:
if optional:
logger.debug(
'Context variable %s not found, replacing with empty string',
"Context variable %s not found, replacing with empty string",
var_name,
)
return ''
return ""
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
raise KeyError(f"Context variable not found: `{var_name}`.")

return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
return await _async_sub(r"{+[^{}]*}+", _replace_match, template)


def _is_valid_state_name(var_name):
Expand All @@ -138,12 +149,12 @@ def _is_valid_state_name(var_name):
Returns:
True if the variable name is a valid state name, False otherwise.
"""
parts = var_name.split(':')
parts = var_name.split(":")
if len(parts) == 1:
return var_name.isidentifier()

if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
if (parts[0] + ':') in prefixes:
if (parts[0] + ":") in prefixes:
return parts[1].isidentifier()
return False
Loading