Skip to content

Commit c698d95

Browse files
authored
Update imports to be modules not classes (#46)
This aligns with the Google style guide: https://google.github.io/styleguide/pyguide.html#22-imports Also remove `__all__` from `__init__.py` files. This only applies when using `import *`, which is not common or recommended, so it's not worth us protecting against within the library.
1 parent 24fc44c commit c698d95

File tree

4 files changed

+45
-71
lines changed

4 files changed

+45
-71
lines changed

CONTRIBUTING.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ Please make sure that your PR passes all tests by running `pytest ./src/` on you
2828
local machine. Also, you can run only tests that are affected by your code
2929
changes, but you will need to select them manually.
3030

31+
Metrax uses [ruff](https://github.com/astral-sh/ruff) for linting. Before
32+
sending a PR please run `ruff check` to catch any issues.
33+
3134
## Community Guidelines
3235

3336
This project follows [Google's Open Source Community
34-
Guidelines](https://opensource.google.com/conduct/).
37+
Guidelines](https://opensource.google.com/conduct/).

src/metrax/__init__.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from metrax.base import (
16-
Average,
17-
)
18-
from metrax.classification_metrics import (
19-
AUCPR,
20-
AUCROC,
21-
Precision,
22-
Recall,
23-
)
24-
from metrax.nlp_metrics import (
25-
Perplexity,
26-
WER
27-
)
28-
from metrax.ranking_metrics import (
29-
AveragePrecisionAtK,
30-
)
31-
from metrax.regression_metrics import (
32-
MSE,
33-
RMSE,
34-
RSQUARED,
35-
)
15+
from metrax import base
16+
from metrax import classification_metrics
17+
from metrax import nlp_metrics
18+
from metrax import ranking_metrics
19+
from metrax import regression_metrics
3620

37-
__all__ = [
38-
"AUCPR",
39-
"AUCROC",
40-
"Average",
41-
"AveragePrecisionAtK",
42-
"MSE",
43-
"Perplexity",
44-
"Precision",
45-
"Recall",
46-
"RMSE",
47-
"RSQUARED",
48-
"WER",
49-
]
21+
Average = base.Average
22+
AUCPR = classification_metrics.AUCPR
23+
AUCROC = classification_metrics.AUCROC
24+
Precision = classification_metrics.Precision
25+
Recall = classification_metrics.Recall
26+
Perplexity = nlp_metrics.Perplexity
27+
WER = nlp_metrics.WER
28+
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
29+
MSE = regression_metrics.MSE
30+
RMSE = regression_metrics.RMSE
31+
RSQUARED = regression_metrics.RSQUARED

src/metrax/nnx/__init__.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from metrax.nnx.nnx_metrics import (
16-
AUCPR,
17-
AUCROC,
18-
Average,
19-
AveragePrecisionAtK,
20-
MSE,
21-
Perplexity,
22-
Precision,
23-
RMSE,
24-
RSQUARED,
25-
Recall,
26-
WER,
27-
)
15+
from metrax.nnx import nnx_metrics
2816

29-
__all__ = [
30-
"AUCPR",
31-
"AUCROC",
32-
"Average",
33-
"AveragePrecisionAtK",
34-
"MSE",
35-
"Perplexity",
36-
"Precision",
37-
"Recall",
38-
"RMSE",
39-
"RSQUARED",
40-
"WER",
41-
]
17+
AUCPR = nnx_metrics.AUCPR
18+
AUCROC = nnx_metrics.AUCROC
19+
Average = nnx_metrics.Average
20+
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
21+
MSE = nnx_metrics.MSE
22+
Perplexity = nnx_metrics.Perplexity
23+
Precision = nnx_metrics.Precision
24+
RMSE = nnx_metrics.RMSE
25+
RSQUARED = nnx_metrics.RSQUARED
26+
Recall = nnx_metrics.Recall
27+
WER = nnx_metrics.WER

src/metrax/nnx/nnx_metrics_test.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
"""Tests for metrax NNX metrics."""
1616

17-
import importlib
17+
import dataclasses
1818
import inspect
19-
import pkgutil
19+
2020
from absl.testing import absltest
2121
from absl.testing import parameterized
22-
from clu import metrics as clu_metrics
22+
from flax import nnx
2323
import metrax
2424
import metrax.nnx
2525

@@ -28,13 +28,16 @@ class NnxMetricsTest(parameterized.TestCase):
2828

2929
def test_nnx_metrics_exists(self):
3030
"""Tests that every metrax CLU metric has an NNX counterpart."""
31-
metrax_metrics = metrax.nnx.__all__
32-
for _, module_name, _ in pkgutil.iter_modules(metrax.__path__):
33-
full_module_name = f"{metrax.__name__}.{module_name}"
34-
module = importlib.import_module(full_module_name)
35-
for name, obj in inspect.getmembers(module):
36-
if inspect.isclass(obj) and issubclass(obj, clu_metrics.Metric):
37-
self.assertIn(name, metrax_metrics)
31+
metrax_metric_keys = [
32+
key for key, metric in inspect.getmembers(metrax)
33+
if dataclasses.is_dataclass(metric)
34+
]
35+
metrax_nnx_metric_keys = [
36+
key for key, metric in inspect.getmembers(metrax.nnx)
37+
if inspect.isclass(metric) and issubclass(metric, nnx.Metric)
38+
]
39+
self.assertGreater(len(metrax_metric_keys), 0)
40+
self.assertSameElements(metrax_metric_keys, metrax_nnx_metric_keys)
3841

3942

4043
if __name__ == '__main__':

0 commit comments

Comments
 (0)