
[BUG] ONNXModel fails in local Spark 3.5.6 environment  #2417

@bgreenwell

Description


SynapseML version

1.0.13

System information

  • Language version: Python 3.8+
  • Spark Version: 3.5.6
  • Spark Platform: Local Spark session (macOS)
  • OS: macOS (Darwin 24.6.0)
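
"Python 3.8+" gives a range rather than the interpreter actually used. On macOS, the CPU architecture (arm64 vs x86_64) also decides which ONNX Runtime native binary has to be loaded. A small stdlib-only sketch like the following could collect the exact values for the report. The helper name `environment_report` and the default package list are illustrative and are not part of SynapseML.

```python
import platform
from importlib import metadata


def environment_report(packages=("pyspark", "synapseml", "onnxruntime", "skl2onnx")):
    """Collect interpreter, OS, CPU and package versions relevant to this issue."""
    report = {
        "python": platform.python_version(),
        "os": platform.system(),
        "os_release": platform.release(),
        # arm64 vs x86_64 matters for which native libraries can be loaded
        "machine": platform.machine(),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this interpreter
    return report
```

Printing `environment_report()` from the same interpreter that runs the repro would give the exact versions.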

Component(s) Affected:

  • ONNX Integration

Language(s) Affected:

  • Python

Describe the problem

SynapseML's ONNXModel fails to run ONNX models in a local Spark session, raising errors about the ONNX Runtime native library path,
while the same code and models run successfully on Databricks. This suggests a configuration or environment-setup issue
specific to local Spark deployments.

Expected Behavior:
The ONNXModel.transform() method should work consistently across both local Spark environments and Databricks, given
proper configuration.

Actual Behavior:

  • ✅ Works in Databricks: Same ONNX models and SynapseML code execute successfully
  • ❌ Fails locally: Transformation fails with ONNX Runtime native library loading errors
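
One way to tell a SynapseML problem apart from an ONNX Runtime problem is to load ONNX Runtime directly in the driver JVM through PySpark's Py4J gateway, before calling `ONNXModel.transform()`. The sketch below assumes the `spark` session from the repro, with the SynapseML package on the classpath. It calls `ai.onnxruntime.OrtEnvironment.getEnvironment()` from the ONNX Runtime Java API, which initializes ONNX Runtime and loads its native libraries. `spark._jvm` is an internal PySpark attribute, and `probe_onnxruntime_jvm` is a hypothetical helper name.

```python
def probe_onnxruntime_jvm(spark):
    """Try to initialize ONNX Runtime inside the driver JVM via Py4J.

    Returns the JVM's platform properties plus whether OrtEnvironment
    could be created and, if not, the error text.
    """
    jvm = spark._jvm  # internal PySpark handle to the Py4J JVM view
    system = jvm.java.lang.System
    result = {
        "os.name": system.getProperty("os.name"),
        # e.g. aarch64 for an arm64 JDK, x86_64 for an Intel JDK
        "os.arch": system.getProperty("os.arch"),
        "java.version": system.getProperty("java.version"),
        "java.library.path": system.getProperty("java.library.path"),
    }
    try:
        # Initializes ONNX Runtime, which loads its native libraries
        jvm.ai.onnxruntime.OrtEnvironment.getEnvironment()
        result["ort_loaded"] = True
        result["error"] = None
    except Exception as exc:  # typically a Py4JJavaError wrapping the Java error
        result["ort_loaded"] = False
        result["error"] = str(exc)
    return result
```

Calling `probe_onnxruntime_jvm(spark)` right after `getOrCreate()` would show whether the native load fails on its own and which `os.arch` the JVM reports. An x86_64 JDK running under Rosetta on Apple Silicon needs different binaries from an arm64 JDK. In local mode the executors run inside the driver JVM, so this one check covers both. If the `ai.onnxruntime` classes are not on the classpath at all, the error will be a Py4J message about a `JavaPackage` not being callable rather than a native-library error.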

Code to reproduce issue

  """
  Minimal Reproducible Example for SynapseML ONNX Issue

  Requirements:
  pip install pyspark synapseml scikit-learn onnx onnxruntime skl2onnx
  """

  import tempfile
  import os
  from sklearn.ensemble import RandomForestClassifier
  from sklearn.datasets import make_classification
  import numpy as np

  def create_simple_onnx_model():
      """Create a simple ONNX model for testing."""
      from skl2onnx import to_onnx
      from skl2onnx.common.data_types import FloatTensorType

      print("Creating simple scikit-learn model...")

      # Create simple dataset
      X, y = make_classification(
          n_samples=100, n_features=3, n_informative=2,
          n_redundant=1, n_classes=2, random_state=42
      )

      # Train simple model
      model = RandomForestClassifier(n_estimators=3, max_depth=3, random_state=42)
      model.fit(X, y)

      # Convert to ONNX
      initial_types = [('input', FloatTensorType([None, 3]))]
      onnx_model = to_onnx(model, X.astype(np.float32), initial_types=initial_types)

      # Save to a temporary file; delete=False keeps it on disk after the handle closes
      with tempfile.NamedTemporaryFile(delete=False, suffix='.onnx') as temp_file:
          temp_file.write(onnx_model.SerializeToString())

      print(f"ONNX model saved to: {temp_file.name}")
      return temp_file.name, X, y

  def test_synapseml_onnx():
      """Test ONNX model with SynapseML (fails locally, works in Databricks)."""
      from pyspark.sql import SparkSession
      from synapse.ml.onnx import ONNXModel
      from pyspark.ml.linalg import VectorUDT, Vectors
      from pyspark.sql.functions import udf

      print("\n=== Testing with SynapseML ONNXModel ===")

      # Create ONNX model
      onnx_path, X, y = create_simple_onnx_model()

      # Create Spark session (local configuration)
      print("Creating local Spark session...")
      spark = (SparkSession.builder
               .appName("SynapseML-ONNX-Bug-Repro")
               .config("spark.jars.packages", "com.microsoft.azure:synapseml_2.12:1.0.13")
               .config("spark.sql.adaptive.enabled", "false")
               .config("spark.driver.memory", "2g")
               .getOrCreate())

      spark.sparkContext.setLogLevel("WARN")

      try:
          # Convert test data to Spark DataFrame with Vector column
          def to_vector(row):
              return Vectors.dense(row)

          vector_udf = udf(to_vector, VectorUDT())

          # Create DataFrame
          test_data = X[:10].tolist()
          df = spark.createDataFrame(
              [(i, row) for i, row in enumerate(test_data)],
              ['id', 'features_list']
          )
          df = df.withColumn('features', vector_udf(df['features_list'])).select('id', 'features')

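          # Configure SynapseML ONNX model.
          # feedDict maps ONNX input name -> DataFrame column;
          # fetchDict maps output DataFrame column -> ONNX output name.
          onnx_ml = (ONNXModel()
                    .setModelLocation(onnx_path)
                    .setDeviceType("CPU")
                    .setFeedDict({"input": "features"})
                    .setFetchDict({"prediction": "output_label", "probability": "output_probability"}))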

          # THIS IS WHERE IT FAILS IN LOCAL ENVIRONMENTS
          print("Attempting ONNX transformation...")
          result = onnx_ml.transform(df)

          print("✅ SynapseML ONNX transformation successful!")
          result.show(5)
          return True

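      except Exception as e:
          import traceback
          print(f"❌ SynapseML ONNX transformation failed: {e}")
          print(f"Error type: {type(e).__name__}")
          # Print the full stack trace to show where the native-library error is raised
          traceback.print_exc()
          return False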

      finally:
          spark.stop()
          if os.path.exists(onnx_path):
              os.unlink(onnx_path)

  if __name__ == "__main__":
      test_synapseml_onnx()
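
Independently of the native-library error, the `fetchDict` direction is worth checking. In SynapseML's ONNX examples, `setFeedDict` maps ONNX input names to DataFrame columns, while `setFetchDict` maps output DataFrame column names to ONNX output names. skl2onnx names a classifier's outputs `output_label` and `output_probability` by default, so a mapping keyed by those names is reversed. A minimal pre-flight check, using the hypothetical name `check_fetch_dict`, could look like this:

```python
def check_fetch_dict(fetch_dict, onnx_output_names):
    """Validate a fetchDict against the ONNX model's output names.

    Assumes the convention from SynapseML's examples: keys are output
    DataFrame column names, values are ONNX output names.
    Returns a list of human-readable problems (empty if none).
    """
    outputs = set(onnx_output_names)
    problems = []
    for column, onnx_name in fetch_dict.items():
        if onnx_name in outputs:
            continue  # value is a real ONNX output: mapping looks right
        if column in outputs:
            # The key is an ONNX output name, so keys and values are swapped
            problems.append(
                f"{column!r}: {onnx_name!r} looks reversed; "
                f"try {onnx_name!r}: {column!r}"
            )
        else:
            problems.append(
                f"{onnx_name!r} is not an ONNX output (available: {sorted(outputs)})"
            )
    return problems
```

The ONNX output names can be read from the model file with the `onnx` package already listed in the requirements, e.g. `[o.name for o in onnx.load(onnx_path).graph.output]`.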

Other info / logs

Additional Context:

  1. Reference issue: this relates to the working Databricks setup documented in
     Support for EBM/ONNX Model (Possible Bug?) [BUG] [HELP] #1902
  2. Local vs. managed environment: the core difference appears to be how ONNX Runtime's native libraries are
     configured and discovered in a local Spark session compared with a managed Databricks environment.
  3. Configuration attempts: the following local configurations have been tried, none of which resolved the error:
     - Explicit java.library.path settings
     - Different ONNX Runtime Maven coordinates
     - Memory and execution tuning
     - Manual native library path specification
  4. Request for documentation: documentation of the configuration required for local Spark environments to
     match Databricks functionality would be helpful.
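
For the explicit `java.library.path` and manual native-path attempts, the settings have to reach the JVMs when they start. The following is a sketch of the corresponding Spark configuration. It assumes that the ONNX Runtime Java build in use honors the `onnxruntime.native.path` system property, which makes it load its native libraries from a directory instead of extracting them from the jar. It also assumes that `native_dir` (a hypothetical directory) holds binaries built for the JVM's OS and architecture. `native_path_confs` is likewise a hypothetical helper.

```python
def native_path_confs(native_dir):
    """Build Spark settings that point the driver and executor JVMs at native_dir.

    Assumption: the ONNX Runtime Java build honors the onnxruntime.native.path
    system property; java.library.path is set too for System.loadLibrary lookups.
    """
    if any(ch.isspace() for ch in native_dir):
        # Spark splits extraJavaOptions into JVM arguments on whitespace;
        # quoting is possible but omitted to keep this sketch simple.
        raise ValueError("native_dir must not contain whitespace")
    java_opts = (
        f"-Djava.library.path={native_dir} "
        f"-Donnxruntime.native.path={native_dir}"
    )
    return {
        # Applied when the driver JVM is launched (local mode runs executors in it)
        "spark.driver.extraJavaOptions": java_opts,
        # Applied to separate executor JVMs on a cluster
        "spark.executor.extraJavaOptions": java_opts,
    }
```

Each entry would be passed with `SparkSession.builder.config(key, value)` before `getOrCreate()`. Driver options only take effect when a new driver JVM is launched, so they are ignored if a session already exists in the same Python process.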

Suggested Solution:

Could SynapseML provide:

  1. Detailed local environment setup guide for ONNX integration
  2. Specific Maven/configuration requirements for local Spark deployments
  3. Native library path configuration examples for different platforms (macOS, Linux, Windows)

This would help bridge the gap between managed Databricks environments and local development setups.
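
As a starting point for the per-platform part, one quick check is whether the ONNX Runtime jar resolved through `spark.jars.packages` contains binaries for the local platform at all. The sketch assumes the jar layout used by ONNX Runtime Java releases, with native files under `ai/onnxruntime/native/<os>-<arch>/` (for example `osx-aarch64`, `linux-x64`, `win-x64`). The helper names are hypothetical.

```python
import platform
import zipfile

# Assumed jar layout: native binaries live under ai/onnxruntime/native/<os>-<arch>/
NATIVE_PREFIX = "ai/onnxruntime/native/"


def expected_native_dir(system=None, machine=None):
    """Map an OS/CPU pair (default: this Python process) to the assumed <os>-<arch> name."""
    system = (system or platform.system()).lower()
    machine = (machine or platform.machine()).lower()
    os_part = {"darwin": "osx", "linux": "linux", "windows": "win"}.get(system, system)
    arch_part = {"x86_64": "x64", "amd64": "x64",
                 "arm64": "aarch64", "aarch64": "aarch64"}.get(machine, machine)
    return f"{os_part}-{arch_part}"


def native_platforms_in_jar(jar_path):
    """Return the sorted <os>-<arch> folders that contain native files in a jar."""
    found = set()
    with zipfile.ZipFile(jar_path) as jar:
        for name in jar.namelist():
            if name.startswith(NATIVE_PREFIX) and not name.endswith("/"):
                rest = name[len(NATIVE_PREFIX):]
                if "/" in rest:  # skip files placed directly under native/
                    found.add(rest.split("/", 1)[0])
    return sorted(found)
```

Packages resolved through `spark.jars.packages` land in `~/.ivy2/jars/` by default. Something like `for jar in glob.glob(os.path.expanduser("~/.ivy2/jars/*onnxruntime*.jar")): print(jar, native_platforms_in_jar(jar), expected_native_dir())` would therefore show whether the expected folder is present. `expected_native_dir()` uses the Python interpreter's architecture, which can differ from the JVM's if one of them runs under Rosetta.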
