Skip to content

Support log scale for the return curve. #886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions rqalpha/mod/rqalpha_mod_sys_analyser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
# 是否在收益图中展示买卖点
'open_close_points': False,
# 是否在收益图中展示周度指标和收益曲线
'weekly_indicators': False
'weekly_indicators': False,
# 是否对收益图使用对数坐标
'log_scale': False
},
}

Expand Down Expand Up @@ -95,6 +97,11 @@ def load_mod():
is_flag=True, default=None,
help=_("[sys_analyser] show weekly indicators and return curve on plot")
))
inject_run_param(click.Option(
("--plot-log-scale", cli_prefix + "plot_config__log_scale"),
is_flag=True, default=None,
help=_("[sys_analyser] show return curve at log scale")
))


@cli.command(help=_("[sys_analyser] Plot from strategy output file"))
Expand All @@ -103,13 +110,14 @@ def load_mod():
@click.option('--plot-save', 'plot_save_file', default=None, type=click.Path(), help=_("save plot result to file"))
@click.option('--plot-open-close-points', is_flag=True, help=_("show open close points on plot"))
@click.option('--plot-weekly-indicators', is_flag=True, help=_("show weekly indicators and return curve on plot"))
def plot(result_pickle_file_path, show, plot_save_file, plot_open_close_points, plot_weekly_indicators):
@click.option('--plot-log-scale', is_flag=True, help=_("show return curve at log scale"))
def plot(result_pickle_file_path, show, plot_save_file, plot_open_close_points, plot_weekly_indicators, plot_log_scale):
import pandas as pd
from .plot import plot_result

result_dict = pd.read_pickle(result_pickle_file_path)
print(plot_open_close_points, plot_weekly_indicators)
plot_result(result_dict, show, plot_save_file, plot_weekly_indicators, plot_open_close_points)
plot_result(result_dict, show, plot_save_file, plot_weekly_indicators, plot_open_close_points, plot_log_scale)


@cli.command(help=_("[sys_analyser] Generate report from strategy output file"))
Expand Down
3 changes: 2 additions & 1 deletion rqalpha/mod/rqalpha_mod_sys_analyser/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ def tear_down(self, code, exception=None):
_plot_template_cls = PLOT_TEMPLATE.get(self._mod_config.plot, DefaultPlot)
plot_result(
result_dict, self._mod_config.plot, self._mod_config.plot_save_file,
plot_config.weekly_indicators, plot_config.open_close_points, _plot_template_cls, self._mod_config.strategy_name
plot_config.weekly_indicators, plot_config.open_close_points, _plot_template_cls, self._mod_config.strategy_name,
plot_config.log_scale
)

return result_dict
30 changes: 21 additions & 9 deletions rqalpha/mod/rqalpha_mod_sys_analyser/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,42 +83,54 @@ def __init__(
self,
returns,
lines: List[Tuple[pd.Series, LineInfo]],
spots_on_returns: List[Tuple[Sequence[int], SpotInfo]]
spots_on_returns: List[Tuple[Sequence[int], SpotInfo]],
log_scale: bool
):
self._returns = returns
self._lines = lines
self._spots_on_returns = spots_on_returns
self._log_scale = log_scale

@classmethod
def _plot_line(cls, ax, returns, info: LineInfo):
if returns is not None:
ax.plot(returns, label=info.label, alpha=info.alpha, linewidth=info.linewidth, color=info.color)

def _plot_spots_on_returns(self, ax, positions: Sequence[int], info: SpotInfo):
return_or_net_values = self._returns[positions] if not self._log_scale else 1 + self._returns[positions]
ax.plot(
self._returns.index[positions], self._returns[positions],
self._returns.index[positions], return_or_net_values,
info.marker, color=info.color, markersize=info.markersize, alpha=info.alpha, label=info.label
)

def plot(self, ax: Axes):
ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator())
ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator())
ax.grid(visible=True, which='minor', linewidth=.2)
ax.grid(visible=True, which='major', linewidth=1)
ax.patch.set_alpha(0.6)

# plot lines
for returns, info in self._lines:
self._plot_line(ax, returns, info)
return_or_net_values = returns if not self._log_scale else 1 + returns
self._plot_line(ax, return_or_net_values, info)
# plot MaxDD/MaxDDD
for positions, info in self._spots_on_returns:
self._plot_spots_on_returns(ax, positions, info)

# place legend
pyplot.legend(loc="best").get_frame().set_alpha(0.5)

# manipulate axis
ax.set_yticks(ax.get_yticks()) # make matplotlib happy
ax.set_yticklabels(['{:3.2f}%'.format(x * 100) for x in ax.get_yticks()])
ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator())
ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator())
if self._log_scale:
ax.set_yscale('log')
ax.yaxis.set_major_locator(ticker.AutoLocator())
formatter = ticker.FuncFormatter(lambda x, _: '{:3.2f}%'.format((x - 1) * 100))
ax.yaxis.set_major_formatter(formatter)
ax.yaxis.set_minor_formatter(formatter)
else:
ax.set_yticks(ax.get_yticks()) # make matplotlib happy
ax.set_yticklabels(['{:3.2f}%'.format(x * 100) for x in ax.get_yticks()])


class UserPlot(SubPlot):
Expand Down Expand Up @@ -184,7 +196,7 @@ def _plot(title: str, sub_plots: List[SubPlot], strategy_name):

def plot_result(
result_dict, show=True, save=None, weekly_indicators: bool = False, open_close_points: bool = False,
plot_template_cls=DefaultPlot, strategy_name=None
plot_template_cls=DefaultPlot, strategy_name=None, log_scale: bool = False
):
summary = result_dict["summary"]
portfolio = result_dict["portfolio"]
Expand Down Expand Up @@ -234,7 +246,7 @@ def plot_result(
"max_dd_ddd": "MaxDD {}\nMaxDDD {}".format(max_dd.repr, max_ddd.repr),
"excess_max_dd_ddd": ex_max_dd_ddd,
}), plot_template, strategy_name), ReturnPlot(
portfolio.unit_net_value - 1, return_lines, spots_on_returns
portfolio.unit_net_value - 1, return_lines, spots_on_returns, log_scale
)]
if "plots" in result_dict:
sub_plots.append(UserPlot(result_dict["plots"]))
Expand Down