diff --git a/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs b/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs index fb0cf3e..b833d15 100644 --- a/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs +++ b/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs @@ -65,7 +65,7 @@ private class SampleUIData : INotifyPropertyChanged public required string SampleName { get; init; } public required Sample Sample { get; init; } - public required Dictionary CachedModelsToGenerator { get; init; } + public required Dictionary CachedModelsToGenerator { get; init; } public Brush? StatusColor { get => statusColor; @@ -177,17 +177,17 @@ private static IEnumerable GetAllForSample(Sample s) }; } - static Dictionary GetModelsToGenerator(Sample s, List>> modelsDetails, KeyValuePair> keyValuePair) + static Dictionary GetModelsToGenerator(Sample s, List>> modelsDetails, KeyValuePair> keyValuePair) { - Dictionary cachedModelsToGenerator = new(); + Dictionary cachedModelsToGenerator = new(); ModelDetails modelDetails1 = keyValuePair.Value.First(); - cachedModelsToGenerator[keyValuePair.Key] = (modelDetails1.Url, modelDetails1.Url, modelDetails1.HardwareAccelerators.First()); + cachedModelsToGenerator[keyValuePair.Key] = new(modelDetails1.Id, modelDetails1.Url, modelDetails1.Url, 0, modelDetails1.HardwareAccelerators.First()); if (s.Model2Types != null && modelsDetails.Count > 1) { ModelDetails modelDetails2 = modelsDetails[1].Values.First().First(); - cachedModelsToGenerator[s.Model2Types.First()] = (modelDetails2.Url, modelDetails2.Url, modelDetails2.HardwareAccelerators.First()); + cachedModelsToGenerator[s.Model2Types.First()] = new(modelDetails2.Id, modelDetails2.Url, modelDetails2.Url, 0, modelDetails2.HardwareAccelerators.First()); } return cachedModelsToGenerator; diff --git a/AIDevGallery/Controls/SampleContainer.xaml b/AIDevGallery/Controls/SampleContainer.xaml index 81baa83..86b3419 100644 --- a/AIDevGallery/Controls/SampleContainer.xaml +++ b/AIDevGallery/Controls/SampleContainer.xaml @@ -201,7 +201,7 @@ Text="Required NuGet packages" /> - + SetValue(DisclaimerHorizontalAlignmentProperty, value); } + public List NugetPackageReferences + { + get { return (List)GetValue(NugetPackageReferencesProperty); } + set { SetValue(NugetPackageReferencesProperty, value); } + } + + public static readonly DependencyProperty NugetPackageReferencesProperty = + DependencyProperty.Register("NugetPackageReferences", typeof(List), typeof(SampleContainer), new PropertyMetadata(null)); + private Sample? _sampleCache; + private Dictionary? _cachedModels; private List? _modelsCache; private CancellationTokenSource? _sampleLoadingCts; private TaskCompletionSource? _sampleLoadedCompletionSource; @@ -272,7 +283,17 @@ private bool LoadSampleMetadata(Sample sample, List? models) } _sampleCache = sample; - _modelsCache = models; + + if (models != null) + { + _cachedModels = sample.GetCacheModelDetailsDictionary(models.ToArray()); + + if (_cachedModels != null) + { + NugetPackageReferences = sample.GetAllNugetPackageReferences(_cachedModels); + _modelsCache = models; + } + } if (sample == null) { @@ -308,9 +329,9 @@ private void RenderCode(bool force = false) CodePivot.Items.Add(CreateCodeBlock(codeFormatter, "Sample.xaml", _sampleCache.XAMLCode, Languages.FindById("xaml"))); } - if (_sampleCache.SharedCode != null && _sampleCache.SharedCode.Count != 0) + if (_cachedModels != null) { - foreach (var sharedCodeEnum in _sampleCache.SharedCode) + foreach (var sharedCodeEnum in _sampleCache.GetAllSharedCode(_cachedModels)) { string sharedCodeName = Samples.SharedCodeHelpers.GetName(sharedCodeEnum); string sharedCodeContent = Samples.SharedCodeHelpers.GetSource(sharedCodeEnum); diff --git a/AIDevGallery/Helpers/SamplesHelper.cs b/AIDevGallery/Helpers/SamplesHelper.cs new file mode 100644 index 0000000..c045a9f --- /dev/null +++ b/AIDevGallery/Helpers/SamplesHelper.cs @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using AIDevGallery.Models; +using AIDevGallery.Samples; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace AIDevGallery.Helpers; + +internal static class SamplesHelper +{ + public static List GetAllSharedCode(this Sample sample, Dictionary models) + { + var sharedCode = sample.SharedCode.ToList(); + + bool isLanguageModel = ModelDetailsHelper.EqualOrParent(models.Keys.First(), ModelType.LanguageModels); + + if (isLanguageModel) + { + AddUnique(SharedCodeEnum.GenAIModel); + } + + if (sharedCode.Contains(SharedCodeEnum.GenAIModel)) + { + AddUnique(SharedCodeEnum.LlmPromptTemplate); + } + + if (models.Any(m => ModelDetailsHelper.EqualOrParent(m.Key, ModelType.PhiSilica))) + { + AddUnique(SharedCodeEnum.PhiSilicaClient); + } + + if (sharedCode.Contains(SharedCodeEnum.DeviceUtils)) + { + AddUnique(SharedCodeEnum.NativeMethods); + } + + return sharedCode; + + void AddUnique(SharedCodeEnum sharedCodeEnumToAdd) + { + if (!sharedCode.Contains(sharedCodeEnumToAdd)) + { + sharedCode.Add(sharedCodeEnumToAdd); + } + } + } + + public static List GetAllNugetPackageReferences(this Sample sample, Dictionary models) + { + var packageReferences = sample.NugetPackageReferences.ToList(); + + var modelTypes = sample.Model1Types.Concat(sample.Model2Types ?? Enumerable.Empty()) + .Where(models.ContainsKey); + + bool isLanguageModel = modelTypes.Any(modelType => ModelDetailsHelper.EqualOrParent(modelType, ModelType.LanguageModels)); + + if (isLanguageModel) + { + AddUnique("Microsoft.ML.OnnxRuntimeGenAI.DirectML"); + } + + var sharedCode = sample.GetAllSharedCode(models); + + if (sharedCode.Contains(SharedCodeEnum.NativeMethods)) + { + AddUnique("Microsoft.Windows.CsWin32"); + } + + return packageReferences; + + void AddUnique(string packageNameToAdd) + { + if (!packageReferences.Any(packageName => packageName == packageNameToAdd)) + { + packageReferences.Add(packageNameToAdd); + } + } + } + + public static Dictionary? GetCacheModelDetailsDictionary(this Sample sample, ModelDetails?[] modelDetails) + { + if (modelDetails.Length == 0 || modelDetails.Length > 2) + { + throw new ArgumentException(modelDetails.Length == 0 ? "No model details provided" : "More than 2 model details provided"); + } + + var selectedModelDetails = modelDetails[0]; + var selectedModelDetails2 = modelDetails.Length > 1 ? modelDetails[1] : null; + + if (selectedModelDetails == null) + { + return null; + } + + Dictionary cachedModels = []; + + ExpandedModelDetails cachedModel; + + if (selectedModelDetails.Size == 0) + { + cachedModel = new(selectedModelDetails.Id, selectedModelDetails.Url, selectedModelDetails.Url, 0, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); + } + else + { + var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails.Url); + if (realCachedModel == null) + { + return null; + } + + cachedModel = new(selectedModelDetails.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); + } + + var cachedSampleItem = App.FindSampleItemById(cachedModel.Id); + + var model1Type = sample.Model1Types.Any(cachedSampleItem.Contains) + ? sample.Model1Types.First(cachedSampleItem.Contains) + : sample.Model1Types.First(); + cachedModels.Add(model1Type, cachedModel); + + if (sample.Model2Types != null) + { + if (selectedModelDetails2 == null) + { + return null; + } + + if (selectedModelDetails2.Size == 0) + { + cachedModel = new(selectedModelDetails2.Id, selectedModelDetails2.Url, selectedModelDetails2.Url, 0, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); + } + else + { + var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails2.Url); + if (realCachedModel == null) + { + return null; + } + + cachedModel = new(selectedModelDetails2.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); + } + + var model2Type = sample.Model2Types.Any(cachedSampleItem.Contains) + ? sample.Model2Types.First(cachedSampleItem.Contains) + : sample.Model2Types.First(); + + cachedModels.Add(model2Type, cachedModel); + } + + return cachedModels; + } +} \ No newline at end of file diff --git a/AIDevGallery/Models/ExpandedModelDetails.cs b/AIDevGallery/Models/ExpandedModelDetails.cs new file mode 100644 index 0000000..8c07930 --- /dev/null +++ b/AIDevGallery/Models/ExpandedModelDetails.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace AIDevGallery.Models; + +internal record class ExpandedModelDetails(string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator); \ No newline at end of file diff --git a/AIDevGallery/Pages/Scenarios/ScenarioPage.xaml.cs b/AIDevGallery/Pages/Scenarios/ScenarioPage.xaml.cs index f9d878f..34fe598 100644 --- a/AIDevGallery/Pages/Scenarios/ScenarioPage.xaml.cs +++ b/AIDevGallery/Pages/Scenarios/ScenarioPage.xaml.cs @@ -248,59 +248,11 @@ private async void ExportSampleToggle_Click(object sender, RoutedEventArgs e) return; } - Dictionary cachedModels = []; + var cachedModels = sample.GetCacheModelDetailsDictionary([selectedModelDetails, selectedModelDetails2]); - (string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator) cachedModel; - - if (selectedModelDetails.Size == 0) - { - cachedModel = (selectedModelDetails.Id, selectedModelDetails.Url, selectedModelDetails.Url, 0, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); - } - else - { - var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails.Url); - if (realCachedModel == null) - { - return; - } - - cachedModel = (selectedModelDetails.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); - } - - var cachedSampleItem = App.FindSampleItemById(cachedModel.Id); - - var model1Type = sample.Model1Types.Any(cachedSampleItem.Contains) - ? sample.Model1Types.First(cachedSampleItem.Contains) - : sample.Model1Types.First(); - cachedModels.Add(model1Type, cachedModel); - - if (sample.Model2Types != null) + if (cachedModels == null) { - if (selectedModelDetails2 == null) - { - return; - } - - if (selectedModelDetails2.Size == 0) - { - cachedModel = (selectedModelDetails2.Id, selectedModelDetails2.Url, selectedModelDetails2.Url, 0, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); - } - else - { - var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails2.Url); - if (realCachedModel == null) - { - return; - } - - cachedModel = (selectedModelDetails2.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); - } - - var model2Type = sample.Model2Types.Any(cachedSampleItem.Contains) - ? sample.Model2Types.First(cachedSampleItem.Contains) - : sample.Model2Types.First(); - - cachedModels.Add(model2Type, cachedModel); + return; } ContentDialog? dialog = null; @@ -338,13 +290,9 @@ private async void ExportSampleToggle_Click(object sender, RoutedEventArgs e) }; _ = dialog.ShowAsync(); - Dictionary cachedModelsToGenerator = cachedModels - .Select(cm => (cm.Key, (cm.Value.Path, cm.Value.Url, cm.Value.HardwareAccelerator))) - .ToDictionary(x => x.Key, x => (x.Item2.Path, x.Item2.Url, x.Item2.HardwareAccelerator)); - var projectPath = await generator.GenerateAsync( sample, - cachedModelsToGenerator, + cachedModels, copyRadioButton.IsChecked == true && copyRadioButtons.Visibility == Visibility.Visible, folder.Path, CancellationToken.None); diff --git a/AIDevGallery/ProjectGenerator/Generator.cs b/AIDevGallery/ProjectGenerator/Generator.cs index e343965..ad6397c 100644 --- a/AIDevGallery/ProjectGenerator/Generator.cs +++ b/AIDevGallery/ProjectGenerator/Generator.cs @@ -46,25 +46,9 @@ private static string ToSafeVariableName(string input) return safeName; } - internal Task GenerateAsync(Sample sample, Dictionary models, bool copyModelLocally, string outputPath, CancellationToken cancellationToken) - { - var packageReferences = new List<(string PackageName, string? Version)> - { - ("Microsoft.WindowsAppSDK", null), - ("Microsoft.Windows.SDK.BuildTools", null), - }; - - foreach (var nugetPackageReference in sample.NugetPackageReferences) - { - packageReferences.Add(new(nugetPackageReference, null)); - } - - return GenerateAsyncInternal(sample, models, copyModelLocally, packageReferences, outputPath, cancellationToken); - } - internal const string DotNetVersion = "net9.0"; - private async Task GenerateAsyncInternal(Sample sample, Dictionary models, bool copyModelLocally, List<(string PackageName, string? Version)> packageReferences, string outputPath, CancellationToken cancellationToken) + internal async Task GenerateAsync(Sample sample, Dictionary models, bool copyModelLocally, string outputPath, CancellationToken cancellationToken) { var projectName = $"{sample.Name}Sample"; string safeProjectName = ToSafeVariableName(projectName); @@ -94,20 +78,20 @@ private async Task GenerateAsyncInternal(Sample sample, Dictionary new FileInfo(f).Length); + sumTotalSize += Directory.GetFiles(modelInfo.Path, "*", SearchOption.AllDirectories).Sum(f => new FileInfo(f).Length); } else { - sumTotalSize += new FileInfo(modelInfo.CachedModelDirectoryPath).Length; + sumTotalSize += new FileInfo(modelInfo.Path).Length; } } @@ -120,7 +104,7 @@ private async Task GenerateAsyncInternal(Sample sample, Dictionary modelInfos = []; + Dictionary modelInfos = []; string model1Id = string.Empty; string model2Id = string.Empty; foreach (var modelType in modelTypes) @@ -133,14 +117,13 @@ private async Task GenerateAsyncInternal(Sample sample, Dictionary mf.Value.Url == modelInfo.ModelUrl) is var modelDetails2 && modelDetails2.Value != null) + else if (ModelTypeHelpers.ModelDetails.FirstOrDefault(mf => mf.Value.Url == modelInfo.Url) is var modelDetails2 && modelDetails2.Value != null) { modelPromptTemplate = modelDetails2.Value.PromptTemplate; modelId = modelDetails2.Value.Id; @@ -152,19 +135,19 @@ private async Task GenerateAsyncInternal(Sample sample, Dictionary GenerateAsyncInternal(Sample sample, Dictionary GenerateAsyncInternal(Sample sample, Dictionary GenerateAsyncInternal(Sample sample, Dictionary packageReferences.PackageName == genAiPackage)) - { - packageReferences.Add((genAiPackage, null)); - } - } } SampleProjectGeneratedEvent.Log(sample.Id, model1Id, model2Id, copyModelLocally); @@ -232,7 +206,7 @@ private async Task GenerateAsyncInternal(Sample sample, Dictionary GenerateAsyncInternal(Sample sample, Dictionary packageReferences = sample.GetAllNugetPackageReferences(models); + packageReferences.Add("Microsoft.WindowsAppSDK"); + packageReferences.Add("Microsoft.Windows.SDK.BuildTools"); + // Add NuGet references if (packageReferences.Count > 0 || copyModelLocally) { var project = ProjectRootElement.Open(csproj); var itemGroup = project.AddItemGroup(); - static void AddPackageReference(ProjectItemGroupElement itemGroup, string packageName, string? version) + static void AddPackageReference(ProjectItemGroupElement itemGroup, string packageName) { if (itemGroup.Items.Any(i => i.ItemType == "PackageReference" && i.Include == packageName)) { @@ -324,7 +302,7 @@ static void AddPackageReference(ProjectItemGroupElement itemGroup, string packag packageReferenceItem.Condition = "$(Platform) == 'ARM64'"; } - var versionStr = version ?? PackageVersionHelpers.PackageVersions[packageName]; + var versionStr = PackageVersionHelpers.PackageVersions[packageName]; packageReferenceItem.AddMetadata("Version", versionStr, true); if (packageName == "Microsoft.ML.OnnxRuntimeGenAI") @@ -340,22 +318,20 @@ static void AddPackageReference(ProjectItemGroupElement itemGroup, string packag } } - foreach (var packageReference in packageReferences) + foreach (var packageName in packageReferences) { - var packageName = packageReference.PackageName; - var version = packageReference.Version; if (packageName == "Microsoft.ML.OnnxRuntime.DirectML") { - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntime.Qnn", null); + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntime.Qnn"); } else if (packageName == "Microsoft.ML.OnnxRuntimeGenAI.DirectML") { - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntime.Qnn", null); - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI", null); - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI.Managed", null); + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntime.Qnn"); + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI"); + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI.Managed"); } - AddPackageReference(itemGroup, packageName, version); + AddPackageReference(itemGroup, packageName); } if (copyModelLocally) @@ -363,18 +339,18 @@ static void AddPackageReference(ProjectItemGroupElement itemGroup, string packag var modelContentItemGroup = project.AddItemGroup(); foreach (var modelInfo in modelInfos) { - if (modelInfo.Value.CachedModelDirectoryPath.Contains("file://", StringComparison.OrdinalIgnoreCase)) + if (modelInfo.Value.ExpandedModelDetails.Path.Contains("file://", StringComparison.OrdinalIgnoreCase)) { continue; } if (modelInfo.Value.IsSingleFile) { - modelContentItemGroup.AddItem("Content", @$"Models\{modelInfo.Value.CachedModelDirectoryPath}"); + modelContentItemGroup.AddItem("Content", @$"Models\{modelInfo.Value.ExpandedModelDetails.Path}"); } else { - modelContentItemGroup.AddItem("Content", @$"Models\{modelInfo.Value.CachedModelDirectoryPath}\**"); + modelContentItemGroup.AddItem("Content", @$"Models\{modelInfo.Value.ExpandedModelDetails.Path}\**"); } } } @@ -524,45 +500,12 @@ private string GetPromptTemplateString(PromptTemplate? promptTemplate, int space private async Task AddFilesFromSampleAsync( Sample sample, - List<(string PackageName, string? Version)> packageReferences, string baseNamespace, string outputPath, - Dictionary modelInfos, + Dictionary modelInfos, CancellationToken cancellationToken) { - List sharedCode = sample.SharedCode.ToList(); - bool isLanguageModel = ModelDetailsHelper.EqualOrParent(modelInfos.Keys.First(), ModelType.LanguageModels); - - if (isLanguageModel && !sharedCode.Contains(SharedCodeEnum.GenAIModel)) - { - sharedCode.Add(SharedCodeEnum.GenAIModel); - } - - if (sharedCode.Contains(SharedCodeEnum.GenAIModel)) - { - if (!sharedCode.Contains(SharedCodeEnum.LlmPromptTemplate)) - { - sharedCode.Add(SharedCodeEnum.LlmPromptTemplate); - } - } - - if (modelInfos.Values.Any(mi => mi.IsPhiSilica)) - { - if (!sharedCode.Contains(SharedCodeEnum.PhiSilicaClient)) - { - sharedCode.Add(SharedCodeEnum.PhiSilicaClient); - } - } - - if (sharedCode.Contains(SharedCodeEnum.DeviceUtils) && !sharedCode.Contains(SharedCodeEnum.NativeMethods)) - { - sharedCode.Add(SharedCodeEnum.NativeMethods); - var csWin32 = "Microsoft.Windows.CsWin32"; - if (!packageReferences.Any(packageReferences => packageReferences.PackageName == csWin32)) - { - packageReferences.Add((csWin32, null)); - } - } + List sharedCode = sample.GetAllSharedCode(modelInfos.ToDictionary(m => m.Key, m => m.Value.ExpandedModelDetails)); foreach (var sharedCodeEnum in sharedCode) { @@ -624,7 +567,7 @@ private async Task AddFilesFromSampleAsync( int i = 0; foreach (var modelInfo in modelInfos) { - cleanCsSource = cleanCsSource.Replace($"sampleParams.HardwareAccelerators[{i}]", $"HardwareAccelerator.{modelInfo.Value.HardwareAccelerator}"); + cleanCsSource = cleanCsSource.Replace($"sampleParams.HardwareAccelerators[{i}]", $"HardwareAccelerator.{modelInfo.Value.ExpandedModelDetails.HardwareAccelerator}"); cleanCsSource = cleanCsSource.Replace($"sampleParams.ModelPaths[{i}]", modelInfo.Value.ModelPathStr); i++; } @@ -634,7 +577,7 @@ private async Task AddFilesFromSampleAsync( else { var modelInfo = modelInfos.Values.First(); - cleanCsSource = cleanCsSource.Replace("sampleParams.HardwareAccelerator", $"HardwareAccelerator.{modelInfo.HardwareAccelerator}"); + cleanCsSource = cleanCsSource.Replace("sampleParams.HardwareAccelerator", $"HardwareAccelerator.{modelInfo.ExpandedModelDetails.HardwareAccelerator}"); cleanCsSource = cleanCsSource.Replace("sampleParams.ModelPath", modelInfo.ModelPathStr); modelPath = modelInfo.ModelPathStr; } @@ -651,7 +594,7 @@ private async Task AddFilesFromSampleAsync( var spaceCount = subStr.Length - subStrWithoutSpaces.Length; var modelInfo = modelInfos.Values.First(); var promptTemplate = GetPromptTemplateString(modelInfo.ModelPromptTemplate, spaceCount); - var chatClientLoader = GetChatClientLoaderString(sharedCode, modelPath, promptTemplate, modelInfo.IsPhiSilica, modelInfos.Keys.First()); + var chatClientLoader = GetChatClientLoaderString(sharedCode, modelPath, promptTemplate, modelInfos.Any(m => ModelDetailsHelper.EqualOrParent(m.Key, ModelType.PhiSilica)), modelInfos.Keys.First()); if (chatClientLoader != null) { cleanCsSource = cleanCsSource.Replace(search, chatClientLoader);