Skip to content

Commit 9ead0c4

Browse files
feat: Add row, column, and table structured outputs (#206)
fix: Revert Google client to OpenAI, Semantic Kernel's Google connector was too alpha
1 parent d97786c commit 9ead0c4

File tree

12 files changed

+298
-85
lines changed

12 files changed

+298
-85
lines changed

src/Cellm/AddIn/CellmAddInConfiguration.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace Cellm.AddIn;
1+
using Cellm.Models.Prompts;
2+
3+
namespace Cellm.AddIn;
24

35
public class CellmAddInConfiguration
46
{
@@ -19,4 +21,6 @@ public class CellmAddInConfiguration
1921
public int CacheTimeoutInSeconds { get; init; } = 3600;
2022

2123
public List<string> Models { get; init; } = [];
24+
25+
public StructuredOutputShape StructuredOutputShape { get; init; } = StructuredOutputShape.Table;
2226
}

src/Cellm/AddIn/CellmFunctions.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ internal static async Task<object> GetResponseAsync(Arguments arguments, Stopwat
161161
.SetModel(arguments.Model)
162162
.SetTemperature(arguments.Temperature)
163163
.SetMaxOutputTokens(cellmAddInConfiguration.CurrentValue.MaxOutputTokens)
164+
.SetOutputShape(cellmAddInConfiguration.CurrentValue.StructuredOutputShape)
164165
.AddSystemMessage(SystemMessages.SystemMessage)
165166
.AddUserMessage(userMessage)
166167
.Build();
@@ -172,8 +173,16 @@ internal static async Task<object> GetResponseAsync(Arguments arguments, Stopwat
172173
var response = await client.GetResponseAsync(prompt, arguments.Provider, cancellationToken).ConfigureAwait(false);
173174
var assistantMessage = response.Messages.LastOrDefault()?.Text ?? throw new InvalidOperationException("No text response");
174175

176+
// Check for cancellation before returning response
177+
cancellationToken.ThrowIfCancellationRequested();
178+
175179
logger.LogInformation("Sending prompt to {}/{} ({}) ... Done (elapsed time: {}ms, request time: {}ms)", arguments.Provider, arguments.Model, callerCoordinates, wallClock.ElapsedMilliseconds, requestClock.ElapsedMilliseconds);
176180

181+
if (StructuredOutput.TryParse(assistantMessage, response.OutputShape, out var array2d) && array2d is not null)
182+
{
183+
return array2d;
184+
}
185+
177186
return assistantMessage;
178187
}
179188
// Short-circuit if any cells were found to be #GETTING_DATA or contain other errors during cell parsing.

src/Cellm/AddIn/UserInterface/Ribbon/RibbonModelGroup.cs

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Text;
22
using Cellm.AddIn.UserInterface.Forms;
3+
using Cellm.Models.Prompts;
34
using Cellm.Models.Providers;
45
using Cellm.Models.Providers.Anthropic;
56
using Cellm.Models.Providers.Aws;
@@ -25,6 +26,7 @@ public partial class RibbonMain
2526
private enum ModelGroupControlIds
2627
{
2728
VerticalContainer,
29+
HorizontalContainer,
2830

2931
ModelProviderGroup,
3032
ProviderModelBox,
@@ -34,6 +36,12 @@ private enum ModelGroupControlIds
3436

3537
ModelComboBox,
3638
TemperatureComboBox,
39+
40+
OutputCell,
41+
OutputRow,
42+
OutputTable,
43+
OutputColumn,
44+
3745
CacheToggleButton,
3846

3947
ProviderSettingsButton
@@ -155,11 +163,41 @@ public string ModelGroup()
155163
getItemLabel="{nameof(GetTemperatureItemLabel)}"
156164
/>
157165
</box>
166+
<box id="{nameof(ModelGroupControlIds.HorizontalContainer)}" boxStyle="horizontal">
167+
<buttonGroup id="SelectionButtonGroup">
168+
<toggleButton id="{nameof(ModelGroupControlIds.OutputCell)}"
169+
imageMso="TableSelectCell"
170+
getPressed="{nameof(GetOutputCellPressed)}"
171+
onAction="{nameof(OnOutputCellClicked)}"
172+
screentip="Output response in a single cell (default)" />
173+
<toggleButton id="{nameof(ModelGroupControlIds.OutputRow)}"
174+
imageMso="TableRowSelect"
175+
getPressed="{nameof(GetOutputRowPressed)}"
176+
onAction="{nameof(OnOutputRowClicked)}"
177+
screentip="Respond with row"
178+
supertip="Spill multiple response values (if any) across cells to the right." />
179+
<toggleButton id="{nameof(ModelGroupControlIds.OutputTable)}"
180+
imageMso="TableSelect"
181+
getPressed="{nameof(GetOutputTablePressed)}"
182+
onAction="{nameof(OnOutputTableClicked)}"
183+
screentip="Respond with table"
184+
supertip="Let model decide how to output multiple values (as single cell, row, column, or table, just tell it what you want)" />
185+
<toggleButton id="{nameof(ModelGroupControlIds.OutputColumn)}"
186+
imageMso="TableColumnSelect"
187+
getPressed="{nameof(GetOutputColumnPressed)}"
188+
onAction="{nameof(OnOutputColumnClicked)}"
189+
screentip="Respond with column"
190+
supertip="Spill multiple response values (if any) across cells below" />
191+
</buttonGroup>
192+
</box>
158193
</box>
159194
<separator id="cacheSeparator" />
160-
<toggleButton id="{nameof(ModelGroupControlIds.CacheToggleButton)}" label="Cache" size="large" imageMso="SourceControlRefreshStatus"
195+
<toggleButton id="{nameof(ModelGroupControlIds.CacheToggleButton)}"
196+
label="Memory On" size="large"
197+
imageMso="SourceControlRefreshStatus"
161198
screentip="Enable/disable local caching of model responses. Enabled: Return cached responses for identical prompts. Disabled: Always request new responses. Disabling cache will clear entries."
162-
onAction="{nameof(OnCacheToggled)}" getPressed="{nameof(GetCachePressed)}" />
199+
onAction="{nameof(OnCacheToggled)}"
200+
getPressed="{nameof(GetCachePressed)}" />
163201
</group>
164202
""";
165203
}
@@ -809,4 +847,81 @@ public string GetTemperatureItemLabel(IRibbonControl control, int index)
809847
_logger.LogWarning("Invalid index {index} requested for GetTemperatureItemLabel", index);
810848
return string.Empty;
811849
}
850+
851+
public void OnOutputCellClicked(IRibbonControl control, bool isPressed)
852+
{
853+
// Default, cannot toggle off via this button
854+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.None.ToString());
855+
InvalidateOutputToggleButtons();
856+
}
857+
858+
public void OnOutputRowClicked(IRibbonControl control, bool isPressed)
859+
{
860+
if (isPressed)
861+
{
862+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.Row.ToString());
863+
InvalidateOutputToggleButtons();
864+
}
865+
else
866+
{
867+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.None.ToString());
868+
InvalidateOutputToggleButtons();
869+
}
870+
}
871+
872+
public void OnOutputTableClicked(IRibbonControl control, bool isPressed)
873+
{
874+
if (isPressed)
875+
{
876+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.Table.ToString());
877+
InvalidateOutputToggleButtons();
878+
}
879+
else
880+
{
881+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.None.ToString());
882+
InvalidateOutputToggleButtons();
883+
}
884+
}
885+
886+
public void OnOutputColumnClicked(IRibbonControl control, bool isPressed)
887+
{
888+
if (isPressed)
889+
{
890+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.Column.ToString());
891+
InvalidateOutputToggleButtons();
892+
}
893+
else
894+
{
895+
SetValue($"{nameof(CellmAddInConfiguration)}:{nameof(StructuredOutputShape)}", StructuredOutputShape.None.ToString());
896+
InvalidateOutputToggleButtons();
897+
}
898+
}
899+
900+
public bool GetOutputCellPressed(IRibbonControl control)
901+
{
902+
return GetValue($"{nameof(CellmAddInConfiguration)}:{nameof(CellmAddInConfiguration.StructuredOutputShape)}") == StructuredOutputShape.None.ToString();
903+
}
904+
905+
public bool GetOutputRowPressed(IRibbonControl control)
906+
{
907+
return GetValue($"{nameof(CellmAddInConfiguration)}:{nameof(CellmAddInConfiguration.StructuredOutputShape)}") == StructuredOutputShape.Row.ToString();
908+
}
909+
910+
public bool GetOutputTablePressed(IRibbonControl control)
911+
{
912+
return GetValue($"{nameof(CellmAddInConfiguration)}:{nameof(CellmAddInConfiguration.StructuredOutputShape)}") == StructuredOutputShape.Table.ToString();
913+
}
914+
915+
public bool GetOutputColumnPressed(IRibbonControl control)
916+
{
917+
return GetValue($"{nameof(CellmAddInConfiguration)}:{nameof(CellmAddInConfiguration.StructuredOutputShape)}") == StructuredOutputShape.Column.ToString();
918+
}
919+
920+
private void InvalidateOutputToggleButtons()
921+
{
922+
_ribbonUi?.InvalidateControl(nameof(ModelGroupControlIds.OutputCell));
923+
_ribbonUi?.InvalidateControl(nameof(ModelGroupControlIds.OutputRow));
924+
_ribbonUi?.InvalidateControl(nameof(ModelGroupControlIds.OutputTable));
925+
_ribbonUi?.InvalidateControl(nameof(ModelGroupControlIds.OutputColumn));
926+
}
812927
}

src/Cellm/Cellm.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="9.0.5" />
5757
<PackageReference Include="Microsoft.Extensions.Options" Version="9.0.5" />
5858
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="9.0.5" />
59-
<PackageReference Include="Microsoft.SemanticKernel.Connectors.Google" Version="1.56.0-alpha" />
6059
<PackageReference Include="Mistral.SDK" Version="2.2.0" />
6160
<PackageReference Include="ModelContextProtocol" Version="0.2.0-preview.3" />
6261
<PackageReference Include="OllamaSharp" Version="5.2.2" />

src/Cellm/Models/Prompts/Prompt.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
namespace Cellm.Models.Prompts;
44

5-
public record Prompt(IList<ChatMessage> Messages, ChatOptions Options);
5+
internal record Prompt(IList<ChatMessage> Messages, ChatOptions Options, StructuredOutputShape OutputShape);

src/Cellm/Models/Prompts/PromptBuilder.cs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,78 +2,86 @@
22

33
namespace Cellm.Models.Prompts;
44

5-
public class PromptBuilder
5+
internal class PromptBuilder
66
{
77
private readonly List<ChatMessage> _messages = [];
88
private readonly ChatOptions _options = new();
9+
private StructuredOutputShape _outputShape = StructuredOutputShape.None;
910

10-
public PromptBuilder()
11+
internal PromptBuilder()
1112
{
1213
}
1314

14-
public PromptBuilder(Prompt prompt)
15+
internal PromptBuilder(Prompt prompt)
1516
{
1617
// Do not mutate prompt
1718
_messages = new List<ChatMessage>(prompt.Messages);
1819
_options = prompt.Options.Clone();
20+
_outputShape = prompt.OutputShape;
1921
}
2022

21-
public PromptBuilder SetModel(string model)
23+
internal PromptBuilder SetModel(string model)
2224
{
2325
_options.ModelId = model;
2426
return this;
2527
}
2628

27-
public PromptBuilder SetTemperature(double temperature)
29+
internal PromptBuilder SetTemperature(double temperature)
2830
{
2931
_options.Temperature = (float)temperature;
3032
return this;
3133
}
3234

33-
public PromptBuilder SetMaxOutputTokens(int maxOutputTokens)
35+
internal PromptBuilder SetMaxOutputTokens(int maxOutputTokens)
3436
{
3537
_options.MaxOutputTokens = maxOutputTokens;
3638
return this;
3739
}
3840

39-
public PromptBuilder AddSystemMessage(string content)
41+
internal PromptBuilder SetOutputShape(StructuredOutputShape outputShape)
42+
{
43+
_outputShape = outputShape;
44+
return this;
45+
}
46+
47+
internal PromptBuilder AddSystemMessage(string content)
4048
{
4149
_messages.Add(new ChatMessage(ChatRole.System, content));
4250
return this;
4351
}
4452

45-
public PromptBuilder AddUserMessage(string content)
53+
internal PromptBuilder AddUserMessage(string content)
4654
{
4755
_messages.Add(new ChatMessage(ChatRole.User, content));
4856
return this;
4957
}
5058

51-
public PromptBuilder AddAssistantMessage(string content)
59+
internal PromptBuilder AddAssistantMessage(string content)
5260
{
5361
_messages.Add(new ChatMessage(ChatRole.User, content));
5462
return this;
5563
}
5664

57-
public PromptBuilder AddMessage(ChatMessage message)
65+
internal PromptBuilder AddMessage(ChatMessage message)
5866
{
5967
_messages.Add(message);
6068
return this;
6169
}
6270

63-
public PromptBuilder AddMessages(IList<ChatMessage> messages)
71+
internal PromptBuilder AddMessages(IList<ChatMessage> messages)
6472
{
6573
_messages.AddRange(messages);
6674
return this;
6775
}
6876

69-
public PromptBuilder SetTools(IList<AITool> tools)
77+
internal PromptBuilder SetTools(IList<AITool> tools)
7078
{
7179
_options.Tools = tools;
7280
return this;
7381
}
7482

75-
public Prompt Build()
83+
internal Prompt Build()
7684
{
77-
return new Prompt(_messages, _options);
85+
return new Prompt(_messages, _options, _outputShape);
7886
}
7987
}

0 commit comments

Comments
 (0)