@@ -1467,7 +1467,9 @@ def saveImpl(self, path: str) -> None:
14671467
14681468 spark = _get_spark_session ()
14691469
1470- def write_sparse_array (array : scipy .sparse .spmatrix , df_dir : str ) -> None :
1470+ def write_sparse_array (
1471+ array : scipy .sparse .spmatrix , df_dir : str , mode : str
1472+ ) -> None :
14711473 indptr_schema = StructType ([StructField ("indptr" , IntegerType (), False )])
14721474 indptr_df = spark .createDataFrame (
14731475 pd .DataFrame (array .indptr ), schema = indptr_schema
@@ -1491,10 +1493,12 @@ def write_sparse_array(array: scipy.sparse.spmatrix, df_dir: str) -> None:
14911493 schema = indices_data_schema ,
14921494 )
14931495
1494- indptr_df .write .parquet (os .path .join (df_dir , "indptr.parquet" ))
1495- indices_data_df .write .parquet (os .path .join (df_dir , "indices_data.parquet" ))
1496+ indptr_df .write .parquet (os .path .join (df_dir , "indptr.parquet" ), mode = mode )
1497+ indices_data_df .write .parquet (
1498+ os .path .join (df_dir , "indices_data.parquet" ), mode = mode
1499+ )
14961500
1497- def write_dense_array (array : np .ndarray , df_path : str ) -> None :
1501+ def write_dense_array (array : np .ndarray , df_path : str , mode : str ) -> None :
14981502 assert (
14991503 spark .conf .get ("spark.sql.execution.arrow.pyspark.enabled" ) == "true"
15001504 ), "spark.sql.execution.arrow.pyspark.enabled must be set to true to persist array attributes"
@@ -1514,7 +1518,8 @@ def write_dense_array(array: np.ndarray, df_path: str) -> None:
15141518 ),
15151519 schema = schema ,
15161520 )
1517- data_df .write .parquet (df_path )
1521+
1522+ data_df .write .parquet (df_path , mode = mode )
15181523
15191524 DefaultParamsWriter .saveMetadata (
15201525 self .instance ,
@@ -1527,6 +1532,9 @@ def write_dense_array(array: np.ndarray, df_path: str) -> None:
15271532 },
15281533 )
15291534
1535+ # adhere to the overwrite() -> shouldOverWrite flag from the MLWriter
1536+ write_mode = "overwrite" if self .shouldOverwrite else "errorifexists"
1537+
15301538 # get a copy, since we're going to modify the array attributes
15311539 model_attributes = self .instance ._get_model_attributes ()
15321540 assert model_attributes is not None
@@ -1538,12 +1546,12 @@ def write_dense_array(array: np.ndarray, df_path: str) -> None:
15381546 array = model_attributes [key ]
15391547 if isinstance (array , scipy .sparse .csr_matrix ):
15401548 df_dir = os .path .join (data_path , f"{ key } csr" )
1541- write_sparse_array (array , df_dir )
1549+ write_sparse_array (array , df_dir , write_mode )
15421550 model_attributes [key ] = df_dir
15431551 model_attributes [key + "shape" ] = array .shape
15441552 else :
15451553 df_path = os .path .join (data_path , f"{ key } .parquet" )
1546- write_dense_array (array , df_path )
1554+ write_dense_array (array , df_path , write_mode )
15471555 model_attributes [key ] = df_path
15481556
15491557 metadata_file_path = os .path .join (data_path , "metadata.json" )
0 commit comments