Skip to content

Commit

Permalink
Update PromptService, ConfigViewModels, and SemanticFunctionViewModel…
Browse files Browse the repository at this point in the history
… for DashScope integration and configuration improvements (#25)

* stage: more llm config

* feat: add openai llm

* add DashScope
  • Loading branch information
xbotter authored Sep 1, 2023
1 parent 0c6fd07 commit 9ca7ca6
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 60 deletions.
31 changes: 31 additions & 0 deletions PromptPlayground/Converters/StringToBooleanConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Avalonia.Data.Converters;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace PromptPlayground.Converters
{
public class StringToBooleanConverter : IValueConverter
{
public object Convert(object value, Type targetType, object parameter, CultureInfo culture)
{
if (value == null || parameter == null || !(value is string) || !(parameter is string))
{
return false;
}

var strValue = (string)value;
var strParameter = (string)parameter;

return strValue.Equals(strParameter, StringComparison.OrdinalIgnoreCase);
}

public object ConvertBack(object value, Type targetType, object parameter, CultureInfo culture)
{
throw new NotSupportedException();
}
}
}
8 changes: 4 additions & 4 deletions PromptPlayground/PromptPlayground.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@
<PackageReference Include="Avalonia.Xaml.Interactions" Version="11.0.2" />
<PackageReference Include="Bogus" Version="34.0.2" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="8.2.0" />
<PackageReference Include="DashScope.SemanticKernel" Version="0.1.0-preview" />
<PackageReference Include="Humanizer" Version="2.14.1" />
<PackageReference Include="Microsoft.SemanticKernel.Connectors.Memory.Qdrant" Version="0.18.230725.3-preview" />
<PackageReference Include="Moq" Version="4.18.4" />
<PackageReference Include="Microsoft.SemanticKernel.Connectors.Memory.Qdrant" Version="0.21.230828.2-preview" />
<PackageReference Include="Projektanker.Icons.Avalonia.MaterialDesign" Version="8.1.0" />
<PackageReference Include="Semi.Avalonia" Version="$(AvaloniaVersion)" />

<!--Condition below is needed to remove Avalonia.Diagnostics package from build output in Release configuration.-->
<PackageReference Condition="'$(Configuration)' == 'Debug'" Include="Avalonia.Diagnostics" Version="$(AvaloniaVersion)" />
<PackageReference Include="ERNIE-Bot.SemanticKernel" Version="0.3.2-preview" />
<PackageReference Include="ERNIE-Bot.SemanticKernel" Version="0.5.1-preview" />
<PackageReference Include="MessageBox.Avalonia" Version="3.0.0" />
<PackageReference Include="Microsoft.SemanticKernel" Version="0.18.230725.3-preview" />
<PackageReference Include="Microsoft.SemanticKernel" Version="0.21.230828.2-preview" />
</ItemGroup>

<ItemGroup>
Expand Down
10 changes: 5 additions & 5 deletions PromptPlayground/Services/PromptService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ public async Task<GenerateResult> RunAsync(string prompt, PromptTemplateConfig c

if (provider.GetVectorDb() is IVectorDbConfigViewModel vectorDb)
{
context[TextMemorySkill.CollectionParam] = vectorDb.Collection;
context[TextMemorySkill.LimitParam] = vectorDb.Limit.ToString();
context[TextMemorySkill.RelevanceParam] = vectorDb.Relevance.ToString();
context.Variables[TextMemorySkill.CollectionParam] = vectorDb.Collection;
context.Variables[TextMemorySkill.LimitParam] = vectorDb.Limit.ToString();
context.Variables[TextMemorySkill.RelevanceParam] = vectorDb.Relevance.ToString();
}

var result = await func.InvokeAsync(context);
Expand All @@ -71,7 +71,7 @@ public async Task<GenerateResult> RunAsync(string prompt, PromptTemplateConfig c
{
Text = result.Result,
Elapsed = sw.Elapsed,
Error = result.LastErrorDescription,
Error = result.LastException?.Message,
TokenUsage = usage
};
}
Expand All @@ -81,7 +81,7 @@ public async Task<GenerateResult> RunAsync(string prompt, PromptTemplateConfig c
{
Text = result.Result,
Elapsed = sw.Elapsed,
Error = result.LastErrorDescription,
Error = result.LastException!.Message,
};
}
}
Expand Down
5 changes: 2 additions & 3 deletions PromptPlayground/Services/TemplateEngine/Blocks/CodeBlock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.SkillDefinition;
using PromptPlayground.Services.TemplateEngine;
Expand Down Expand Up @@ -127,9 +128,7 @@ private async Task<string> RenderFunctionCallAsync(FunctionIdBlock fBlock, SKCon
{
if (context.Skills == null)
{
throw new KernelException(
KernelException.ErrorCodes.SkillCollectionNotSet,
"Skill collection not found in the context");
throw new SKException("Skill collection not found in the context");
}

if (!GetFunctionFromSkillCollection(context.Skills!, fBlock, out ISKFunction? function))
Expand Down
10 changes: 8 additions & 2 deletions PromptPlayground/ViewModels/ConfigViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ public partial class ConfigViewModel : ViewModelBase, IConfigAttributesProvider,
private string[] RequiredAttributes = new string[]
{
#region LLM Config
ConfigAttribute.AzureDeployment,
ConfigAttribute.AzureDeployment,
ConfigAttribute.AzureEndpoint,
ConfigAttribute.AzureSecret,
ConfigAttribute.BaiduClientId,
ConfigAttribute.BaiduSecret,
ConfigAttribute.BaiduModel,
ConfigAttribute.OpenAIApiKey,
ConfigAttribute.OpenAIModel,
ConfigAttribute.DashScopeApiKey,
ConfigAttribute.DashScopeModel,
#endregion
#region Vector DB
ConfigAttribute.VectorSize,
Expand Down Expand Up @@ -164,8 +169,9 @@ public ConfigViewModel()
this.AllAttributes = CheckAttributes(this.AllAttributes);

LLMs.Add(new AzureOpenAIConfigViewModel(this));
LLMs.Add(new BaiduTurboConfigViewModel(this));
LLMs.Add(new BaiduConfigViewModel(this));
LLMs.Add(new OpenAIConfigViewModel(this));
LLMs.Add(new DashScopeConfigViewModel(this));

Embeddings.Add(new AzureOpenAIEmbeddingConfigViewModel(this));

Expand Down
98 changes: 98 additions & 0 deletions PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using CommunityToolkit.Mvvm.ComponentModel;
using DashScope;
using Humanizer;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.Json.Serialization;

namespace PromptPlayground.ViewModels.ConfigViewModels
{
public class ConfigAttribute : ObservableObject
{
static ConfigAttribute()
{
var consts = typeof(ConfigAttribute).GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy).Where(fi => fi.IsLiteral && !fi.IsInitOnly).ToList();
var ct = typeof(ConfigTypeAttribute);
_fields = consts.Select(c =>
{
var value = c.GetValue(null) as string;

if (c.IsDefined(ct))
{
return (value, c.GetCustomAttribute(ct) as ConfigTypeAttribute);
}
else
{
return (value, new ConfigTypeAttribute());
}
}).ToDictionary(_ => _.Item1!, _ => _.Item2!);
}
static readonly Dictionary<string, ConfigTypeAttribute> _fields;
public ConfigAttribute(string name)
{
Name = name;
this.Type = _fields[name].Type;
this.SelectValues = _fields[name].SelectValues?.ToList() ?? new List<string>();
}
private string _value = string.Empty;
[JsonIgnore]
public string HumanizeName => Name.Humanize();
public string Type { get; set; } = "string";
public List<string> SelectValues { get; set; }
public string Name { get; set; } = string.Empty;
public string Value { get => _value; set => SetProperty(ref _value, value, nameof(Value)); }

#region Constants
public const string AzureDeployment = nameof(AzureDeployment);
public const string AzureEndpoint = nameof(AzureEndpoint);
public const string AzureSecret = nameof(AzureSecret);
public const string AzureEmbeddingDeployment = nameof(AzureEmbeddingDeployment);

public const string BaiduClientId = nameof(BaiduClientId);
public const string BaiduSecret = nameof(BaiduSecret);
[ConfigType("select", "Ernie-Bot", "Ernie-Bot-turbo", "BLOOMZ_7B")]
public const string BaiduModel = nameof(BaiduModel);

public const string OpenAIApiKey = nameof(OpenAIApiKey);
public const string OpenAIModel = nameof(OpenAIModel);

public const string DashScopeApiKey = nameof(DashScopeApiKey);
[ConfigType("select", DashScopeModels.QWenV1, DashScopeModels.QWenPlusV1)]
public const string DashScopeModel = nameof(DashScopeModel);

public const string QdrantEndpoint = nameof(QdrantEndpoint);
public const string QdrantApiKey = nameof(QdrantApiKey);

public const string VectorSize = nameof(VectorSize);
#endregion
}

[AttributeUsage(AttributeTargets.Field, AllowMultiple = false)]
public class ConfigTypeAttribute : Attribute
{
/// <summary>
/// `string` or `select`, use select must set SelectValues
/// </summary>
/// <param name="type"></param>
public ConfigTypeAttribute(string type = "string", params string[] selectValues)
{
Type = type;
if (type == "select")
{
if (selectValues == null || selectValues.Length == 0)
{
throw new ArgumentException("selectValues must not be null or empty when type is select");
}

SelectValues = selectValues;
}
}

public string Type { get; }
public string[]? SelectValues { get; }
}


}
36 changes: 1 addition & 35 deletions PromptPlayground/ViewModels/ConfigViewModels/IConfigViewModel.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using CommunityToolkit.Mvvm.ComponentModel;
using Humanizer;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Orchestration;
using PromptPlayground.Services;
using PromptPlayground.ViewModels.ConfigViewModels.Embedding;
Expand All @@ -10,7 +8,6 @@
using System.Collections.ObjectModel;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Text.Json.Serialization;

namespace PromptPlayground.ViewModels.ConfigViewModels
{
Expand Down Expand Up @@ -57,35 +54,4 @@ public interface IConfigAttributesProvider
IVectorDbConfigViewModel? GetVectorDb();
IEmbeddingConfigViewModel? GetEmbedding();
}

public class ConfigAttribute : ObservableObject
{
public ConfigAttribute(string name)
{
Name = name;
}
private string _value = string.Empty;
[JsonIgnore]
public string HumanizeName => Name.Humanize();
public string Type { get; set; } = "string";
public string Name { get; set; } = string.Empty;
public string Value { get => _value; set => SetProperty(ref _value, value, nameof(Value)); }

#region Constants
public const string AzureDeployment = nameof(AzureDeployment);
public const string AzureEndpoint = nameof(AzureEndpoint);
public const string AzureSecret = nameof(AzureSecret);
public const string AzureEmbeddingDeployment = nameof(AzureEmbeddingDeployment);

public const string BaiduClientId = nameof(BaiduClientId);
public const string BaiduSecret = nameof(BaiduSecret);

public const string QdrantEndpoint = nameof(QdrantEndpoint);
public const string QdrantApiKey = nameof(QdrantApiKey);

public const string VectorSize = nameof(VectorSize);
#endregion
}


}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ERNIE_Bot.SDK.Models;
using ERNIE_Bot.SDK;
using ERNIE_Bot.SDK.Models;
using Microsoft;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Orchestration;
Expand All @@ -15,22 +16,27 @@ public class BaiduConfigViewModel : ConfigViewModelBase, ILLMConfigViewModel

public BaiduConfigViewModel(IConfigAttributesProvider provider) : base(provider)
{
RequireAttribute(ConfigAttribute.BaiduModel);
RequireAttribute(ClientId);
RequireAttribute(Secret);
}
public KernelBuilder CreateKernelBuilder()
{
Requires.NotNullOrWhiteSpace(ClientId, nameof(ClientId));
Requires.NotNullOrWhiteSpace(ClientId, nameof(ClientId));

return Kernel.Builder
.WithERNIEBotChatCompletionService(GetAttribute(ClientId), GetAttribute(Secret))
;
.WithERNIEBotChatCompletionService(GetAttribute(ClientId), GetAttribute(Secret), modelEndpoint: ModelEndpoint);
}
public ResultTokenUsage? GetUsage(ModelResult result)
{
var completions = result.GetResult<ChatResponse>();
return new ResultTokenUsage(completions.Usage.TotalTokens, completions.Usage.PromptTokens, completions.Usage.CompletionTokens);
}

private string ModelEndpoint =>
GetAttribute(ConfigAttribute.BaiduModel) switch
{
"BLOOMZ_7B" => ModelEndpoints.BLOOMZ_7B,
"Ernie-Bot-turbo" => ModelEndpoints.ERNIE_Bot_Turbo,
_ => ModelEndpoints.ERNIE_Bot
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Orchestration;
using PromptPlayground.Services;
using System;

namespace PromptPlayground.ViewModels.ConfigViewModels.LLM
{
[Obsolete("Use BaiduConfigViewModel with model")]
public class BaiduTurboConfigViewModel : ConfigViewModelBase, ILLMConfigViewModel
{
const string ClientId = ConfigAttribute.BaiduClientId;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using Azure.AI.OpenAI;
using DashScope;
using DashScope.Models;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Orchestration;
using PromptPlayground.Services;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace PromptPlayground.ViewModels.ConfigViewModels.LLM
{
internal class DashScopeConfigViewModel : ConfigViewModelBase, ILLMConfigViewModel
{
public override string Name => "DashScope";

public DashScopeConfigViewModel(IConfigAttributesProvider provider) : base(provider)
{
RequireAttribute(ConfigAttribute.DashScopeApiKey);
RequireAttribute(ConfigAttribute.DashScopeModel);
}

public KernelBuilder CreateKernelBuilder()
{
var apiKey = GetAttribute(ConfigAttribute.DashScopeApiKey);
var model = GetAttribute(ConfigAttribute.DashScopeModel);

return Kernel.Builder.WithDashScopeCompletionService(apiKey, model);
}

public ResultTokenUsage? GetUsage(ModelResult resultModel)
{
var usage = resultModel.GetResult<CompletionResponse>().Usage;
return new ResultTokenUsage(usage.InputTokens + usage.OutputTokens, usage.InputTokens, usage.OutputTokens);
}
}
}
Loading

0 comments on commit 9ca7ca6

Please sign in to comment.