Skip to content

Commit eb4a7cf

Browse files
speedstorm1copybara-github
authored andcommitted
feat: support hyperparameters in distillation tuning
FUTURE_COPYBARA_INTEGRATE_REVIEW=#232 from googleapis:release-please--branches--main--components--Google.GenAI b8eaed6 PiperOrigin-RevId: 882708166
1 parent c736cae commit eb4a7cf

File tree

4 files changed

+55
-8
lines changed

4 files changed

+55
-8
lines changed

Google.GenAI/Tunings.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ internal JsonNode CreateTuningJobConfigToVertex(JsonNode fromObject, JsonObject
355355
Common.SetValueByPath(parentObject, new string[] { "supervisedTuningSpec", "tuningMode" },
356356
Common.GetValueByPath(fromObject, new string[] { "tuningMode" }));
357357
}
358+
} else if (discriminatorValueTuningMode == "DISTILLATION") {
359+
if (Common.GetValueByPath(fromObject, new string[] { "tuningMode" }) != null) {
360+
Common.SetValueByPath(parentObject, new string[] { "distillationSpec", "tuningMode" },
361+
Common.GetValueByPath(fromObject, new string[] { "tuningMode" }));
362+
}
358363
}
359364
if (Common.GetValueByPath(fromObject, new string[] { "customBaseModel" }) != null) {
360365
Common.SetValueByPath(
@@ -373,6 +378,12 @@ internal JsonNode CreateTuningJobConfigToVertex(JsonNode fromObject, JsonObject
373378
parentObject, new string[] { "supervisedTuningSpec", "hyperParameters", "batchSize" },
374379
Common.GetValueByPath(fromObject, new string[] { "batchSize" }));
375380
}
381+
} else if (discriminatorValueBatchSize == "DISTILLATION") {
382+
if (Common.GetValueByPath(fromObject, new string[] { "batchSize" }) != null) {
383+
Common.SetValueByPath(parentObject,
384+
new string[] { "distillationSpec", "hyperParameters", "batchSize" },
385+
Common.GetValueByPath(fromObject, new string[] { "batchSize" }));
386+
}
376387
}
377388

378389
JsonNode discriminatorLearningRate =
@@ -387,6 +398,12 @@ internal JsonNode CreateTuningJobConfigToVertex(JsonNode fromObject, JsonObject
387398
new string[] { "supervisedTuningSpec", "hyperParameters", "learningRate" },
388399
Common.GetValueByPath(fromObject, new string[] { "learningRate" }));
389400
}
401+
} else if (discriminatorValueLearningRate == "DISTILLATION") {
402+
if (Common.GetValueByPath(fromObject, new string[] { "learningRate" }) != null) {
403+
Common.SetValueByPath(
404+
parentObject, new string[] { "distillationSpec", "hyperParameters", "learningRate" },
405+
Common.GetValueByPath(fromObject, new string[] { "learningRate" }));
406+
}
390407
}
391408

392409
if (Common.GetValueByPath(fromObject, new string[] { "labels" }) != null) {

Google.GenAI/types/CreateTuningJobConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public AdapterSize
129129
}
130130

131131
/// <summary>
132-
/// Tuning mode for SFT tuning.
132+
/// Tuning mode for tuning.
133133
/// </summary>
134134
[JsonPropertyName("tuningMode")]
135135
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]

Google.GenAI/types/DistillationHyperParameters.cs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,36 @@
2323

2424
namespace Google.GenAI.Types {
2525
/// <summary>
26-
/// Hyperparameters for Distillation. This data type is not supported in Gemini API.
26+
/// Hyperparameters for distillation.
2727
/// </summary>
2828

2929
public record DistillationHyperParameters {
30+
/// <summary>
31+
/// The batch size hyperparameter for tuning. This is only supported for OSS models in Vertex.
32+
/// </summary>
33+
[JsonPropertyName("batchSize")]
34+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
35+
public int ? BatchSize { get; set; }
36+
37+
/// <summary>
38+
/// The learning rate for tuning. OSS models only.
39+
/// </summary>
40+
[JsonPropertyName("learningRate")]
41+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
42+
public double
43+
? LearningRate {
44+
get; set;
45+
}
46+
3047
/// <summary>
3148
/// Optional. Adapter size for distillation.
3249
/// </summary>
3350
[JsonPropertyName("adapterSize")]
3451
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
35-
public AdapterSize ? AdapterSize { get; set; }
52+
public AdapterSize
53+
? AdapterSize {
54+
get; set;
55+
}
3656

3757
/// <summary>
3858
/// Optional. Number of complete passes the model makes over the entire training dataset during

Google.GenAI/types/DistillationSpec.cs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,12 @@ public record DistillationSpec {
3535
public string ? PromptDatasetUri { get; set; }
3636

3737
/// <summary>
38-
/// The base teacher model that is being distilled. See Supported models
39-
/// (https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).
38+
/// Tuning mode for tuning.
4039
/// </summary>
41-
[JsonPropertyName("baseTeacherModel")]
40+
[JsonPropertyName("tuningMode")]
4241
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
43-
public string
44-
? BaseTeacherModel {
42+
public TuningMode
43+
? TuningMode {
4544
get; set;
4645
}
4746

@@ -55,6 +54,17 @@ public DistillationHyperParameters
5554
get; set;
5655
}
5756

57+
/// <summary>
58+
/// The base teacher model that is being distilled. See Supported models
59+
/// (https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).
60+
/// </summary>
61+
[JsonPropertyName("baseTeacherModel")]
62+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
63+
public string
64+
? BaseTeacherModel {
65+
get; set;
66+
}
67+
5868
/// <summary>
5969
/// Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output
6070
/// directory of the distillation pipeline. It is used by the system to generate the paths of

0 commit comments

Comments
 (0)