@@ -87,7 +87,7 @@ def path_to_uri(path, scheme="https://", domain="docs.ray.io"):
     return scheme + domain + path.split(domain)[-1]


-def parse_file(record):
+def parse_html_file(record):
     html_content = load_html_file(record["path"])
     if not html_content:
         return []
@@ -100,6 +100,17 @@ def parse_file(record):
     ]


+def parse_text_file(record):
+    with open(record["path"]) as f:
+        text = f.read()
+    return [
+        {
+            "source": str(record["path"]),
+            "text": text,
+        }
+    ]
+
+
 class EmbedChunks:
     def __init__(self, model_name):
         self.embedding_model = HuggingFaceEmbeddings(
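Both parsers now emit the same record shape, so the chunking and embedding stages downstream stay format-agnostic. A quick usage sketch against the script's own namespace (the input path is hypothetical, not taken from this change):

    record = {"path": Path("/data/docs.ray.io/en/master/notes.txt")}  # hypothetical input file
    sections = parse_text_file(record)
    # -> [{"source": "/data/docs.ray.io/en/master/notes.txt", "text": "<full file contents>"}]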
@@ -139,6 +150,7 @@ def __call__(self, batch):
 @app.command()
 def create_index(
     docs_path: Annotated[str, typer.Option(help="location of data")] = DOCS_PATH,
+    extension_type: Annotated[str, typer.Option(help="type of data")] = "html",
     embedding_model: Annotated[str, typer.Option(help="embedder")] = EMBEDDING_MODEL,
     chunk_size: Annotated[int, typer.Option(help="chunk size")] = CHUNK_SIZE,
     chunk_overlap: Annotated[int, typer.Option(help="chunk overlap")] = CHUNK_OVERLAP,
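For reference, typer turns the new extension_type parameter into an --extension-type CLI flag, and since typer commands are plain functions the option can also be exercised directly from Python. A minimal usage sketch, with a placeholder docs path that is not part of this change:

    # Index plain-text files instead of HTML; *.txt paths are routed to parse_text_file.
    create_index(
        docs_path="/data/text-corpus",  # hypothetical location
        extension_type="txt",
    )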
@@ -148,11 +160,17 @@ def create_index(

     # Dataset
     ds = ray.data.from_items(
-        [{"path": path} for path in Path(docs_path).rglob("*.html") if not path.is_dir()]
+        [
+            {"path": path}
+            for path in Path(docs_path).rglob(f"*.{extension_type}")
+            if not path.is_dir()
+        ]
     )

     # Sections
-    sections_ds = ds.flat_map(parse_file)
+    parser = parse_html_file if extension_type == "html" else parse_text_file
+    sections_ds = ds.flat_map(parser)
+    # TODO: do we really need to take_all()? Bring the splitter to the cluster
     sections = sections_ds.take_all()

     # Chunking
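One way to act on the TODO above would be to keep chunking on the Ray cluster by splitting inside flat_map instead of materializing every section on the driver with take_all(). This is only a sketch: it assumes the chunking step uses LangChain's RecursiveCharacterTextSplitter, which fits the chunk_size/chunk_overlap options already in the script but is not shown in this diff.

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    def chunk_section(section, chunk_size, chunk_overlap):
        # Split one {"source", "text"} record into chunk-sized records on the worker.
        splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunks = splitter.create_documents(
            texts=[section["text"]], metadatas=[{"source": section["source"]}]
        )
        return [{"text": c.page_content, "source": c.metadata["source"]} for c in chunks]

    # Chunking stays distributed; no take_all() round-trip through the driver.
    chunks_ds = sections_ds.flat_map(
        lambda section: chunk_section(section, chunk_size, chunk_overlap)
    )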