|
4 | 4 | package main |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "bytes" |
7 | 8 | "context" |
8 | 9 | "encoding/json" |
9 | 10 | "fmt" |
@@ -118,7 +119,7 @@ func handleAIJobPage(ctx context.Context, w http.ResponseWriter, r *http.Request |
118 | 119 | default: |
119 | 120 | job.Correct = spanner.NullBool{} |
120 | 121 | } |
121 | | - if err := aidb.UpdateJob(ctx, job); err != nil { |
| 122 | + if err := aiJobUpdate(ctx, job); err != nil { |
122 | 123 | return err |
123 | 124 | } |
124 | 125 | } |
@@ -284,10 +285,92 @@ func apiAIJobDone(ctx context.Context, req *dashapi.AIJobDoneReq) (any, error) { |
284 | 285 | if len(req.Results) != 0 { |
285 | 286 | job.Results = spanner.NullJSON{Value: req.Results, Valid: true} |
286 | 287 | } |
287 | | - err = aidb.UpdateJob(ctx, job) |
| 288 | + err = aiJobUpdate(ctx, job) |
288 | 289 | return nil, err |
289 | 290 | } |
290 | 291 |
|
| 292 | +func aiJobUpdate(ctx context.Context, job *aidb.Job) error { |
| 293 | + if err := aidb.UpdateJob(ctx, job); err != nil { |
| 294 | + return err |
| 295 | + } |
| 296 | + if !job.BugID.Valid || !job.Finished.Valid || job.Error != "" { |
| 297 | + return nil |
| 298 | + } |
| 299 | + bug, err := loadBug(ctx, job.BugID.StringVal) |
| 300 | + if err != nil { |
| 301 | + return err |
| 302 | + } |
| 303 | + labelType, labelValue, labelAdd, err := aiBugLabel(job) |
| 304 | + if err != nil || labelType == EmptyLabel { |
| 305 | + return err |
| 306 | + } |
| 307 | + label := BugLabel{ |
| 308 | + Label: labelType, |
| 309 | + Value: labelValue, |
| 310 | + Link: job.ID, |
| 311 | + } |
| 312 | + labelSet := makeLabelSet(ctx, bug) |
| 313 | + return updateSingleBug(ctx, bug.key(ctx), func(bug *Bug) error { |
| 314 | + if bug.HasUserLabel(labelType) { |
| 315 | + return nil |
| 316 | + } |
| 317 | + if labelAdd { |
| 318 | + return bug.SetLabels(labelSet, []BugLabel{label}) |
| 319 | + } |
| 320 | + bug.UnsetLabels(labelType) |
| 321 | + return nil |
| 322 | + }) |
| 323 | +} |
| 324 | + |
| 325 | +func aiBugLabel(job *aidb.Job) (typ BugLabelType, value string, set bool, err0 error) { |
| 326 | + switch job.Type { |
| 327 | + case ai.WorkflowAssessmentKCSAN: |
| 328 | + // For now we require a manual correctness check, |
| 329 | + // later we may apply some labels w/o the manual check. |
| 330 | + if !job.Correct.Valid { |
| 331 | + return |
| 332 | + } |
| 333 | + if !job.Correct.Bool { |
| 334 | + return RaceLabel, "", false, nil |
| 335 | + } |
| 336 | + res, err := castJobResults[ai.AssessmentKCSANOutputs](job) |
| 337 | + if err != nil { |
| 338 | + err0 = err |
| 339 | + return |
| 340 | + } |
| 341 | + if !res.Confident { |
| 342 | + return |
| 343 | + } |
| 344 | + if res.Benign { |
| 345 | + return RaceLabel, BenignRace, true, nil |
| 346 | + } |
| 347 | + return RaceLabel, HarmfulRace, true, nil |
| 348 | + } |
| 349 | + return |
| 350 | +} |
| 351 | + |
| 352 | +func castJobResults[T any](job *aidb.Job) (T, error) { |
| 353 | + var res T |
| 354 | + raw, ok := job.Results.Value.(map[string]any) |
| 355 | + if !ok || !job.Results.Valid { |
| 356 | + return res, fmt.Errorf("finished job %v %v does not have results", job.Type, job.ID) |
| 357 | + } |
| 358 | + // Database may store older versions of the output structs. |
| 359 | + // It's not possible to automatically handle all possible changes to the structs. |
| 360 | + // For now we just parse in some way. Later when we start changing output structs, |
| 361 | + // we may need to reconsider and use more careful parsing. |
| 362 | + data, err := json.Marshal(raw) |
| 363 | + if err != nil { |
| 364 | + return res, err |
| 365 | + } |
| 366 | + dec := json.NewDecoder(bytes.NewReader(data)) |
| 367 | + dec.DisallowUnknownFields() |
| 368 | + if err := dec.Decode(&res); err != nil { |
| 369 | + return res, fmt.Errorf("failed to unmarshal %T: %w", res, err) |
| 370 | + } |
| 371 | + return res, nil |
| 372 | +} |
| 373 | + |
291 | 374 | func apiAITrajectoryLog(ctx context.Context, req *dashapi.AITrajectoryReq) (any, error) { |
292 | 375 | err := aidb.StoreTrajectorySpan(ctx, req.JobID, req.Span) |
293 | 376 | return nil, err |
|
0 commit comments