Skip to content

Commit 625e5c1

Browse files
committed
changed postprocessing and added tests for all the tasks
1 parent cf59057 commit 625e5c1

5 files changed

Lines changed: 409 additions & 141 deletions

File tree

src/main/scala/com/johnsnowlabs/ml/ai/Florence2.scala

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{
3333
}
3434
import org.intel.openvino.InferRequest
3535
import com.johnsnowlabs.ml.ai.util.Florence2Utils
36+
import org.json4s._
37+
import org.json4s.jackson.JsonMethods._
38+
import org.json4s.JsonDSL._
3639

3740
private[johnsnowlabs] class Florence2(
3841
val onnxWrappers: Option[DecoderWrappers],
@@ -166,16 +169,6 @@ private[johnsnowlabs] class Florence2(
166169
effectiveBatch_mult = 1
167170
}
168171

169-
// val inferRequestDecoderModel =
170-
// openvinoWrapper.get.decoderModel.getCompiledModel().create_infer_request()
171-
// val inferRequestEncoderModel =
172-
// openvinoWrapper.get.encoderModel.getCompiledModel().create_infer_request()
173-
// val inferRequestImageEncoder =
174-
// openvinoWrapper.get.imageEmbedModel.getCompiledModel().create_infer_request()
175-
// val inferRequestTextEmbeddings =
176-
// openvinoWrapper.get.textEmbeddingsModel.getCompiledModel().create_infer_request()
177-
// val inferRequestModelMerger =
178-
// openvinoWrapper.get.modelMergerModel.getCompiledModel().create_infer_request()
179172

180173
// use eosTokenId as the starting token for the decoder
181174
val decoderInputIds =
@@ -306,6 +299,26 @@ private[johnsnowlabs] class Florence2(
306299
val imageSize =
307300
imageAnnotations.headOption.map(img => (img.width, img.height)).getOrElse((1000, 1000))
308301
val postProcessed = Florence2Utils.postProcessGeneration(content, task, imageSize)
302+
// Serialize postProcessed to JSON string for raw values using json4s
303+
implicit val formats = DefaultFormats
304+
val postProcessedRaw = postProcessed match {
305+
case Florence2Utils.BBoxesResult(bboxes) =>
306+
compact(render(Extraction.decompose(bboxes)))
307+
case Florence2Utils.OCRResult(instances) =>
308+
compact(render(Extraction.decompose(instances)))
309+
case Florence2Utils.PhraseGroundingResult(instances) =>
310+
compact(render(Extraction.decompose(instances)))
311+
case Florence2Utils.PolygonsResult(instances) =>
312+
compact(render(Extraction.decompose(instances)))
313+
case Florence2Utils.MixedResult(bboxes, bboxesLabels, polygons, polygonsLabels) =>
314+
val obj = ("bboxes" -> bboxes.map(_.bbox)) ~
315+
("bboxesLabels" -> bboxesLabels) ~
316+
("polygons" -> polygons) ~
317+
("polygonsLabels" -> polygonsLabels)
318+
compact(render(obj))
319+
case Florence2Utils.PureTextResult(text) =>
320+
compact(render("text" -> text))
321+
}
309322
// If we have an image, try to generate a visualization
310323
val imageOpt = imageAnnotations.headOption.map { imgAnn =>
311324
com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
@@ -316,7 +329,9 @@ private[johnsnowlabs] class Florence2(
316329
}
317330
val newMetadata =
318331
ann.metadata ++
319-
Map("florence2_postprocessed" -> postProcessed.toString) ++
332+
Map(
333+
"florence2_postprocessed" -> postProcessed.toString,
334+
"florence2_postprocessed_raw" -> postProcessedRaw) ++
320335
imageBase64Opt.map(b64 => Map("florence2_image" -> b64)).getOrElse(Map.empty)
321336
new Annotation(
322337
annotatorType = DOCUMENT,

src/main/scala/com/johnsnowlabs/ml/ai/util/Florence2Utils.scala

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ object Florence2Utils {
7777
text: String,
7878
task: String,
7979
imageSize: (Int, Int)): Florence2Result = {
80-
val taskType = taskAnswerPostProcessingType.getOrElse(task, "pure_text")
80+
val baseTask = getBaseTaskToken(task)
81+
val taskType = taskAnswerPostProcessingType.getOrElse(baseTask, "pure_text")
8182
taskType match {
8283
case "pure_text" =>
8384
PureTextResult(text.replace("<s>", "").replace("</s>", ""))
@@ -180,28 +181,39 @@ object Florence2Utils {
180181

181182
// Parse polygons
182183
def parsePolygons(text: String, imageSize: (Int, Int)): Seq[PolygonInstance] = {
183-
val polygonStart = "<poly>"
184-
val polygonEnd = "</poly>"
185-
val polygonSep = "<sep>"
186-
val phrasePattern = new Regex(
187-
s"([^<]+(?:<loc_\\d+>|$polygonSep|$polygonStart|$polygonEnd){4,})")
188-
val boxPattern = new Regex("<loc_([0-9]+)><loc_([0-9]+)><loc_([0-9]+)><loc_([0-9]+)>")
189-
phrasePattern
190-
.findAllMatchIn(text.replace("<s>", "").replace("</s>", "").replace("<pad>", ""))
191-
.flatMap { m =>
184+
val cleanedText = text.replace("<s>", "").replace("</s>", "").replace("<pad>", "").trim
185+
if (cleanedText.startsWith("<loc_")) {
186+
// Fallback: treat as a single polygon
187+
val fallbackLocPattern = new Regex("<loc_([0-9]+)>")
188+
val locs = fallbackLocPattern.findAllMatchIn(cleanedText).map(_.group(1).toInt).toSeq
189+
if (locs.length >= 4 && locs.length % 2 == 0) {
190+
val polygon = dequantizeCoordinates(locs, imageSize)
191+
Seq(PolygonInstance(Seq(polygon), "")) // one polygon, empty label
192+
} else Seq.empty
193+
} else {
194+
// Phrase-based logic as before
195+
val polygonStart = "<poly>"
196+
val polygonEnd = "</poly>"
197+
val polygonSep = "<sep>"
198+
val phrasePattern = new Regex(
199+
s"([^<]+(?:<loc_\\d+>|$polygonSep|$polygonStart|$polygonEnd){4,})")
200+
val boxPattern = new Regex("<loc_([0-9]+)><loc_([0-9]+)><loc_([0-9]+)><loc_([0-9]+)>")
201+
val matches = phrasePattern.findAllMatchIn(cleanedText).toSeq
202+
matches.flatMap { m =>
192203
val phraseText = m.group(1)
193204
val phrase = phraseText.takeWhile(_ != '<').trim
194-
val polygons = boxPattern
195-
.findAllMatchIn(phraseText)
196-
.map { b =>
197-
val bins = (1 to 4).map(i => b.group(i).toInt)
198-
dequantizeBox(bins, imageSize)
199-
}
200-
.toSeq
201-
if (phrase.nonEmpty && polygons.nonEmpty) Some(PolygonInstance(polygons, phrase))
202-
else None
205+
if (phrase.nonEmpty) {
206+
val polygons = boxPattern
207+
.findAllMatchIn(phraseText)
208+
.map { b =>
209+
val bins = (1 to 4).map(i => b.group(i).toInt)
210+
dequantizeBox(bins, imageSize)
211+
}
212+
.toSeq
213+
if (polygons.nonEmpty) Some(PolygonInstance(polygons, phrase)) else None
214+
} else None
203215
}
204-
.toSeq
216+
}
205217
}
206218

207219
// --- Quantization helpers ---
@@ -451,7 +463,8 @@ object Florence2Utils {
451463
task: String,
452464
result: Florence2Result,
453465
textInput: Option[String] = None): Option[String] = {
454-
task match {
466+
val baseTask = getBaseTaskToken(task)
467+
baseTask match {
455468
case "<OD>" | "<DENSE_REGION_CAPTION>" | "<REGION_PROPOSAL>" =>
456469
result match {
457470
case BBoxesResult(bboxes) if bboxes.nonEmpty =>
@@ -481,8 +494,23 @@ object Florence2Utils {
481494
}
482495
case "<OPEN_VOCABULARY_DETECTION>" =>
483496
result match {
484-
case MixedResult(bboxes, bboxesLabels, _, _) if bboxes.nonEmpty =>
485-
val img = plotBBox(image, bboxes.map(_.bbox), bboxesLabels)
497+
case MixedResult(bboxes, bboxesLabels, polygons, polygonsLabels) =>
498+
if (polygons.nonEmpty) {
499+
val img = drawPolygons(image, polygons.map(Seq(_)), polygonsLabels, fillMask = true)
500+
Some(bufferedImageToBase64PNG(img))
501+
} else if (bboxes.nonEmpty) {
502+
val img = plotBBox(image, bboxes.map(_.bbox), bboxesLabels)
503+
Some(bufferedImageToBase64PNG(img))
504+
} else None
505+
case PolygonsResult(instances) if instances.nonEmpty =>
506+
val img = drawPolygons(
507+
image,
508+
instances.map(_.polygons),
509+
instances.map(_.catName),
510+
fillMask = true)
511+
Some(bufferedImageToBase64PNG(img))
512+
case BBoxesResult(bboxes) if bboxes.nonEmpty =>
513+
val img = plotBBox(image, bboxes.map(_.bbox), bboxes.map(_.catName))
486514
Some(bufferedImageToBase64PNG(img))
487515
case _ => None
488516
}
@@ -496,4 +524,8 @@ object Florence2Utils {
496524
case _ => None
497525
}
498526
}
527+
528+
/** Resolves a task string to the key used for task-specific lookups.
  *
  * Searches the keys of `taskAnswerPostProcessingType` for one that `task` starts with and
  * returns the first such key found while iterating the keys. When no key is a prefix of
  * `task`, the input string is returned unchanged.
  *
  * NOTE(review): presumably this lets a task token followed by extra prompt text resolve to
  * its bare token — confirm against the callers.
  *
  * @param task the task string to resolve
  * @return the matching key, or `task` itself when nothing matches
  */
def getBaseTaskToken(task: String): String = {
529+
taskAnswerPostProcessingType.keys.find(task.startsWith).getOrElse(task)
530+
}
499531
}

0 commit comments

Comments
 (0)