Skip to content

Commit 1113863

Browse files
authored
[SPARKNLP-1378] HTMLReader Default Headers Error (#14770)
1 parent c137dde commit 1113863

14 files changed

Lines changed: 173 additions & 44 deletions

python/test/partition/partition_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,11 @@ def runTest(self):
111111
html_df = Partition(content_type = "text/html").partition(self.html_directory)
112112
html_file_df = Partition().partition(f"{self.html_directory}/fake-html.html")
113113

114-
self.assertTrue(html_df.select("html").count() > 0)
115-
self.assertTrue(html_file_df.select("html").count() > 0)
114+
html_rows = html_df.select("html").collect()
115+
html_file_rows = html_file_df.select("html").collect()
116+
117+
self.assertTrue(len(html_rows) > 0)
118+
self.assertTrue(len(html_file_rows) > 0)
116119

117120

118121
@pytest.mark.slow
@@ -122,8 +125,11 @@ def runTest(self):
122125
url_df = Partition().partition("https://www.wikipedia.org", headers={"User-Agent": "Mozilla/5.0"})
123126
urls_df = Partition().partition_urls(["https://www.wikipedia.org", "https://example.com/"])
124127

125-
self.assertTrue(url_df.select("html").count() > 0)
126-
self.assertTrue(urls_df.select("html").count() > 0)
128+
url_rows = url_df.select("html").collect()
129+
urls_rows = urls_df.select("html").collect()
130+
131+
self.assertTrue(len(url_rows) > 0)
132+
self.assertTrue(len(urls_rows) > 0)
127133

128134

129135
@pytest.mark.fast

python/test/partition/partition_transformer_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def runTest(self):
4949
resultDf = pipelineModel.transform(self.testDataSet)
5050
resultDf.show(truncate=False)
5151

52-
self.assertTrue(resultDf.select("partition").count() > 0)
52+
rows = resultDf.select("partition").collect()
53+
self.assertTrue(len(rows) > 0)
5354

5455

5556
@pytest.mark.slow
@@ -80,7 +81,8 @@ def runTest(self):
8081

8182
resultDf = pipelineModel.transform(self.testDataSet)
8283

83-
self.assertTrue(resultDf.select("partition").count() > 0)
84+
rows = resultDf.select("partition").collect()
85+
self.assertTrue(len(rows) > 0)
8486

8587

8688
@pytest.mark.fast
@@ -108,4 +110,5 @@ def runTest(self):
108110

109111
resultDf = pipelineModel.transform(self.emptyDataSet)
110112

111-
self.assertTrue(resultDf.select("partition").count() >= 0)
113+
rows = resultDf.select("partition").collect()
114+
self.assertTrue(len(rows) >= 0)

python/test/reader/reader2doc_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def runTest(self):
4141

4242
result_df = model.transform(self.empty_df)
4343

44-
self.assertTrue(result_df.select("document").count() > 0)
44+
rows = result_df.select("document").collect()
45+
self.assertTrue(len(rows) > 0)
4546

4647

4748
@pytest.mark.fast
@@ -67,7 +68,8 @@ def runTest(self):
6768

6869
result_df = model.transform(self.empty_df)
6970

70-
self.assertTrue(result_df.select("document").count() > 0)
71+
rows = result_df.select("document").collect()
72+
self.assertTrue(len(rows) > 0)
7173

7274

7375
@pytest.mark.fast
@@ -89,7 +91,8 @@ def runTest(self):
8991

9092
result_df = model.transform(self.empty_df)
9193

92-
self.assertTrue(result_df.select("document").count() > 0)
94+
rows = result_df.select("document").collect()
95+
self.assertTrue(len(rows) > 0)
9396

9497

9598
@pytest.mark.fast
@@ -157,7 +160,8 @@ def runTest(self):
157160

158161
result_df = model.transform(self.html_df)
159162

160-
self.assertTrue(result_df.select("document").count() > 0)
163+
rows = result_df.select("document").collect()
164+
self.assertTrue(len(rows) > 0)
161165

162166

163167
@pytest.mark.fast

python/test/reader/reader2image_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def runTest(self):
4141

4242
result_df = model.transform(self.empty_df)
4343

44-
self.assertTrue(result_df.select("image").count() > 0)
44+
rows = result_df.select("image").collect()
45+
self.assertTrue(len(rows) > 0)
4546

4647

4748
@pytest.mark.slow
@@ -73,4 +74,4 @@ def runTest(self):
7374
result_df.select("image.origin", "answer.result").show(truncate=False)
7475

7576
# Assertion
76-
self.assertTrue(result_df.count() > 0)
77+
self.assertTrue(result_df.count() > 0)

python/test/reader/reader2table_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def runTest(self):
4141

4242
result_df = model.transform(self.empty_df)
4343

44-
self.assertTrue(result_df.select("document").count() > 0)
44+
rows = result_df.select("document").collect()
45+
self.assertTrue(len(rows) > 0)
4546

4647
@pytest.mark.fast
4748
class Reader2TableMixedFilesTest(unittest.TestCase):
@@ -90,4 +91,5 @@ def runTest(self):
9091

9192
result_df = model.transform(self.html_df)
9293

93-
self.assertTrue(result_df.select("document").count() > 0)
94+
rows = result_df.select("document").collect()
95+
self.assertTrue(len(rows) > 0)

python/test/reader/readerassembler_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@ def runTest(self):
3333
reader_assembler = ReaderAssembler() \
3434
.setContentType("text/html") \
3535
.setContentPath(f"file:///{os.getcwd()}/../src/test/resources/reader/html/table-image.html") \
36-
.setOutputCol("document")
36+
.setOutputCol("document") \
37+
.setOutputAsDocument(False) \
38+
.setExplodeDocs(False)
3739

3840
pipeline = Pipeline(stages=[reader_assembler])
3941
model = pipeline.fit(self.empty_df)
4042

41-
result_df = model.transform(self.empty_df)
43+
rows = model.transform(self.empty_df).collect()
4244

43-
self.assertTrue(result_df.count() > 0)
45+
self.assertTrue(len(rows) > 0)
46+
self.assertTrue(any(row.document_image for row in rows))

src/main/scala/com/johnsnowlabs/partition/HasHTMLReaderProperties.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,26 @@ trait HasHTMLReaderProperties extends ParamsAndFeaturesWritable {
3737
setHeaders(headers.asScala.toMap)
3838
}
3939

40+
protected def getHeadersAsJava: java.util.Map[String, String] = {
41+
val headersCopy = new java.util.HashMap[String, String]()
42+
val rawHeaders = getOrDefault(headers.asInstanceOf[Param[Any]])
43+
rawHeaders match {
44+
case null =>
45+
case javaHeaders: java.util.Map[_, _] =>
46+
javaHeaders.asScala.foreach { case (key, value) =>
47+
if (key != null && value != null) headersCopy.put(key.toString, value.toString)
48+
}
49+
case scalaHeaders: scala.collection.Map[_, _] =>
50+
scalaHeaders.foreach { case (key, value) =>
51+
if (key != null && value != null) headersCopy.put(key.toString, value.toString)
52+
}
53+
case other =>
54+
throw new IllegalArgumentException(
55+
s"headers must be a Map[String, String], but got ${other.getClass.getName}")
56+
}
57+
headersCopy
58+
}
59+
4060
val includeTitleTag = new Param[Boolean](
4161
this,
4262
"includeTitleTag",

src/main/scala/com/johnsnowlabs/partition/PartitionTransformer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,9 @@ class PartitionTransformer(override val uid: String)
171171
partitionInstance.setOutputColumn(inputColum)
172172

173173
val partitionDf = if (isStringContent($(contentType))) {
174-
val partitionUDF = udf((text: String) =>
175-
partitionInstance.partitionStringContent(text, $(this.headers).asJava))
174+
val requestHeaders = getHeadersAsJava
175+
val partitionUDF =
176+
udf((text: String) => partitionInstance.partitionStringContent(text, requestHeaders))
176177
val schemaFieldOpt = dataset.schema.find(_.name == inputColum)
177178

178179
schemaFieldOpt match {

src/main/scala/com/johnsnowlabs/reader/HTMLReader.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,11 @@ class HTMLReader(
276276

277277
private def extractElements(root: Node): Array[HTMLElement] = {
278278
var sentenceIndex = 0
279+
var paragraphIndex = 0
279280
val elements = ArrayBuffer[HTMLElement]()
280281
val trackingNodes = mutable.Map[Node, NodeMetadata]()
281282
var pageNumber = 1
283+
val paragraphSpacingY = 25
282284

283285
// Track parent-child hierarchy
284286
var currentParentId: Option[String] = None
@@ -428,6 +430,10 @@ class HTMLReader(
428430
case "a" =>
429431
pageMetadata("sentence") = sentenceIndex.toString
430432
sentenceIndex += 1
433+
pageMetadata("paragraph_index") = paragraphIndex.toString
434+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
435+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
436+
paragraphIndex += 1
431437
val href = element.attr("href").trim
432438
val linkText = element.text().trim
433439
if (href.nonEmpty && linkText.nonEmpty && !visitedNode) {
@@ -443,6 +449,10 @@ class HTMLReader(
443449
case "table" =>
444450
pageMetadata("sentence") = sentenceIndex.toString
445451
sentenceIndex += 1
452+
pageMetadata("paragraph_index") = paragraphIndex.toString
453+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
454+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
455+
paragraphIndex += 1
446456
val tableContent = outputFormat match {
447457
case "plain-text" => extractNestedTableContent(element).trim
448458
case "html-table" =>
@@ -474,6 +484,10 @@ class HTMLReader(
474484
case "li" =>
475485
pageMetadata("sentence") = sentenceIndex.toString
476486
sentenceIndex += 1
487+
pageMetadata("paragraph_index") = paragraphIndex.toString
488+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
489+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
490+
paragraphIndex += 1
477491
val itemText = element.text().trim
478492
if (itemText.nonEmpty && !visitedNode) {
479493
trackingNodes(element).visited = true
@@ -493,6 +507,10 @@ class HTMLReader(
493507
if (codeText.nonEmpty && !visitedNode) {
494508
pageMetadata("sentence") = sentenceIndex.toString
495509
sentenceIndex += 1
510+
pageMetadata("paragraph_index") = paragraphIndex.toString
511+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
512+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
513+
paragraphIndex += 1
496514
trackingNodes(element).visited = true
497515
pageMetadata("element_id") = newUUID()
498516
currentParentId.foreach(pid => pageMetadata("parent_id") = pid)
@@ -519,6 +537,9 @@ class HTMLReader(
519537
sentenceIndex += 1
520538
trackingNodes(element).visited = true
521539
pageMetadata("element_id") = newUUID()
540+
pageMetadata("paragraph_index") = paragraphIndex.toString
541+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
542+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
522543
currentParentId.foreach(pid => pageMetadata("parent_id") = pid)
523544
elements += HTMLElement(
524545
ElementType.NARRATIVE_TEXT,
@@ -534,11 +555,15 @@ class HTMLReader(
534555
trackingNodes(element).visited = true
535556
val titleId = newUUID()
536557
pageMetadata("element_id") = titleId
558+
pageMetadata("paragraph_index") = paragraphIndex.toString
559+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
560+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
537561
elements += HTMLElement(
538562
ElementType.TITLE,
539563
content = titleText,
540564
metadata = pageMetadata)
541565
currentParentId = Some(titleId)
566+
paragraphIndex += 1
542567
}
543568

544569
case ElementType.UNCATEGORIZED_TEXT =>
@@ -548,11 +573,15 @@ class HTMLReader(
548573
sentenceIndex += 1
549574
trackingNodes(element).visited = true
550575
pageMetadata("element_id") = newUUID()
576+
pageMetadata("paragraph_index") = paragraphIndex.toString
577+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
578+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
551579
currentParentId.foreach(pid => pageMetadata("parent_id") = pid)
552580
elements += HTMLElement(
553581
ElementType.UNCATEGORIZED_TEXT,
554582
content = text,
555583
metadata = pageMetadata)
584+
paragraphIndex += 1
556585
}
557586
}
558587
}
@@ -565,6 +594,10 @@ class HTMLReader(
565594
sentenceIndex += 1
566595
val titleId = newUUID()
567596
pageMetadata("element_id") = titleId
597+
pageMetadata("paragraph_index") = paragraphIndex.toString
598+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
599+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
600+
paragraphIndex += 1
568601
elements += HTMLElement(
569602
ElementType.TITLE,
570603
content = titleText,
@@ -585,6 +618,10 @@ class HTMLReader(
585618
if (divText.nonEmpty) {
586619
pageMetadata("sentence") = sentenceIndex.toString
587620
sentenceIndex += 1
621+
pageMetadata("paragraph_index") = paragraphIndex.toString
622+
pageMetadata("paragraph_y") = (paragraphIndex * paragraphSpacingY).toString
623+
pageMetadata("page_y") = (paragraphIndex * paragraphSpacingY).toString
624+
paragraphIndex += 1
588625
trackingNodes(element).visited = true
589626
pageMetadata("element_id") = newUUID()
590627
currentParentId.foreach(pid => pageMetadata("parent_id") = pid)

src/main/scala/com/johnsnowlabs/reader/HasReaderContent.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ trait HasReaderContent extends HasReaderProperties with HasTagsReaderProperties
123123
if (isDelimitedContentType($(contentType)) || isDelimitedExtension(ext)) {
124124
partitionDelimitedContent(partition, contentPath, $(contentType), ext)
125125
} else {
126+
val requestHeaders = getHeadersAsJava
126127
val partitionUDF =
127-
udf((text: String) => partition.partitionStringContent(text, $(this.headers).asJava))
128+
udf((text: String) => partition.partitionStringContent(text, requestHeaders))
128129
datasetWithTextFile(dataset.sparkSession, contentPath)
129130
.withColumn(partition.getOutputColumn, partitionUDF(col("content")))
130131
}
@@ -172,8 +173,9 @@ trait HasReaderContent extends HasReaderProperties with HasTagsReaderProperties
172173
partition: Partition,
173174
dataset: Dataset[_],
174175
inputCol: String): DataFrame = {
176+
val requestHeaders = getHeadersAsJava
175177
val partitionUDF =
176-
udf((text: String) => partition.partitionStringContent(text, $(this.headers).asJava))
178+
udf((text: String) => partition.partitionStringContent(text, requestHeaders))
177179

178180
dataset
179181
.withColumn(partition.getOutputColumn, partitionUDF(col(inputCol)))

0 commit comments

Comments
 (0)