Skip to content

Commit 1739d62

Browse files
authored
Merge pull request microsoft#351 from microsoft/zt/custom-models
[DRAFT] Uploading custom models for any sample
2 parents 093a369 + 0743825 commit 1739d62

23 files changed

+748
-188
lines changed

AIDevGallery.SourceGenerator/Models/Model.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,6 @@ internal class Model
2727
public List<AIToolkitAction>? AIToolkitActions { get; init; }
2828
public string? AIToolkitId { get; init; }
2929
public string? AIToolkitFinetuningId { get; init; }
30+
public List<int[]>? InputDimensions { get; set; }
31+
public List<int[]>? OutputDimensions { get; set; }
3032
}

AIDevGallery.SourceGenerator/ModelsSourceGenerator.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ private void GenerateModelDetails(StringBuilder sourceBuilder, Dictionary<string
253253
var aiToolkitActions = modelDefinition.AIToolkitActions != null ? string.Join(", ", modelDefinition.AIToolkitActions.Select(action => $"AIToolkitAction.{action}")) : string.Empty;
254254
var aiToolkitId = !string.IsNullOrEmpty(modelDefinition.AIToolkitId) ? $"\"{modelDefinition.AIToolkitId}\"" : "null";
255255
var aiToolkitFinetuningId = !string.IsNullOrEmpty(modelDefinition.AIToolkitFinetuningId) ? $"\"{modelDefinition.AIToolkitFinetuningId}\"" : "null";
256+
var inputDimensions = modelDefinition.InputDimensions != null ? "[ " + string.Join(", ", modelDefinition.InputDimensions.Select(dimension => "[" + string.Join(", ", dimension.Select(d => d.ToString())) + "]")) + "]" : "null";
257+
var outputDimensions = modelDefinition.OutputDimensions != null ? "[ " + string.Join(", ", modelDefinition.OutputDimensions.Select(dimension => "[" + string.Join(", ", dimension.Select(d => d.ToString())) + "]")) + "]" : "null";
256258

257259
sourceBuilder.AppendLine(
258260
$$""""
@@ -274,7 +276,9 @@ private void GenerateModelDetails(StringBuilder sourceBuilder, Dictionary<string
274276
FileFilters = [ {{fileFilters}} ],
275277
AIToolkitActions = [ {{aiToolkitActions}} ],
276278
AIToolkitId = {{aiToolkitId}},
277-
AIToolkitFinetuningId = {{aiToolkitFinetuningId}}
279+
AIToolkitFinetuningId = {{aiToolkitFinetuningId}},
280+
InputDimensions = {{inputDimensions}},
281+
OutputDimensions = {{outputDimensions}}
278282
}
279283
},
280284
"""");

AIDevGallery/AIDevGallery.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@
8484
<PackageReference Include="NAudio.WinMM" />
8585
<PackageReference Include="System.Numerics.Tensors" />
8686
<PackageReference Include="WinUIEx" />
87-
<PackageReference Include="HtmlAgilityPack" />
88-
<PackageReference Include="Markdig" />
89-
<PackageReference Include="Roman-Numerals" />
87+
<PackageReference Include="HtmlAgilityPack"/>
88+
<PackageReference Include="Markdig"/>
89+
<PackageReference Include="Roman-Numerals"/>
9090
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.Managed" />
9191
<Manifest Include="$(ApplicationManifest)" />
9292
</ItemGroup>

AIDevGallery/App.xaml.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ protected override async void OnLaunched(LaunchActivatedEventArgs args)
4141
{
4242
await LoadSamples();
4343
AppActivationArguments appActivationArguments = AppInstance.GetCurrent().GetActivatedEventArgs();
44-
var activationParam = ActivationHelper.GetActivationParam(appActivationArguments);
44+
var activationParam = await ActivationHelper.GetActivationParam(appActivationArguments);
4545
MainWindow = new MainWindow(activationParam);
4646

4747
MainWindow.Activate();

AIDevGallery/Controls/ModelPicker/AddHFModelView.xaml.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ private async Task SearchModels(string query, CancellationToken cancellationToke
9999

100100
try
101101
{
102-
accelerator = UserAddedModelUtilsTemp.GetHardwareAcceleratorFromConfig(configContents);
102+
accelerator = UserAddedModelUtil.GetHardwareAcceleratorFromConfig(configContents);
103103
}
104104
catch (JsonException)
105105
{

AIDevGallery/Controls/ModelPicker/ModelPickerViews/OnnxPickerView.xaml.cs

Lines changed: 27 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
using System.Linq;
1717
using System.Threading.Tasks;
1818
using Windows.ApplicationModel.DataTransfer;
19-
using Windows.Storage.Pickers;
2019

2120
namespace AIDevGallery.Controls.ModelPickerViews;
2221
internal sealed partial class OnnxPickerView : BaseModelPickerView
2322
{
24-
private List<ModelDetails> Models { get; set; } = new();
25-
private List<ModelType>? ModelTypes { get; set; }
23+
private List<ModelDetails> models = [];
24+
private List<ModelType>? modelTypes;
25+
2626
public ModelDetails? Selected { get; private set; }
2727

2828
private ObservableCollection<AvailableModel> AvailableModels { get; } = [];
@@ -38,48 +38,47 @@ public OnnxPickerView()
3838

3939
public override Task Load(List<ModelType> types)
4040
{
41-
ModelTypes = types;
41+
modelTypes = types;
42+
43+
ResetAndLoadModelList();
4244

4345
if (types.Contains(ModelType.LanguageModels))
4446
{
4547
AddHFModelButton.Visibility = Visibility.Visible;
4648
}
4749

4850
// local models supported for types
49-
// TODO: check which models support it
50-
if (types.Contains(ModelType.LanguageModels) || true)
51+
if (types.Contains(ModelType.LanguageModels) || models.IsModelsDetailsListUploadCompatible())
5152
{
5253
AddLocalModelButton.Visibility = Visibility.Visible;
5354
}
5455

55-
ResetAndLoadModelList();
56-
5756
return Task.CompletedTask;
5857
}
5958

6059
private void ResetAndLoadModelList()
6160
{
62-
Models.Clear();
61+
models.Clear();
6362
AvailableModels.Clear();
6463
DownloadableModels.Clear();
6564
UnavailableModels.Clear();
6665

67-
if (ModelTypes == null || ModelTypes.Count == 0)
66+
if (modelTypes == null || modelTypes.Count == 0)
6867
{
6968
return;
7069
}
7170

72-
foreach (ModelType type in ModelTypes)
71+
foreach (ModelType type in modelTypes)
7372
{
74-
Models.AddRange(ModelDetailsHelper.GetModelDetailsForModelType(type));
73+
models.AddRange(ModelDetailsHelper.GetModelDetailsForModelType(type));
7574
}
7675

77-
if (Models == null || Models.Count == 0)
76+
if (models == null || models.Count == 0)
7877
{
7978
return;
8079
}
8180

82-
foreach (var model in Models)
81+
foreach (var model in models)
8382
{
8483
if (!model.IsOnnxModel())
8584
{
@@ -155,7 +154,7 @@ private void ResetAndLoadModelList()
155154

156155
private void CacheStore_ModelsChanged(ModelCacheStore sender)
157156
{
158-
ResetAndLoadModelList();
157+
DispatcherQueue.TryEnqueue(ResetAndLoadModelList);
159158
}
160159

161160
private void ModelSelectionView_SelectionChanged(object sender, SelectionChangedEventArgs e)
@@ -269,6 +268,7 @@ private async void DeleteModel_Click(object sender, RoutedEventArgs e)
269268
if (result == ContentDialogResult.Primary)
270269
{
271270
await App.ModelCache.DeleteModelFromCache(details.Url);
271+
ResetAndLoadModelList();
272272
}
273273
}
274274
}
@@ -375,129 +375,20 @@ private void AddHFModelView_CloseRequested(object sender)
375375

376376
private async void AddLocalModelButton_Click(object sender, RoutedEventArgs e)
377377
{
378-
var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(App.MainWindow);
379-
var picker = new FolderPicker();
380-
picker.FileTypeFilter.Add("*");
381-
WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd);
382-
var folder = await picker.PickSingleFolderAsync();
383-
384-
if (folder != null)
378+
if (modelTypes == null)
385379
{
386-
var files = Directory.GetFiles(folder.Path);
387-
var config = files.Where(r => Path.GetFileName(r) == "genai_config.json").FirstOrDefault();
388-
389-
if (string.IsNullOrEmpty(config) || App.ModelCache.Models.Any(m => m.Path == folder.Path))
390-
{
391-
var message = string.IsNullOrEmpty(config) ?
392-
"The folder does not contain a model you can add. Ensure \"genai_config.json\" is present in the selected directory" :
393-
"This model is already added";
394-
395-
ContentDialog confirmFolderDialog = new()
396-
{
397-
Title = "Can't add model",
398-
Content = message,
399-
XamlRoot = this.Content.XamlRoot,
400-
CloseButtonText = "OK"
401-
};
402-
403-
await confirmFolderDialog.ShowAsync();
404-
return;
405-
}
406-
407-
HardwareAccelerator accelerator = HardwareAccelerator.CPU;
408-
409-
try
410-
{
411-
string configContents = string.Empty;
412-
configContents = await File.ReadAllTextAsync(config);
413-
accelerator = UserAddedModelUtilsTemp.GetHardwareAcceleratorFromConfig(configContents);
414-
}
415-
catch (Exception ex)
416-
{
417-
ContentDialog confirmFolderDialog = new()
418-
{
419-
Title = "Can't read genai_config.json",
420-
Content = ex.Message,
421-
XamlRoot = this.Content.XamlRoot,
422-
CloseButtonText = "OK"
423-
};
424-
425-
await confirmFolderDialog.ShowAsync();
426-
return;
427-
}
428-
429-
var nameTextBox = new TextBox()
430-
{
431-
Text = Path.GetFileName(folder.Path),
432-
Width = 300,
433-
HorizontalAlignment = HorizontalAlignment.Left,
434-
Margin = new Thickness(0, 0, 0, 10),
435-
Header = "Model name"
436-
};
437-
438-
ContentDialog nameModelDialog = new()
439-
{
440-
Title = "Add model",
441-
Content = new StackPanel()
442-
{
443-
Orientation = Orientation.Vertical,
444-
Spacing = 8,
445-
Children =
446-
{
447-
new TextBlock()
448-
{
449-
Text = $"Adding ONNX model from \n \"{folder.Path}\"",
450-
TextWrapping = TextWrapping.WrapWholeWords
451-
},
452-
nameTextBox
453-
}
454-
},
455-
XamlRoot = this.Content.XamlRoot,
456-
CloseButtonText = "Cancel",
457-
PrimaryButtonText = "Add",
458-
DefaultButton = ContentDialogButton.Primary,
459-
Style = Application.Current.Resources["DefaultContentDialogStyle"] as Style
460-
};
461-
462-
string modelName = nameTextBox.Text;
463-
464-
nameTextBox.TextChanged += (s, e) =>
465-
{
466-
if (string.IsNullOrEmpty(nameTextBox.Text))
467-
{
468-
nameModelDialog.IsPrimaryButtonEnabled = false;
469-
}
470-
else
471-
{
472-
modelName = nameTextBox.Text;
473-
nameModelDialog.IsPrimaryButtonEnabled = true;
474-
}
475-
};
476-
477-
var result = await nameModelDialog.ShowAsync();
478-
if (result != ContentDialogResult.Primary)
479-
{
480-
return;
481-
}
482-
483-
DirectoryInfo dirInfo = new DirectoryInfo(folder.Path);
484-
long dirSize = await Task.Run(() => dirInfo.EnumerateFiles("*", SearchOption.AllDirectories).Sum(file => file.Length));
485-
486-
var details = new ModelDetails()
487-
{
488-
Id = "useradded-local-languagemodel-" + Guid.NewGuid().ToString(),
489-
Name = modelName,
490-
Url = $"local-file:///{folder.Path}",
491-
Description = "Localy added GenAI Model",
492-
HardwareAccelerators = [accelerator],
493-
IsUserAdded = true,
494-
PromptTemplate = ModelDetailsHelper.GetTemplateFromName(folder.Path),
495-
Size = dirSize,
496-
ReadmeUrl = null,
497-
License = "unknown"
498-
};
380+
return;
381+
}
499382

500-
await App.ModelCache.AddLocalModelToCache(details, folder.Path);
383+
if (modelTypes.Contains(ModelType.LanguageModels))
384+
{
385+
await UserAddedModelUtil.OpenAddLanguageModelFlow(Content.XamlRoot);
386+
}
387+
else
388+
{
389+
await UserAddedModelUtil.OpenAddModelFlow(Content.XamlRoot, modelTypes);
501390
}
391+
392+
ResetAndLoadModelList();
502393
}
503394
}

AIDevGallery/Controls/ModelSelectionControl.xaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,11 @@
612612
</DataTemplate>
613613
</ItemsRepeater.ItemTemplate>
614614
</ItemsRepeater>
615+
<!--<Button
616+
x:Name="AddLocalModelButton"
617+
Click="AddLocalModelButton_Click"
618+
HorizontalAlignment="Center"
619+
Content="Add Local Model"/>-->
615620
</StackPanel>
616621
</ScrollViewer>
617622
<ContentDialog

AIDevGallery/Controls/ModelSelectionControl.xaml.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace AIDevGallery.Controls;
2222
internal partial class ModelSelectionControl : UserControl
2323
{
2424
public List<ModelDetails>? Models { get; private set; }
25+
public Scenario? Scenario { get; set; }
2526
public ModelDetails? Selected { get; private set; }
2627

2728
public static readonly DependencyProperty DownloadableModelsTitleProperty = DependencyProperty.Register(nameof(DownloadableModelsTitle), typeof(string), typeof(ModelSelectionControl), new PropertyMetadata(defaultValue: null));
@@ -85,7 +86,7 @@ public void SetModels(List<ModelDetails>? models, ModelDetails? initialSelectedM
8586

8687
private void CacheStore_ModelsChanged(ModelCacheStore sender)
8788
{
88-
PopulateModelDetailsLists();
89+
DispatcherQueue.TryEnqueue(PopulateModelDetailsLists);
8990
}
9091

9192
private void ResetAndLoadModelList(ModelDetails? selectedModel = null)

0 commit comments

Comments
 (0)