@@ -262,41 +262,48 @@ def _add_extra_contexts_to_samples(ds: Dataset, p, num_doc_in_context=4):
262262 `keep_context_separate` equal to True. When this finishes, the `context`
263263 column is removed from the dataset and all context moved to the user
264264 messages.
265+
266+ This is inspired by the concepts of Retrieval Augmented FineTuning (RAFT)
267+ from https://arxiv.org/abs/2403.10131
265268 """
266- all_context = ds ["context" ]
267- all_context = [
268- " " .join (e .split (" " )[: random .randint (100 , 500 )]) for e in all_context
269- ]
270- ds = ds .add_column ("row_idx" , range (ds .num_rows ))
269+ all_context = list (set (ds ["context" ]))
271270
272271 def __pick_documents (rec , p ):
273- # Loop until we find enough other documents to add to the context
274- # for this document. Exit the loop early if we have fewer total
275- # documents than the number of documents we want in our context
276- # so that we don't end up looping forever. This handles edge
277- # cases where the number of generated instructions is very low,
278- # like in CI or user's testing small sizes.
279- while True :
280- selected_docs = random .choices (range (ds .num_rows ), k = num_doc_in_context )
281- if ds .num_rows <= num_doc_in_context :
282- break
283- if rec ["row_idx" ] not in selected_docs :
284- break
285- if random .uniform (0 , 1 ) < p :
286- docs = [
287- all_context [idx ] for idx in selected_docs [: num_doc_in_context - 1 ]
288- ] + [rec ["context" ]]
289- # rec['indicator'] ='golden'
272+ answer_document = [rec ["context" ]]
273+ selected_docs = [e for e in all_context if e != answer_document ]
274+ if len (selected_docs ) > 0 :
275+ if len (selected_docs ) < num_doc_in_context :
276+ logger .debug (
277+ f"Number of unique documents is { len (selected_docs )} which is less than { num_doc_in_context } . Using all the documents in the expanded context."
278+ )
279+ if random .uniform (0 , 1 ) < p :
280+ # golden/answer + distractor documents
281+ docs = (
282+ random .sample (selected_docs , k = num_doc_in_context )
283+ if len (selected_docs ) >= num_doc_in_context
284+ else selected_docs + [answer_document ]
285+ )
286+ else :
287+ # distractor documents
288+ docs = (
289+ random .sample (selected_docs , k = num_doc_in_context )
290+ if len (selected_docs ) >= num_doc_in_context
291+ else selected_docs
292+ )
290293 else :
291- docs = [all_context [idx ] for idx in selected_docs ]
292- # rec['indicator'] = 'distractor'
294+ logger .warning (
295+ "Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results."
296+ )
297+ docs = [answer_document ]
293298 random .shuffle (docs )
294299 docs = "\n " .join (([f"Document:\n { e } \n \n " for idx , e in enumerate (docs )]))
295- user_idx , user_msg = [
300+ user_idx_msgs = [
296301 (idx , rec_msg )
297302 for idx , rec_msg in enumerate (rec ["messages" ])
298303 if rec_msg ["role" ] == "user"
299- ][0 ]
304+ ]
305+ assert len (user_idx_msgs ) > 0 , "No user role found in dataset"
306+ user_idx , user_msg = user_idx_msgs [0 ]
300307 user_inst = user_msg ["content" ]
301308 rec ["messages" ][user_idx ]["content" ] = f"{ docs } \n \n { user_inst } "
302309 rec ["messages" ] = rec ["messages" ]
0 commit comments