# --- apps/candlestick-patterns/app.py: pluggable data-source additions ---
# NOTE(review): reconstructed from a whitespace-mangled unified diff.  The
# surrounding app.py (the Dash `app`, the Flask-Caching `cache`, `yf`
# (yfinance), Dash `Input`/`Output`/`html`/`dcc`, and the layout into which
# the original diff inserted a "CleanCache" button (id="clean_button") and a
# "Source:" dropdown (id="datasource", options yahoo/csv/csv_all/sql,
# default 'csv')) lives outside this chunk and is not reproduced here.

import glob
import getpass
import os
import time

import pandas as pd

from dal_stock_sql.dal import MarketDataRepository

# Module-level state for the pluggable data backends.
# NOTE(review): `data_mode` defaults to 'sql' here while the UI dropdown
# defaults to 'csv' -- TODO confirm which callback (not visible in this
# chunk) keeps the two in sync.
data_mode = 'sql'
df_all_symbol = None  # lazily-loaded DataFrame for the 'csv_all' backend
sql_dal = None        # lazily-initialised MarketDataRepository for 'sql'


# BUG FIX: the dispatcher itself must NOT be memoized.  Its result depends
# on the mutable global `data_mode`, which is not part of the cache key, so
# a cached entry produced under one backend would keep being served after
# the user switches to another.  Each concrete fetcher below is memoized on
# its own arguments instead.
def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from the backend selected by the global `data_mode`.

    Falls through to Yahoo! Finance for any unrecognised mode.
    """
    if data_mode == 'csv_all':
        return fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust)
    if data_mode == 'csv':
        return fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust)
    if data_mode == 'sql':
        return fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust)
    return fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust)


@cache.memoize()
def fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from Yahoo! Finance."""
    # NOTE(review): the diff only showed `period=period,` before the hunk
    # boundary; intermediate keyword arguments were unchanged context --
    # TODO confirm against the full app.py.
    return yf.Ticker(symbol).history(
        period=period,
        interval=interval,
        auto_adjust=auto_adjust,
        back_adjust=back_adjust,
    )


@cache.memoize()
def fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust,
                   csvdir='~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks'):
    """Fetch OHLCV data from a per-symbol CSV file (<csvdir>/<symbol>.csv)."""
    # Dead timing code (`_start`/`_duration`, never used) removed.
    csvfile = os.path.expanduser(os.path.join(csvdir, symbol + '.csv'))
    return pd.read_csv(csvfile, parse_dates=['Date'], index_col='Date')


@cache.memoize()
def fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust,
                       csvfile='~/trade/data/test.csv'):
    """Fetch OHLCV data for one symbol out of a CSV holding every symbol.

    The whole CSV is parsed once into the module-level `df_all_symbol`;
    subsequent calls only filter the cached frame.
    """
    global df_all_symbol
    # BUG FIX: the original immediately reassigned `csvfile` to a
    # hard-coded path, silently ignoring the parameter.
    if df_all_symbol is None:
        df = pd.read_csv(os.path.expanduser(csvfile), parse_dates=['date'])
        df = df.rename(columns={"date": "Date", "open": "Open", "high": "High",
                                "low": "Low", "close": "Close", "volume": "Volume"})
        # Strip the timezone, otherwise DatetimeIndex lookups downstream fail.
        df['Date'] = df['Date'].map(
            lambda t: pd.to_datetime(t.replace(tzinfo=None)).to_pydatetime())
        df_all_symbol = df.set_index(pd.DatetimeIndex(df['Date']))
    res = df_all_symbol[df_all_symbol['symbol'] == symbol]
    return res.drop(['symbol'], axis=1)


@cache.memoize()
def fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from PostgreSQL via MarketDataRepository.

    Raises:
        ValueError: for any interval other than '60m', '15m' or '1d'.
    """
    global sql_dal
    tables = {
        '60m': 'stock_market.ohlcv_1_hour',
        '15m': 'stock_market.ohlcv_15_minute',
        '1d': 'stock_market.ohlcv_1_day',
    }
    try:
        dbtable_read = tables[interval]
    except KeyError:
        raise ValueError(f"error unknown interval {interval}") from None
    if sql_dal is None:
        # DB user comes from the OS account, password from the environment.
        sql_dal = MarketDataRepository()
        sql_dal.init(user=getpass.getuser(), password=os.getenv('pg_password'))
    res = sql_dal.get_one_symbol(dbtable_read, symbol)
    return res.rename(columns={"date": "Date", "open": "Open", "high": "High",
                               "low": "Low", "close": "Close", "volume": "Volume"})


@app.callback(
    [Output('clean_button', 'children')],
    [Input('clean_button', 'n_clicks')],
    prevent_initial_call=True
)
def clean_cache(_):
    """Delete every cached file under data/.

    BUG FIX: this callback declares a one-element list of Outputs, so it
    must return a one-element list; the original returned bare None, which
    Dash rejects at runtime.
    """
    for f in glob.glob('data/*'):
        if os.path.isfile(f):  # skip subdirectories, which os.remove cannot delete
            os.remove(f)
    return ["CleanCache"]


# NOTE(review): the Input list of this callback was unchanged context in the
# diff and therefore not visible -- the ids below mirror the parameter names;
# TODO confirm against the full app.py.
@app.callback(
    [Output('data_signal', 'children'),
     Output('index_signal', 'children')],
    [Input('symbol', 'value'),
     Input('period', 'value'),
     Input('interval', 'value'),
     Input('yf_options', 'value')]
)
def update_data(symbol, period, interval, yf_options):
    """Store data into a hidden DIV to avoid repeatedly calling the backend."""
    auto_adjust = 'auto_adjust' in yf_options
    back_adjust = 'back_adjust' in yf_options
    # Dead timing code (`_start`/`_duration`, never used) removed.
    df = fetch_data(symbol, period, interval, auto_adjust, back_adjust)
    return df.to_json(date_format='iso', orient='split'), df.index.tolist()
# --- apps/candlestick-patterns/dal_stock_sql/dal.py ---
# NOTE(review): reconstructed from a whitespace-mangled diff; the module
# preamble and class header originally sat in the preceding chunk, so the
# complete class (with its imports) is emitted here.

import datetime as dt
import logging

import pandas as pd

logger = logging.getLogger(__name__)


class MarketDataRepository:
    """Read-only access to OHLCV market data stored in PostgreSQL.

    Connections are drawn from a ThreadedConnectionPool created by init();
    the constructor deliberately does not connect.
    """

    # Column dtypes shared by every OHLCV query result.
    _DTYPES = {
        'date': 'datetime64[ns, UTC]',
        'open': 'float',
        'high': 'float',
        'low': 'float',
        'close': 'float',
        'volume': 'int',
    }

    def __init__(self):
        """Initialize basic structure but don't connect yet."""
        self.connection_pool = None
        self.ver_twsapi = None

    def init(self, database='postgres', user='admin', password='admin',
             host='127.0.0.1', port='5432', ver_twsapi=10, maxconn=2):
        """Create the connection pool (1..maxconn connections)."""
        # psycopg2 is imported lazily so merely importing this module does
        # not require the driver until a connection is actually requested.
        from psycopg2 import pool as pg_pool
        self.ver_twsapi = ver_twsapi
        logger.info('connecting to host %s and DB %s', host, database)
        self.connection_pool = pg_pool.ThreadedConnectionPool(
            1, maxconn,  # minconn, maxconn
            user=user,
            password=password,
            host=host,
            port=port,
            database=database,
        )

    def get_connection(self):
        """Borrow a connection from the pool."""
        # The original's `except Exception as e: raise e` added nothing and
        # was removed.
        return self.connection_pool.getconn()

    def release_connection(self, connection):
        """Return a borrowed connection to the pool."""
        self.connection_pool.putconn(connection)

    def end(self):
        """Close every pooled connection."""
        self.connection_pool.closeall()

    @staticmethod
    def _build_query(dbtable, symbol=None, date_start=None, date_end=None):
        """Build a parameterized SELECT and its parameter tuple.

        SECURITY FIX: the original interpolated `symbol` and the epoch
        strings straight into the SQL text with %-formatting (SQL
        injection); they are now bound with %s placeholders.  The table
        name cannot be a bound parameter -- callers pass hard-coded,
        trusted table names only.  PORTABILITY FIX: `strftime('%s')` is a
        glibc extension; `datetime.timestamp()` is used instead.
        """
        clauses, params = [], []
        if symbol is not None:
            clauses.append("symbol = %s")
            params.append(symbol)
        if date_start is not None:
            clauses.append("date >= to_timestamp(%s)")
            params.append(date_start.timestamp())
        if date_end is not None:
            clauses.append("date <= to_timestamp(%s)")
            params.append(date_end.timestamp())
        sql = 'SELECT * FROM %s' % dbtable
        if clauses:
            sql += ' WHERE ' + ' AND '.join(clauses)
        return sql + ';', tuple(params)

    def _fetch_df(self, sql, params):
        """Run the query and return a typed, date-indexed DataFrame.

        The borrowed connection is always released, even on error.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                # `with` closes the cursor; the original's extra
                # cursor.close() was redundant.
                cursor.execute(sql, params)
                columns = [desc[0] for desc in cursor.description]
                df = pd.DataFrame(cursor.fetchall(), columns=columns)
            df = df.astype(self._DTYPES)
        finally:
            self.release_connection(conn)
        # Index on the tz-aware date column; lookups elsewhere expect a
        # DatetimeIndex.
        return df.set_index(pd.DatetimeIndex(df['date']))

    def get_all_symbol(self, dbtable: str, date_start: dt.datetime = None,
                       date_end: dt.datetime = None):
        """Fetch rows for every symbol, optionally bounded by date.

        Both bounds None -> whole table; one bound -> one-sided filter.
        (Annotation fixed: `dt` is a module, not a type.)
        """
        sql, params = self._build_query(dbtable, date_start=date_start,
                                        date_end=date_end)
        return self._fetch_df(sql, params)

    def get_one_symbol(self, dbtable: str, symbol: str,
                       date_start: dt.datetime = None,
                       date_end: dt.datetime = None):
        """Fetch rows for one symbol, optionally bounded by date.

        Returns None (after logging) for symbols without a ':' separator,
        preserving the original's contract.
        """
        if ':' not in symbol:
            logger.error('Incorrect symbol %s', symbol)
            return None
        sql, params = self._build_query(dbtable, symbol=symbol,
                                        date_start=date_start,
                                        date_end=date_end)
        return self._fetch_df(sql, params)