Skip to content

feat: add SQL and CSV data sources for webapp #772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 112 additions & 2 deletions apps/candlestick-patterns/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Run this app with `python app.py` and
# visit http://127.0.0.1:8050/ in your web browser.

import time
import dash
import dash_table
import dash_core_components as dcc
Expand All @@ -16,6 +17,8 @@
import plotly.graph_objects as go

import os
import getpass
import glob
import numpy as np
import pandas as pd
import json
Expand All @@ -34,6 +37,8 @@
from vectorbt.portfolio.enums import Direction, DirectionConflictMode
from vectorbt.portfolio.base import Portfolio

from dal_stock_sql.dal import MarketDataRepository

USE_CACHING = os.environ.get(
"USE_CACHING",
"True",
Expand Down Expand Up @@ -463,6 +468,10 @@
"Reset",
id="reset_button"
),
html.Button(
"CleanCache",
id="clean_button"
),
html.Details(
open=True,
children=[
Expand Down Expand Up @@ -509,7 +518,20 @@
value=default_interval,
),
]
)
),
dbc.Col(
children=[
html.Label("Source:"),
dcc.Dropdown(
id="datasource",
options=[{"value": 'yahoo', "label": 'yahoo'},
{"value": 'csv', "label": 'csv'},
{"value": 'csv_all', "label": 'csv_all'},
{"value": 'sql', "label": 'sql'}],
value='csv',
),
]
),
],
),
html.Label("Filter period:"),
Expand Down Expand Up @@ -938,8 +960,26 @@
)


# Active data backend: one of 'yahoo', 'csv', 'csv_all', 'sql' — read by
# fetch_data() to pick a fetcher.  NOTE(review): initialized to 'sql' while the
# "datasource" dropdown defaults to 'csv' — confirm which one wins at startup.
data_mode = 'sql'
# Lazily-loaded DataFrame holding every symbol (populated by fetch_data_csv_all).
df_all_symbol = None
# Lazily-initialized MarketDataRepository handle (populated by fetch_data_sql).
sql_dal = None


def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from the backend selected by the module-level ``data_mode``.

    Deliberately NOT memoized: ``data_mode`` is not part of the argument list,
    so caching at this level would keep serving frames from a previously
    selected backend after the user switches data sources.  Each per-backend
    fetcher below carries its own ``@cache.memoize()``, so results are still
    cached per (symbol, period, interval, ...) key.
    """
    # Dispatch table; any unknown mode falls back to Yahoo! Finance.
    fetchers = {
        'csv_all': fetch_data_csv_all,
        'csv': fetch_data_csv,
        'sql': fetch_data_sql,
    }
    fetcher = fetchers.get(data_mode, fetch_data_yf)
    return fetcher(symbol, period, interval, auto_adjust, back_adjust)


@cache.memoize()
def fetch_data_yf(symbol, period, interval, auto_adjust, back_adjust):
"""Fetch OHLCV data from Yahoo! Finance."""
return yf.Ticker(symbol).history(
period=period,
Expand All @@ -950,6 +990,73 @@ def fetch_data(symbol, period, interval, auto_adjust, back_adjust):
)


@cache.memoize()
def fetch_data_csv(symbol, period, interval, auto_adjust, back_adjust,
                   csvdir='~/trade/data/kaggle_allUS_daily_with_volume_yahoo/stocks'):
    """Fetch OHLCV data from a per-symbol CSV file.

    ``period``, ``interval``, ``auto_adjust`` and ``back_adjust`` are accepted
    only for signature compatibility with the other fetchers — the CSV is
    served as-is.

    csvdir: directory containing one ``<symbol>.csv`` per symbol.  The default
        is a developer-specific path — TODO(review): move to configuration
        instead of hard-coding a '~' path here.
    Returns a DataFrame indexed by the parsed 'Date' column.
    """
    # Expand '~' explicitly and build the path portably instead of
    # concatenating with '/'.
    csvfile = os.path.join(os.path.expanduser(csvdir), f"{symbol}.csv")
    return pd.read_csv(csvfile, parse_dates=['Date'], index_col='Date')


@cache.memoize()
def fetch_data_csv_all(symbol, period, interval, auto_adjust, back_adjust,
                       csvfile='~/trade/data/test.csv'):
    """Fetch OHLCV data for one symbol from a single CSV containing all symbols.

    The CSV is read once into the module-level ``df_all_symbol`` cache;
    subsequent calls only filter it by symbol.  ``period``/``interval``/
    ``auto_adjust``/``back_adjust`` are accepted for signature compatibility.

    csvfile: path of the combined CSV.  Previously this parameter was
        unconditionally overwritten inside the function; now the caller-supplied
        value is honored.  TODO(review): move the default to configuration.
    Returns a DataFrame indexed by 'Date' with the 'symbol' column dropped.
    """
    global df_all_symbol  # module-level one-shot cache of the full CSV
    if df_all_symbol is None:
        # Expand '~' explicitly rather than relying on pandas to do it.
        df_all_symbol = pd.read_csv(os.path.expanduser(csvfile), parse_dates=['date'])
        df_all_symbol = df_all_symbol.rename(
            columns={"date": "Date", "open": "Open", "high": "High",
                     "low": "Low", "close": "Close", "volume": "Volume"})
        # Strip timezone info: naive timestamps are required for the
        # DatetimeIndex lookups performed elsewhere in the app.
        df_all_symbol['Date'] = df_all_symbol['Date'].map(
            lambda timestamp: pd.to_datetime(timestamp.replace(tzinfo=None)).to_pydatetime())
        df_all_symbol = df_all_symbol.set_index(pd.DatetimeIndex(df_all_symbol['Date']))
    selected = df_all_symbol[df_all_symbol['symbol'] == symbol]
    return selected.drop(['symbol'], axis=1)


@cache.memoize()
def fetch_data_sql(symbol, period, interval, auto_adjust, back_adjust):
    """Fetch OHLCV data from the PostgreSQL market-data store.

    ``period``, ``auto_adjust`` and ``back_adjust`` are accepted for signature
    compatibility; only ``interval`` selects the backing table.

    Raises:
        ValueError: if ``interval`` has no corresponding database table.
    Returns a DataFrame with Yahoo-style column names.
    """
    global sql_dal  # lazily-initialized repository shared across calls
    # Map the UI interval value to the table holding bars at that granularity.
    tables = {
        '60m': 'stock_market.ohlcv_1_hour',
        '15m': 'stock_market.ohlcv_15_minute',
        '1d': 'stock_market.ohlcv_1_day',
    }
    try:
        dbtable_read = tables[interval]
    except KeyError:
        raise ValueError(f"error unknown interval {interval}") from None
    if sql_dal is None:
        # Credentials: current OS user plus a password from the environment.
        # NOTE(review): confirm this matches the deployment's PostgreSQL setup.
        sql_dal = MarketDataRepository.Instance()
        sql_dal.init(user=getpass.getuser(), password=os.getenv('pg_password'))
    bars = sql_dal.get_one_symbol(dbtable_read, symbol)
    return bars.rename(
        columns={"date": "Date", "open": "Open", "high": "High",
                 "low": "Low", "close": "Close", "volume": "Volume"})


@app.callback(
    [Output('clean_button', 'children')],
    [Input('clean_button', 'n_clicks')],
    prevent_initial_call=True
)
def clean_cache(_):
    """Delete every cached data file and restore the button label.

    The callback declares one Output (list form), so it must return a
    one-element list; the previous bare ``return`` (None) made Dash raise a
    callback-output mismatch on every click.
    """
    # NOTE(review): 'data/*' is relative to the CWD — presumably the cache
    # directory configured for Flask-Caching; confirm against the cache config.
    for cached_file in glob.glob('data/*'):
        os.remove(cached_file)
    return ["CleanCache"]


@app.callback(
[Output('data_signal', 'children'),
Output('index_signal', 'children')],
Expand All @@ -962,8 +1069,11 @@ def update_data(symbol, period, interval, yf_options):
"""Store data into a hidden DIV to avoid repeatedly calling Yahoo's API."""
auto_adjust = 'auto_adjust' in yf_options
back_adjust = 'back_adjust' in yf_options
_start = time.perf_counter()
df = fetch_data(symbol, period, interval, auto_adjust, back_adjust)
return df.to_json(date_format='iso', orient='split'), df.index.tolist()
_duration = time.perf_counter() - _start
res = df.to_json(date_format='iso', orient='split'), df.index.tolist()
return res


@app.callback(
Expand Down
126 changes: 126 additions & 0 deletions apps/candlestick-patterns/dal_stock_sql/dal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import datetime as dt
import psycopg2
from psycopg2 import pool
import pandas as pd
import logging


logger = logging.getLogger(__name__)


class MarketDataRepository:
    """Data-access layer for OHLCV market data stored in PostgreSQL.

    Connections come from a ThreadedConnectionPool created by ``init()``;
    constructing the object does not connect.  Use ``Instance()`` to share a
    single repository process-wide (app.py calls
    ``MarketDataRepository.Instance()``).
    """

    _instance = None  # process-wide singleton, see Instance()

    def __init__(self):
        """Initialize basic structure but don't connect yet."""
        self.connection_pool = None
        self.ver_twsapi = None

    @classmethod
    def Instance(cls):
        """Return the process-wide singleton instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def init(self, database='postgres', user='admin', password='admin',
             host='127.0.0.1', port='5432', ver_twsapi=10, maxconn=2):
        """Create the connection pool against the given PostgreSQL server."""
        self.ver_twsapi = ver_twsapi
        logger.info('connecting to host %s and DB %s', host, database)
        self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
            1, maxconn,  # minconn, maxconn
            user=user,
            password=password,
            host=host,
            port=port,
            database=database
        )

    def get_connection(self):
        """Borrow a connection from the pool (caller must release it)."""
        return self.connection_pool.getconn()

    def release_connection(self, connection):
        """Return a borrowed connection to the pool."""
        self.connection_pool.putconn(connection)

    def end(self):
        """Close every pooled connection."""
        self.connection_pool.closeall()

    @staticmethod
    def _date_clause(date_start, date_end, params):
        """Build WHERE fragments for the optional date bounds.

        Appends the bound epoch values to ``params`` (mutated in place) and
        returns the matching list of SQL fragments with ``%s`` placeholders.
        """
        clauses = []
        if date_start is not None:
            clauses.append("date >= to_timestamp(%s)")
            # .timestamp() is portable; strftime('%s') is not guaranteed on
            # all platforms.
            params.append(date_start.timestamp())
        if date_end is not None:
            clauses.append("date <= to_timestamp(%s)")
            params.append(date_end.timestamp())
        return clauses

    def _query_to_frame(self, sql, params):
        """Run a SELECT and return the rows as a typed, date-indexed DataFrame.

        pandas officially supports only SQLAlchemy connectables, so we go
        through a raw cursor and build the DataFrame ourselves.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql, params)
                rows = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]
            df = pd.DataFrame(rows, columns=columns)
            # Coerce column types explicitly; the raw driver returns Decimals/strings.
            df = df.astype({
                'date': 'datetime64[ns, UTC]',
                'open': 'float',
                'high': 'float',
                'low': 'float',
                'close': 'float',
                'volume': 'int'
            })
        finally:
            self.release_connection(conn)
        # Index by date so time-range lookups work downstream.
        return df.set_index(pd.DatetimeIndex(df['date']))

    def get_all_symbol(self, dbtable: str, date_start: dt.datetime = None,
                       date_end: dt.datetime = None):
        """Fetch OHLCV rows for every symbol in ``dbtable``.

        date_start/date_end: optional inclusive bounds; either may be None.
        """
        params = []
        clauses = self._date_clause(date_start, date_end, params)
        # dbtable is an SQL identifier and cannot be a bound parameter; callers
        # must only pass trusted, hard-coded table names.
        sql = 'SELECT * FROM %s' % dbtable
        if clauses:
            sql += ' WHERE ' + ' AND '.join(clauses)
        return self._query_to_frame(sql + ';', params)

    def get_one_symbol(self, dbtable: str, symbol: str,
                       date_start: dt.datetime = None, date_end: dt.datetime = None):
        """Fetch OHLCV rows for one symbol, optionally bounded by dates.

        Returns None (after logging) for a malformed symbol — presumably
        symbols are stored as 'EXCHANGE:TICKER'; confirm against the loader.
        """
        if ':' not in symbol:
            logger.error('Incorrect symbol %s', symbol)
            return
        # Bind symbol and dates as parameters instead of string-formatting
        # them into the statement (SQL-injection safe).
        params = [symbol]
        clauses = ["symbol = %s"] + self._date_clause(date_start, date_end, params)
        sql = 'SELECT * FROM %s WHERE %s;' % (dbtable, ' AND '.join(clauses))
        return self._query_to_frame(sql, params)