feat: add datasource sql and csv for webapp #772
base: master
Changes from all commits
@@ -6,6 +6,7 @@
# Run this app with `python app.py` and
# visit http://127.0.0.1:8050/ in your web browser.

import time
import dash
import dash_table
import dash_core_components as dcc
@@ -16,6 +17,8 @@
import plotly.graph_objects as go

import os
import getpass
import glob
import numpy as np
import pandas as pd
import json
@@ -34,6 +37,8 @@
from vectorbt.portfolio.enums import Direction, DirectionConflictMode
from vectorbt.portfolio.base import Portfolio

from dal_stock_sql.dal import MarketDataRepository

USE_CACHING = os.environ.get(
    "USE_CACHING",
    "True",
@@ -463,6 +468,10 @@
        "Reset",
        id="reset_button"
    ),
    html.Button(
        "CleanCache",
        id="clean_button"
    ),
    html.Details(
        open=True,
        children=[
@@ -509,7 +518,20 @@
                value=default_interval,
            ),
        ]
    )
    ),
    dbc.Col(
        children=[
            html.Label("Source:"),
            dcc.Dropdown(
                id="datasource",
                options=[{"value": 'yahoo', "label": 'yahoo'},
                         {"value": 'csv', "label": 'csv'},
                         {"value": 'csv_all', "label": 'csv_all'},
                         {"value": 'sql', "label": 'sql'}],
                value='csv',
            ),
        ]
    ),
],
),
html.Label("Filter period:"),
@@ -938,8 +960,26 @@
)


data_mode = 'sql'
df_all_symbol = None
sql_dal = None


@cache.memoize()
def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from the selected backend."""
    global data_mode
    if data_mode == 'csv_all':
        return fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust)
    elif data_mode == 'csv':
        return fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust)
    elif data_mode == 'sql':
        return fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust)
    return fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust)


@cache.memoize()
def fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from Yahoo! Finance."""
    return yf.Ticker(symbol).history(
        period=period,
@@ -950,6 +990,73 @@ def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
    )


@cache.memoize()
def fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust, csvdir='~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks'):
    """Fetch OHLCV data from a csv containing one symbol."""
    csvfile = csvdir + '/' + symbol + '.csv'
Review comment: change to f string.
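A minimal sketch of the suggested change:

```python
# Per the review comment, build the path with an f-string:
csvfile = f"{csvdir}/{symbol}.csv"
# or equivalently, letting the OS handle separators (an alternative, not required by the comment):
csvfile = os.path.join(csvdir, f"{symbol}.csv")
```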
    _start = time.perf_counter()
    df = pd.read_csv(csvfile, parse_dates=['Date'], index_col='Date')
    _duration = time.perf_counter() - _start
    return df


@cache.memoize()
def fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust, csvfile='~/trade/data/test.csv'):
    """Fetch OHLCV data from a csv containing multiple symbols."""
    global df_all_symbol
Review comment: a global doesn't feel appropriate there.
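One way to address this, sketched under the assumption that the CSV layout matches the code below (`load_all_symbols` is a hypothetical name): memoize the loaded frame with `functools.lru_cache` instead of a module global.

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def load_all_symbols(csvfile: str) -> pd.DataFrame:
    """Load and normalize the multi-symbol CSV once per process."""
    df = pd.read_csv(csvfile, parse_dates=['date'])
    df = df.rename(columns={"date": "Date", "open": "Open", "high": "High",
                            "low": "Low", "close": "Close", "volume": "Volume"})
    if df['Date'].dt.tz is not None:
        # drop the timezone so pandas DatetimeIndex lookups behave
        df['Date'] = df['Date'].dt.tz_localize(None)
    return df.set_index(pd.DatetimeIndex(df['Date']))
```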
    csvfile = '~/trade/data/test.csv'
Review comment: we shouldn't use …
    if df_all_symbol is None:
        _start = time.perf_counter()
        df_all_symbol = pd.read_csv(csvfile, parse_dates=['date'])
        df_all_symbol = df_all_symbol.rename(
            columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close",
                     "volume": "Volume"})
        df_all_symbol['Date'] = df_all_symbol['Date'].map(
            lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
Review comment: total nitpick, but don't say 't'; be more descriptive about this variable.
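For instance, naming the lambda argument for what it is:

```python
df_all_symbol['Date'] = df_all_symbol['Date'].map(
    lambda timestamp: pd.to_datetime(timestamp.replace(tzinfo=None)).to_pydatetime())
```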
        df_all_symbol = df_all_symbol.set_index(pd.DatetimeIndex(df_all_symbol['Date']))
        _duration = time.perf_counter() - _start
    res = df_all_symbol[df_all_symbol['symbol'] == symbol]
    res = res.drop(['symbol'], axis=1)
    return res


@cache.memoize()
def fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from SQL."""
    global sql_dal
Review comment: remove global.
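A sketch of one alternative, assuming the existing `MarketDataRepository.Instance()` helper and the `os`/`getpass` imports above: memoize the connected repository in a function instead of a module global.

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def get_sql_dal() -> MarketDataRepository:
    """Create and initialize the repository once, on first use."""
    dal = MarketDataRepository.Instance()
    dal.init(user=getpass.getuser(), password=os.getenv('pg_password'))
    return dal
```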
    if interval == '60m':
        dbtableRead = 'stock_market.ohlcv_1_hour'
    elif interval == '15m':
        dbtableRead = 'stock_market.ohlcv_15_minute'
    elif interval == '1d':
Review comment: this seems counterintuitive. You have args "period" and "interval" but you're checking …
        dbtableRead = 'stock_market.ohlcv_1_day'
    else:
        raise(ValueError(f"error unknown interval {interval}"))
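The chain could also be a lookup table, which makes the supported intervals explicit at a glance (a sketch, same table names as above):

```python
INTERVAL_TO_TABLE = {
    '60m': 'stock_market.ohlcv_1_hour',
    '15m': 'stock_market.ohlcv_15_minute',
    '1d': 'stock_market.ohlcv_1_day',
}

try:
    dbtableRead = INTERVAL_TO_TABLE[interval]
except KeyError:
    raise ValueError(f"error unknown interval {interval}")
```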
    if sql_dal is None:
        user = getpass.getuser()
        password = os.getenv('pg_password')
        sql_dal = MarketDataRepository.Instance()
        sql_dal.init(user=user, password=password)
    res = sql_dal.get_one_symbol(dbtableRead, symbol)
    res = res.rename(
        columns={"date": "Date", "open": "Open", "high": "High", "low": "Low", "close": "Close",
                 "volume": "Volume"})
    return res


@app.callback(
    [Output('clean_button', 'children')],
    [Input('clean_button', 'n_clicks')],
    prevent_initial_call=True
)
def clean_cache(_):
    """Clean all data in cache."""
    files = glob.glob('data/*')
    for f in files:
        os.remove(f)
    return
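As written, the callback declares a list-wrapped Output but returns nothing, which Dash will reject at runtime since it expects one value per declared Output. A sketch of one fix that keeps the button label stable:

```python
@app.callback(
    Output('clean_button', 'children'),
    Input('clean_button', 'n_clicks'),
    prevent_initial_call=True
)
def clean_cache(_):
    """Delete cached files; return one value to satisfy the single Output."""
    for cached_file in glob.glob('data/*'):
        os.remove(cached_file)
    return "CleanCache"
```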


@app.callback(
    [Output('data_signal', 'children'),
     Output('index_signal', 'children')],

@@ -962,8 +1069,11 @@ def update_data(symbol, period, interval, yf_options):
    """Store data into a hidden DIV to avoid repeatedly calling Yahoo's API."""
    auto_adjust = 'auto_adjust' in yf_options
    back_adjust = 'back_adjust' in yf_options
    _start = time.perf_counter()
    df = fetch_data(symbol, period, interval, auto_adjust, back_adjust)
    return df.to_json(date_format='iso', orient='split'), df.index.tolist()
    _duration = time.perf_counter() - _start
    res = df.to_json(date_format='iso', orient='split'), df.index.tolist()
    return res


@app.callback(
@@ -0,0 +1,126 @@
import datetime as dt
import psycopg2
from psycopg2 import pool
import pandas as pd
import logging


logger = logging.getLogger(__name__)


class MarketDataRepository:
    def __init__(self):
        """Initialize the basic structure but don't connect yet."""
        self.connection_pool = None
        self.ver_twsapi = None

    def init(self, database='postgres', user='admin', password='admin', host='127.0.0.1', port='5432',
             ver_twsapi=10, maxconn=2):
        self.ver_twsapi = ver_twsapi
        # Establishing the connection
        logger.info('connecting to host %s and DB %s' % (host, database))
        self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
            1, maxconn,  # minconn, maxconn
            user=user,
            password=password,
            host=host,
            port=port,
            database=database
        )

    def get_connection(self):
        try:
            conn = self.connection_pool.getconn()
        except Exception as e:
            raise e
        return conn

    def release_connection(self, connection):
        self.connection_pool.putconn(connection)

    def end(self):
        self.connection_pool.closeall()

    def get_all_symbol(self, dbtable: str, date_start: dt = None, date_end: dt = None):
        if None not in [date_start, date_end]:
            sql = '''SELECT * FROM %s WHERE date >= to_timestamp(%s) AND date <= to_timestamp(%s);''' % \
                  (dbtable, date_start.strftime('%s'), date_end.strftime('%s'))
        elif date_start == date_end:
            sql = '''SELECT * FROM %s ;''' % (dbtable)
            print(sql)
        elif date_end is None:
            sql = '''SELECT * FROM %s WHERE date >= to_timestamp(%s);''' % \
                  (dbtable, date_start.strftime('%s'))
        else:
            sql = '''SELECT * FROM %s WHERE date <= to_timestamp(%s);''' % \
                  (dbtable, date_end.strftime('%s'))
        # pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy.
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql)
                result = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
                cursor.close()
            df = pd.DataFrame(result, columns=columns)
            # Specify the types directly at DataFrame creation
            df = df.astype({
                'date': 'datetime64[ns, UTC]',
                'open': 'float',
                'high': 'float',
                'low': 'float',
                'close': 'float',
                'volume': 'int'
            })
        except Exception as e:
            raise e
        finally:
            self.release_connection(conn)
        # remove TZ, otherwise pandas DatetimeIndex lookups do not work
        # df['date'] = df['date'].map(lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
        df = df.set_index(pd.DatetimeIndex(df['date']))
        return df
||
def get_one_symbol(self, dbtable: str, symbol: str, date_start: dt = None, date_end: dt = None): | ||
if ':' not in symbol: | ||
logger.error('Incorrect symbol %s' % symbol) | ||
return | ||
# raise ValueError | ||
if None not in [date_start, date_end]: | ||
sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date >= to_timestamp(%s) AND date <= to_timestamp(%s);''' % \ | ||
(dbtable, symbol, date_start.strftime('%s'), date_end.strftime('%s')) | ||
elif date_start == date_end: | ||
sql = '''SELECT * FROM %s WHERE symbol = '%s';''' % (dbtable, symbol) | ||
print(sql) | ||
elif date_end is None: | ||
sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date >= to_timestamp(%s);''' % \ | ||
(dbtable, symbol, date_start.strftime('%s')) | ||
else: | ||
sql = '''SELECT * FROM %s WHERE symbol = '%s' AND date <= to_timestamp(%s);''' % \ | ||
(dbtable, symbol, date_end.strftime('%s')) | ||
# pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy. | ||
conn = self.get_connection() | ||
try: | ||
with conn.cursor() as cursor: | ||
cursor.execute(sql) | ||
result = cursor.fetchall() | ||
columns = [desc[0] for desc in cursor.description] | ||
cursor.close() | ||
df = pd.DataFrame(result, columns=columns) | ||
# Specify the types directly at DataFrame creation | ||
df = df.astype({ | ||
'date': 'datetime64[ns, UTC]', | ||
'open': 'float', | ||
'high': 'float', | ||
'low': 'float', | ||
'close': 'float', | ||
'volume': 'int' | ||
}) | ||
except Exception as e: | ||
raise e | ||
finally: | ||
self.release_connection(conn) | ||
# remove TZ otherwise pandas DateTimeIndex lookup in pandas do not work | ||
#df['date'] = df['date'].map(lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime()) | ||
df = df.set_index(pd.DatetimeIndex(df['date'])) | ||
return df |
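These %-formatted queries interpolate `symbol` and the dates straight into the SQL string, which invites quoting bugs and SQL injection. A hedged sketch of the same lookup using psycopg2's `sql` composition and bound parameters (structure is illustrative; dtype coercion omitted for brevity; psycopg2 adapts `datetime` values directly, so `to_timestamp(strftime('%s'))` is not needed):

```python
from psycopg2 import sql

def get_one_symbol(self, dbtable: str, symbol: str, date_start=None, date_end=None):
    """Fetch one symbol with a composed query and bound parameters."""
    schema, table = dbtable.split('.')  # assumes schema-qualified names like 'stock_market.ohlcv_1_day'
    query = sql.SQL("SELECT * FROM {} WHERE symbol = %s").format(
        sql.Identifier(schema, table))
    params = [symbol]
    if date_start is not None:
        query = sql.SQL("{} AND date >= %s").format(query)
        params.append(date_start)
    if date_end is not None:
        query = sql.SQL("{} AND date <= %s").format(query)
        params.append(date_end)
    conn = self.get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            columns = [desc[0] for desc in cursor.description]
            return pd.DataFrame(cursor.fetchall(), columns=columns)
    finally:
        self.release_connection(conn)
```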
Review comment: csvdir is using ~; this shouldn't be hard coded.
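A small sketch of one way to address this: take the directory from an environment variable (`CSV_DATA_DIR` is a hypothetical name) and expand the user home explicitly.

```python
import os

# Hypothetical env var; the fallback reuses the current default only for illustration.
CSV_DATA_DIR = os.path.expanduser(
    os.environ.get("CSV_DATA_DIR",
                   "~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks"))
```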