Skip to content
This repository was archived by the owner on Apr 30, 2026. It is now read-only.

Commit afadfd5

Browse files
authored
Merge pull request #215 from bbrowning/knowledge-contexts-fixes
Incorporate knowledge generation context selection improvements
2 parents 25018dc + 415d1a5 commit afadfd5

2 files changed

Lines changed: 95 additions & 26 deletions

File tree

src/instructlab/sdg/datamixing.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -262,41 +262,48 @@ def _add_extra_contexts_to_samples(ds: Dataset, p, num_doc_in_context=4):
262262
`keep_context_separate` equal to True. When this finishes, the `context`
263263
column is removed from the dataset and all context moved to the user
264264
messages.
265+
266+
This is inspired by the concepts of Retrieval Augmented FineTuning (RAFT)
267+
from https://arxiv.org/abs/2403.10131
265268
"""
266-
all_context = ds["context"]
267-
all_context = [
268-
" ".join(e.split(" ")[: random.randint(100, 500)]) for e in all_context
269-
]
270-
ds = ds.add_column("row_idx", range(ds.num_rows))
269+
all_context = list(set(ds["context"]))
271270

272271
def __pick_documents(rec, p):
273-
# Loop until we find enough other documents to add to the context
274-
# for this document. Exit the loop early if we have fewer total
275-
# documents than the number of documents we want in our context
276-
# so that we don't end up looping forever. This handles edge
277-
# cases where the number of generated instructions is very low,
278-
# like in CI or user's testing small sizes.
279-
while True:
280-
selected_docs = random.choices(range(ds.num_rows), k=num_doc_in_context)
281-
if ds.num_rows <= num_doc_in_context:
282-
break
283-
if rec["row_idx"] not in selected_docs:
284-
break
285-
if random.uniform(0, 1) < p:
286-
docs = [
287-
all_context[idx] for idx in selected_docs[: num_doc_in_context - 1]
288-
] + [rec["context"]]
289-
# rec['indicator'] ='golden'
272+
answer_document = [rec["context"]]
273+
selected_docs = [e for e in all_context if e != answer_document]
274+
if len(selected_docs) > 0:
275+
if len(selected_docs) < num_doc_in_context:
276+
logger.debug(
277+
f"Number of unique documents is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the expanded context."
278+
)
279+
if random.uniform(0, 1) < p:
280+
# golden/answer + distractor documents
281+
docs = (
282+
random.sample(selected_docs, k=num_doc_in_context)
283+
if len(selected_docs) >= num_doc_in_context
284+
else selected_docs + [answer_document]
285+
)
286+
else:
287+
# distractor documents
288+
docs = (
289+
random.sample(selected_docs, k=num_doc_in_context)
290+
if len(selected_docs) >= num_doc_in_context
291+
else selected_docs
292+
)
290293
else:
291-
docs = [all_context[idx] for idx in selected_docs]
292-
# rec['indicator'] = 'distractor'
294+
logger.warning(
295+
"Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results."
296+
)
297+
docs = [answer_document]
293298
random.shuffle(docs)
294299
docs = "\n".join(([f"Document:\n{e}\n\n" for idx, e in enumerate(docs)]))
295-
user_idx, user_msg = [
300+
user_idx_msgs = [
296301
(idx, rec_msg)
297302
for idx, rec_msg in enumerate(rec["messages"])
298303
if rec_msg["role"] == "user"
299-
][0]
304+
]
305+
assert len(user_idx_msgs) > 0, "No user role found in dataset"
306+
user_idx, user_msg = user_idx_msgs[0]
300307
user_inst = user_msg["content"]
301308
rec["messages"][user_idx]["content"] = f"{docs}\n\n{user_inst}"
302309
rec["messages"] = rec["messages"]

tests/test_datamixing.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
Unit tests for the top-level datamixing module.
3+
"""
4+
5+
# Third Party
6+
from datasets import Dataset
7+
8+
# First Party
9+
from instructlab.sdg.datamixing import _add_extra_contexts_to_samples
10+
11+
12+
def _fake_context(msg_id):
13+
return {
14+
"context": f"context {msg_id}",
15+
"id": msg_id,
16+
"messages": [{"role": "user", "content": f"user content {msg_id}"}],
17+
"metadata": '{"dataset": []}',
18+
}
19+
20+
21+
def test_add_extra_contexts_to_samples_with_one_sample():
22+
"""
23+
Test _add_extra_contexts_to_samples doesn't error out when
24+
given only one sample
25+
"""
26+
samples = Dataset.from_list([_fake_context("abc123")])
27+
dataset = _add_extra_contexts_to_samples(samples, p=0.4)
28+
assert len(dataset) == 1
29+
30+
31+
def test_add_extra_contexts_to_samples_with_two_samples():
32+
"""
33+
Test _add_extra_contexts_to_samples doesn't error out when
34+
given only two samples
35+
"""
36+
samples = Dataset.from_list(
37+
[
38+
_fake_context("abc123"),
39+
_fake_context("bcd234"),
40+
]
41+
)
42+
dataset = _add_extra_contexts_to_samples(samples, p=0.4)
43+
assert len(dataset) == 2
44+
45+
46+
def test_add_extra_contexts_to_samples_with_six_samples():
47+
"""
48+
Test _add_extra_contexts_to_samples doesn't error out when
49+
given more samples
50+
"""
51+
samples = Dataset.from_list(
52+
[
53+
_fake_context("s1"),
54+
_fake_context("s2"),
55+
_fake_context("s3"),
56+
_fake_context("s4"),
57+
_fake_context("s5"),
58+
_fake_context("s6"),
59+
]
60+
)
61+
dataset = _add_extra_contexts_to_samples(samples, p=0.4)
62+
assert len(dataset) == 6

0 commit comments

Comments
 (0)