Skip to content

Commit bcf251c

Browse files
levscautWendong Li
andauthored
chore: Replace placeholder with file name in prompt template (#2456)
* remove placeholder * style issue --------- Co-authored-by: Wendong Li <[email protected]>
1 parent b979064 commit bcf251c

File tree

2 files changed

+56
-45
lines changed

2 files changed

+56
-45
lines changed

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
195195
private val imageExtensions = Set("jpg", "jpeg", "png", "gif", "webp")
196196
private val audioExtensions = Set("mp3", "wav")
197197

198-
private def attachmentPlaceholder(columnName: String): String =
199-
s"[Content for column '$columnName' will be provided later as an attachment.]"
200-
201-
private[openai] def applyPathPlaceholders(template: String, pathColumns: Seq[String]): String = {
202-
pathColumns.foldLeft(template) { (current, columnName) =>
203-
current.replace(s"{$columnName}", attachmentPlaceholder(columnName))
204-
}
198+
private def extractFilename = udf { (path: String) =>
199+
Option(path).map(_.trim).filter(_.nonEmpty).map(p => new HPath(p).getName).orNull
205200
}
206201

207202
private def addRAIErrors[T <: OpenAIServicesBase with HasRAIContentFilter](
@@ -313,32 +308,51 @@ class OpenAIPrompt(override val uid: String) extends Transformer
313308
results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
314309
}
315310

311+
private def processPathColumns(df: DataFrame): (DataFrame, Seq[String], Map[String, String], String) = {
312+
val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map.empty[String, String]
313+
314+
columnTypeMap.foreach { case (colName, colType) =>
315+
require(colType == "text" || colType == "path",
316+
s"Unsupported column type '$colType' for column '$colName'. Supported types are 'text' and 'path'.")
317+
}
318+
319+
val pathColumnNames = columnTypeMap.collect {
320+
case (colName, colType) if colType == "path" => colName
321+
}.toSeq
322+
323+
pathColumnNames.foreach { colName =>
324+
require(
325+
df.columns.contains(colName),
326+
s"Column '$colName' specified in columnTypes was not found in the DataFrame. " +
327+
s"Available columns: ${df.columns.mkString(", ")}"
328+
)
329+
}
330+
331+
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._
332+
val (dfWithFilenames, filenameColMapping) = pathColumnNames.foldLeft((df, Map.empty[String, String])) {
333+
case ((currentDf, mapping), colName) =>
334+
val filenameCol = currentDf.withDerivativeCol(s"${colName}_filename")
335+
(currentDf.withColumn(filenameCol, extractFilename(F.col(colName))), mapping + (colName -> filenameCol))
336+
}
337+
338+
val templateWithFilenameRefs = filenameColMapping.foldLeft(getPromptTemplate) {
339+
case (template, (colName, filenameCol)) =>
340+
template.replace(s"{$colName}", s"{$filenameCol}")
341+
}
342+
343+
(dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs)
344+
}
345+
316346
override def transform(dataset: Dataset[_]): DataFrame = {
317347
transferGlobalParamsToParamMap()
318348
logTransform[DataFrame]({
319349
val df = dataset.toDF
320350
val service = getOpenAIChatService
321-
val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map.empty[String, String]
322351

323-
columnTypeMap.foreach { case (colName, colType) =>
324-
val normalized = colType.toLowerCase
325-
require(normalized == "text" || normalized == "path",
326-
s"Unsupported column type '$colType' for column '$colName'. Supported types are 'text' and 'path'.")
327-
}
328-
329-
val pathColumnNames = columnTypeMap.collect {
330-
case (colName, colType) if colType.equalsIgnoreCase("path") => colName
331-
}.toSeq
352+
val (dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs) =
353+
processPathColumns(df)
332354

333-
val promptTemplateWithPlaceholders = applyPathPlaceholders(getPromptTemplate, pathColumnNames)
334-
val promptCol = Functions.template(promptTemplateWithPlaceholders)
335-
336-
pathColumnNames.foreach { colName =>
337-
require(
338-
df.columns.contains(colName),
339-
s"Column '$colName' specified in columnTypes was not found in the DataFrame."
340-
)
341-
}
355+
val promptCol = Functions.template(templateWithFilenameRefs)
342356

343357
val attachmentsColumn: Column =
344358
if (pathColumnNames.nonEmpty) {
@@ -364,9 +378,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer
364378
}
365379

366380
val (dfTemplated, inputColName, serviceConfigured) =
367-
configureService(service, df, promptCol, createMessagesUDF, attachmentsColumn)
381+
configureService(service, dfWithFilenames, promptCol, createMessagesUDF, attachmentsColumn)
368382
val result = generateText(serviceConfigured, dfTemplated)
369-
if (getDropPrompt) result.drop(inputColName) else result
383+
384+
val resultCleaned = filenameColMapping.values.foldLeft(result) { (df, colName) =>
385+
if (df.columns.contains(colName)) df.drop(colName) else df
386+
}
387+
388+
if (getDropPrompt) resultCleaned.drop(inputColName) else resultCleaned
370389
}, dataset.columns.length)
371390
}
372391

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,6 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
8686
}
8787
}
8888

89-
test("applyPathPlaceholders replaces path columns with attachment notice") {
90-
val prompt = new OpenAIPrompt()
91-
val template = "Describe {text} with reference to {filePath}"
92-
val updated = prompt.applyPathPlaceholders(template, Seq("filePath"))
93-
assert(updated.contains("Content for column 'filePath' will be provided later as an attachment."))
94-
assert(!updated.contains("{filePath}"))
95-
assert(updated.contains("{text}"))
96-
}
97-
9889
test("RAI Usage") {
9990
val result = prompt
10091
.setDeploymentName(deploymentName)
@@ -312,14 +303,15 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
312303

313304
val keywordsForEachQuestions = List("knn", "sorry")
314305

315-
promptResponses.transform(urlDF)
316-
.select("outParsed")
317-
.where(col("outParsed").isNotNull)
318-
.collect()
319-
.zip(keywordsForEachQuestions)
320-
.foreach { case (row, keyword) =>
321-
assert(row.getString(0).toLowerCase.contains(keyword))
322-
}
306+
promptResponses
307+
.transform(urlDF)
308+
.select("outParsed")
309+
.where(col("outParsed").isNotNull)
310+
.collect()
311+
.zip(keywordsForEachQuestions)
312+
.foreach { case (row, keyword) =>
313+
assert(row.getString(0).toLowerCase.contains(keyword))
314+
}
323315
}
324316

325317
ignore("Custom EndPoint") {

0 commit comments

Comments
 (0)