11#!/usr/bin/env python
22
3- # Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+ # Copyright (c) 2022, 2025 Oracle and/or its affiliates.
44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
66import os
77import shutil
88from unittest .mock import patch
99
10- import numpy as np
1110import pytest
1211import sklearn
1312import xgboost
14- from sklearn import datasets , linear_model
13+ from sklearn import linear_model
14+ from sklearn .datasets import make_regression
1515
1616from ads .feature_engineering .schema import Schema
1717from ads .model .framework .sklearn_model import SklearnModel
2222
2323class TestMetadataMixin :
2424 def setup_method (cls ):
25- # Load the diabetes dataset
26- diabetes_X , diabetes_y = datasets .load_diabetes (return_X_y = True )
27-
28- # Use only one feature
29- diabetes_X = diabetes_X [:, np .newaxis , 2 ]
25+ X , y = make_regression (
26+ n_samples = 442 , n_features = 1 , n_informative = 1 , noise = 10.0 , random_state = 42
27+ )
3028
3129 # Split the data into training/testing sets
32- cls .diabetes_X_train = diabetes_X [:- 20 ]
33- cls .diabetes_X_test = diabetes_X [- 20 :]
30+ cls .X_train = X [:- 20 ]
31+ cls .X_test = X [- 20 :]
3432
3533 # Split the targets into training/testing sets
36- cls .diabetes_y_train = diabetes_y [:- 20 ]
37- cls .diabetes_y_test = diabetes_y [- 20 :]
34+ cls .y_train = y [:- 20 ]
35+ cls .y_test = y [- 20 :]
3836
3937 # Create linear regression object
4038 regr = linear_model .LinearRegression ()
@@ -43,8 +41,8 @@ def setup_method(cls):
4341
4442 xgb_regr = XGBRegressor ()
4543 # Train the model using the training sets
46- cls .rgr = regr .fit (cls .diabetes_X_train , cls .diabetes_y_train )
47- cls .xgb_rgr = xgb_regr .fit (cls .diabetes_X_train , cls .diabetes_y_train )
44+ cls .rgr = regr .fit (cls .X_train , cls .y_train )
45+ cls .xgb_rgr = xgb_regr .fit (cls .X_train , cls .y_train )
4846
4947 def test_metadata_generic_model (self ):
5048 model = GenericModel (self .rgr , artifact_dir = "~/test_generic" )
@@ -132,8 +130,8 @@ def test_metadata_sklearn_model(self, mock_get_service_packs):
132130 )
133131 model .populate_metadata (
134132 use_case_type = "other" ,
135- X_sample = self .diabetes_X_test ,
136- y_sample = self .diabetes_y_test ,
133+ X_sample = self .X_test ,
134+ y_sample = self .y_test ,
137135 )
138136
139137 assert model .metadata_custom .get ("ModelSerializationFormat" ).value == "joblib"
@@ -185,8 +183,8 @@ def test_metadata_xgboost_model(self, mock_get_service_packs):
185183 )
186184 model .populate_metadata (
187185 use_case_type = "binary_classification" ,
188- X_sample = self .diabetes_X_test ,
189- y_sample = self .diabetes_y_test ,
186+ X_sample = self .X_test ,
187+ y_sample = self .y_test ,
190188 )
191189 assert (
192190 model .metadata_custom .get ("CondaEnvironment" ).value
0 commit comments