Skip to content

Commit a63dcfe

Browse files
authored
Refactor: improve cohere calculate total counts (#12007)
### What problem does this PR solve? improve cohere calculate total counts ### Type of change - [x] Refactoring
1 parent 4dd8cdc commit a63dcfe

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

common/token_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ def total_token_count_from_response(resp):
5656
except Exception:
5757
pass
5858

59+
if hasattr(resp, "meta") and hasattr(resp.meta, "billed_units") and hasattr(resp.meta.billed_units, "input_tokens"):
60+
try:
61+
return resp.meta.billed_units.input_tokens
62+
except Exception:
63+
pass
64+
5965
if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']:
6066
try:
6167
return resp["usage"]["total_tokens"]

rag/llm/embedding_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def encode(self, texts: list):
639639
)
640640
try:
641641
ress.extend([d for d in res.embeddings.float])
642-
token_count += res.meta.billed_units.input_tokens
642+
token_count += total_token_count_from_response(res)
643643
except Exception as _e:
644644
log_exception(_e, res)
645645
raise Exception(f"Error: {res}")
@@ -653,7 +653,7 @@ def encode_queries(self, text):
653653
embedding_types=["float"],
654654
)
655655
try:
656-
return np.array(res.embeddings.float[0]), int(res.meta.billed_units.input_tokens)
656+
return np.array(res.embeddings.float[0]), int(total_token_count_from_response(res))
657657
except Exception as _e:
658658
log_exception(_e, res)
659659
raise Exception(f"Error: {res}")

0 commit comments

Comments
 (0)