Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,43 @@ def get_api_type(self):

def reset_api_type(self):
self.defaults.resetApiType()

def set_api_timeout(self, timeout):
timeout_float = float(timeout)
if timeout_float <= 0:
raise ValueError(
f"API timeout must be greater than 0, got: {timeout_float}"
)
self.defaults.setApiTimeout(timeout_float)

def get_api_timeout(self):
return getOption(self.defaults.getApiTimeout())

def reset_api_timeout(self):
self.defaults.resetApiTimeout()

def set_connection_timeout(self, timeout):
timeout_float = float(timeout)
if timeout_float <= 0:
raise ValueError(
f"Connection timeout must be greater than 0, got: {timeout_float}"
)
self.defaults.setConnectionTimeout(timeout_float)

def get_connection_timeout(self):
return getOption(self.defaults.getConnectionTimeout())

def reset_connection_timeout(self):
self.defaults.resetConnectionTimeout()

def set_timeout(self, timeout):
timeout_float = float(timeout)
if timeout_float <= 0:
raise ValueError(f"Timeout must be greater than 0, got: {timeout_float}")
self.defaults.setTimeout(timeout_float)

def get_timeout(self):
return getOption(self.defaults.getTimeout())

def reset_timeout(self):
self.defaults.resetTimeout()
Original file line number Diff line number Diff line change
Expand Up @@ -579,18 +579,22 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
case l => l
}

val baseTransformer = new SimpleHTTPTransformer()
.setInputCol(dynamicParamColName)
.setOutputCol(getOutputCol)
.setInputParser(getInternalInputParser(schema))
.setOutputParser(getInternalOutputParser(schema))
.setHandler(handlingFunc _)
.setConcurrency(getConcurrency)
.setConcurrentTimeout(get(concurrentTimeout))
.setApiTimeout(getApiTimeout)
.setConnectionTimeout(getConnectionTimeout)
.setErrorCol(getErrorCol)
val transformer = get(timeout).map(baseTransformer.setTimeout).getOrElse(baseTransformer)

val stages = Array(
Lambda(_.withColumn(dynamicParamColName, struct(dynamicParamCols: _*))),
new SimpleHTTPTransformer()
.setInputCol(dynamicParamColName)
.setOutputCol(getOutputCol)
.setInputParser(getInternalInputParser(schema))
.setOutputParser(getInternalOutputParser(schema))
.setHandler(handlingFunc _)
.setConcurrency(getConcurrency)
.setConcurrentTimeout(get(concurrentTimeout))
.setTimeout(getTimeout)
.setErrorCol(getErrorCol),
transformer,
new DropColumns().setCol(dynamicParamColName)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ case object OpenAITopPKey extends GlobalKey[Either[Double, String]]
case object OpenAIVerbosityKey extends GlobalKey[Either[String, String]]
case object OpenAIReasoningEffortKey extends GlobalKey[Either[String, String]]
case object OpenAIApiTypeKey extends GlobalKey[String]
case object OpenAIApiTimeoutKey extends GlobalKey[Double]
case object OpenAIConnectionTimeoutKey extends GlobalKey[Double]
case object OpenAITimeoutKey extends GlobalKey[Double]

// scalastyle:off number.of.methods
trait HasOpenAITextParams extends HasOpenAISharedParams {
Expand Down Expand Up @@ -412,7 +415,7 @@ trait HasTextOutput {

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
with HasOpenAISharedParams with OpenAIFabricSetting {
setDefault(timeout -> 360.0)
setDefault(apiTimeout -> 600.0)

private def usingDefaultOpenAIEndpoint(): Boolean = {
getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,45 @@ object OpenAIDefaults {
GlobalParams.resetGlobalParam(OpenAIApiTypeKey)
}

def setApiTimeout(v: Double): Unit = {
require(v > 0, s"API timeout must be greater than 0, got: $v")
GlobalParams.setGlobalParam(OpenAIApiTimeoutKey, v)
}

def getApiTimeout: Option[Double] = {
GlobalParams.getGlobalParam(OpenAIApiTimeoutKey)
}

def resetApiTimeout(): Unit = {
GlobalParams.resetGlobalParam(OpenAIApiTimeoutKey)
}

def setConnectionTimeout(v: Double): Unit = {
require(v > 0, s"Connection timeout must be greater than 0, got: $v")
GlobalParams.setGlobalParam(OpenAIConnectionTimeoutKey, v)
}

def getConnectionTimeout: Option[Double] = {
GlobalParams.getGlobalParam(OpenAIConnectionTimeoutKey)
}

def resetConnectionTimeout(): Unit = {
GlobalParams.resetGlobalParam(OpenAIConnectionTimeoutKey)
}

def setTimeout(v: Double): Unit = {
require(v > 0, s"Timeout must be greater than 0, got: $v")
GlobalParams.setGlobalParam(OpenAITimeoutKey, v)
}

def getTimeout: Option[Double] = {
GlobalParams.getGlobalParam(OpenAITimeoutKey)
}

def resetTimeout(): Unit = {
GlobalParams.resetGlobalParam(OpenAITimeoutKey)
}

private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = {
optEither match {
case Some(Left(v)) => Some(v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasReturnUsage {
logClass(FeatureNames.AiServices.OpenAI)

GlobalParams.registerParam(apiTimeout, OpenAIApiTimeoutKey)
GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey)
GlobalParams.registerParam(timeout, OpenAITimeoutKey)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))

def urlPath: String = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer

logClass(FeatureNames.AiServices.OpenAI)

GlobalParams.registerParam(apiTimeout, OpenAIApiTimeoutKey)
GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey)
GlobalParams.registerParam(timeout, OpenAITimeoutKey)

def this() = this(Identifiable.randomUID("OpenAIPrompt"))

override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
Expand Down Expand Up @@ -176,7 +180,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
systemPrompt -> defaultSystemPrompt,
apiType -> "chat_completions",
columnTypes -> Map.empty,
timeout -> 360.0
apiTimeout -> 600.0
)

override def setCustomServiceName(v: String): this.type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def test_setters_and_getters(self):
defaults.set_api_version("2024-05-01-preview")
defaults.set_model("grok-3-mini")
defaults.set_embedding_deployment_name("text-embedding-ada-002")
defaults.set_api_timeout(600.0)
defaults.set_connection_timeout(5.0)
defaults.set_timeout(120.0)

self.assertEqual(defaults.get_deployment_name(), "Bing Bong")
self.assertEqual(defaults.get_subscription_key(), "SubKey")
Expand All @@ -42,6 +45,9 @@ def test_setters_and_getters(self):
self.assertEqual(
defaults.get_embedding_deployment_name(), "text-embedding-ada-002"
)
self.assertEqual(defaults.get_api_timeout(), 600.0)
self.assertEqual(defaults.get_connection_timeout(), 5.0)
self.assertEqual(defaults.get_timeout(), 120.0)

def test_resetters(self):
defaults = OpenAIDefaults()
Expand All @@ -55,6 +61,9 @@ def test_resetters(self):
defaults.set_api_version("2024-05-01-preview")
defaults.set_model("grok-3-mini")
defaults.set_embedding_deployment_name("text-embedding-ada-002")
defaults.set_api_timeout(600.0)
defaults.set_connection_timeout(5.0)
defaults.set_timeout(120.0)

self.assertEqual(defaults.get_deployment_name(), "Bing Bong")
self.assertEqual(defaults.get_subscription_key(), "SubKey")
Expand All @@ -67,6 +76,9 @@ def test_resetters(self):
self.assertEqual(
defaults.get_embedding_deployment_name(), "text-embedding-ada-002"
)
self.assertEqual(defaults.get_api_timeout(), 600.0)
self.assertEqual(defaults.get_connection_timeout(), 5.0)
self.assertEqual(defaults.get_timeout(), 120.0)

defaults.reset_deployment_name()
defaults.reset_subscription_key()
Expand All @@ -77,6 +89,9 @@ def test_resetters(self):
defaults.reset_api_version()
defaults.reset_model()
defaults.reset_embedding_deployment_name()
defaults.reset_api_timeout()
defaults.reset_connection_timeout()
defaults.reset_timeout()

self.assertEqual(defaults.get_deployment_name(), None)
self.assertEqual(defaults.get_subscription_key(), None)
Expand All @@ -87,6 +102,9 @@ def test_resetters(self):
self.assertEqual(defaults.get_api_version(), None)
self.assertEqual(defaults.get_model(), None)
self.assertEqual(defaults.get_embedding_deployment_name(), None)
self.assertEqual(defaults.get_api_timeout(), None)
self.assertEqual(defaults.get_connection_timeout(), None)
self.assertEqual(defaults.get_timeout(), None)

def test_two_defaults(self):
defaults = OpenAIDefaults()
Expand Down Expand Up @@ -168,6 +186,28 @@ def test_parameter_validation(self):
with self.assertRaises(ValueError):
defaults.set_top_p(1.1)

# Test valid timeout values
defaults.set_api_timeout(1.0)
defaults.set_api_timeout(600.0)
defaults.set_connection_timeout(1.0)
defaults.set_connection_timeout(5.0)
defaults.set_timeout(60.0)
defaults.set_timeout(120.0)

# Test invalid timeout values (must be > 0)
with self.assertRaises(ValueError):
defaults.set_api_timeout(0.0)
with self.assertRaises(ValueError):
defaults.set_api_timeout(-1.0)
with self.assertRaises(ValueError):
defaults.set_connection_timeout(0.0)
with self.assertRaises(ValueError):
defaults.set_connection_timeout(-1.0)
with self.assertRaises(ValueError):
defaults.set_timeout(0.0)
with self.assertRaises(ValueError):
defaults.set_timeout(-1.0)


class TestResponseFormatJsonSchema(unittest.TestCase):
def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
OpenAIDefaults.setVerbosity("medium")
OpenAIDefaults.setReasoningEffort("medium")
OpenAIDefaults.setApiType("responses")
OpenAIDefaults.setApiTimeout(600.0)
OpenAIDefaults.setConnectionTimeout(5.0)
OpenAIDefaults.setTimeout(120.0)

assert(OpenAIDefaults.getDeploymentName.contains(deploymentName))
assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey))
Expand All @@ -101,6 +104,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
assert(OpenAIDefaults.getVerbosity.contains("medium"))
assert(OpenAIDefaults.getReasoningEffort.contains("medium"))
assert(OpenAIDefaults.getApiType.contains("responses"))
assert(OpenAIDefaults.getApiTimeout.contains(600.0))
assert(OpenAIDefaults.getConnectionTimeout.contains(5.0))
assert(OpenAIDefaults.getTimeout.contains(120.0))
}

test("Test Resetters") {
Expand All @@ -116,6 +122,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
OpenAIDefaults.setVerbosity("medium")
OpenAIDefaults.setReasoningEffort("medium")
OpenAIDefaults.setApiType("responses")
OpenAIDefaults.setApiTimeout(600.0)
OpenAIDefaults.setConnectionTimeout(5.0)
OpenAIDefaults.setTimeout(120.0)

OpenAIDefaults.resetDeploymentName()
OpenAIDefaults.resetSubscriptionKey()
Expand All @@ -129,6 +138,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
OpenAIDefaults.resetVerbosity()
OpenAIDefaults.resetReasoningEffort()
OpenAIDefaults.resetApiType()
OpenAIDefaults.resetApiTimeout()
OpenAIDefaults.resetConnectionTimeout()
OpenAIDefaults.resetTimeout()

assert(OpenAIDefaults.getDeploymentName.isEmpty)
assert(OpenAIDefaults.getSubscriptionKey.isEmpty)
Expand All @@ -142,6 +154,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
assert(OpenAIDefaults.getVerbosity.isEmpty)
assert(OpenAIDefaults.getReasoningEffort.isEmpty)
assert(OpenAIDefaults.getApiType.isEmpty)
assert(OpenAIDefaults.getApiTimeout.isEmpty)
assert(OpenAIDefaults.getConnectionTimeout.isEmpty)
assert(OpenAIDefaults.getTimeout.isEmpty)
}

test("Test Parameter Validation") {
Expand Down Expand Up @@ -183,5 +198,33 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {
// Test reasoning effort values
OpenAIDefaults.setReasoningEffort("low")
OpenAIDefaults.setReasoningEffort("anything")

// Test valid timeout values
OpenAIDefaults.setApiTimeout(1.0)
OpenAIDefaults.setApiTimeout(600.0)
OpenAIDefaults.setConnectionTimeout(1.0)
OpenAIDefaults.setConnectionTimeout(5.0)
OpenAIDefaults.setTimeout(60.0)
OpenAIDefaults.setTimeout(120.0)

// Test invalid timeout values (must be > 0)
assertThrows[IllegalArgumentException] {
OpenAIDefaults.setApiTimeout(0.0)
}
assertThrows[IllegalArgumentException] {
OpenAIDefaults.setApiTimeout(-1.0)
}
assertThrows[IllegalArgumentException] {
OpenAIDefaults.setConnectionTimeout(0.0)
}
assertThrows[IllegalArgumentException] {
OpenAIDefaults.setConnectionTimeout(-1.0)
}
assertThrows[IllegalArgumentException] {
OpenAIDefaults.setTimeout(0.0)
}
assertThrows[IllegalArgumentException] {
OpenAIDefaults.setTimeout(-1.0)
}
}
}
Loading
Loading