Skip to content

Commit c85a02b

Browse files
committed
Update on pattern and trend agent to use precomputed images
1 parent 8fa3a1b commit c85a02b

File tree

6 files changed

+318
-85
lines changed

6 files changed

+318
-85
lines changed

pattern_agent.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def invoke_tool_with_retry(tool_fn, tool_args, retries=3, wait_sec=4):
2222
def create_pattern_agent(tool_llm, graph_llm, toolkit):
2323
"""
2424
Create a pattern recognition agent node for candlestick pattern analysis.
25-
The agent uses an LLM and a chart generation tool to identify classic trading patterns.
25+
The agent uses precomputed images from state or falls back to tool generation.
2626
"""
2727
def pattern_agent_node(state):
2828
# --- Tool and pattern definitions ---
@@ -49,25 +49,9 @@ def pattern_agent_node(state):
4949
16. Symmetrical Triangle: Highs and lows converge toward the apex, usually followed by a breakout.
5050
"""
5151

52-
# --- Step 1: System prompt setup ---
53-
prompt = ChatPromptTemplate.from_messages(
54-
[
55-
(
56-
"system",
57-
"You are a trading pattern recognition assistant tasked with identifying classical high-frequency trading patterns. "
58-
"You have access to tool: generate_kline_image"
59-
"Use it by providing appropriate arguments like `kline_data`\n\n"
60-
"Once the chart is generated, compare it to classical pattern descriptions and determine if any known pattern is present."
61-
),
62-
MessagesPlaceholder(variable_name="messages"),
63-
]
64-
).partial(
65-
kline_data=json.dumps(state["kline_data"], indent=2)
66-
)
67-
68-
chain = prompt | tool_llm.bind_tools(tools)
69-
messages = state.get("messages", [])
70-
52+
# --- Check for precomputed image in state ---
53+
pattern_image_b64 = state.get("pattern_image")
54+
7155
# --- Retry wrapper for LLM invocation ---
7256
def invoke_with_retry(call_fn, *args, retries=3, wait_sec=8):
7357
for attempt in range(retries):
@@ -81,30 +65,54 @@ def invoke_with_retry(call_fn, *args, retries=3, wait_sec=8):
8165
time.sleep(wait_sec)
8266
raise RuntimeError("Max retries exceeded")
8367

84-
# --- Step 2: First LLM call to determine tool usage ---
85-
ai_response = invoke_with_retry(chain.invoke, messages)
86-
messages.append(ai_response)
68+
messages = state.get("messages", [])
69+
70+
# --- If no precomputed image, fall back to tool generation ---
71+
if not pattern_image_b64:
72+
print("No precomputed pattern image found in state, generating with tool...")
73+
74+
# --- System prompt setup for tool generation ---
75+
prompt = ChatPromptTemplate.from_messages(
76+
[
77+
(
78+
"system",
79+
"You are a trading pattern recognition assistant tasked with identifying classical high-frequency trading patterns. "
80+
"You have access to tool: generate_kline_image. "
81+
"Use it by providing appropriate arguments like `kline_data`\n\n"
82+
"Once the chart is generated, compare it to classical pattern descriptions and determine if any known pattern is present."
83+
),
84+
MessagesPlaceholder(variable_name="messages"),
85+
]
86+
).partial(
87+
kline_data=json.dumps(state["kline_data"], indent=2)
88+
)
8789

88-
pattern_image_b64 = None
90+
chain = prompt | tool_llm.bind_tools(tools)
91+
92+
# --- Step 1: First LLM call to determine tool usage ---
93+
ai_response = invoke_with_retry(chain.invoke, messages)
94+
messages.append(ai_response)
8995

90-
# --- Step 3: Handle tool call (generate_kline_image) ---
91-
if hasattr(ai_response, "tool_calls"):
92-
for call in ai_response.tool_calls:
93-
tool_name = call["name"]
94-
tool_args = call["args"]
95-
# Always provide kline_data
96-
tool_args["kline_data"] = copy.deepcopy(state["kline_data"])
97-
tool_fn = next(t for t in tools if t.name == tool_name)
98-
tool_result = invoke_tool_with_retry(tool_fn, tool_args)
99-
pattern_image_b64 = tool_result.get("pattern_image")
100-
messages.append(
101-
ToolMessage(
102-
tool_call_id=call["id"],
103-
content=json.dumps(tool_result)
96+
# --- Step 2: Handle tool call (generate_kline_image) ---
97+
if hasattr(ai_response, "tool_calls"):
98+
for call in ai_response.tool_calls:
99+
tool_name = call["name"]
100+
tool_args = call["args"]
101+
# Always provide kline_data
102+
tool_args["kline_data"] = copy.deepcopy(state["kline_data"])
103+
tool_fn = next(t for t in tools if t.name == tool_name)
104+
tool_result = invoke_tool_with_retry(tool_fn, tool_args)
105+
pattern_image_b64 = tool_result.get("pattern_image")
106+
messages.append(
107+
ToolMessage(
108+
tool_call_id=call["id"],
109+
content=json.dumps(tool_result)
110+
)
104111
)
105-
)
112+
else:
113+
print("Using precomputed pattern image from state")
106114

107-
# --- Step 4: Second call with image (Vision LLM expects image_url + context) ---
115+
# --- Step 3: Vision analysis with image (precomputed or generated) ---
108116
if pattern_image_b64:
109117
image_prompt = [
110118
{

static_util.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import matplotlib
2+
matplotlib.use('Agg')
3+
import talib
4+
import pandas as pd
5+
import matplotlib.pyplot as plt
6+
import talib
7+
import numpy as np
8+
from langchain_core.tools import tool
9+
from typing import Annotated
10+
import mplfinance as mpf
11+
import base64
12+
import io
13+
import mplfinance as mpf
14+
import color_style as color
15+
16+
17+
def generate_kline_image(kline_data) -> dict:
18+
"""
19+
Generate a candlestick (K-line) chart from OHLCV data, save it locally, and return a base64-encoded image.
20+
21+
Args:
22+
kline_data (dict): Dictionary with keys including 'Datetime', 'Open', 'High', 'Low', 'Close'.
23+
filename (str): Name of the file to save the image locally (default: 'kline_chart.png').
24+
25+
Returns:
26+
dict: Dictionary containing base64-encoded image string and local file path.
27+
"""
28+
29+
df = pd.DataFrame(kline_data)
30+
# take recent 40
31+
df = df.tail(40)
32+
33+
df.to_csv("record.csv", index=False, date_format="%Y-%m-%d %H:%M:%S")
34+
try:
35+
# df.index = pd.to_datetime(df["Datetime"])
36+
df.index = pd.to_datetime(df["Datetime"], format="%Y-%m-%d %H:%M:%S")
37+
38+
except ValueError:
39+
print("ValueError at graph_util.py\n")
40+
41+
42+
43+
# Save image locally
44+
fig, axlist = mpf.plot(
45+
df[["Open", "High", "Low", "Close"]],
46+
type="candle",
47+
style=color.my_color_style,
48+
figsize=(12, 6),
49+
returnfig=True,
50+
block=False,
51+
52+
)
53+
axlist[0].set_ylabel('Price', fontweight='normal')
54+
axlist[0].set_xlabel('Datetime', fontweight='normal')
55+
56+
fig.savefig(
57+
fname="kline_chart.png",
58+
dpi=600,
59+
bbox_inches="tight",
60+
pad_inches=0.1,
61+
)
62+
plt.close(fig)
63+
# ---------- Encode to base64 -----------------
64+
buf = io.BytesIO()
65+
fig.savefig(buf, format="png", dpi=600, bbox_inches="tight", pad_inches=0.1)
66+
plt.close(fig) # release memory
67+
68+
buf.seek(0)
69+
img_b64 = base64.b64encode(buf.read()).decode("utf-8")
70+
71+
return {
72+
"pattern_image": img_b64,
73+
"pattern_image_description": "Candlestick chart saved locally and returned as base64 string."
74+
}
75+
76+
from graph_util import *
77+
def generate_trend_image(kline_data) -> dict:
78+
"""
79+
Generate a candlestick chart with trendlines from OHLCV data,
80+
save it locally as 'trend_graph.png', and return a base64-encoded image.
81+
82+
Returns:
83+
dict: base64 image and description
84+
"""
85+
data = pd.DataFrame(kline_data)
86+
candles = data.iloc[-50:].copy()
87+
88+
candles["Datetime"] = pd.to_datetime(candles["Datetime"])
89+
candles.set_index("Datetime", inplace=True)
90+
91+
# Trendline fit functions assumed to be defined outside this scope
92+
support_coefs_c, resist_coefs_c = fit_trendlines_single(candles['Close'])
93+
support_coefs, resist_coefs = fit_trendlines_high_low(candles['High'], candles['Low'], candles['Close'])
94+
95+
# Trendline values
96+
support_line_c = support_coefs_c[0] * np.arange(len(candles)) + support_coefs_c[1]
97+
resist_line_c = resist_coefs_c[0] * np.arange(len(candles)) + resist_coefs_c[1]
98+
support_line = support_coefs[0] * np.arange(len(candles)) + support_coefs[1]
99+
resist_line = resist_coefs[0] * np.arange(len(candles)) + resist_coefs[1]
100+
101+
# Convert to time-anchored coordinates
102+
s_seq = get_line_points(candles, support_line)
103+
r_seq = get_line_points(candles, resist_line)
104+
s_seq2 = get_line_points(candles, support_line_c)
105+
r_seq2 = get_line_points(candles, resist_line_c)
106+
107+
s_segments = split_line_into_segments(s_seq)
108+
r_segments = split_line_into_segments(r_seq)
109+
s2_segments = split_line_into_segments(s_seq2)
110+
r2_segments = split_line_into_segments(r_seq2)
111+
112+
all_segments = s_segments + r_segments + s2_segments + r2_segments
113+
colors = ['white'] * len(s_segments) + ['white'] * len(r_segments) + ['blue'] * len(s2_segments) + ['red'] * len(r2_segments)
114+
115+
# Create addplot lines for close-based support/resistance
116+
apds = [
117+
mpf.make_addplot(support_line_c, color='blue', width=1, label="Close Support"),
118+
mpf.make_addplot(resist_line_c, color='red', width=1, label="Close Resistance")
119+
]
120+
121+
# Generate figure with legend and save locally
122+
fig, axlist = mpf.plot(
123+
candles,
124+
type='candle',
125+
style=color.my_color_style,
126+
addplot=apds,
127+
alines=dict(alines=all_segments, colors=colors, linewidths=1),
128+
returnfig=True,
129+
figsize=(12, 6),
130+
block=False,
131+
)
132+
133+
axlist[0].set_ylabel('Price', fontweight='normal')
134+
axlist[0].set_xlabel('Datetime', fontweight='normal')
135+
136+
#save fig locally
137+
fig.savefig(
138+
"trend_graph.png",
139+
format="png",
140+
dpi=600,
141+
bbox_inches="tight",
142+
pad_inches=0.1
143+
)
144+
plt.close(fig)
145+
146+
# Add legend manually
147+
axlist[0].legend(loc='upper left')
148+
149+
# Save to base64
150+
buf = io.BytesIO()
151+
fig.savefig(buf, format="png")
152+
buf.seek(0)
153+
img_b64 = base64.b64encode(buf.read()).decode("utf-8")
154+
plt.close(fig)
155+
156+
return {
157+
"trend_image": img_b64,
158+
"trend_image_description": "Trend-enhanced candlestick chart with support/resistance lines."
159+
}

templates/demo_new.html

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,10 +1613,14 @@ <h4 class="panel-title">
16131613
body: JSON.stringify(requestData)
16141614
})
16151615
.then(response => response.json())
1616-
.then(data => {
1617-
if (data.redirect) {
1618-
// Redirect to output page with results
1619-
window.location.href = data.redirect;
1616+
.then(data => {
1617+
if (data.redirect) {
1618+
// Store full results (with images) in sessionStorage before redirect
1619+
if (data.full_results) {
1620+
sessionStorage.setItem('analysisResults', JSON.stringify(data.full_results));
1621+
}
1622+
// Redirect to output page with results
1623+
window.location.href = data.redirect;
16201624
} else if (data.error) {
16211625
// Show error message
16221626
alert('Analysis failed: ' + data.error);

templates/output.html

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,38 @@ <h4 class="panel-title">
662662
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
663663
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
664664
<script>
665+
// Check for full results in sessionStorage and update page
666+
document.addEventListener('DOMContentLoaded', function() {
667+
// Check if we have full results with images in sessionStorage
668+
const fullResults = sessionStorage.getItem('analysisResults');
669+
if (fullResults) {
670+
try {
671+
const results = JSON.parse(fullResults);
672+
673+
// Update pattern chart if available
674+
if (results.pattern_chart) {
675+
const patternImg = document.querySelector('img[alt="Pattern Analysis Chart"]');
676+
if (patternImg) {
677+
patternImg.src = `data:image/png;base64,${results.pattern_chart}`;
678+
}
679+
}
680+
681+
// Update trend chart if available
682+
if (results.trend_chart) {
683+
const trendImg = document.querySelector('img[alt="Trend Analysis Chart"]');
684+
if (trendImg) {
685+
trendImg.src = `data:image/png;base64,${results.trend_chart}`;
686+
}
687+
}
688+
689+
// Clear sessionStorage after using the results
690+
sessionStorage.removeItem('analysisResults');
691+
} catch (e) {
692+
console.error('Error parsing sessionStorage results:', e);
693+
}
694+
}
695+
});
696+
665697
// Convert markdown to HTML and beautify content when page loads
666698
document.addEventListener('DOMContentLoaded', function() {
667699
// Find all elements that might contain markdown

0 commit comments

Comments
 (0)