11"""Mapper for orchestrating variable transformations."""
22
3+ from dataclasses import dataclass
34from typing import Dict , Any , List , cast
45
56from mmirage .core .process .variables import BaseVar , InputVar , OutputVar
67from mmirage .core .process .base import AutoProcessor , BaseProcessor , BaseProcessorConfig
78
9+
10+ @dataclass
11+ class TokenCounts :
12+ """Cumulative token counts from LLM processors."""
13+
14+ input_tokens : int
15+ output_tokens : int
16+
817import logging
918
1019from mmirage .core .process .variables import VariableEnvironment
@@ -104,14 +113,14 @@ def rewrite_batch(
104113
105114 return batch_environment
106115
107- def get_token_counts (self ) -> Dict [ str , int ] :
116+ def get_token_counts (self ) -> TokenCounts :
108117 """Return cumulative token counts aggregated across all LLM processors.
109118
110119 Sums ``input_tokens`` and ``output_tokens`` from every processor that
111120 exposes a ``get_token_counts()`` method (i.e., ``LLMProcessor``).
112121
113122 Returns:
114- Dict with ``input_tokens`` and ``output_tokens`` keys .
123+ TokenCounts with ``input_tokens`` and ``output_tokens`` fields .
115124 """
116125 total_input = 0
117126 total_output = 0
@@ -120,7 +129,7 @@ def get_token_counts(self) -> Dict[str, int]:
120129 counts = proc .get_token_counts ()
121130 total_input += counts .get ("input_tokens" , 0 )
122131 total_output += counts .get ("output_tokens" , 0 )
123- return { " input_tokens" : total_input , " output_tokens" : total_output }
132+ return TokenCounts ( input_tokens = total_input , output_tokens = total_output )
124133
125134 def get_load_time (self ) -> float :
126135 """Return total model-loading time (seconds) summed across all LLM processors."""
0 commit comments