Skip to content

Commit 551b436

Browse files
speedstorm1copybara-github
authored andcommitted
feat: support hyperparameters in distillation tuning
PiperOrigin-RevId: 882708166
1 parent ea49f9a commit 551b436

File tree

4 files changed

+64
-2
lines changed

4 files changed

+64
-2
lines changed

replay_sanitizer.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,36 @@ func sanitizeMapWithSourceType(t *testing.T, sourceType reflect.Type, m any) {
6767
sanitizeMapByPath(m.(map[string]any), path, stdBase64Handler, false)
6868
}
6969
}
70+
71+
int32Paths := make([]string, 0)
72+
visitedTypesInt32 := make(map[string]bool)
73+
if err := getFieldPath(st, reflect.TypeOf(int32(0)), &int32Paths, "", visitedTypesInt32, false); err != nil {
74+
t.Fatal(err)
75+
}
76+
77+
numericStringHandler := func(data any, path string) any {
78+
s, ok := data.(string)
79+
if !ok {
80+
return data
81+
}
82+
f, err := strconv.ParseFloat(s, 64)
83+
if err != nil {
84+
t.Errorf("invalid numeric string %s at path %s", s, path)
85+
return data
86+
}
87+
return f
88+
}
89+
90+
for _, path := range int32Paths {
91+
if sourceType.Kind() == reflect.Slice {
92+
data := m.([]any)
93+
for i := 0; i < len(data); i++ {
94+
sanitizeMapByPath(data[i], path, numericStringHandler, false)
95+
}
96+
} else {
97+
sanitizeMapByPath(m.(map[string]any), path, numericStringHandler, false)
98+
}
99+
}
70100
}
71101

72102
// sanitizeMapByPath sanitizes a value within a nested map structure based on the given path.

table_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,16 @@ func TestTable(t *testing.T) {
330330
}
331331
} else {
332332
method := extractMethod(t, &testTableFile, client)
333+
334+
// Sanitize replay response body segments based on the method's return type.
335+
// This is needed because some numeric fields in the replay files might be
336+
// represented as strings, which causes unmarshalling errors in the Go SDK.
337+
for _, interaction := range replayClient.ReplayFile.Interactions {
338+
for _, bodySegment := range interaction.Response.BodySegments {
339+
sanitizeMapWithSourceType(t, method.Type().Out(0), bodySegment)
340+
}
341+
}
342+
333343
args := extractArgs(ctx, t, method, &testTableFile, testTableItem)
334344

335345
// Inject unknown fields to the replay file to simulate the case where the SDK adds

tunings.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

types.go

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)