Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 105 additions & 7 deletions dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/TypeId.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.IO;
using System.Reflection;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Agents.AI.Workflows.Checkpointing;
Expand All @@ -11,7 +14,7 @@ namespace Microsoft.Agents.AI.Workflows.Checkpointing;
/// </summary>
public sealed class TypeId : IEquatable<TypeId>
{
/// <inheritdoc cref="System.Reflection.Assembly.FullName"/>
/// <inheritdoc cref="Assembly.FullName"/>
public string AssemblyName { get; }

/// <inheritdoc cref="Type.FullName"/>
Expand Down Expand Up @@ -46,6 +49,11 @@ public override bool Equals(object? obj)
=> this.Equals(obj as TypeId);

/// <inheritdoc />
/// <remarks>
/// Compares the type full name and the simple assembly name. Version, culture, and public key
/// token are ignored both in <see cref="AssemblyName"/> and in any assembly-qualified generic
/// arguments embedded in <see cref="TypeName"/>.
/// </remarks>
public bool Equals(TypeId? other)
{
if (other is null)
Expand All @@ -58,11 +66,27 @@ public bool Equals(TypeId? other)
return true;
}

return this.AssemblyName == other.AssemblyName && this.TypeName == other.TypeName;
if (this.NormalizedTypeName != other.NormalizedTypeName)
{
return false;
}

if (string.Equals(this.AssemblyName, other.AssemblyName, StringComparison.Ordinal))
{
return true;
}

string? thisSimpleName = this.SimpleAssemblyName;
string? otherSimpleName = other.SimpleAssemblyName;

return thisSimpleName is not null
&& string.Equals(thisSimpleName, otherSimpleName, StringComparison.Ordinal);
}

/// <inheritdoc />
public override int GetHashCode() => HashCode.Combine(this.AssemblyName, this.TypeName);
/// <remarks>Hashes the normalized type name and the simple assembly name.</remarks>
public override int GetHashCode()
=> HashCode.Combine(this.SimpleAssemblyName, this.NormalizedTypeName);

/// <inheritdoc />
public static bool operator ==(TypeId? left, TypeId? right) => left is null ? right is null : left.Equals(right);
Expand All @@ -73,13 +97,27 @@ public bool Equals(TypeId? other)
/// <summary>
/// Determines whether the specified type matches both the assembly name and type name represented by this instance.
/// </summary>
/// <remarks>
/// Compares the type full name and the simple assembly name. Version, culture, and public key
/// token are ignored both in <see cref="AssemblyName"/> and in any assembly-qualified generic
/// arguments embedded in <see cref="TypeName"/>.
/// </remarks>
/// <param name="type">The type to compare against the stored assembly and type names. Cannot be null.</param>
/// <returns>true if the specified type's assembly and type names are equal to those stored in this instance; otherwise,
/// false.</returns>
/// <returns>true if the specified type's assembly simple name and normalized type full name are equal to those stored
/// in this instance; otherwise, false.</returns>
public bool IsMatch(Type type)
{
return this.AssemblyName == type.Assembly.FullName
&& this.TypeName == type.FullName;
string? runtimeNormalizedTypeName = type.FullName is null ? null : NormalizeTypeName(type.FullName);
if (this.NormalizedTypeName != runtimeNormalizedTypeName)
{
return false;
}

string? storedSimpleName = this.SimpleAssemblyName;
string? runtimeSimpleName = type.Assembly.GetName().Name;

return storedSimpleName is not null
&& string.Equals(storedSimpleName, runtimeSimpleName, StringComparison.Ordinal);
}

/// <summary>
Expand Down Expand Up @@ -113,4 +151,64 @@ public bool IsMatchPolymorphic(Type type)

/// <inheritdoc/>
public override string ToString() => $"{this.TypeName}, {this.AssemblyName}";

/// <summary>
/// The simple assembly name parsed from <see cref="AssemblyName"/>, lazily computed and cached.
/// </summary>
internal string? SimpleAssemblyName
=> field ??= GetSimpleAssemblyName(this.AssemblyName);

/// <summary>
/// The type full name with embedded assembly-qualified generic arguments stripped of
/// version, culture, and public key token. Lazily computed and cached.
/// </summary>
internal string NormalizedTypeName
=> field ??= NormalizeTypeName(this.TypeName);

private static readonly Regex s_assemblyQualifierPattern = new(
@", Version=[^,\]]+, Culture=[^,\]]+, PublicKeyToken=[^,\]]+",
RegexOptions.Compiled | RegexOptions.CultureInvariant);

/// <summary>
/// Removes <c>, Version=...</c>, <c>, Culture=...</c>, and <c>, PublicKeyToken=...</c> triplets
/// from <paramref name="typeName"/>. Returns the input unchanged when no triplet is present.
/// </summary>
internal static string NormalizeTypeName(string typeName)
{
if (typeName.IndexOf("Version=", StringComparison.Ordinal) < 0)
{
return typeName;
}

return s_assemblyQualifierPattern.Replace(typeName, string.Empty);
}

/// <summary>
/// Returns the simple assembly name parsed from an <see cref="Assembly.FullName"/>-style string,
/// or <see langword="null"/> when both parsing and the substring fallback fail.
/// </summary>
internal static string? GetSimpleAssemblyName(string assemblyFullName)
{
if (string.IsNullOrEmpty(assemblyFullName))
{
return null;
}

try
{
string? parsed = new AssemblyName(assemblyFullName).Name;
if (!string.IsNullOrEmpty(parsed))
{
return parsed;
}
}
catch (Exception ex) when (ex is FileLoadException or ArgumentException)
{
// Fall through to substring fallback.
}

int comma = assemblyFullName.IndexOf(',');
string fallback = (comma < 0 ? assemblyFullName : assemblyFullName.Substring(0, comma)).Trim();
return fallback.Length == 0 ? null : fallback;
}
}
19 changes: 18 additions & 1 deletion dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
Expand Down Expand Up @@ -297,12 +298,14 @@ private async ValueTask<ResumeDispatchInfo> SendMessagesWithResponseConversionAs
/// interface assignment succeeds.
/// </summary>
[UnconditionalSuppressMessage("Trimming", "IL2057:Unrecognized value passed to the parameter of method", Justification = "Higher-layer envelope types are explicitly preserved by the package that defines them.")]
[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access", Justification = "Higher-layer envelope types are explicitly preserved by the package that defines them.")]
[UnconditionalSuppressMessage("Trimming", "IL2073:Members annotated with 'DynamicallyAccessedMembersAttribute' require dynamic access", Justification = "Higher-layer envelope types are explicitly preserved by the package that defines them.")]
private static bool TryGetRequestEnvelope(ExternalRequest request, [NotNullWhen(true)] out IExternalRequestEnvelope? envelope)
{
envelope = null;

TypeId requestType = request.PortInfo.RequestType;
Type? concreteType = Type.GetType($"{requestType.TypeName}, {requestType.AssemblyName}", throwOnError: false);
Type? concreteType = ResolveTypeLenient(requestType);
if (concreteType is null || !typeof(IExternalRequestEnvelope).IsAssignableFrom(concreteType))
{
return false;
Expand All @@ -317,6 +320,20 @@ private static bool TryGetRequestEnvelope(ExternalRequest request, [NotNullWhen(
return true;
}

/// <summary>Caches <see cref="ResolveTypeLenient"/> results keyed by <see cref="TypeId"/>.</summary>
private static readonly ConcurrentDictionary<TypeId, Type?> s_envelopeTypeCache = new();

/// <summary>
/// Resolves a <see cref="TypeId"/> to a loaded <see cref="Type"/> using partial-name binding,
/// which matches any loaded assembly with the same simple name regardless of version. Results
/// are cached.
/// </summary>
[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access", Justification = "Higher-layer envelope types are explicitly preserved by the package that defines them.")]
[UnconditionalSuppressMessage("Trimming", "IL2057:Unrecognized value passed to the parameter of method", Justification = "Higher-layer envelope types are explicitly preserved by the package that defines them.")]
internal static Type? ResolveTypeLenient(TypeId typeId)
=> s_envelopeTypeCache.GetOrAdd(typeId, static id =>
Type.GetType($"{id.NormalizedTypeName}, {id.SimpleAssemblyName}", throwOnError: false));

Comment thread
peibekwe marked this conversation as resolved.
/// <summary>
/// Creates the workflow-facing request content surfaced in response updates.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Agents.AI.Workflows.Checkpointing;
using Microsoft.Agents.AI.Workflows.InProc;

namespace Microsoft.Agents.AI.Workflows.UnitTests;

/// <summary>
/// Verifies that a checkpoint serialized through <see cref="JsonCheckpointStore"/> can be restored
/// after every <c>Version=X.Y.Z.W</c> substring in the persisted JSON is rewritten to a different value.
/// </summary>
public class CheckpointVersionToleranceTests
{
private sealed class EchoExecutor() : Executor("Echo")
{
protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder)
=> protocolBuilder.ConfigureRoutes(routeBuilder =>
routeBuilder.AddHandler<string>((msg, ctx) => ctx.SendMessageAsync(msg)));
}

[Theory]
[InlineData(ExecutionEnvironment.InProcess_OffThread)]
[InlineData(ExecutionEnvironment.InProcess_Lockstep)]
internal async Task Test_Checkpoint_Resumes_AfterAssemblyVersionRewriteAsync(ExecutionEnvironment environment)
{
Comment thread
peibekwe marked this conversation as resolved.
// Arrange
RequestPort<string, string> requestPort = RequestPort.Create<string, string>("TestPort");
EchoExecutor echo = new();

Workflow workflow = new WorkflowBuilder(requestPort)
.AddEdge(requestPort, echo)
.Build();

VersionMutatingJsonStore store = new();
CheckpointManager checkpointManager = CheckpointManager.CreateJson(store);
InProcessExecutionEnvironment env = environment.ToWorkflowExecutionEnvironment();

// Run the workflow and capture a checkpoint.
CheckpointInfo? checkpoint = null;
await using (StreamingRun firstRun = await env.WithCheckpointing(checkpointManager)
.RunStreamingAsync(workflow, "Hello"))
{
await foreach (WorkflowEvent evt in firstRun.WatchStreamAsync(blockOnPendingRequest: false))
{
if (evt is SuperStepCompletedEvent step && step.CompletionInfo?.Checkpoint is { } cp)
{
checkpoint = cp;
}
}
}

checkpoint.Should().NotBeNull();
store.MutationApplied.Should().BeFalse();

// Resume against the mutated store, which rewrites every Version=X.Y.Z.W in the persisted JSON.
Func<Task> resume = async () =>
{
await using StreamingRun resumed = await env.WithCheckpointing(checkpointManager)
.ResumeStreamingAsync(workflow, checkpoint!);
using CancellationTokenSource cts = new(TimeSpan.FromSeconds(10));
await foreach (WorkflowEvent _ in resumed.WatchStreamAsync(blockOnPendingRequest: false, cts.Token))
{
}
};

await resume.Should().NotThrowAsync("resume must succeed when persisted assembly versions differ from loaded ones");
store.MutationApplied.Should().BeTrue();
}

/// <summary>
/// JSON checkpoint store that rewrites every <c>Version=N.N.N.N</c> token in the persisted
/// payload at retrieval time.
/// </summary>
private sealed class VersionMutatingJsonStore : JsonCheckpointStore
{
private static readonly Regex s_versionPattern = new(@"Version=\d+\.\d+\.\d+\.\d+", RegexOptions.Compiled);

private readonly Dictionary<string, Dictionary<string, JsonElement>> _store = [];

public string ReplacementVersion { get; init; } = "99.0.0.0";

public bool MutationApplied { get; private set; }

public override ValueTask<CheckpointInfo> CreateCheckpointAsync(string sessionId, JsonElement value, CheckpointInfo? parent = null)
{
if (!this._store.TryGetValue(sessionId, out Dictionary<string, JsonElement>? sessionStore))
{
sessionStore = this._store[sessionId] = [];
}

CheckpointInfo info = new(sessionId);
sessionStore[info.CheckpointId] = value.Clone();
return new ValueTask<CheckpointInfo>(info);
}

public override ValueTask<JsonElement> RetrieveCheckpointAsync(string sessionId, CheckpointInfo key)
{
if (!this._store.TryGetValue(sessionId, out Dictionary<string, JsonElement>? sessionStore)
|| !sessionStore.TryGetValue(key.CheckpointId, out JsonElement raw))
{
throw new KeyNotFoundException($"Could not retrieve checkpoint with id {key.CheckpointId} for session {sessionId}");
}

string rawText = raw.GetRawText();
string mutatedText = s_versionPattern.Replace(rawText, $"Version={this.ReplacementVersion}");

if (!ReferenceEquals(rawText, mutatedText) && rawText != mutatedText)
{
this.MutationApplied = true;
}

using JsonDocument doc = JsonDocument.Parse(mutatedText);
return new ValueTask<JsonElement>(doc.RootElement.Clone());
}

public override ValueTask<IEnumerable<CheckpointInfo>> RetrieveIndexAsync(string sessionId, CheckpointInfo? withParent = null)
{
if (!this._store.TryGetValue(sessionId, out Dictionary<string, JsonElement>? sessionStore))
{
return new ValueTask<IEnumerable<CheckpointInfo>>(Array.Empty<CheckpointInfo>());
}

IEnumerable<CheckpointInfo> infos = sessionStore.Keys.Select(id => new CheckpointInfo(sessionId, id));
return new ValueTask<IEnumerable<CheckpointInfo>>(infos);
}
}
}
Loading
Loading