Skip to content

Commit f5ac3b3

Browse files
speedstorm1copybara-github
authored andcommitted
feat: support hyperparameters in distillation tuning
PiperOrigin-RevId: 882708166
1 parent 594f64e commit f5ac3b3

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

Google.GenAI/Tunings.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ internal JsonNode CreateTuningJobConfigToVertex(JsonNode fromObject, JsonObject
355355
Common.SetValueByPath(parentObject, new string[] { "supervisedTuningSpec", "tuningMode" },
356356
Common.GetValueByPath(fromObject, new string[] { "tuningMode" }));
357357
}
358+
} else if (discriminatorValueTuningMode == "DISTILLATION") {
359+
if (Common.GetValueByPath(fromObject, new string[] { "tuningMode" }) != null) {
360+
Common.SetValueByPath(parentObject, new string[] { "distillationSpec", "tuningMode" },
361+
Common.GetValueByPath(fromObject, new string[] { "tuningMode" }));
362+
}
358363
}
359364
if (Common.GetValueByPath(fromObject, new string[] { "customBaseModel" }) != null) {
360365
Common.SetValueByPath(
@@ -373,6 +378,12 @@ internal JsonNode CreateTuningJobConfigToVertex(JsonNode fromObject, JsonObject
373378
parentObject, new string[] { "supervisedTuningSpec", "hyperParameters", "batchSize" },
374379
Common.GetValueByPath(fromObject, new string[] { "batchSize" }));
375380
}
381+
} else if (discriminatorValueBatchSize == "DISTILLATION") {
382+
if (Common.GetValueByPath(fromObject, new string[] { "batchSize" }) != null) {
383+
Common.SetValueByPath(parentObject,
384+
new string[] { "distillationSpec", "hyperParameters", "batchSize" },
385+
Common.GetValueByPath(fromObject, new string[] { "batchSize" }));
386+
}
376387
}
377388

378389
JsonNode discriminatorLearningRate =
@@ -387,6 +398,12 @@ internal JsonNode CreateTuningJobConfigToVertex(JsonNode fromObject, JsonObject
387398
new string[] { "supervisedTuningSpec", "hyperParameters", "learningRate" },
388399
Common.GetValueByPath(fromObject, new string[] { "learningRate" }));
389400
}
401+
} else if (discriminatorValueLearningRate == "DISTILLATION") {
402+
if (Common.GetValueByPath(fromObject, new string[] { "learningRate" }) != null) {
403+
Common.SetValueByPath(
404+
parentObject, new string[] { "distillationSpec", "hyperParameters", "learningRate" },
405+
Common.GetValueByPath(fromObject, new string[] { "learningRate" }));
406+
}
390407
}
391408

392409
if (Common.GetValueByPath(fromObject, new string[] { "labels" }) != null) {

Google.GenAI/types/CreateTuningJobConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public AdapterSize
129129
}
130130

131131
/// <summary>
132-
/// Tuning mode for SFT tuning.
132+
/// Tuning mode for tuning.
133133
/// </summary>
134134
[JsonPropertyName("tuningMode")]
135135
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]

Google.GenAI/types/DistillationHyperParameters.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
namespace Google.GenAI.Types {
2525
/// <summary>
26-
/// Hyperparameters for Distillation. This data type is not supported in Gemini API.
26+
/// Hyperparameters for distillation.
2727
/// </summary>
2828

2929
public record DistillationHyperParameters {
@@ -56,6 +56,26 @@ public double
5656
get; set;
5757
}
5858

59+
/// <summary>
60+
/// The batch size hyperparameter for tuning. This is only supported for OSS models in Vertex.
61+
/// </summary>
62+
[JsonPropertyName("batchSize")]
63+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
64+
public int
65+
? BatchSize {
66+
get; set;
67+
}
68+
69+
/// <summary>
70+
/// The learning rate for tuning. OSS models only.
71+
/// </summary>
72+
[JsonPropertyName("learningRate")]
73+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
74+
public double
75+
? LearningRate {
76+
get; set;
77+
}
78+
5979
/// <summary>
6080
/// Deserializes a JSON string to a DistillationHyperParameters object.
6181
/// </summary>

Google.GenAI/types/DistillationSpec.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ public string
111111
get; set;
112112
}
113113

114+
/// <summary>
115+
/// Tuning mode for tuning.
116+
/// </summary>
117+
[JsonPropertyName("tuningMode")]
118+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
119+
public TuningMode
120+
? TuningMode {
121+
get; set;
122+
}
123+
114124
/// <summary>
115125
/// Deserializes a JSON string to a DistillationSpec object.
116126
/// </summary>

0 commit comments

Comments
 (0)