Skip to content

Commit bfa8c46

Browse files
levscautWendong Li
andauthored
feat: Add OpenAI API Type to OpenAIDefaults (#2453)
Co-authored-by: Wendong Li <[email protected]>
1 parent 95c8b61 commit bfa8c46

File tree

5 files changed

+41
-2
lines changed

5 files changed

+41
-2
lines changed

cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,12 @@ def get_reasoning_effort(self):
128128

129129
def reset_reasoning_effort(self):
130130
self.defaults.resetReasoningEffort()
131+
132+
def set_api_type(self, api_type):
133+
self.defaults.setApiType(api_type)
134+
135+
def get_api_type(self):
136+
return getOption(self.defaults.getApiType())
137+
138+
def reset_api_type(self):
139+
self.defaults.resetApiType()

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ case object OpenAISeedKey extends GlobalKey[Either[Int, String]]
107107
case object OpenAITopPKey extends GlobalKey[Either[Double, String]]
108108
case object OpenAIVerbosityKey extends GlobalKey[Either[String, String]]
109109
case object OpenAIReasoningEffortKey extends GlobalKey[Either[String, String]]
110+
case object OpenAIApiTypeKey extends GlobalKey[String]
110111

111112
// scalastyle:off number.of.methods
112113
trait HasOpenAITextParams extends HasOpenAISharedParams {

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,23 @@ object OpenAIDefaults {
144144
GlobalParams.resetGlobalParam(OpenAIReasoningEffortKey)
145145
}
146146

147+
def setApiType(v: String): Unit = {
148+
val options = Seq("responses", "chat_completions")
149+
require(
150+
options.contains(v),
151+
s"ApiType must be in ${options.mkString(", ")}, got: $v"
152+
)
153+
GlobalParams.setGlobalParam(OpenAIApiTypeKey, v)
154+
}
155+
156+
def getApiType: Option[String] = {
157+
GlobalParams.getGlobalParam(OpenAIApiTypeKey)
158+
}
159+
160+
def resetApiType(): Unit = {
161+
GlobalParams.resetGlobalParam(OpenAIApiTypeKey)
162+
}
163+
147164
private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = {
148165
optEither match {
149166
case Some(Left(v)) => Some(v)

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import com.microsoft.azure.synapse.ml.core.spark.Functions
88
import com.microsoft.azure.synapse.ml.io.binary.BinaryFileReader
99
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
1010
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
11-
import com.microsoft.azure.synapse.ml.param.{HasGlobalParams, StringStringMapParam}
11+
import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, StringStringMapParam}
1212
import com.microsoft.azure.synapse.ml.services._
1313
import com.microsoft.azure.synapse.ml.services.aifoundry.{AIFoundryChatCompletion, HasAIFoundryTextParamsExtended}
1414
import HasReturnUsage.{UsageFieldMapping, UsageMappings}
@@ -125,15 +125,22 @@ class OpenAIPrompt(override val uid: String) extends Transformer
125125
this, "apiType", "The OpenAI API type to use: 'chat_completions' or 'responses'",
126126
isValid = ParamValidators.inArray(Array("chat_completions", "responses")))
127127

128+
GlobalParams.registerParam(apiType, OpenAIApiTypeKey)
129+
128130
def getApiType: String = $(apiType)
129131

130132
def setApiType(value: String): this.type = set(apiType, value)
131133

132134
val columnTypes = new StringStringMapParam(
133135
this, "columnTypes", "A map from column names to their types. Supported types are 'text' and 'path'.")
134136
private def validateColumnType(value: String) = {
135-
require(value.equalsIgnoreCase("text") || value.equalsIgnoreCase("path"),
137+
if (value.equalsIgnoreCase("path") || value.equalsIgnoreCase("text")) {
138+
logWarning(s"Column type '$value' is deprecated. Please use lowercase 'path' or 'text' instead.")
139+
}
140+
require(value == "text" || value == "path",
136141
s"Unsupported column type: $value. Supported types are 'text' and 'path'.")
142+
require(value != "responses" || this.getApiType == "responses",
143+
s"Column type 'path' is only supported when apiType is set to 'responses'.")
137144
}
138145

139146
def getColumnTypes: Map[String, String] = $(columnTypes)

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
8787
OpenAIDefaults.setEmbeddingDeploymentName("text-embedding-ada-002")
8888
OpenAIDefaults.setVerbosity("medium")
8989
OpenAIDefaults.setReasoningEffort("medium")
90+
OpenAIDefaults.setApiType("responses")
9091

9192
assert(OpenAIDefaults.getDeploymentName.contains(deploymentName))
9293
assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey))
@@ -99,6 +100,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
99100
assert(OpenAIDefaults.getEmbeddingDeploymentName.contains("text-embedding-ada-002"))
100101
assert(OpenAIDefaults.getVerbosity.contains("medium"))
101102
assert(OpenAIDefaults.getReasoningEffort.contains("medium"))
103+
assert(OpenAIDefaults.getApiType.contains("responses"))
102104
}
103105

104106
test("Test Resetters") {
@@ -113,6 +115,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
113115
OpenAIDefaults.setEmbeddingDeploymentName("text-embedding-ada-002")
114116
OpenAIDefaults.setVerbosity("medium")
115117
OpenAIDefaults.setReasoningEffort("medium")
118+
OpenAIDefaults.setApiType("responses")
116119

117120
OpenAIDefaults.resetDeploymentName()
118121
OpenAIDefaults.resetSubscriptionKey()
@@ -125,6 +128,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
125128
OpenAIDefaults.resetEmbeddingDeploymentName()
126129
OpenAIDefaults.resetVerbosity()
127130
OpenAIDefaults.resetReasoningEffort()
131+
OpenAIDefaults.resetApiType()
128132

129133
assert(OpenAIDefaults.getDeploymentName.isEmpty)
130134
assert(OpenAIDefaults.getSubscriptionKey.isEmpty)
@@ -137,6 +141,7 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
137141
assert(OpenAIDefaults.getEmbeddingDeploymentName.isEmpty)
138142
assert(OpenAIDefaults.getVerbosity.isEmpty)
139143
assert(OpenAIDefaults.getReasoningEffort.isEmpty)
144+
assert(OpenAIDefaults.getApiType.isEmpty)
140145
}
141146

142147
test("Test Parameter Validation") {

0 commit comments

Comments
 (0)