diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml deleted file mode 100644 index 65b3c8a..0000000 --- a/.github/workflows/codeql.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: CodeQL - -on: - push: - branches-ignore: [ master ] - paths: [ 'src/**' ] - pull_request: - branches: [ master, develop] - paths: [ 'src/**' ] - workflow_dispatch: - schedule: - - cron: '0 0 1,15 * *' - -env: - NET_VERSION: '9.x' - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: [ 'csharp' ] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] - # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support - - steps: - - name: Checkout repository - uses: actions/checkout@v3 - with: - fetch-depth: 0 # avoid shallow clone so nbgv can do its work. - - - name: Setup .NET SDK ${{ env.NET_VERSION }} - uses: actions/setup-dotnet@v3 - with: - dotnet-version: ${{ env.NET_VERSION }} - dotnet-quality: 'ga' - - # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: ${{ matrix.language }} - # If you wish to specify custom queries, you can do so here or in a config file. - # By default, queries listed here will override any specified in a config file. - # Prefix the list here with "+" to use these queries and those in the config file. - - # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs - # queries: security-extended,security-and-quality - - # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). - # If this step fails, then you should remove it and run the build manually (see below) - - name: Autobuild - uses: github/codeql-action/autobuild@v3 - - # ?? Command-line programs to run using the OS shell. - # ?? See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun - - # If the Autobuild fails above, remove it and uncomment the following three lines. - # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. - - # - run: | - # echo "Run, Build Application using script" - # ./location_of_script_within_repo/buildscript.sh - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 diff --git a/samples/DatabaseGpt.Web/DatabaseGpt.Web.csproj b/samples/DatabaseGpt.Web/DatabaseGpt.Web.csproj index 57ca36e..60b95ce 100644 --- a/samples/DatabaseGpt.Web/DatabaseGpt.Web.csproj +++ b/samples/DatabaseGpt.Web/DatabaseGpt.Web.csproj @@ -1,21 +1,21 @@  - net8.0 + net9.0 enable build$([System.DateTime]::UtcNow.ToString("yyyyMMddHHmmss")) Marco Minerva - - - - - - - - + + + + + + + + diff --git a/samples/DatabaseGptConsole/DatabaseGptConsole.csproj b/samples/DatabaseGptConsole/DatabaseGptConsole.csproj index e5c08e2..366fd1a 100644 --- a/samples/DatabaseGptConsole/DatabaseGptConsole.csproj +++ b/samples/DatabaseGptConsole/DatabaseGptConsole.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + net9.0 enable enable 1.4 @@ -10,8 +10,11 @@ - - + + + + + diff --git a/samples/DatabaseGptConsole/Program.cs b/samples/DatabaseGptConsole/Program.cs index 1b76db5..221a90a 100644 --- a/samples/DatabaseGptConsole/Program.cs +++ b/samples/DatabaseGptConsole/Program.cs @@ -1,6 +1,8 @@ -using ChatGptNet; +using Azure; +using Azure.AI.OpenAI; using DatabaseGpt; using DatabaseGptConsole; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -29,6 +31,17 @@ static void ConfigureServices(HostBuilderContext context, IServiceCollection ser { services.AddSingleton(); + services.AddHybridCache(); + + var apiKey = context.Configuration.GetValue("ChatGPT:ApiKey")!; + var deploymentName = context.Configuration.GetValue("ChatGPT:DeploymentName")!; + var endpoint = context.Configuration.GetValue("ChatGpt:Endpoint")!; + + var azureOpenAIClient = new AzureOpenAIClient(new(endpoint), new AzureKeyCredential(apiKey)); + var chatClient = azureOpenAIClient.GetChatClient(deploymentName).AsIChatClient(); + + services.AddChatClient(chatClient); + services.AddDatabaseGpt(database => { // For SQL Server. @@ -42,9 +55,5 @@ static void ConfigureServices(HostBuilderContext context, IServiceCollection ser // For SQLite. //database.UseConfiguration(context.Configuration) // .UseSqlite(context.Configuration.GetConnectionString("SqliteConnection")); - }, - chatGpt => - { - chatGpt.UseConfiguration(context.Configuration); }); } diff --git a/samples/DatabaseGptConsole/appsettings.json b/samples/DatabaseGptConsole/appsettings.json index ca07d3d..e38b412 100644 --- a/samples/DatabaseGptConsole/appsettings.json +++ b/samples/DatabaseGptConsole/appsettings.json @@ -3,12 +3,9 @@ "SqlConnection": "" }, "ChatGPT": { - "Provider": "OpenAI", "ApiKey": "", - "Organization": "", - "ResourceName": "", - "AuthenticationType": "ApiKey", - "DefaultModel": "gpt-3.5-turbo-16k", + "Endpoint": "", + "DeploymentName": "", "MessageLimit": 20, "MessageExpiration": "00:30:00" }, diff --git a/src/DatabaseGpt.Abstractions/DatabaseGpt.Abstractions.csproj b/src/DatabaseGpt.Abstractions/DatabaseGpt.Abstractions.csproj index 9cfe87d..c0b991b 100644 --- a/src/DatabaseGpt.Abstractions/DatabaseGpt.Abstractions.csproj +++ b/src/DatabaseGpt.Abstractions/DatabaseGpt.Abstractions.csproj @@ -1,13 +1,13 @@ - net8.0 + net9.0 enable enable - + diff --git a/src/DatabaseGpt.Npgsql/DatabaseGpt.Npgsql.csproj b/src/DatabaseGpt.Npgsql/DatabaseGpt.Npgsql.csproj index fa783b6..6dc0cb4 100644 --- a/src/DatabaseGpt.Npgsql/DatabaseGpt.Npgsql.csproj +++ b/src/DatabaseGpt.Npgsql/DatabaseGpt.Npgsql.csproj @@ -1,14 +1,14 @@ - net8.0 + net9.0 enable enable - - + + diff --git a/src/DatabaseGpt.SqlServer/DatabaseGpt.SqlServer.csproj b/src/DatabaseGpt.SqlServer/DatabaseGpt.SqlServer.csproj index 5e295f9..e41c6eb 100644 --- a/src/DatabaseGpt.SqlServer/DatabaseGpt.SqlServer.csproj +++ b/src/DatabaseGpt.SqlServer/DatabaseGpt.SqlServer.csproj @@ -1,14 +1,14 @@ - net8.0 + net9.0 enable enable - - + + diff --git a/src/DatabaseGpt.SqlServer/Models/ColumnEntity.cs b/src/DatabaseGpt.SqlServer/Models/ColumnEntity.cs deleted file mode 100644 index 4ade9bf..0000000 --- a/src/DatabaseGpt.SqlServer/Models/ColumnEntity.cs +++ /dev/null @@ -1,12 +0,0 @@ -namespace DatabaseGpt.SqlServer.Models; - -internal class ColumnEntity -{ - public string Schema { get; set; } = null!; - - public string Table { get; set; } = null!; - - public string Column { get; set; } = null!; - - public string Description { get; set; } = null!; -} diff --git a/src/DatabaseGpt.SqlServer/SqlServerDatabaseGptProvider.cs b/src/DatabaseGpt.SqlServer/SqlServerDatabaseGptProvider.cs index 134a6ea..a57ade5 100644 --- a/src/DatabaseGpt.SqlServer/SqlServerDatabaseGptProvider.cs +++ b/src/DatabaseGpt.SqlServer/SqlServerDatabaseGptProvider.cs @@ -4,7 +4,6 @@ using Dapper; using DatabaseGpt.Abstractions; using DatabaseGpt.Abstractions.Exceptions; -using DatabaseGpt.SqlServer.Models; using Microsoft.Data.SqlClient; namespace DatabaseGpt.SqlServer; @@ -68,7 +67,7 @@ FROM INFORMATION_SCHEMA.COLUMNS AND TABLE_SCHEMA + '.' + TABLE_NAME + '.' + COLUMN_NAME NOT IN @{nameof(excludedColumns)}; """; - var columns = await connection.QueryAsync(query, new { schema = table.Schema, table = table.Name, excludedColumns }); + var columns = await connection.QueryAsync(query, new { schema = table.Schema, table = table.Name, excludedColumns }); result.AppendLine($"CREATE TABLE [{table.Schema}].[{table.Name}] ({string.Join(',', columns)});"); } diff --git a/src/DatabaseGpt.Sqlite/DatabaseGpt.Sqlite.csproj b/src/DatabaseGpt.Sqlite/DatabaseGpt.Sqlite.csproj index efe00c9..5724868 100644 --- a/src/DatabaseGpt.Sqlite/DatabaseGpt.Sqlite.csproj +++ b/src/DatabaseGpt.Sqlite/DatabaseGpt.Sqlite.csproj @@ -1,14 +1,14 @@  - net8.0 + net9.0 enable enable - - + + diff --git a/src/DatabaseGpt/DatabaseGpt.csproj b/src/DatabaseGpt/DatabaseGpt.csproj index 7d813e2..4587581 100644 --- a/src/DatabaseGpt/DatabaseGpt.csproj +++ b/src/DatabaseGpt/DatabaseGpt.csproj @@ -1,7 +1,7 @@  - net8.0 + net9.0 enable enable @@ -11,8 +11,9 @@ - - + + + diff --git a/src/DatabaseGpt/DatabaseGptClient.cs b/src/DatabaseGpt/DatabaseGptClient.cs index 890dd41..d26b919 100644 --- a/src/DatabaseGpt/DatabaseGptClient.cs +++ b/src/DatabaseGpt/DatabaseGptClient.cs @@ -1,15 +1,17 @@ -using ChatGptNet; -using ChatGptNet.Extensions; -using DatabaseGpt.Abstractions; +using DatabaseGpt.Abstractions; +using DatabaseGpt.Abstractions.Exceptions; using DatabaseGpt.Exceptions; using DatabaseGpt.Models; using DatabaseGpt.Settings; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Hybrid; using Polly; using Polly.Registry; +using ChatHistory = System.Collections.Generic.List; namespace DatabaseGpt; -internal class DatabaseGptClient(IChatGptClient chatGptClient, ResiliencePipelineProvider pipelineProvider, IServiceProvider serviceProvider, DatabaseGptSettings databaseGptSettings) : IDatabaseGptClient +internal class DatabaseGptClient(IChatClient chatGptClient, HybridCache cache, ResiliencePipelineProvider pipelineProvider, IServiceProvider serviceProvider, DatabaseGptSettings databaseGptSettings) : IDatabaseGptClient { private readonly IDatabaseGptProvider provider = databaseGptSettings.CreateProvider(); private readonly ResiliencePipeline pipeline = pipelineProvider.GetPipeline(nameof(DatabaseGptClient)); @@ -39,8 +41,21 @@ private async Task ExecuteNaturalLanguageQueryInternalAs var tables = await GetTablesAsync(sessionId, question, options, cancellationToken); var query = await GetQueryAsync(sessionId, question, tables, options, cancellationToken); - var reader = await provider.ExecuteQueryAsync(query, cancellationToken); - return (query, reader); + try + { + var reader = await provider.ExecuteQueryAsync(query, cancellationToken); + return (query, reader); + } + catch (DatabaseGptException ex) + { + // If there is an exception while executing the query, we log it in the chat history, so the assistant can learn from it. + var chat = await GetChatHistoryAsync(sessionId, cancellationToken); + chat.Add(new(ChatRole.Assistant, ex.ToString())); + await UpdateCacheAsync(sessionId, chat, cancellationToken); + + //Rethrow the exception, so it will be handled by the pipeline. + throw; + } }, cancellationToken); return new(query, reader); @@ -48,8 +63,10 @@ private async Task ExecuteNaturalLanguageQueryInternalAs private async Task CreateSessionAsync(Guid sessionId, CancellationToken cancellationToken) { - var conversationExists = await chatGptClient.ConversationExistsAsync(sessionId, cancellationToken); - if (!conversationExists) + sessionId = sessionId == default ? Guid.CreateVersion7() : sessionId; + var history = await GetChatHistoryAsync(sessionId, cancellationToken); + + if (history.Count == 0) { var tables = await provider.GetTablesAsync(databaseGptSettings.IncludedTables, databaseGptSettings.ExcludedTables, cancellationToken); @@ -67,7 +84,8 @@ private async Task CreateSessionAsync(Guid sessionId, CancellationToken ca """; } - sessionId = await chatGptClient.SetupAsync(sessionId, systemMessage, cancellationToken); + history.Add(new(ChatRole.System, systemMessage)); + await UpdateCacheAsync(sessionId, history, cancellationToken); } return sessionId; @@ -90,9 +108,15 @@ The selected tables should be returned in a comma separated list. Your response await options.OnStarting.Invoke(serviceProvider); } - var response = await chatGptClient.AskAsync(sessionId, request, cancellationToken: cancellationToken); + var chat = await GetChatHistoryAsync(sessionId, cancellationToken); + chat.Add(new(ChatRole.User, request)); - var candidateTables = response.GetContent()!.Trim('\''); + var response = await chatGptClient.GetResponseAsync(chat, cancellationToken: cancellationToken); + + chat.Add(new(ChatRole.Assistant, response.Text)); + await UpdateCacheAsync(sessionId, chat, cancellationToken); + + var candidateTables = response.Text.Trim('\''); if (candidateTables == "NONE") { throw new NoTableFoundException($"No available information in the provided tables can be useful for the question '{question}'."); @@ -134,9 +158,15 @@ CREATE TABLE Table2 (Column3 VARCHAR(255), Column4 VARCHAR(255)) request += $"{Environment.NewLine}{queryHints}"; } - var response = await chatGptClient.AskAsync(sessionId, request, cancellationToken: cancellationToken); + var chat = await GetChatHistoryAsync(sessionId, cancellationToken); + chat.Add(new(ChatRole.User, request)); + + var response = await chatGptClient.GetResponseAsync(chat, cancellationToken: cancellationToken); - var query = response.GetContent()!; + chat.Add(new(ChatRole.Assistant, response.Text)); + await UpdateCacheAsync(sessionId, chat, cancellationToken); + + var query = response.Text; if (query == "NONE") { throw new InvalidSqlException($"The question '{question}' requires an INSERT, UPDATE or DELETE command, that isn't supported."); @@ -155,6 +185,27 @@ CREATE TABLE Table2 (Column3 VARCHAR(255), Column4 VARCHAR(255)) return query; } + private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat, CancellationToken cancellationToken) + { + if (chat.Count > databaseGptSettings.MessageLimit) + { + chat.RemoveRange(0, chat.Count - databaseGptSettings.MessageLimit); + } + + await cache.SetAsync(conversationId.ToString(), chat, cancellationToken: cancellationToken); + } + + private async Task GetChatHistoryAsync(Guid conversationId, CancellationToken cancellationToken) + { + var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(), (cancellationToken) => + { + return ValueTask.FromResult([]); + }, cancellationToken: cancellationToken); + + var chat = new ChatHistory(historyCache); + return chat; + } + protected virtual void Dispose(bool disposing) { if (!disposedValue) diff --git a/src/DatabaseGpt/DatabaseGptServiceCollectionExtensions.cs b/src/DatabaseGpt/DatabaseGptServiceCollectionExtensions.cs index 47e0e59..d4f24d0 100644 --- a/src/DatabaseGpt/DatabaseGptServiceCollectionExtensions.cs +++ b/src/DatabaseGpt/DatabaseGptServiceCollectionExtensions.cs @@ -1,5 +1,4 @@ -using ChatGptNet; -using DatabaseGpt.Abstractions.Exceptions; +using DatabaseGpt.Abstractions.Exceptions; using DatabaseGpt.Settings; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; @@ -11,28 +10,10 @@ namespace DatabaseGpt; public static class DatabaseGptServiceCollectionExtensions { - public static IServiceCollection AddDatabaseGpt(this IServiceCollection services, Action configureDatabaseGpt, Action configureChatGpt, ServiceLifetime lifetime = ServiceLifetime.Scoped) + public static IServiceCollection AddDatabaseGpt(this IServiceCollection services, Action configureDatabaseGpt, int maxQueryGenerationRetries = 3) { ArgumentNullException.ThrowIfNull(services); ArgumentNullException.ThrowIfNull(configureDatabaseGpt); - ArgumentNullException.ThrowIfNull(configureChatGpt); - - var settings = new DatabaseGptSettings(); - configureDatabaseGpt(settings); - services.AddSingleton(settings); - - services.Add(new ServiceDescriptor(typeof(IDatabaseGptClient), typeof(DatabaseGptClient), lifetime)); - services.AddChatGpt(configureChatGpt, httpClient => ConfigureHttpClientResiliency(httpClient)); - services.AddResiliencePipeline(settings.MaxRetries); - - return services; - } - - public static IServiceCollection AddDatabaseGpt(this IServiceCollection services, Action configureDatabaseGpt, Action configureChatGpt, int maxQueryGenerationRetries = 3) - { - ArgumentNullException.ThrowIfNull(services); - ArgumentNullException.ThrowIfNull(configureDatabaseGpt); - ArgumentNullException.ThrowIfNull(configureChatGpt); services.AddScoped(provider => { @@ -42,49 +23,26 @@ public static IServiceCollection AddDatabaseGpt(this IServiceCollection services }); services.AddScoped(); - services.AddChatGpt(configureChatGpt, httpClient => ConfigureHttpClientResiliency(httpClient)); services.AddResiliencePipeline(maxQueryGenerationRetries); return services; } - public static IServiceCollection AddDatabaseGpt(this IServiceCollection services, Action configureDatabaseGpt, Action configureChatGpt, ServiceLifetime lifetime = ServiceLifetime.Scoped) + public static IServiceCollection AddDatabaseGpt(this IServiceCollection services, Action configureDatabaseGpt, ServiceLifetime lifetime = ServiceLifetime.Scoped) { ArgumentNullException.ThrowIfNull(services); ArgumentNullException.ThrowIfNull(configureDatabaseGpt); - ArgumentNullException.ThrowIfNull(configureChatGpt); var settings = new DatabaseGptSettings(); configureDatabaseGpt(settings); services.AddSingleton(settings); services.Add(new ServiceDescriptor(typeof(IDatabaseGptClient), typeof(DatabaseGptClient), lifetime)); - services.AddChatGpt(configureChatGpt, httpClient => ConfigureHttpClientResiliency(httpClient)); services.AddResiliencePipeline(settings.MaxRetries); return services; } - public static IServiceCollection AddDatabaseGpt(this IServiceCollection services, Action configureDatabaseGpt, Action configureChatGpt, int maxQueryGenerationRetries = 3) - { - ArgumentNullException.ThrowIfNull(services); - ArgumentNullException.ThrowIfNull(configureDatabaseGpt); - ArgumentNullException.ThrowIfNull(configureChatGpt); - - services.AddScoped(provider => - { - var settings = new DatabaseGptSettings(); - configureDatabaseGpt(provider, settings); - return settings; - }); - - services.AddScoped(); - services.AddChatGpt(configureChatGpt, httpClient => ConfigureHttpClientResiliency(httpClient)); - services.AddResiliencePipeline(maxQueryGenerationRetries); - - return services; - } - public static DatabaseGptSettings UseConfiguration(this DatabaseGptSettings settings, IConfiguration configuration, string databaseGptSettingsSectionName = "DatabaseGptSettings") { ArgumentNullException.ThrowIfNull(settings); diff --git a/src/DatabaseGpt/Settings/DatabaseGptSettings.cs b/src/DatabaseGpt/Settings/DatabaseGptSettings.cs index c10071b..4a4414d 100644 --- a/src/DatabaseGpt/Settings/DatabaseGptSettings.cs +++ b/src/DatabaseGpt/Settings/DatabaseGptSettings.cs @@ -16,6 +16,8 @@ public class DatabaseGptSettings : IDatabaseGptSettings public int MaxRetries { get; set; } = 3; + public int MessageLimit { get; set; } = 10; + public void SetDatabaseGptProviderFactory(Func providerFactory) { ArgumentNullException.ThrowIfNull(providerFactory);