Skip to content

Commit 01559c4

Browse files
Merge branch 'master' into rana/updated-openai-timeouts
2 parents 48617a9 + bcf251c commit 01559c4

File tree

2 files changed

+56
-45
lines changed

2 files changed

+56
-45
lines changed

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
199199
private val imageExtensions = Set("jpg", "jpeg", "png", "gif", "webp")
200200
private val audioExtensions = Set("mp3", "wav")
201201

202-
private def attachmentPlaceholder(columnName: String): String =
203-
s"[Content for column '$columnName' will be provided later as an attachment.]"
204-
205-
private[openai] def applyPathPlaceholders(template: String, pathColumns: Seq[String]): String = {
206-
pathColumns.foldLeft(template) { (current, columnName) =>
207-
current.replace(s"{$columnName}", attachmentPlaceholder(columnName))
208-
}
202+
private def extractFilename = udf { (path: String) =>
203+
Option(path).map(_.trim).filter(_.nonEmpty).map(p => new HPath(p).getName).orNull
209204
}
210205

211206
private def addRAIErrors[T <: OpenAIServicesBase with HasRAIContentFilter](
@@ -317,32 +312,51 @@ class OpenAIPrompt(override val uid: String) extends Transformer
317312
results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
318313
}
319314

315+
private def processPathColumns(df: DataFrame): (DataFrame, Seq[String], Map[String, String], String) = {
316+
val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map.empty[String, String]
317+
318+
columnTypeMap.foreach { case (colName, colType) =>
319+
require(colType == "text" || colType == "path",
320+
s"Unsupported column type '$colType' for column '$colName'. Supported types are 'text' and 'path'.")
321+
}
322+
323+
val pathColumnNames = columnTypeMap.collect {
324+
case (colName, colType) if colType == "path" => colName
325+
}.toSeq
326+
327+
pathColumnNames.foreach { colName =>
328+
require(
329+
df.columns.contains(colName),
330+
s"Column '$colName' specified in columnTypes was not found in the DataFrame. " +
331+
s"Available columns: ${df.columns.mkString(", ")}"
332+
)
333+
}
334+
335+
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._
336+
val (dfWithFilenames, filenameColMapping) = pathColumnNames.foldLeft((df, Map.empty[String, String])) {
337+
case ((currentDf, mapping), colName) =>
338+
val filenameCol = currentDf.withDerivativeCol(s"${colName}_filename")
339+
(currentDf.withColumn(filenameCol, extractFilename(F.col(colName))), mapping + (colName -> filenameCol))
340+
}
341+
342+
val templateWithFilenameRefs = filenameColMapping.foldLeft(getPromptTemplate) {
343+
case (template, (colName, filenameCol)) =>
344+
template.replace(s"{$colName}", s"{$filenameCol}")
345+
}
346+
347+
(dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs)
348+
}
349+
320350
override def transform(dataset: Dataset[_]): DataFrame = {
321351
transferGlobalParamsToParamMap()
322352
logTransform[DataFrame]({
323353
val df = dataset.toDF
324354
val service = getOpenAIChatService
325-
val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map.empty[String, String]
326355

327-
columnTypeMap.foreach { case (colName, colType) =>
328-
val normalized = colType.toLowerCase
329-
require(normalized == "text" || normalized == "path",
330-
s"Unsupported column type '$colType' for column '$colName'. Supported types are 'text' and 'path'.")
331-
}
332-
333-
val pathColumnNames = columnTypeMap.collect {
334-
case (colName, colType) if colType.equalsIgnoreCase("path") => colName
335-
}.toSeq
356+
val (dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs) =
357+
processPathColumns(df)
336358

337-
val promptTemplateWithPlaceholders = applyPathPlaceholders(getPromptTemplate, pathColumnNames)
338-
val promptCol = Functions.template(promptTemplateWithPlaceholders)
339-
340-
pathColumnNames.foreach { colName =>
341-
require(
342-
df.columns.contains(colName),
343-
s"Column '$colName' specified in columnTypes was not found in the DataFrame."
344-
)
345-
}
359+
val promptCol = Functions.template(templateWithFilenameRefs)
346360

347361
val attachmentsColumn: Column =
348362
if (pathColumnNames.nonEmpty) {
@@ -368,9 +382,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer
368382
}
369383

370384
val (dfTemplated, inputColName, serviceConfigured) =
371-
configureService(service, df, promptCol, createMessagesUDF, attachmentsColumn)
385+
configureService(service, dfWithFilenames, promptCol, createMessagesUDF, attachmentsColumn)
372386
val result = generateText(serviceConfigured, dfTemplated)
373-
if (getDropPrompt) result.drop(inputColName) else result
387+
388+
val resultCleaned = filenameColMapping.values.foldLeft(result) { (df, colName) =>
389+
if (df.columns.contains(colName)) df.drop(colName) else df
390+
}
391+
392+
if (getDropPrompt) resultCleaned.drop(inputColName) else resultCleaned
374393
}, dataset.columns.length)
375394
}
376395

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,6 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
8686
}
8787
}
8888

89-
test("applyPathPlaceholders replaces path columns with attachment notice") {
90-
val prompt = new OpenAIPrompt()
91-
val template = "Describe {text} with reference to {filePath}"
92-
val updated = prompt.applyPathPlaceholders(template, Seq("filePath"))
93-
assert(updated.contains("Content for column 'filePath' will be provided later as an attachment."))
94-
assert(!updated.contains("{filePath}"))
95-
assert(updated.contains("{text}"))
96-
}
97-
9889
test("RAI Usage") {
9990
val result = prompt
10091
.setDeploymentName(deploymentName)
@@ -312,14 +303,15 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
312303

313304
val keywordsForEachQuestions = List("knn", "sorry")
314305

315-
promptResponses.transform(urlDF)
316-
.select("outParsed")
317-
.where(col("outParsed").isNotNull)
318-
.collect()
319-
.zip(keywordsForEachQuestions)
320-
.foreach { case (row, keyword) =>
321-
assert(row.getString(0).toLowerCase.contains(keyword))
322-
}
306+
promptResponses
307+
.transform(urlDF)
308+
.select("outParsed")
309+
.where(col("outParsed").isNotNull)
310+
.collect()
311+
.zip(keywordsForEachQuestions)
312+
.foreach { case (row, keyword) =>
313+
assert(row.getString(0).toLowerCase.contains(keyword))
314+
}
323315
}
324316

325317
ignore("Custom EndPoint") {

0 commit comments

Comments
 (0)