Skip to content

Commit

Permalink
Merge pull request #233 from microsoft/alzollin/refactorSamplesNuPkg
Browse files Browse the repository at this point in the history
Refactor NugetPackage logic and SharedCode logic to be shared between sample container and project generator.
  • Loading branch information
nmetulev authored Feb 22, 2025
2 parents cf82d60 + 1038a0c commit 3a4624e
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 159 deletions.
10 changes: 5 additions & 5 deletions AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private class SampleUIData : INotifyPropertyChanged

public required string SampleName { get; init; }
public required Sample Sample { get; init; }
public required Dictionary<ModelType, (string CachedModelDirectoryPath, string ModelUrl, HardwareAccelerator HardwareAccelerator)> CachedModelsToGenerator { get; init; }
public required Dictionary<ModelType, ExpandedModelDetails> CachedModelsToGenerator { get; init; }
public Brush? StatusColor
{
get => statusColor;
Expand Down Expand Up @@ -177,17 +177,17 @@ private static IEnumerable<SampleUIData> GetAllForSample(Sample s)
};
}

static Dictionary<ModelType, (string CachedModelDirectoryPath, string ModelUrl, HardwareAccelerator HardwareAccelerator)> GetModelsToGenerator(Sample s, List<Dictionary<ModelType, List<ModelDetails>>> modelsDetails, KeyValuePair<ModelType, List<ModelDetails>> keyValuePair)
static Dictionary<ModelType, ExpandedModelDetails> GetModelsToGenerator(Sample s, List<Dictionary<ModelType, List<ModelDetails>>> modelsDetails, KeyValuePair<ModelType, List<ModelDetails>> keyValuePair)
{
Dictionary<ModelType, (string CachedModelDirectoryPath, string ModelUrl, HardwareAccelerator HardwareAccelerator)> cachedModelsToGenerator = new();
Dictionary<ModelType, ExpandedModelDetails> cachedModelsToGenerator = new();

ModelDetails modelDetails1 = keyValuePair.Value.First();
cachedModelsToGenerator[keyValuePair.Key] = (modelDetails1.Url, modelDetails1.Url, modelDetails1.HardwareAccelerators.First());
cachedModelsToGenerator[keyValuePair.Key] = new(modelDetails1.Id, modelDetails1.Url, modelDetails1.Url, 0, modelDetails1.HardwareAccelerators.First());

if (s.Model2Types != null && modelsDetails.Count > 1)
{
ModelDetails modelDetails2 = modelsDetails[1].Values.First().First();
cachedModelsToGenerator[s.Model2Types.First()] = (modelDetails2.Url, modelDetails2.Url, modelDetails2.HardwareAccelerators.First());
cachedModelsToGenerator[s.Model2Types.First()] = new(modelDetails2.Id, modelDetails2.Url, modelDetails2.Url, 0, modelDetails2.HardwareAccelerators.First());
}

return cachedModelsToGenerator;
Expand Down
2 changes: 1 addition & 1 deletion AIDevGallery/Controls/SampleContainer.xaml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
Text="Required NuGet packages" />
</Expander.Header>
<Expander.Content>
<ItemsRepeater Margin="-12,-8,-12,-8" ItemsSource="{x:Bind _sampleCache.NugetPackageReferences}">
<ItemsRepeater Margin="-12,-8,-12,-8" ItemsSource="{x:Bind NugetPackageReferences, Mode=OneWay}">
<ItemsRepeater.ItemTemplate>
<DataTemplate x:DataType="x:String">
<HyperlinkButton
Expand Down
27 changes: 24 additions & 3 deletions AIDevGallery/Controls/SampleContainer.xaml.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using AIDevGallery.Helpers;
using AIDevGallery.Models;
using AIDevGallery.Samples.SharedCode;
using AIDevGallery.Telemetry.Events;
Expand Down Expand Up @@ -30,7 +31,17 @@ public HorizontalAlignment DisclaimerHorizontalAlignment
set => SetValue(DisclaimerHorizontalAlignmentProperty, value);
}

public List<string> NugetPackageReferences
{
get { return (List<string>)GetValue(NugetPackageReferencesProperty); }
set { SetValue(NugetPackageReferencesProperty, value); }
}

public static readonly DependencyProperty NugetPackageReferencesProperty =
DependencyProperty.Register("NugetPackageReferences", typeof(List<string>), typeof(SampleContainer), new PropertyMetadata(null));

private Sample? _sampleCache;
private Dictionary<ModelType, ExpandedModelDetails>? _cachedModels;
private List<ModelDetails>? _modelsCache;
private CancellationTokenSource? _sampleLoadingCts;
private TaskCompletionSource? _sampleLoadedCompletionSource;
Expand Down Expand Up @@ -272,7 +283,17 @@ private bool LoadSampleMetadata(Sample sample, List<ModelDetails>? models)
}

_sampleCache = sample;
_modelsCache = models;

if (models != null)
{
_cachedModels = sample.GetCacheModelDetailsDictionary(models.ToArray());

if (_cachedModels != null)
{
NugetPackageReferences = sample.GetAllNugetPackageReferences(_cachedModels);
_modelsCache = models;
}
}

if (sample == null)
{
Expand Down Expand Up @@ -308,9 +329,9 @@ private void RenderCode(bool force = false)
CodePivot.Items.Add(CreateCodeBlock(codeFormatter, "Sample.xaml", _sampleCache.XAMLCode, Languages.FindById("xaml")));
}

if (_sampleCache.SharedCode != null && _sampleCache.SharedCode.Count != 0)
if (_cachedModels != null)
{
foreach (var sharedCodeEnum in _sampleCache.SharedCode)
foreach (var sharedCodeEnum in _sampleCache.GetAllSharedCode(_cachedModels))
{
string sharedCodeName = Samples.SharedCodeHelpers.GetName(sharedCodeEnum);
string sharedCodeContent = Samples.SharedCodeHelpers.GetSource(sharedCodeEnum);
Expand Down
155 changes: 155 additions & 0 deletions AIDevGallery/Helpers/SamplesHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using AIDevGallery.Models;
using AIDevGallery.Samples;
using System;
using System.Collections.Generic;
using System.Linq;

namespace AIDevGallery.Helpers;

internal static class SamplesHelper
{
public static List<SharedCodeEnum> GetAllSharedCode(this Sample sample, Dictionary<ModelType, ExpandedModelDetails> models)
{
var sharedCode = sample.SharedCode.ToList();

bool isLanguageModel = ModelDetailsHelper.EqualOrParent(models.Keys.First(), ModelType.LanguageModels);

if (isLanguageModel)
{
AddUnique(SharedCodeEnum.GenAIModel);
}

if (sharedCode.Contains(SharedCodeEnum.GenAIModel))
{
AddUnique(SharedCodeEnum.LlmPromptTemplate);
}

if (models.Any(m => ModelDetailsHelper.EqualOrParent(m.Key, ModelType.PhiSilica)))
{
AddUnique(SharedCodeEnum.PhiSilicaClient);
}

if (sharedCode.Contains(SharedCodeEnum.DeviceUtils))
{
AddUnique(SharedCodeEnum.NativeMethods);
}

return sharedCode;

void AddUnique(SharedCodeEnum sharedCodeEnumToAdd)
{
if (!sharedCode.Contains(sharedCodeEnumToAdd))
{
sharedCode.Add(sharedCodeEnumToAdd);
}
}
}

public static List<string> GetAllNugetPackageReferences(this Sample sample, Dictionary<ModelType, ExpandedModelDetails> models)
{
var packageReferences = sample.NugetPackageReferences.ToList();

var modelTypes = sample.Model1Types.Concat(sample.Model2Types ?? Enumerable.Empty<ModelType>())
.Where(models.ContainsKey);

bool isLanguageModel = modelTypes.Any(modelType => ModelDetailsHelper.EqualOrParent(modelType, ModelType.LanguageModels));

if (isLanguageModel)
{
AddUnique("Microsoft.ML.OnnxRuntimeGenAI.DirectML");
}

var sharedCode = sample.GetAllSharedCode(models);

if (sharedCode.Contains(SharedCodeEnum.NativeMethods))
{
AddUnique("Microsoft.Windows.CsWin32");
}

return packageReferences;

void AddUnique(string packageNameToAdd)
{
if (!packageReferences.Any(packageName => packageName == packageNameToAdd))
{
packageReferences.Add(packageNameToAdd);
}
}
}

public static Dictionary<ModelType, ExpandedModelDetails>? GetCacheModelDetailsDictionary(this Sample sample, ModelDetails?[] modelDetails)
{
if (modelDetails.Length == 0 || modelDetails.Length > 2)
{
throw new ArgumentException(modelDetails.Length == 0 ? "No model details provided" : "More than 2 model details provided");
}

var selectedModelDetails = modelDetails[0];
var selectedModelDetails2 = modelDetails.Length > 1 ? modelDetails[1] : null;

if (selectedModelDetails == null)
{
return null;
}

Dictionary<ModelType, ExpandedModelDetails> cachedModels = [];

ExpandedModelDetails cachedModel;

if (selectedModelDetails.Size == 0)
{
cachedModel = new(selectedModelDetails.Id, selectedModelDetails.Url, selectedModelDetails.Url, 0, selectedModelDetails.HardwareAccelerators.FirstOrDefault());
}
else
{
var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails.Url);
if (realCachedModel == null)
{
return null;
}

cachedModel = new(selectedModelDetails.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails.HardwareAccelerators.FirstOrDefault());
}

var cachedSampleItem = App.FindSampleItemById(cachedModel.Id);

var model1Type = sample.Model1Types.Any(cachedSampleItem.Contains)
? sample.Model1Types.First(cachedSampleItem.Contains)
: sample.Model1Types.First();
cachedModels.Add(model1Type, cachedModel);

if (sample.Model2Types != null)
{
if (selectedModelDetails2 == null)
{
return null;
}

if (selectedModelDetails2.Size == 0)
{
cachedModel = new(selectedModelDetails2.Id, selectedModelDetails2.Url, selectedModelDetails2.Url, 0, selectedModelDetails2.HardwareAccelerators.FirstOrDefault());
}
else
{
var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails2.Url);
if (realCachedModel == null)
{
return null;
}

cachedModel = new(selectedModelDetails2.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails2.HardwareAccelerators.FirstOrDefault());
}

var model2Type = sample.Model2Types.Any(cachedSampleItem.Contains)
? sample.Model2Types.First(cachedSampleItem.Contains)
: sample.Model2Types.First();

cachedModels.Add(model2Type, cachedModel);
}

return cachedModels;
}
}
6 changes: 6 additions & 0 deletions AIDevGallery/Models/ExpandedModelDetails.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace AIDevGallery.Models;

internal record class ExpandedModelDetails(string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator);
60 changes: 4 additions & 56 deletions AIDevGallery/Pages/Scenarios/ScenarioPage.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -248,59 +248,11 @@ private async void ExportSampleToggle_Click(object sender, RoutedEventArgs e)
return;
}

Dictionary<ModelType, (string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator)> cachedModels = [];
var cachedModels = sample.GetCacheModelDetailsDictionary([selectedModelDetails, selectedModelDetails2]);

(string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator) cachedModel;

if (selectedModelDetails.Size == 0)
{
cachedModel = (selectedModelDetails.Id, selectedModelDetails.Url, selectedModelDetails.Url, 0, selectedModelDetails.HardwareAccelerators.FirstOrDefault());
}
else
{
var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails.Url);
if (realCachedModel == null)
{
return;
}

cachedModel = (selectedModelDetails.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails.HardwareAccelerators.FirstOrDefault());
}

var cachedSampleItem = App.FindSampleItemById(cachedModel.Id);

var model1Type = sample.Model1Types.Any(cachedSampleItem.Contains)
? sample.Model1Types.First(cachedSampleItem.Contains)
: sample.Model1Types.First();
cachedModels.Add(model1Type, cachedModel);

if (sample.Model2Types != null)
if (cachedModels == null)
{
if (selectedModelDetails2 == null)
{
return;
}

if (selectedModelDetails2.Size == 0)
{
cachedModel = (selectedModelDetails2.Id, selectedModelDetails2.Url, selectedModelDetails2.Url, 0, selectedModelDetails2.HardwareAccelerators.FirstOrDefault());
}
else
{
var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails2.Url);
if (realCachedModel == null)
{
return;
}

cachedModel = (selectedModelDetails2.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails2.HardwareAccelerators.FirstOrDefault());
}

var model2Type = sample.Model2Types.Any(cachedSampleItem.Contains)
? sample.Model2Types.First(cachedSampleItem.Contains)
: sample.Model2Types.First();

cachedModels.Add(model2Type, cachedModel);
return;
}

ContentDialog? dialog = null;
Expand Down Expand Up @@ -338,13 +290,9 @@ private async void ExportSampleToggle_Click(object sender, RoutedEventArgs e)
};
_ = dialog.ShowAsync();

Dictionary<ModelType, (string CachedModelDirectoryPath, string ModelUrl, HardwareAccelerator HardwareAccelerator)> cachedModelsToGenerator = cachedModels
.Select(cm => (cm.Key, (cm.Value.Path, cm.Value.Url, cm.Value.HardwareAccelerator)))
.ToDictionary(x => x.Key, x => (x.Item2.Path, x.Item2.Url, x.Item2.HardwareAccelerator));

var projectPath = await generator.GenerateAsync(
sample,
cachedModelsToGenerator,
cachedModels,
copyRadioButton.IsChecked == true && copyRadioButtons.Visibility == Visibility.Visible,
folder.Path,
CancellationToken.None);
Expand Down
Loading

0 comments on commit 3a4624e

Please sign in to comment.