-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path process.rb
More file actions
executable file
·95 lines (82 loc) · 3.96 KB
/
process.rb
File metadata and controls
executable file
·95 lines (82 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# frozen_string_literal: true
require 'remote_file/reader'
module PredictionResults
  # Processes an ML prediction-results export for a subject set:
  # subjects whose prediction probability meets the threshold are moved into
  # the active subject set, under-threshold subjects are removed, and a small
  # random sample of under-threshold subjects ('spice') is added back so the
  # active set isn't composed solely of confident positives.
  class Process
    # Number of subject ids sent per bulk job to the subject-set API.
    SUBJECT_ACTION_API_BATCH_SIZE = ENV.fetch('SUBJECT_ACTION_API_BATCH_SIZE', '10').to_i
    attr_accessor :results_url, :subject_set_id, :probability_threshold,
                  :over_threshold_subject_ids, :under_threshold_subject_ids,
                  :random_spice_subject_ids, :randomisation_factor, :prediction_data

    # @param results_url [String] URL of the prediction results JSON file
    # @param subject_set_id [Integer] id of the active subject set to modify
    # @param probability_threshold [Float] minimum probability for a subject
    #   to be considered a confident ('over threshold') prediction
    # @param randomisation_factor [Float] fraction of the over-threshold count
    #   used to size the random under-threshold sample (e.g. 0.2 => 20%)
    def initialize(results_url:, subject_set_id:, probability_threshold: 0.8, randomisation_factor: 0.2)
      @results_url = results_url
      @subject_set_id = subject_set_id
      @probability_threshold = probability_threshold
      @randomisation_factor = randomisation_factor
      @over_threshold_subject_ids = []
      @under_threshold_subject_ids = []
      @random_spice_subject_ids = []
      @prediction_data = nil
    end

    # Downloads the results file, partitions the subject ids by the specified
    # probability threshold, schedules the subject-set add/remove jobs and
    # finally kicks off a retirement check for the set.
    def run
      RemoteFile::Reader.stream_to_tempfile(results_url) do |results_file|
        # read the prediction json data from the tempfile
        prediction_data_results = JSON.parse(results_file.read)
        @prediction_data = prediction_data_results['data']
        partition_results
        # NOTE: partition_results guarantees the three id lists are mutually
        # exclusive (spice ids are sampled from, then subtracted out of, the
        # under-threshold list), so these jobs never both add and remove the
        # same subject id
        move_over_threshold_subjects_to_active_set
        remove_under_threshold_subjects_from_active_set
        add_random_spice_subjects_to_active_set
      end
      schedule_subjects_retirement_check
    end

    # Partitions prediction_data into mutually exclusive over-threshold,
    # under-threshold and random 'spice' subject id lists.
    def partition_results
      prediction_data.each do |subject_id, prediction_samples|
        # data schema format is published in the file
        # and https://github.com/zooniverse/bajor/blob/main/azure/batch/scripts/predict_on_catalog.py
        # the hash is keyed by the sample_num
        # we will use the first sample for the prediction results
        prediction_results = prediction_samples['0']
        # and we want the probability from the first entry in the prediction results array
        probability = prediction_results[0]
        @over_threshold_subject_ids << subject_id if probability >= probability_threshold
        @under_threshold_subject_ids << subject_id if probability < probability_threshold
      end
      # now add some 'spice' to the results by adding some random under threshold subject ids
      # but don't skew the prediction results by adding too many under threshold images
      # ensure we only apply the randomisation factor to the count of over threshold subject ids
      # i.e. 20% of the number of over threshold subject ids
      num_random_subject_ids_to_sample = (over_threshold_subject_ids.count * randomisation_factor).to_i
      @random_spice_subject_ids = under_threshold_subject_ids.sample(num_random_subject_ids_to_sample)
      # ensure the random subject ids aren't in the under_threshold_subject_ids list
      @under_threshold_subject_ids = under_threshold_subject_ids - random_spice_subject_ids
    end

    # Schedules bulk jobs adding all confident subjects to the active set.
    def move_over_threshold_subjects_to_active_set
      AddSubjectToSubjectSetJob.perform_bulk(
        api_batch_bulk_job_args(over_threshold_subject_ids)
      )
    end

    # Schedules bulk jobs removing unconfident subjects from the active set.
    def remove_under_threshold_subjects_from_active_set
      RemoveSubjectFromSubjectSetJob.perform_bulk(
        api_batch_bulk_job_args(under_threshold_subject_ids)
      )
    end

    # Schedules bulk jobs adding the random 'spice' subjects to the active set.
    def add_random_spice_subjects_to_active_set
      AddSubjectToSubjectSetJob.perform_bulk(
        api_batch_bulk_job_args(random_spice_subject_ids)
      )
    end

    # Slices subject_ids into API-sized batches and pairs each batch with the
    # subject set id, matching the bulk-job argument shape:
    # [[batch_ids, subject_set_id], ...]
    def api_batch_bulk_job_args(subject_ids)
      subject_ids
        .each_slice(SUBJECT_ACTION_API_BATCH_SIZE)
        .map { |batch_subject_ids| [batch_subject_ids, subject_set_id] }
    end

    private

    # Enqueues an async retirement check for the subject set.
    def schedule_subjects_retirement_check
      SubjectsRetirementWorker.perform_async(subject_set_id)
    end
  end
end