Skip to content

Commit e28413d

Browse files
speedstorm1 and copybara-github
authored and committed
feat: support hyperparameters in distillation tuning
PiperOrigin-RevId: 882708166
1 parent 150698e commit e28413d

File tree

1 file changed

+62
-13
lines changed

1 file changed

+62
-13
lines changed

GeneratedFirebaseAI/Sources/Types.swift

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19483,9 +19483,16 @@ extension PreferenceOptimizationSpec: Codable {
1948319483
}
1948419484
}
1948519485

19486-
/// Hyperparameters for Distillation. This data type is not supported in Gemini API.
19486+
/// Hyperparameters for distillation.
1948719487
@available(iOS 15.0, macOS 13.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
1948819488
public struct DistillationHyperParameters: Sendable {
19489+
/// The batch size hyperparameter for tuning.
19490+
/// This is only supported for OSS models in Vertex.
19491+
public let batchSize: Int32?
19492+
19493+
/// The learning rate for tuning. OSS models only.
19494+
public let learningRate: Float?
19495+
1948919496
/// Optional. Adapter size for distillation.
1949019497
public let adapterSize: AdapterSize?
1949119498

@@ -19498,10 +19505,14 @@ public struct DistillationHyperParameters: Sendable {
1949819505

1949919506
/// Default initializer.
1950019507
public init(
19508+
batchSize: Int32? = nil,
19509+
learningRate: Float? = nil,
1950119510
adapterSize: AdapterSize? = nil,
1950219511
epochCount: Int64? = nil,
1950319512
learningRateMultiplier: Double? = nil
1950419513
) {
19514+
self.batchSize = batchSize
19515+
self.learningRate = learningRate
1950519516
self.adapterSize = adapterSize
1950619517
self.epochCount = epochCount
1950719518
self.learningRateMultiplier = learningRateMultiplier
@@ -19513,6 +19524,8 @@ extension DistillationHyperParameters: Codable {
1951319524

1951419525
// MARK: - Codable
1951519526
public enum VertexKeys: String, CodingKey {
19527+
case batchSize = "batchSize"
19528+
case learningRate = "learningRate"
1951619529
case adapterSize = "adapterSize"
1951719530
case epochCount = "epochCount"
1951819531
case learningRateMultiplier = "learningRateMultiplier"
@@ -19522,6 +19535,16 @@ extension DistillationHyperParameters: Codable {
1952219535
let configuration: APIClient = try decoder.userInfoOrThrow(.configuration)
1952319536

1952419537
let VertexKeysContainer = try decoder.container(keyedBy: VertexKeys.self)
19538+
batchSize = try VertexKeysContainer.decodeIfPresent(
19539+
Int32.self,
19540+
forKey: .batchSize
19541+
)
19542+
19543+
learningRate = try VertexKeysContainer.decodeIfPresent(
19544+
Float.self,
19545+
forKey: .learningRate
19546+
)
19547+
1952519548
adapterSize = try VertexKeysContainer.decodeIfPresent(
1952619549
AdapterSize.self,
1952719550
forKey: .adapterSize
@@ -19544,6 +19567,16 @@ extension DistillationHyperParameters: Codable {
1954419567
if configuration.isVertexAI() {
1954519568

1954619569
var VertexKeysContainer = encoder.container(keyedBy: VertexKeys.self)
19570+
try VertexKeysContainer.encodeIfPresent(
19571+
batchSize,
19572+
forKey: .batchSize
19573+
)
19574+
19575+
try VertexKeysContainer.encodeIfPresent(
19576+
learningRate,
19577+
forKey: .learningRate
19578+
)
19579+
1954719580
try VertexKeysContainer.encodeIfPresent(
1954819581
adapterSize,
1954919582
forKey: .adapterSize
@@ -19569,14 +19602,17 @@ public struct DistillationSpec: Sendable {
1956919602
/// The GCS URI of the prompt dataset to use during distillation.
1957019603
public let promptDatasetUri: String?
1957119604

19605+
/// Tuning mode for tuning.
19606+
public let tuningMode: TuningMode?
19607+
19608+
/// Optional. Hyperparameters for Distillation.
19609+
public let hyperParameters: DistillationHyperParameters?
19610+
1957219611
/// The base teacher model that is being distilled. See [Supported
1957319612
/// models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-
1957419613
/// reference/tuning#supported_models).
1957519614
public let baseTeacherModel: String?
1957619615

19577-
/// Optional. Hyperparameters for Distillation.
19578-
public let hyperParameters: DistillationHyperParameters?
19579-
1958019616
/// Deprecated. A path in a Cloud Storage bucket, which will be treated as the root
1958119617
/// output directory of the distillation pipeline. It is used by the system to
1958219618
/// generate the paths of output artifacts.
@@ -19601,17 +19637,19 @@ public struct DistillationSpec: Sendable {
1960119637
/// Default initializer.
1960219638
public init(
1960319639
promptDatasetUri: String? = nil,
19604-
baseTeacherModel: String? = nil,
19640+
tuningMode: TuningMode? = nil,
1960519641
hyperParameters: DistillationHyperParameters? = nil,
19642+
baseTeacherModel: String? = nil,
1960619643
pipelineRootDirectory: String? = nil,
1960719644
studentModel: String? = nil,
1960819645
trainingDatasetUri: String? = nil,
1960919646
tunedTeacherModelSource: String? = nil,
1961019647
validationDatasetUri: String? = nil
1961119648
) {
1961219649
self.promptDatasetUri = promptDatasetUri
19613-
self.baseTeacherModel = baseTeacherModel
19650+
self.tuningMode = tuningMode
1961419651
self.hyperParameters = hyperParameters
19652+
self.baseTeacherModel = baseTeacherModel
1961519653
self.pipelineRootDirectory = pipelineRootDirectory
1961619654
self.studentModel = studentModel
1961719655
self.trainingDatasetUri = trainingDatasetUri
@@ -19626,8 +19664,9 @@ extension DistillationSpec: Codable {
1962619664
// MARK: - Codable
1962719665
public enum VertexKeys: String, CodingKey {
1962819666
case promptDatasetUri = "promptDatasetUri"
19629-
case baseTeacherModel = "baseTeacherModel"
19667+
case tuningMode = "tuningMode"
1963019668
case hyperParameters = "hyperParameters"
19669+
case baseTeacherModel = "baseTeacherModel"
1963119670
case pipelineRootDirectory = "pipelineRootDirectory"
1963219671
case studentModel = "studentModel"
1963319672
case trainingDatasetUri = "trainingDatasetUri"
@@ -19644,16 +19683,21 @@ extension DistillationSpec: Codable {
1964419683
forKey: .promptDatasetUri
1964519684
)
1964619685

19647-
baseTeacherModel = try VertexKeysContainer.decodeIfPresent(
19648-
String.self,
19649-
forKey: .baseTeacherModel
19686+
tuningMode = try VertexKeysContainer.decodeIfPresent(
19687+
TuningMode.self,
19688+
forKey: .tuningMode
1965019689
)
1965119690

1965219691
hyperParameters = try VertexKeysContainer.decodeIfPresent(
1965319692
DistillationHyperParameters.self,
1965419693
forKey: .hyperParameters
1965519694
)
1965619695

19696+
baseTeacherModel = try VertexKeysContainer.decodeIfPresent(
19697+
String.self,
19698+
forKey: .baseTeacherModel
19699+
)
19700+
1965719701
pipelineRootDirectory = try VertexKeysContainer.decodeIfPresent(
1965819702
String.self,
1965919703
forKey: .pipelineRootDirectory
@@ -19692,15 +19736,20 @@ extension DistillationSpec: Codable {
1969219736
)
1969319737

1969419738
try VertexKeysContainer.encodeIfPresent(
19695-
baseTeacherModel,
19696-
forKey: .baseTeacherModel
19739+
tuningMode,
19740+
forKey: .tuningMode
1969719741
)
1969819742

1969919743
try VertexKeysContainer.encodeIfPresent(
1970019744
hyperParameters,
1970119745
forKey: .hyperParameters
1970219746
)
1970319747

19748+
try VertexKeysContainer.encodeIfPresent(
19749+
baseTeacherModel,
19750+
forKey: .baseTeacherModel
19751+
)
19752+
1970419753
try VertexKeysContainer.encodeIfPresent(
1970519754
pipelineRootDirectory,
1970619755
forKey: .pipelineRootDirectory
@@ -23809,7 +23858,7 @@ public struct CreateTuningJobConfig: Sendable {
2380923858
/// Adapter size for tuning.
2381023859
public let adapterSize: AdapterSize?
2381123860

23812-
/// Tuning mode for SFT tuning.
23861+
/// Tuning mode for tuning.
2381323862
public let tuningMode: TuningMode?
2381423863

2381523864
/// Custom base model for tuning. This is only supported for OSS models in Vertex.

0 commit comments

Comments
 (0)