-
Notifications
You must be signed in to change notification settings - Fork 281
Expand file tree
/
Copy pathModel.cs
More file actions
130 lines (105 loc) · 4.75 KB
/
Model.cs
File metadata and controls
130 lines (105 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// --------------------------------------------------------------------------------------------------------------------
// <copyright company="Microsoft">
// Copyright (c) Microsoft. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------
namespace Microsoft.AI.Foundry.Local;
using Microsoft.Extensions.Logging;
public class Model : IModel
{
    private readonly ILogger _logger;

    /// <summary>
    /// All variants of this model, in the order they were added.
    /// </summary>
    public List<ModelVariant> Variants { get; internal set; }

    /// <summary>
    /// The variant that every <see cref="IModel"/> operation on this model is forwarded to.
    /// </summary>
    public ModelVariant SelectedVariant { get; internal set; }

    /// <summary>
    /// Alias shared by every variant of this model.
    /// </summary>
    public string Alias { get; init; }

    /// <summary>
    /// Id of the currently selected variant.
    /// </summary>
    public string Id => SelectedVariant.Id;

    /// <summary>
    /// Is the currently selected variant cached locally?
    /// </summary>
    public Task<bool> IsCachedAsync(CancellationToken? ct = null) => SelectedVariant.IsCachedAsync(ct);

    /// <summary>
    /// Is the currently selected variant loaded in memory?
    /// </summary>
    public Task<bool> IsLoadedAsync(CancellationToken? ct = null) => SelectedVariant.IsLoadedAsync(ct);

    /// <summary>
    /// Create a model from its first variant. The model takes the variant's alias, and that variant
    /// becomes the initial selection.
    /// </summary>
    /// <param name="modelVariant">First variant of the model.</param>
    /// <param name="logger">Logger used when reporting internal errors.</param>
    internal Model(ModelVariant modelVariant, ILogger logger)
    {
        _logger = logger;
        Alias = modelVariant.Alias;
        Variants = new() { modelVariant };

        // Variants are supplied in priority order (sorted by Core), so the first one added is the default.
        SelectedVariant = modelVariant;
    }

    /// <summary>
    /// Append another variant of this model. If the currently selected variant is not cached locally
    /// but <paramref name="variant"/> is, the new variant becomes the selected variant.
    /// </summary>
    /// <param name="variant">Variant to add. Its alias must equal <see cref="Alias"/>.</param>
    /// <exception cref="FoundryLocalException">If the variant's alias does not match this model's alias.</exception>
    internal void AddVariant(ModelVariant variant)
    {
        if (Alias != variant.Alias)
        {
            // internal error so log
            throw new FoundryLocalException($"Variant alias {variant.Alias} does not match model alias {Alias}",
                                            _logger);
        }

        Variants.Add(variant);

        // Prefer the highest priority locally cached variant: because variants arrive in priority order,
        // only switch while the current selection is not cached, so the first cached variant wins.
        if (variant.Info.Cached && !SelectedVariant.Info.Cached)
        {
            SelectedVariant = variant;
        }
    }

    /// <summary>
    /// Select a specific model variant from <see cref="Variants"/> to use for <see cref="IModel"/> operations.
    /// </summary>
    /// <param name="variant">Model variant to select. Must be one of the variants in <see cref="Variants"/>.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="variant"/> is null.</exception>
    /// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
    public void SelectVariant(ModelVariant variant)
    {
        // Guard explicitly: otherwise building the error message below would throw a NullReferenceException.
        if (variant is null)
        {
            throw new ArgumentNullException(nameof(variant));
        }

        if (!Variants.Any(v => v == variant))
        {
            // user error so don't log
            throw new FoundryLocalException($"Model {Alias} does not have a {variant.Id} variant.");
        }

        SelectedVariant = variant;
    }

    /// <summary>
    /// Get the latest version of the specified model variant.
    /// </summary>
    /// <param name="variant">Model variant.</param>
    /// <returns>ModelVariant for latest version. Same as `variant` if that is the latest version.</returns>
    /// <exception cref="ArgumentNullException">If <paramref name="variant"/> is null.</exception>
    /// <exception cref="FoundryLocalException">If variant is not valid for this model.</exception>
    public ModelVariant GetLatestVersion(ModelVariant variant)
    {
        // Guard explicitly: otherwise the lookup lambda would throw a NullReferenceException.
        if (variant is null)
        {
            throw new ArgumentNullException(nameof(variant));
        }

        // variants are sorted by version, so the first one matching the name is the latest version for that variant.
        return Variants.FirstOrDefault(v => v.Info.Name == variant.Info.Name) ??
               // user error so don't log
               throw new FoundryLocalException($"Model {Alias} does not have a {variant.Id} variant.");
    }

    /// <summary>
    /// Get the path of the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task<string> GetPathAsync(CancellationToken? ct = null)
    {
        return await SelectedVariant.GetPathAsync(ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Download the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    /// <param name="downloadProgress">Optional callback passed through to the variant's download.</param>
    /// <param name="ct">Optional cancellation token.</param>
    public async Task DownloadAsync(Action<float>? downloadProgress = null,
                                    CancellationToken? ct = null)
    {
        await SelectedVariant.DownloadAsync(downloadProgress, ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Load the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task LoadAsync(CancellationToken? ct = null)
    {
        await SelectedVariant.LoadAsync(ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Get a chat client for the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task<OpenAIChatClient> GetChatClientAsync(CancellationToken? ct = null)
    {
        return await SelectedVariant.GetChatClientAsync(ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Get an audio client for the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task<OpenAIAudioClient> GetAudioClientAsync(CancellationToken? ct = null)
    {
        return await SelectedVariant.GetAudioClientAsync(ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Get a responses client for the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task<OpenAIResponsesClient> GetResponsesClientAsync(CancellationToken? ct = null)
    {
        return await SelectedVariant.GetResponsesClientAsync(ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Unload the currently selected variant (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task UnloadAsync(CancellationToken? ct = null)
    {
        await SelectedVariant.UnloadAsync(ct).ConfigureAwait(false);
    }

    /// <summary>
    /// Remove the currently selected variant from the local cache (forwarded to <see cref="SelectedVariant"/>).
    /// </summary>
    public async Task RemoveFromCacheAsync(CancellationToken? ct = null)
    {
        await SelectedVariant.RemoveFromCacheAsync(ct).ConfigureAwait(false);
    }
}