Skip to content

Commit 0198d3f

Browse files
authored
remove unsupported save,load,read,write from api docs for knn estimat… (#646)
* remove unsupported save,load,read,write from api docs for knn estimator, model classes Signed-off-by: Erik Ordentlich <[email protected]> * fix class names in error messages Signed-off-by: Erik Ordentlich <[email protected]> * typo Signed-off-by: Erik Ordentlich <[email protected]> --------- Signed-off-by: Erik Ordentlich <[email protected]>
1 parent c59795a commit 0198d3f

File tree

3 files changed

+86
-3
lines changed

3 files changed

+86
-3
lines changed

python/src/spark_rapids_ml/knn.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,12 +361,27 @@ def _get_cuml_fit_func(self, dataset: DataFrame) -> Callable[ # type: ignore
361361
pass
362362

363363
def write(self) -> MLWriter:
364+
"""Unsupported."""
364365
raise NotImplementedError(
365366
"NearestNeighbors does not support saving/loading, just re-create the estimator."
366367
)
367368

368369
@classmethod
369370
def read(cls) -> MLReader:
371+
"""Unsupported."""
372+
raise NotImplementedError(
373+
"NearestNeighbors does not support saving/loading, just re-create the estimator."
374+
)
375+
376+
def save(self, path: str) -> None:
377+
"""Unsupported."""
378+
raise NotImplementedError(
379+
"NearestNeighbors does not support saving/loading, just re-create the estimator."
380+
)
381+
382+
@classmethod
383+
def load(cls, path: str) -> MLReader:
384+
"""Unsupported."""
370385
raise NotImplementedError(
371386
"NearestNeighbors does not support saving/loading, just re-create the estimator."
372387
)
@@ -442,14 +457,29 @@ def _nearest_neighbors_join(
442457
return knnjoin_df
443458

444459
def write(self) -> MLWriter:
460+
"""Unsupported."""
445461
raise NotImplementedError(
446462
f"{self.__class__} does not support saving/loading, just re-fit the estimator to re-create a model."
447463
)
448464

449465
@classmethod
450466
def read(cls) -> MLReader:
467+
"""Unsupported."""
451468
raise NotImplementedError(
452-
f"{cls} does not support loading/loading, just re-fit the estimator to re-create a model."
469+
f"{cls} does not support saving/loading, just re-fit the estimator to re-create a model."
470+
)
471+
472+
def save(self, path: str) -> None:
473+
"""Unsupported."""
474+
raise NotImplementedError(
475+
f"{self.__class__} does not support saving/loading, just re-create the estimator."
476+
)
477+
478+
@classmethod
479+
def load(cls, path: str) -> MLReader:
480+
"""Unsupported."""
481+
raise NotImplementedError(
482+
f"{cls} does not support saving/loading, just re-create the estimator."
453483
)
454484

455485

@@ -1040,13 +1070,29 @@ def _get_cuml_fit_func(self, dataset: DataFrame) -> Callable[ # type: ignore
10401070
"""
10411071
pass
10421072

1073+
# for the following 4 methods leave doc string as below so that they are filtered out from api docs
10431074
def write(self) -> MLWriter:
1075+
"""Unsupported."""
10441076
raise NotImplementedError(
10451077
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
10461078
)
10471079

10481080
@classmethod
10491081
def read(cls) -> MLReader:
1082+
"""Unsupported."""
1083+
raise NotImplementedError(
1084+
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
1085+
)
1086+
1087+
@classmethod
1088+
def load(cls, path: str) -> MLReader:
1089+
"""Unsupported."""
1090+
raise NotImplementedError(
1091+
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
1092+
)
1093+
1094+
def save(self, path: str) -> None:
1095+
"""Unsupported."""
10501096
raise NotImplementedError(
10511097
"ApproximateNearestNeighbors does not support saving/loading, just re-create the estimator."
10521098
)

python/src/spark_rapids_ml/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,17 @@ def _unsupported_methods_attributes(clazz: Any) -> Set[str]:
5656
_unsupported_methods: List[str] = sum(
5757
[_method_names_from_param(k) for k in _unsupported_params], []
5858
)
59-
return set(_unsupported_params + _unsupported_methods)
59+
methods_and_functions = inspect.getmembers(
60+
clazz,
61+
predicate=lambda member: inspect.isfunction(member)
62+
or inspect.ismethod(member),
63+
)
64+
_other_unsupported = [
65+
entry[0]
66+
for entry in methods_and_functions
67+
if entry and (entry[1].__doc__) == "Unsupported."
68+
]
69+
return set(_unsupported_params + _unsupported_methods + _other_unsupported)
6070
else:
6171
return set()
6272

python/tests/test_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,35 @@ class A:
6161
def _param_mapping(cls) -> Dict[str, Optional[str]]:
6262
return {"param1": "param2", "param3": None, "param4": ""}
6363

64+
@classmethod
65+
def unsupported_method(cls) -> None:
66+
"""Unsupported."""
67+
pass
68+
69+
def unsupported_function(self) -> None:
70+
"""Unsupported."""
71+
pass
72+
73+
@classmethod
74+
def supported_method(cls) -> None:
75+
"""supported"""
76+
pass
77+
78+
def supported_function(self) -> None:
79+
"""supported"""
80+
pass
81+
6482
assert _unsupported_methods_attributes(A) == set(
65-
["param3", "getParam3", "setParam3", "param4", "getParam4", "setParam4"]
83+
[
84+
"param3",
85+
"getParam3",
86+
"setParam3",
87+
"param4",
88+
"getParam4",
89+
"setParam4",
90+
"unsupported_method",
91+
"unsupported_function",
92+
]
6693
)
6794

6895

0 commit comments

Comments
 (0)