Skip to content

8 visualization should be moved to separate modules #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ jobs:
- name: Make sure virtualevn>20 is installed, which will yield newer pip and possibility to pin pip version.
run: pip install "virtualenv>20"
- name: Install Tox
run: pip install tox
run: pip install tox
- name: Run pre-commit in Tox
run: tox -e pre-commit
Empty file added aiida_uppasd/tools/__init__.py
Empty file.
97 changes: 49 additions & 48 deletions aiida_uppasd/tools/core_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def uppasd_cli():
@click.option('-plot_name', default='None')
@click.option('-width', default=100)
@click.option('-height', default=20)
@click.argument('pk')
@click.argument('node_pk')
def visualization_observations(
iter_slice: int,
y_axis: list,
plot_style: str,
plot_name: str,
width: int,
height: int,
pk: int,
node_pk: int,
): # pylint: disable=too-many-arguments, too-many-branches, too-many-statements
"""
Visualize a given observable
Expand All @@ -51,11 +51,11 @@ def visualization_observations(
:type width: int
:param height: height of the figure
:type height: int
:param pk: pk number of the calculation that one wishes to visualize
:type pk: int
:param node_pk: node_pk number of the calculation that one wishes to visualize
:type node_pk: int
"""
auto_name = locals()
cal_node = orm.load_node(pk)
auto_name = {}
cal_node = orm.load_node(node_pk)
if iter_slice != -1:

for name in y_axis:
Expand All @@ -66,7 +66,7 @@ def visualization_observations(
iter_list = cal_node.get_array('iterations')[:int(iter_slice)].astype(int)
for name in y_axis:
_name = str(name)
plotext.plot(iter_list, eval(_name), label=_name)
plotext.plot(iter_list, _name, label=_name)
plotext.plotsize(width, height)
if plot_name != 'None':
plotext.title(f'{plot_name}')
Expand All @@ -77,7 +77,7 @@ def visualization_observations(
iter_list = cal_node.get_array('iterations')[:int(iter_slice)].astype(int)
for name in y_axis:
_name = str(name)
plotext.scatter(iter_list, eval(_name), label=_name)
plotext.scatter(iter_list, _name, label=_name)
plotext.plotsize(width, height)
if plot_name != 'None':
plotext.title(f'{plot_name}')
Expand All @@ -96,7 +96,7 @@ def visualization_observations(
iter_list = cal_node.get_array('iterations').astype(int)
for name in y_axis:
_name = str(name)
plotext.plot(iter_list, eval(_name), label=_name)
plotext.plot(iter_list, _name, label=_name)
plotext.plotsize(width, height)
if plot_name != 'None':
plotext.title(f'{plot_name}')
Expand All @@ -107,7 +107,7 @@ def visualization_observations(
iter_list = cal_node.get_array('iterations')[:int(iter_slice)].astype(int)
for name in y_axis:
_name = str(name)
plt.scatter(iter_list, eval(_name), label=_name)
plt.scatter(iter_list, _name, label=_name)
plotext.plotsize(width, height)
if plot_name != 'None':
plotext.title(f'{plot_name}')
Expand All @@ -119,15 +119,15 @@ def visualization_observations(


def output_node_query(
cal_node_pk: typing.Union[int, str],
cal_node_node_pk: typing.Union[int, str],
output_array_name: str,
attribute_name: str,
) -> np.ndarray:
"""
Get the array output of a given calculation node

:param cal_node_pk: pk number for the node that is being queried
:type cal_node_pk: typing.Union[int, str]
:param cal_node_node_pk: node_pk number for the node that is being queried
:type cal_node_node_pk: typing.Union[int, str]
:param output_array_name: name of the array that we are looking for
:type output_array_name: str
:param attribute_name: specific entry in the array that one is looking for
Expand All @@ -138,7 +138,7 @@ def output_node_query(
query_builder = orm.QueryBuilder()
query_builder.append(
orm.CalcJobNode,
filters={'id': str(cal_node_pk)},
filters={'id': str(cal_node_node_pk)},
tag='cal_node',
)
query_builder.append(
Expand Down Expand Up @@ -178,7 +178,7 @@ def trajectory_parser(
return mom_states


def get_arrow_next(array_name: str):
def get_arrow_next(array_name: str, rotation, mom_array_from_result, coord_r):
"""
Get the coordinates of an array in cartesian and in rotated coordinates.

Expand All @@ -187,7 +187,7 @@ def get_arrow_next(array_name: str):
:return: coordinates in cartesian and rotated coordinates
:rtype: typing.Union[np.array,np.array,np.array,np.array, np.array, np.array]
"""
rot_mom_array = r.apply(mom_array_from_result[array_name])
rot_mom_array = rotation.apply(mom_array_from_result[array_name])
return (
coord_r[:, 0],
coord_r[:, 1],
Expand All @@ -198,14 +198,24 @@ def get_arrow_next(array_name: str):
)


def animate(data):
def animate(
quivers,
axes,
array_name,
rotation,
mom_array_from_result,
coord_r,
arrow_ratio_arr,
length_arr,
colors_arr,
normalize_flag_arr,
): #pylint: disable=too-many-arguments
"""
Animate the magnetic moments
"""
global quivers
quivers.remove()
quivers = ax.quiver(
*get_arrow_next(data),
quivers = axes.quiver(
*get_arrow_next(array_name, rotation, mom_array_from_result, coord_r),
arrow_length_ratio=arrow_ratio_arr,
length=length_arr,
colors=colors_arr,
Expand All @@ -222,14 +232,13 @@ def animate(data):
@click.option('-normalize_flag', default=True)
@click.option('-height', default=20)
@click.option('-width', default=20)
@click.option('-color_bar_axis', default='x')
@click.option('-path_animation', default='./motion.gif')
@click.option('-interval_time', default=200)
@click.option('-dpi_setting', default=100)
@click.option('-path_frame', default='./motion.png')
@click.option('-frame_number', default=0)
@click.option('-animation_flag', default=False)
@click.argument('pk')
@click.argument('node_pk')
def visualization_motion(
rotation_axis: str,
rotation_matrix: list,
Expand All @@ -239,14 +248,13 @@ def visualization_motion(
normalize_flag: bool,
height: int,
width: int,
color_bar_axis: str,
path_animation: str,
interval_time: int,
dpi_setting: int,
path_frame: str,
frame_number: int,
animation_flag: bool,
pk: int,
node_pk: int,
): # pylint: disable=too-many-arguments, too-many-locals
"""
Visualize the magnetic moments of the calculation node
Expand All @@ -267,8 +275,6 @@ def visualization_motion(
:type height: int
:param width: width of the plot
:type width: int
:param color_bar_axis: which axis determines the color bar
:type color_bar_axis: str
:param path_animation: path to store the animation
:type path_animation: str
:param interval_time: how often one performs the animation
Expand All @@ -281,18 +287,16 @@ def visualization_motion(
:type frame_number: int
:param animation_flag: whether or not to animate the figure
:type animation_flag: bool
:param pk: pk number of the calculation that one wishes to animate the moments
:type pk: int
:param node_pk: node_pk number of the calculation that one wishes to animate the moments
:type node_pk: int
"""
global coord_r, mom_array_from_result, r, quivers, axis_to_colorbar, ax
global arrow_ratio_arr, length_arr, colors_arr, normalize_flag_arr
r = Rotation.from_euler(rotation_axis, rotation_matrix, degrees=True)
mom_states_x = output_node_query(pk, 'trajectories_moments', 'moments_x')
mom_states_y = output_node_query(pk, 'trajectories_moments', 'moments_y')
mom_states_z = output_node_query(pk, 'trajectories_moments', 'moments_z')
coord = output_node_query(pk, 'coord', 'coord')[:, 1:4]
rotation = Rotation.from_euler(rotation_axis, rotation_matrix, degrees=True)
mom_states_x = output_node_query(node_pk, 'trajectories_moments', 'moments_x')
mom_states_y = output_node_query(node_pk, 'trajectories_moments', 'moments_y')
mom_states_z = output_node_query(node_pk, 'trajectories_moments', 'moments_z')
coord = output_node_query(node_pk, 'coord', 'coord')[:, 1:4]

coord_r = r.apply(coord)
coord_r = rotation.apply(coord)
atoms_total = len(coord)
arrow_ratio_arr = arrow_head_ratio
length_arr = length_ratio
Expand All @@ -305,19 +309,12 @@ def visualization_motion(
atoms_total,
)

if color_bar_axis == 'x':
axis_to_colorbar = 0
elif color_bar_axis == 'y':
axis_to_colorbar = 1
else:
axis_to_colorbar = 2

fig = plt.figure(figsize=(height, width))
ax = fig.gca(projection='3d')
axes = fig.gca(projection='3d')

if not animation_flag:
quivers = ax.quiver(
*get_arrow_next(frame_number),
quivers = axes.quiver(
*get_arrow_next(frame_number, rotation, mom_array_from_result, coord_r),
arrow_length_ratio=arrow_ratio_arr,
length=length_arr,
colors=colors_arr,
Expand All @@ -326,8 +323,8 @@ def visualization_motion(
fig.savefig(path_frame)

if animation_flag:
quivers = ax.quiver(
*get_arrow_next(0),
quivers = axes.quiver(
*get_arrow_next(0, rotation, mom_array_from_result, coord_r),
arrow_length_ratio=arrow_ratio_arr,
length=length_arr,
colors=colors_arr,
Expand All @@ -336,6 +333,10 @@ def visualization_motion(
ani = FuncAnimation(
fig,
animate,
fargs=(
quivers, axes, rotation, mom_array_from_result, coord_r, arrow_ratio_arr, length_arr, colors_arr,
normalize_flag_arr
),
frames=list(range(len(mom_array_from_result))),
interval=interval_time,
)
Expand Down
Empty file.
Loading