Commit ce9de75

[Cursor] Improve Reasoning Tokens Documentation and Implementation (#99)
* Refactor plan_exec_llm to use centralized LLM query function

- Update plan_exec_llm to use query_llm from llm_api
- Remove redundant LLM client creation and token tracking logic
- Add support for multiple LLM providers and models via CLI arguments
- Simplify token usage tracking by leveraging existing infrastructure
- Remove hardcoded OpenAI-specific code to improve provider flexibility

* [Cursor] Improve Reasoning Tokens Documentation and Implementation

This commit improves the handling and documentation of reasoning tokens across the codebase:

- Added comprehensive docstrings explaining reasoning tokens
- Enhanced query_llm function documentation for provider-specific behaviors
- Fixed token tracking for o1 model and non-o1 models
- Improved test coverage and documentation
- Added CHANGELOG.md to track changes

Key technical details:

- Reasoning tokens are o1-specific (OpenAI's most advanced model)
- All other models have reasoning_tokens=None
- Token tracking behavior varies by provider (OpenAI, Anthropic, Gemini)

Testing:

- All 21 tests passing
- Added specific test cases for reasoning tokens
- Improved test documentation and coverage

* update token check logic
1 parent 3d0bad9 commit ce9de75
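The diffs below implement the behavior summarized in the commit message. As a rough orientation, here is a minimal usage sketch; the prompt text, token counts, and provider choices are illustrative only and are not part of the commit:

```python
from tools.llm_api import query_llm
from tools.token_tracker import TokenUsage

# Any supported provider can be selected at the call site; only OpenAI's o1
# model reports reasoning tokens in its usage data.
o1_answer = query_llm("Summarize the current plan", provider="openai", model="o1")
chat_answer = query_llm("Summarize the current plan", provider="deepseek", model="deepseek-chat")

# TokenUsage is what the tracker stores per request: reasoning_tokens is
# populated for o1 and stays None for every other model.
o1_usage = TokenUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15, reasoning_tokens=3)
chat_usage = TokenUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15, reasoning_tokens=None)
```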

File tree

7 files changed: +134 -108 lines changed

- .cursorrules
- CHANGELOG.md
- tests/test_llm_api.py
- tests/test_plan_exec_llm.py
- tools/llm_api.py
- tools/plan_exec_llm.py
- tools/token_tracker.py

.cursorrules

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ If needed, you can further use the `web_scraper.py` file to scrape the web page
127127
- When using seaborn styles in matplotlib, use 'seaborn-v0_8' instead of 'seaborn' as the style name due to recent seaborn version changes
128128
- Use `gpt-4o` as the model name for OpenAI. It is the latest GPT model and has vision capabilities as well. `o1` is the most advanced and expensive model from OpenAI. Use it when you need to do reasoning, planning, or get blocked.
129129
- Use `claude-3-5-sonnet-20241022` as the model name for Claude. It is the latest Claude model and has vision capabilities as well.
130+
- When running Python scripts that import from other local modules, use `PYTHONPATH=.` to ensure Python can find the modules. For example: `PYTHONPATH=. python tools/plan_exec_llm.py` instead of just `python tools/plan_exec_llm.py`. This is especially important when using relative imports.
130131

131132
# Multi-Agent Scratchpad
132133

CHANGELOG.md

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Changelog
+
+## [Unreleased]
+
+### Added
+- Comprehensive documentation for reasoning tokens across the codebase
+- Detailed test cases for token tracking with different providers
+- Clear docstrings explaining provider-specific token tracking behavior
+
+### Changed
+- Updated `query_llm` function to properly handle reasoning tokens for o1 model
+- Improved test coverage for token tracking across all providers
+- Enhanced documentation in test files to clarify token tracking behavior
+
+### Fixed
+- Proper handling of reasoning tokens for non-o1 models (explicitly set to None)
+- Token tracking tests to verify correct behavior for all providers

tests/test_llm_api.py

Lines changed: 52 additions & 3 deletions
@@ -1,7 +1,7 @@
 import unittest
 from unittest.mock import patch, MagicMock, mock_open
 from tools.llm_api import create_llm_client, query_llm, load_environment
-from tools.token_tracker import TokenUsage, APIResponse
+from tools.token_tracker import TokenUsage, APIResponse, get_token_tracker
 import os
 import google.generativeai as genai
 import io
@@ -202,18 +202,43 @@ def test_query_azure(self, mock_create_client):
         )
 
     @patch('tools.llm_api.create_llm_client')
-    def test_query_deepseek(self, mock_create_client):
+    @patch('tools.llm_api.get_token_tracker')
+    def test_query_deepseek(self, mock_get_tracker, mock_create_client):
+        """Test querying DeepSeek API with token tracking.
+
+        DeepSeek uses OpenAI-compatible API but like most models does not support
+        reasoning tokens (only OpenAI's o1 model has this feature).
+        """
         mock_create_client.return_value = self.mock_openai_client
+        mock_tracker = MagicMock()
+        mock_get_tracker.return_value = mock_tracker
+
+        # Set up mock response with usage data
+        self.mock_openai_response.usage = MagicMock()
+        self.mock_openai_response.usage.prompt_tokens = 10
+        self.mock_openai_response.usage.completion_tokens = 5
+        self.mock_openai_response.usage.total_tokens = 15
+
         response = query_llm("Test prompt", provider="deepseek", model="deepseek-chat")
         self.assertEqual(response, "Test OpenAI response")
         self.mock_openai_client.chat.completions.create.assert_called_once_with(
             model="deepseek-chat",
             messages=[{"role": "user", "content": [{"type": "text", "text": "Test prompt"}]}],
             temperature=0.7
         )
+        # Verify token usage tracking for OpenAI-style providers
+        self.assertTrue(mock_tracker.track_request.called)
+        api_response = mock_tracker.track_request.call_args[0][0]
+        # Verify reasoning_tokens is None since this is not the o1 model
+        self.assertIsNone(api_response.token_usage.reasoning_tokens)
 
     @patch('tools.llm_api.create_llm_client')
     def test_query_anthropic(self, mock_create_client):
+        """Test querying Anthropic API.
+
+        Note: Anthropic's API has its own token tracking system that differs from OpenAI's.
+        It does not support reasoning tokens (which is an OpenAI o1-specific feature).
+        """
         mock_create_client.return_value = self.mock_anthropic_client
         response = query_llm("Test prompt", provider="anthropic", model="claude-3-5-sonnet-20241022")
         self.assertEqual(response, "Test Anthropic response")
@@ -222,6 +247,7 @@ def test_query_anthropic(self, mock_create_client):
             max_tokens=1000,
             messages=[{"role": "user", "content": [{"type": "text", "text": "Test prompt"}]}]
         )
+        # Note: Token tracking is not yet implemented for Anthropic
 
     @patch('tools.llm_api.create_llm_client')
     def test_query_gemini(self, mock_create_client):
@@ -243,8 +269,26 @@ def test_query_with_custom_model(self, mock_create_client):
         )
 
     @patch('tools.llm_api.create_llm_client')
-    def test_query_o1_model(self, mock_create_client):
+    @patch('tools.llm_api.get_token_tracker')
+    def test_query_o1_model(self, mock_get_tracker, mock_create_client):
+        """Test querying OpenAI's o1 model.
+
+        The o1 model is special in that it:
+        1. Uses a different response format
+        2. Has a reasoning_effort parameter
+        3. Is the only model that provides reasoning_tokens in its response
+        """
         mock_create_client.return_value = self.mock_openai_client
+        mock_tracker = MagicMock()
+        mock_get_tracker.return_value = mock_tracker
+
+        # Set up mock response with usage data including reasoning tokens
+        self.mock_openai_response.usage = MagicMock()
+        self.mock_openai_response.usage.prompt_tokens = 10
+        self.mock_openai_response.usage.completion_tokens = 5
+        self.mock_openai_response.usage.total_tokens = 15
+        self.mock_openai_response.usage.reasoning_tokens = 3  # o1 model provides this
+
         response = query_llm("Test prompt", provider="openai", model="o1")
         self.assertEqual(response, "Test OpenAI response")
         self.mock_openai_client.chat.completions.create.assert_called_once_with(
@@ -253,6 +297,11 @@ def test_query_o1_model(self, mock_create_client):
             response_format={"type": "text"},
             reasoning_effort="low"
         )
+
+        # Verify token usage tracking includes reasoning tokens for o1 model
+        self.assertTrue(mock_tracker.track_request.called)
+        api_response = mock_tracker.track_request.call_args[0][0]
+        self.assertEqual(api_response.token_usage.reasoning_tokens, 3)
 
     @patch('tools.llm_api.create_llm_client')
     def test_query_with_existing_client(self, mock_create_client):

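The new assertions above pull the tracked `APIResponse` back out of the mocked tracker with `track_request.call_args[0][0]`. For readers unfamiliar with that indexing, here is a small standalone sketch of how `unittest.mock` exposes it (the dict is just a stand-in for an `APIResponse`):

```python
from unittest.mock import MagicMock

tracker = MagicMock()
tracker.track_request({"token_usage": "example"})  # stand-in for an APIResponse

# call_args holds an (args, kwargs) pair for the most recent call, so
# call_args[0][0] is the first positional argument passed to track_request.
args, kwargs = tracker.track_request.call_args
assert args[0] == {"token_usage": "example"}
assert tracker.track_request.call_args[0][0] is args[0]
```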
tests/test_plan_exec_llm.py

Lines changed: 30 additions & 46 deletions
@@ -8,8 +8,8 @@
 
 # Add the parent directory to the Python path so we can import the module
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from tools.plan_exec_llm import load_environment, read_plan_status, read_file_content, create_llm_client, query_llm
-from tools.plan_exec_llm import TokenUsage
+from tools.plan_exec_llm import load_environment, read_plan_status, read_file_content, query_llm_with_plan
+from tools.token_tracker import TokenUsage
 
 class TestPlanExecLLM(unittest.TestCase):
     def setUp(self):
@@ -18,9 +18,13 @@ def setUp(self):
         self.original_env = dict(os.environ)
         # Set test environment variables
         os.environ['OPENAI_API_KEY'] = 'test_key'
+        os.environ['DEEPSEEK_API_KEY'] = 'test_deepseek_key'
+        os.environ['ANTHROPIC_API_KEY'] = 'test_anthropic_key'
 
         self.test_env_content = """
 OPENAI_API_KEY=test_key
+DEEPSEEK_API_KEY=test_deepseek_key
+ANTHROPIC_API_KEY=test_anthropic_key
 """
         self.test_plan_content = """
 # Multi-Agent Scratchpad
@@ -66,55 +70,35 @@ def test_read_file_content(self):
         content = read_file_content('nonexistent_file.txt')
         self.assertIsNone(content)
 
-    @patch('tools.plan_exec_llm.OpenAI')
-    def test_create_llm_client(self, mock_openai):
-        """Test LLM client creation"""
-        mock_client = MagicMock()
-        mock_openai.return_value = mock_client
-
-        client = create_llm_client()
-        self.assertEqual(client, mock_client)
-        mock_openai.assert_called_once_with(api_key='test_key')
-
-    @patch('tools.plan_exec_llm.create_llm_client')
-    def test_query_llm(self, mock_create_client):
-        """Test LLM querying"""
-        # Mock the OpenAI response
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].message = MagicMock()
-        mock_response.choices[0].message.content = "Test response"
-        mock_response.usage = MagicMock()
-        mock_response.usage.prompt_tokens = 10
-        mock_response.usage.completion_tokens = 5
-        mock_response.usage.total_tokens = 15
-        mock_response.usage.completion_tokens_details = MagicMock()
-        mock_response.usage.completion_tokens_details.reasoning_tokens = None
-
-        mock_client = MagicMock()
-        mock_client.chat.completions.create.return_value = mock_response
-        mock_create_client.return_value = mock_client
+    @patch('tools.llm_api.query_llm')
+    def test_query_llm_with_plan(self, mock_query_llm):
+        """Test LLM querying with plan context"""
+        # Mock the LLM response
+        mock_query_llm.return_value = "Test response"
 
         # Test with various combinations of parameters
-        response = query_llm("Test plan", "Test prompt", "Test file content")
-        self.assertEqual(response, "Test response")
+        with patch('tools.plan_exec_llm.query_llm') as mock_plan_query_llm:
+            mock_plan_query_llm.return_value = "Test response"
+            response = query_llm_with_plan("Test plan", "Test prompt", "Test file content", provider="openai", model="gpt-4o")
+            self.assertEqual(response, "Test response")
+            mock_plan_query_llm.assert_called_with(unittest.mock.ANY, model="gpt-4o", provider="openai")
 
-        response = query_llm("Test plan", "Test prompt")
-        self.assertEqual(response, "Test response")
+            # Test with DeepSeek
+            response = query_llm_with_plan("Test plan", "Test prompt", provider="deepseek")
+            self.assertEqual(response, "Test response")
+            mock_plan_query_llm.assert_called_with(unittest.mock.ANY, model=None, provider="deepseek")
 
-        response = query_llm("Test plan")
-        self.assertEqual(response, "Test response")
+            # Test with Anthropic
+            response = query_llm_with_plan("Test plan", provider="anthropic")
+            self.assertEqual(response, "Test response")
+            mock_plan_query_llm.assert_called_with(unittest.mock.ANY, model=None, provider="anthropic")
 
-        # Verify the OpenAI client was called with correct parameters
-        mock_client.chat.completions.create.assert_called_with(
-            model="o1",
-            messages=[
-                {"role": "system", "content": ""},
-                {"role": "user", "content": unittest.mock.ANY}
-            ],
-            response_format={"type": "text"},
-            reasoning_effort="low"
-        )
+            # Verify the prompt format
+            calls = mock_plan_query_llm.call_args_list
+            for call in calls:
+                prompt = call[0][0]
+                self.assertIn("Multi-Agent Scratchpad", prompt)
+                self.assertIn("Test plan", prompt)
 
 if __name__ == '__main__':
     unittest.main()

tools/llm_api.py

Lines changed: 14 additions & 2 deletions
@@ -121,12 +121,24 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat
     Args:
         prompt (str): The text prompt to send
         client: The LLM client instance
-        model (str, optional): The model to use
+        model (str, optional): The model to use. Special handling for OpenAI's o1 model:
+            - Uses different response format
+            - Has reasoning_effort parameter
+            - Is the only model that provides reasoning_tokens in its response
         provider (str): The API provider to use
         image_path (str, optional): Path to an image file to attach
 
     Returns:
         Optional[str]: The LLM's response or None if there was an error
+
+    Note:
+        Token tracking behavior varies by provider:
+        - OpenAI-style APIs (OpenAI, Azure, DeepSeek, Local): Full token tracking
+        - Anthropic: Has its own token tracking system (input/output tokens)
+        - Gemini: Token tracking not yet implemented
+
+        Reasoning tokens are only available when using OpenAI's o1 model.
+        For all other models, reasoning_tokens will be None.
     """
     if client is None:
         client = create_llm_client(provider)
@@ -187,7 +199,7 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat
             prompt_tokens=response.usage.prompt_tokens,
             completion_tokens=response.usage.completion_tokens,
             total_tokens=response.usage.total_tokens,
-            reasoning_tokens=response.usage.completion_tokens_details.reasoning_tokens if hasattr(response.usage, 'completion_tokens_details') else None
+            reasoning_tokens=response.usage.reasoning_tokens if model.lower().startswith("o") else None  # Only checks if model starts with "o", e.g., o1, o1-preview, o1-mini, o3, etc. Can update this logic to specific models in the future.
         )
 
         # Calculate cost

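The replaced line above switches the reasoning-token extraction from a `completion_tokens_details` lookup to a prefix check on the model name, so anything starting with "o" (o1, o1-preview, o1-mini, o3, ...) is treated as a reasoning model, while a name like `gpt-4o` is excluded because it merely ends in "o". A minimal sketch of that check in isolation; the helper name and the `FakeUsage` stand-in are hypothetical, not part of the diff:

```python
def extract_reasoning_tokens(usage, model):
    # Mirrors the new check in query_llm: model names starting with "o"
    # (o1, o1-preview, o1-mini, o3, ...) report reasoning tokens; others get None.
    if model and model.lower().startswith("o"):
        return getattr(usage, "reasoning_tokens", None)
    return None

class FakeUsage:
    reasoning_tokens = 3

assert extract_reasoning_tokens(FakeUsage(), "o1") == 3
assert extract_reasoning_tokens(FakeUsage(), "gpt-4o") is None  # starts with "g", not "o"
```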
tools/plan_exec_llm.py

Lines changed: 10 additions & 57 deletions
@@ -3,11 +3,11 @@
 import argparse
 import os
 from pathlib import Path
-from openai import OpenAI
 from dotenv import load_dotenv
 import sys
 import time
-from .token_tracker import TokenUsage, APIResponse, get_token_tracker
+from tools.token_tracker import TokenUsage, APIResponse, get_token_tracker
+from tools.llm_api import query_llm, create_llm_client
 
 STATUS_FILE = '.cursorrules'
 
@@ -52,17 +52,8 @@ def read_file_content(file_path):
         print(f"Error reading {file_path}: {e}", file=sys.stderr)
         return None
 
-def create_llm_client():
-    """Create OpenAI client"""
-    api_key = os.getenv('OPENAI_API_KEY')
-    if not api_key:
-        raise ValueError("OPENAI_API_KEY not found in environment variables")
-    return OpenAI(api_key=api_key)
-
-def query_llm(plan_content, user_prompt=None, file_content=None):
+def query_llm_with_plan(plan_content, user_prompt=None, file_content=None, provider="openai", model=None):
     """Query the LLM with combined prompts"""
-    client = create_llm_client()
-
     # Combine prompts
     system_prompt = """"""
 
@@ -93,54 +84,16 @@
 We will do the actual changes in the .cursorrules file.
 """
 
-    try:
-        start_time = time.time()
-        response = client.chat.completions.create(
-            model="o1",
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": combined_prompt}
-            ],
-            response_format={"type": "text"},
-            reasoning_effort="low"
-        )
-        thinking_time = time.time() - start_time
-
-        # Track token usage
-        token_usage = TokenUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            completion_tokens=response.usage.completion_tokens,
-            total_tokens=response.usage.total_tokens,
-            reasoning_tokens=response.usage.completion_tokens_details.reasoning_tokens if hasattr(response.usage, 'completion_tokens_details') else None
-        )
-
-        # Calculate cost
-        cost = get_token_tracker().calculate_openai_cost(
-            token_usage.prompt_tokens,
-            token_usage.completion_tokens,
-            "o1"
-        )
-
-        # Track the request
-        api_response = APIResponse(
-            content=response.choices[0].message.content,
-            token_usage=token_usage,
-            cost=cost,
-            thinking_time=thinking_time,
-            provider="openai",
-            model="o1"
-        )
-        get_token_tracker().track_request(api_response)
-
-        return response.choices[0].message.content
-    except Exception as e:
-        print(f"Error querying LLM: {e}", file=sys.stderr)
-        return None
+    # Use the imported query_llm function
+    response = query_llm(combined_prompt, model=model, provider=provider)
+    return response
 
 def main():
-    parser = argparse.ArgumentParser(description='Query OpenAI o1 model with project plan context')
+    parser = argparse.ArgumentParser(description='Query LLM with project plan context')
     parser.add_argument('--prompt', type=str, help='Additional prompt to send to the LLM', required=False)
     parser.add_argument('--file', type=str, help='Path to a file whose content should be included in the prompt', required=False)
+    parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use')
+    parser.add_argument('--model', type=str, help='The model to use (default depends on provider)')
     args = parser.parse_args()
 
     # Load environment variables
@@ -157,7 +110,7 @@ def main():
         sys.exit(1)
 
     # Query LLM and output response
-    response = query_llm(plan_content, args.prompt, file_content)
+    response = query_llm_with_plan(plan_content, args.prompt, file_content, provider=args.provider, model=args.model)
     if response:
         print('Following is the instruction on how to revise the Multi-Agent Scratchpad section in .cursorrules:')
         print('========================================================')

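With the refactor, `query_llm_with_plan` simply builds the combined prompt and delegates to `tools.llm_api.query_llm`, so it can be driven from Python as well as from the new CLI flags. A hedged sketch; the plan text and prompt are placeholders:

```python
from tools.plan_exec_llm import query_llm_with_plan

plan = "# Multi-Agent Scratchpad\n\n(placeholder plan content)"
advice = query_llm_with_plan(
    plan,
    user_prompt="What should the Executor do next?",
    provider="anthropic",                  # any provider supported by tools.llm_api
    model="claude-3-5-sonnet-20241022",
)
if advice:
    print(advice)
```

The command-line equivalent would be something like `PYTHONPATH=. python tools/plan_exec_llm.py --provider anthropic --model claude-3-5-sonnet-20241022 --prompt "..."`, matching the new `--provider` and `--model` arguments.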
tools/token_tracker.py

Lines changed: 10 additions & 0 deletions
@@ -14,6 +14,16 @@
 
 @dataclass
 class TokenUsage:
+    """Token usage information for an LLM API request.
+
+    Attributes:
+        prompt_tokens: Number of tokens in the input prompt
+        completion_tokens: Number of tokens in the model's response
+        total_tokens: Total number of tokens used (prompt + completion)
+        reasoning_tokens: Number of tokens used for reasoning (only available for OpenAI's o1 model)
+            This is a special field that's only populated when using OpenAI's o1 model.
+            For all other models (including other OpenAI models), this will be None.
+    """
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int

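Putting the pieces together, a tracked request combines a `TokenUsage`, a cost from the tracker, and an `APIResponse`, using only the fields and methods that appear elsewhere in this commit. A sketch with illustrative numbers:

```python
from tools.token_tracker import TokenUsage, APIResponse, get_token_tracker

usage = TokenUsage(
    prompt_tokens=120,
    completion_tokens=80,
    total_tokens=200,
    reasoning_tokens=40,  # only populated for OpenAI's o1 model
)
tracker = get_token_tracker()
cost = tracker.calculate_openai_cost(usage.prompt_tokens, usage.completion_tokens, "o1")
tracker.track_request(APIResponse(
    content="...model output...",
    token_usage=usage,
    cost=cost,
    thinking_time=2.5,
    provider="openai",
    model="o1",
))
```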