diff --git a/src/Cellm.Models/Prompts/PromptBuilder.cs b/src/Cellm.Models/Prompts/PromptBuilder.cs index 07da4a73..89f8c40b 100644 --- a/src/Cellm.Models/Prompts/PromptBuilder.cs +++ b/src/Cellm.Models/Prompts/PromptBuilder.cs @@ -4,8 +4,8 @@ namespace Cellm.Models.Prompts; public class PromptBuilder { - private List _messages = new(); - private ChatOptions _options = new(); + private readonly List _messages = []; + private readonly ChatOptions _options = new(); public PromptBuilder() { @@ -30,6 +30,12 @@ public PromptBuilder SetTemperature(double temperature) return this; } + public PromptBuilder SetMaxOutputTokens(int maxOutputTokens) + { + _options.MaxOutputTokens = maxOutputTokens; + return this; + } + public PromptBuilder AddSystemMessage(string content) { _messages.Add(new ChatMessage(ChatRole.System, content)); diff --git a/src/Cellm.Models/Providers/Anthropic/AnthropicRequestHandler.cs b/src/Cellm.Models/Providers/Anthropic/AnthropicRequestHandler.cs index 771ee211..2d23d9c0 100644 --- a/src/Cellm.Models/Providers/Anthropic/AnthropicRequestHandler.cs +++ b/src/Cellm.Models/Providers/Anthropic/AnthropicRequestHandler.cs @@ -1,20 +1,15 @@ using Cellm.Models.Prompts; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; namespace Cellm.Models.Providers.Anthropic; internal class AnthropicRequestHandler( - [FromKeyedServices(Provider.Anthropic)] IChatClient chatClient, - IOptionsMonitor providerConfiguration) + [FromKeyedServices(Provider.Anthropic)] IChatClient chatClient) : IModelRequestHandler { public async Task Handle(AnthropicRequest request, CancellationToken cancellationToken) { - // Required by Anthropic API - request.Prompt.Options.MaxOutputTokens ??= providerConfiguration.CurrentValue.MaxOutputTokens; - var chatResponse = await chatClient.GetResponseAsync( request.Prompt.Messages, request.Prompt.Options, diff --git a/src/Cellm.Models/ServiceCollectionExtensions.cs b/src/Cellm.Models/ServiceCollectionExtensions.cs index 9a21c246..e6a606c6 100644 --- a/src/Cellm.Models/ServiceCollectionExtensions.cs +++ b/src/Cellm.Models/ServiceCollectionExtensions.cs @@ -41,6 +41,7 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; +using Mistral.SDK; using OpenAI; using Polly; using Polly.CircuitBreaker; @@ -164,27 +165,6 @@ public static IServiceCollection AddCellmChatClient(this IServiceCollection serv return services; } - public static IServiceCollection AddOllamaChatClient(this IServiceCollection services) - { - services - .AddKeyedChatClient(Provider.Ollama, serviceProvider => - { - var account = ServiceLocator.ServiceProvider.GetRequiredService(); - account.RequireEntitlement(Entitlement.EnableOllamaProvider); - - var ollamaConfiguration = serviceProvider.GetRequiredService>(); - var resilientHttpClient = serviceProvider.GetKeyedService("ResilientHttpClient") ?? throw new NullReferenceException("ResilientHttpClient"); - - return new OllamaChatClient( - ollamaConfiguration.CurrentValue.BaseAddress, - ollamaConfiguration.CurrentValue.DefaultModel, - resilientHttpClient); - }, ServiceLifetime.Transient) - .UseFunctionInvocation(); - - return services; - } - public static IServiceCollection AddDeepSeekChatClient(this IServiceCollection services) { services @@ -222,15 +202,31 @@ public static IServiceCollection AddMistralChatClient(this IServiceCollection se var mistralConfiguration = serviceProvider.GetRequiredService>(); var resilientHttpClient = serviceProvider.GetKeyedService("ResilientHttpClient") ?? throw new NullReferenceException("ResilientHttpClient"); - var openAiClient = new OpenAIClient( - new ApiKeyCredential(mistralConfiguration.CurrentValue.ApiKey), - new OpenAIClientOptions - { - Transport = new HttpClientPipelineTransport(resilientHttpClient), - Endpoint = mistralConfiguration.CurrentValue.BaseAddress - }); + return new MistralClient(mistralConfiguration.CurrentValue.ApiKey, resilientHttpClient) + .Completions + .AsBuilder() + .Build(); + }, ServiceLifetime.Transient) + .UseFunctionInvocation(); + + return services; + } + + public static IServiceCollection AddOllamaChatClient(this IServiceCollection services) + { + services + .AddKeyedChatClient(Provider.Ollama, serviceProvider => + { + var account = ServiceLocator.ServiceProvider.GetRequiredService(); + account.RequireEntitlement(Entitlement.EnableOllamaProvider); + + var ollamaConfiguration = serviceProvider.GetRequiredService>(); + var resilientHttpClient = serviceProvider.GetKeyedService("ResilientHttpClient") ?? throw new NullReferenceException("ResilientHttpClient"); - return openAiClient.GetChatClient(mistralConfiguration.CurrentValue.DefaultModel).AsIChatClient(); + return new OllamaChatClient( + ollamaConfiguration.CurrentValue.BaseAddress, + ollamaConfiguration.CurrentValue.DefaultModel, + resilientHttpClient); }, ServiceLifetime.Transient) .UseFunctionInvocation(); diff --git a/src/Cellm/AddIn/ArgumentParser.cs b/src/Cellm/AddIn/ArgumentParser.cs index c3d275db..06a7ab0b 100644 --- a/src/Cellm/AddIn/ArgumentParser.cs +++ b/src/Cellm/AddIn/ArgumentParser.cs @@ -6,8 +6,6 @@ namespace Cellm.AddIn; -public record Arguments(Provider Provider, string Model, string Context, string Instructions, double Temperature); - public class ArgumentParser { private string? _provider; @@ -231,7 +229,7 @@ private static string RenderInstructions(string instructions) .ToString(); } - private double ParseTemperature(double temperature) + private static double ParseTemperature(double temperature) { if (temperature < 0 || temperature > 1) { diff --git a/src/Cellm/AddIn/Arguments.cs b/src/Cellm/AddIn/Arguments.cs new file mode 100644 index 00000000..550ba373 --- /dev/null +++ b/src/Cellm/AddIn/Arguments.cs @@ -0,0 +1,5 @@ +using Cellm.Models.Providers; + +namespace Cellm.AddIn; + +public record Arguments(Provider Provider, string Model, string Context, string Instructions, double Temperature); diff --git a/src/Cellm/AddIn/ExcelAddin.cs b/src/Cellm/AddIn/ExcelAddin.cs index d0d1b0b8..5eafaabb 100644 --- a/src/Cellm/AddIn/ExcelAddin.cs +++ b/src/Cellm/AddIn/ExcelAddin.cs @@ -9,9 +9,9 @@ public void AutoOpen() { ExcelIntegration.RegisterUnhandledExceptionHandler(obj => { - var ex = (Exception)obj; - SentrySdk.CaptureException(ex); - return ex.Message; + var e = (Exception)obj; + SentrySdk.CaptureException(e); + return e.Message; }); _ = ServiceLocator.ServiceProvider; diff --git a/src/Cellm/AddIn/ExcelFunctions.cs b/src/Cellm/AddIn/ExcelFunctions.cs index 701ddd23..a4ca5ded 100644 --- a/src/Cellm/AddIn/ExcelFunctions.cs +++ b/src/Cellm/AddIn/ExcelFunctions.cs @@ -8,6 +8,7 @@ using ExcelDna.Integration; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; namespace Cellm.AddIn; @@ -74,7 +75,10 @@ public static object PromptWith( { try { - var arguments = ServiceLocator.ServiceProvider.GetRequiredService() + var argumentParser = ServiceLocator.ServiceProvider.GetRequiredService(); + var providerConfiguration = ServiceLocator.ServiceProvider.GetRequiredService>(); + + var arguments = argumentParser .AddProvider(providerAndModel) .AddModel(providerAndModel) .AddInstructionsOrContext(instructionsOrContext) @@ -90,6 +94,7 @@ public static object PromptWith( var prompt = new PromptBuilder() .SetModel(arguments.Model) .SetTemperature(arguments.Temperature) + .SetMaxOutputTokens(providerConfiguration.CurrentValue.MaxOutputTokens) .AddSystemMessage(SystemMessages.SystemMessage) .AddUserMessage(userMessage) .Build(); @@ -101,11 +106,11 @@ public static object PromptWith( }); } - catch (CellmException ex) + catch (CellmException e) { - SentrySdk.CaptureException(ex); - Debug.WriteLine(ex); - return ex.Message; + SentrySdk.CaptureException(e); + Debug.WriteLine(e); + return e.Message; } } diff --git a/src/Cellm/Cellm.csproj b/src/Cellm/Cellm.csproj index c1f5e629..2c06cdde 100644 --- a/src/Cellm/Cellm.csproj +++ b/src/Cellm/Cellm.csproj @@ -40,6 +40,7 @@ + diff --git a/src/Cellm/packages.lock.json b/src/Cellm/packages.lock.json index 91be4a0a..d2a9e5db 100644 --- a/src/Cellm/packages.lock.json +++ b/src/Cellm/packages.lock.json @@ -223,6 +223,16 @@ "Microsoft.Extensions.Primitives": "9.0.4" } }, + "Mistral.SDK": { + "type": "Direct", + "requested": "[2.1.1, )", + "resolved": "2.1.1", + "contentHash": "dBTLqmtfj7C62meCEB9l7VKDtRDDFQgYbx8a5+8uTLtU9bUmw9xYMmyTixHcfOGAQ8LlrXOWNOS6dEFMEgFHhQ==", + "dependencies": { + "Microsoft.Bcl.AsyncInterfaces": "8.0.0", + "Microsoft.Extensions.AI.Abstractions": "9.3.0-preview.1.25161.3" + } + }, "ModelContextProtocol": { "type": "Direct", "requested": "[0.1.0-preview.7, )",