Commit e99f095

simple text data sources (#19)

Authored by Max Pumperla
Signed-off-by: Max Pumperla <[email protected]>

1 parent e802f1b · commit e99f095

File tree

3 files changed: +26 -5 lines changed

.gitignore

+3
@@ -107,3 +107,6 @@ airflow/airflow.db

 # scraped folders
 docs.ray.io/
+
+# book and other source folders
+data/

app/index.py

+21 -3

@@ -87,7 +87,7 @@ def path_to_uri(path, scheme="https://", domain="docs.ray.io"):
     return scheme + domain + path.split(domain)[-1]


-def parse_file(record):
+def parse_html_file(record):
     html_content = load_html_file(record["path"])
     if not html_content:
         return []
@@ -100,6 +100,17 @@ def parse_file(record):
     ]


+def parse_text_file(record):
+    with open(record["path"]) as f:
+        text = f.read()
+    return [
+        {
+            "source": str(record["path"]),
+            "text": text,
+        }
+    ]
+
+
 class EmbedChunks:
     def __init__(self, model_name):
         self.embedding_model = HuggingFaceEmbeddings(
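
For illustration, the new parser wraps a file's entire contents in a single section record, unlike the HTML parser, which splits a page into sections. A minimal sketch of its behavior, with a hypothetical sample file:

from pathlib import Path

from app.index import parse_text_file

# Suppose data/notes.txt is a plain-text file containing "hello world"
# (a hypothetical example file, not part of this commit).
sections = parse_text_file({"path": Path("data/notes.txt")})
print(sections)  # [{'source': 'data/notes.txt', 'text': 'hello world'}]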
@@ -139,6 +150,7 @@ def __call__(self, batch):
 @app.command()
 def create_index(
     docs_path: Annotated[str, typer.Option(help="location of data")] = DOCS_PATH,
+    extension_type: Annotated[str, typer.Option(help="type of data")] = "html",
     embedding_model: Annotated[str, typer.Option(help="embedder")] = EMBEDDING_MODEL,
     chunk_size: Annotated[int, typer.Option(help="chunk size")] = CHUNK_SIZE,
     chunk_overlap: Annotated[int, typer.Option(help="chunk overlap")] = CHUNK_OVERLAP,
@@ -148,11 +160,17 @@ def create_index(

     # Dataset
     ds = ray.data.from_items(
-        [{"path": path} for path in Path(docs_path).rglob("*.html") if not path.is_dir()]
+        [
+            {"path": path}
+            for path in Path(docs_path).rglob(f"*.{extension_type}")
+            if not path.is_dir()
+        ]
     )

     # Sections
-    sections_ds = ds.flat_map(parse_file)
+    parser = parse_html_file if extension_type == "html" else parse_text_file
+    sections_ds = ds.flat_map(parser)
+    # TODO: do we really need to take_all()? Bring the splitter to the cluster
     sections = sections_ds.take_all()

     # Chunking
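
With the new option wired through, indexing a folder of plain-text sources could look like the sketch below. Since create_index is an ordinary function under the @app.command() decorator and every parameter has a default, it can be called directly; the data/ folder is a hypothetical source location:

from app.index import create_index

# Index every *.txt file under data/ instead of the default *.html docs;
# embedding model, chunk size, and chunk overlap keep their defaults.
create_index(docs_path="data", extension_type="txt")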

dashboard/pages/1_✨_Generation.py

+2 -2

@@ -9,7 +9,7 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from pgvector.psycopg import register_vector

-from app.index import parse_file
+from app.index import parse_html_file
 from app.query import generate_response


@@ -38,7 +38,7 @@ def get_ds(docs_path):
 docs_page_url = st.text_input("Docs page URL", "https://docs.ray.io/en/master/train/faq.html")
 docs_page_path = docs_path_str + docs_page_url.split("docs.ray.io/en/master/")[-1]
 with st.expander("View sections"):
-    sections = parse_file({"path": docs_page_path})
+    sections = parse_html_file({"path": docs_page_path})
     st.write(sections)

 # Chunks
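
Because the indexer swaps parse_html_file and parse_text_file interchangeably via flat_map, both presumably yield records of the same shape. A sketch of what the dashboard's sections likely contain; the local path is hypothetical, standing in for the docs_page_path built above:

from app.index import parse_html_file

# A local copy of a docs page, as produced by the scraper (hypothetical path).
sections = parse_html_file({"path": "docs.ray.io/en/master/train/faq.html"})
# Presumably a list of {"source": ..., "text": ...} records, matching
# parse_text_file's output schema.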
