Skip to content

Commit d26a8a8

Browse files
committed
Update imports to be modules not classes
This aligns with the Google style guide: https://google.github.io/styleguide/pyguide.html#22-imports Also remove `__all__` from `__init__.py` files. This only applies when using `import *`, which is not common or recommended, so it's not worth us protecting against within the library.
1 parent 24fc44c commit d26a8a8

File tree

3 files changed

+42
-68
lines changed

3 files changed

+42
-68
lines changed

src/metrax/__init__.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from metrax.base import (
16-
Average,
17-
)
18-
from metrax.classification_metrics import (
19-
AUCPR,
20-
AUCROC,
21-
Precision,
22-
Recall,
23-
)
24-
from metrax.nlp_metrics import (
25-
Perplexity,
26-
WER
27-
)
28-
from metrax.ranking_metrics import (
29-
AveragePrecisionAtK,
30-
)
31-
from metrax.regression_metrics import (
32-
MSE,
33-
RMSE,
34-
RSQUARED,
35-
)
15+
from metrax import base
16+
from metrax import classification_metrics
17+
from metrax import nlp_metrics
18+
from metrax import ranking_metrics
19+
from metrax import regression_metrics
3620

37-
__all__ = [
38-
"AUCPR",
39-
"AUCROC",
40-
"Average",
41-
"AveragePrecisionAtK",
42-
"MSE",
43-
"Perplexity",
44-
"Precision",
45-
"Recall",
46-
"RMSE",
47-
"RSQUARED",
48-
"WER",
49-
]
21+
Average = base.Average
22+
AUCPR = classification_metrics.AUCPR
23+
AUCROC = classification_metrics.AUCROC
24+
Precision = classification_metrics.Precision
25+
Recall = classification_metrics.Recall
26+
Perplexity = nlp_metrics.Perplexity
27+
WER = nlp_metrics.WER
28+
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
29+
MSE = regression_metrics.MSE
30+
RMSE = regression_metrics.RMSE
31+
RSQUARED = regression_metrics.RSQUARED

src/metrax/nnx/__init__.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from metrax.nnx.nnx_metrics import (
16-
AUCPR,
17-
AUCROC,
18-
Average,
19-
AveragePrecisionAtK,
20-
MSE,
21-
Perplexity,
22-
Precision,
23-
RMSE,
24-
RSQUARED,
25-
Recall,
26-
WER,
27-
)
15+
from metrax.nnx import nnx_metrics
2816

29-
__all__ = [
30-
"AUCPR",
31-
"AUCROC",
32-
"Average",
33-
"AveragePrecisionAtK",
34-
"MSE",
35-
"Perplexity",
36-
"Precision",
37-
"Recall",
38-
"RMSE",
39-
"RSQUARED",
40-
"WER",
41-
]
17+
AUCPR = nnx_metrics.AUCPR
18+
AUCROC = nnx_metrics.AUCROC
19+
Average = nnx_metrics.Average
20+
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
21+
MSE = nnx_metrics.MSE
22+
Perplexity = nnx_metrics.Perplexity
23+
Precision = nnx_metrics.Precision
24+
RMSE = nnx_metrics.RMSE
25+
RSQUARED = nnx_metrics.RSQUARED
26+
Recall = nnx_metrics.Recall
27+
WER = nnx_metrics.WER

src/metrax/nnx/nnx_metrics_test.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
"""Tests for metrax NNX metrics."""
1616

17+
import dataclasses
1718
import importlib
1819
import inspect
1920
import pkgutil
21+
2022
from absl.testing import absltest
2123
from absl.testing import parameterized
22-
from clu import metrics as clu_metrics
24+
from flax import nnx
2325
import metrax
2426
import metrax.nnx
2527

@@ -28,13 +30,17 @@ class NnxMetricsTest(parameterized.TestCase):
2830

2931
def test_nnx_metrics_exists(self):
3032
"""Tests that every metrax CLU metric has an NNX counterpart."""
31-
metrax_metrics = metrax.nnx.__all__
32-
for _, module_name, _ in pkgutil.iter_modules(metrax.__path__):
33-
full_module_name = f"{metrax.__name__}.{module_name}"
34-
module = importlib.import_module(full_module_name)
35-
for name, obj in inspect.getmembers(module):
36-
if inspect.isclass(obj) and issubclass(obj, clu_metrics.Metric):
37-
self.assertIn(name, metrax_metrics)
33+
metrax_metric_keys = [
34+
key for key, metric in inspect.getmembers(metrax)
35+
if dataclasses.is_dataclass(metric)
36+
]
37+
metrax_nnx_metric_keys = [
38+
key for key, metric in inspect.getmembers(metrax.nnx)
39+
if inspect.isclass(metric) and issubclass(metric, nnx.Metric)
40+
]
41+
self.assertGreater(len(metrax_metric_keys), 0)
42+
self.assertSameElements(metrax_metric_keys, metrax_nnx_metric_keys)
43+
self.assertSameElements(metrax_metric_keys, metrax_nnx_metric_keys)
3844

3945

4046
if __name__ == '__main__':

0 commit comments

Comments
 (0)