@@ -92,6 +92,8 @@ def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
9292
9393
9494M = TypeVar ("M" , bound = "Metric" )
95+ R = TypeVar ("R" , jnp .ndarray , dict [str , jnp .ndarray ])
96+ V = TypeVar ("V" , clu .values .Value , dict [str , clu .values .Value ])
9597
9698
9799class Metric :
@@ -160,7 +162,7 @@ def merge(self: M, other: M) -> M:
160162 def _reduce_merge (self : M , other : M ) -> M :
161163 return self .merge (other )
162164
163- def compute (self ) -> jnp . ndarray :
165+ def compute (self ) -> R :
164166 """Computes final metrics from intermediate values."""
165167 raise NotImplementedError ("Must override compute()" )
166168
@@ -169,9 +171,13 @@ def empty(cls: type[M]) -> M:
169171 """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
170172 raise NotImplementedError ("Must override empty()" )
171173
172- def compute_value (self ) -> clu .values .Value :
173- """Wraps compute() and returns a values.Value."""
174- return clu .values .Scalar (self .compute ())
174+ def compute_value (self ) -> V :
175+ """Wraps compute() and returns a values.Value or dict of values.Value."""
176+ result = self .compute ()
177+ if isinstance (result , dict ):
178+ return {k : clu .values .Scalar (v ) for k , v in result .items ()}
179+ else :
180+ return clu .values .Scalar (result )
175181
176182 def reduce (self : M ) -> M :
177183 """Reduces the metric along it first axis by calling `_reduce_merge()`.
@@ -623,22 +629,32 @@ def reduce(self: C) -> C:
623629 })
624630
625631 def compute (self ) -> dict [str , jnp .ndarray ]:
626- """Returns a dictionary mapping metric field name to `Metric.compute()`."""
627- _check_reduction_counter_ndim (self ._reduction_counter )
628- return {
629- metric_name : metric .compute ()
630- for metric_name , metric in vars (self ).items ()
631- if metric_name != "_reduction_counter"
632- }
632+ """Returns a dictionary mapping metrics to their computed values."""
633+ metric_results = {}
634+ for metric_name , metric in vars (self ).items ():
635+ if metric_name != "_reduction_counter" :
636+ metric_result = metric .compute ()
637+ if isinstance (metric_result , dict ):
638+ metric_results .update (
639+ {f"{ metric_name } /{ k } " : v for k , v in metric_result .items ()}
640+ )
641+ else :
642+ metric_results [metric_name ] = metric_result
643+ return metric_results
633644
634645 def compute_values (self ) -> dict [str , clu .values .Value ]:
635- """Computes metrics and returns them as clu.values.Value."""
636- _check_reduction_counter_ndim (self ._reduction_counter )
637- return {
638- metric_name : metric .compute_value ()
639- for metric_name , metric in vars (self ).items ()
640- if metric_name != "_reduction_counter"
641- }
646+ """Computes metrics and returns them as clu_values.Value."""
647+ metric_results = {}
648+ for metric_name , metric in vars (self ).items ():
649+ if metric_name != "_reduction_counter" :
650+ metric_result = metric .compute_value ()
651+ if isinstance (metric_result , dict ):
652+ metric_results .update (
653+ {f"{ metric_name } /{ k } " : v for k , v in metric_result .items ()}
654+ )
655+ else :
656+ metric_results [metric_name ] = metric_result
657+ return metric_results
642658
643659 def unreplicate (self : C ) -> C :
644660 """Short-hand for `flax.jax_utils.unreplicate(self)`.
0 commit comments