Enabling 3D plots? #582

Open
@krassowski

Description

Could 3D plots be supported by plotnine, or if not, could the internals be slightly adjusted to enable easy extension to 3D plots? Currently, a simple hack is enough to get a basic 3D plotnine plot:

from plotnine import aes, scale_color_manual, scale_shape_manual, theme_minimal
from plotnine.data import mtcars

mtcars['transmission'] = mtcars['am'].replace({0: 'automatic', 1: 'manual'})
(
    ggplot_3d(mtcars)
    + aes(x='hp', y='disp', z='mpg', shape='transmission', color='transmission')
    + geom_point_3d()
    + theme_minimal()
    + scale_shape_manual(values={'automatic': 'o', 'manual': '*'})
    + scale_color_manual(values={'automatic': 'red', 'manual': 'blue'})
)

image

The hack is:

import matplotlib.pyplot as plt
import numpy as np
from plotnine import ggplot, geom_point
from plotnine.utils import to_rgba, SIZE_FACTOR


class ggplot_3d(ggplot):
    def _create_figure(self):
        figure = plt.figure()
        axs = [plt.axes(projection='3d')]
        figure._themeable = {}
        self.figure = figure
        self.axs = axs
        return figure, axs

    def _draw_labels(self):
        ax = self.axs[0]
        ax.set_xlabel(self.layout.xlabel(self.labels))
        ax.set_ylabel(self.layout.ylabel(self.labels))
        ax.set_zlabel(self.labels['z'])


class geom_point_3d(geom_point):
    REQUIRED_AES = {'x', 'y', 'z'}

    @staticmethod
    def draw_unit(data, panel_params, coord, ax, **params):
        size = ((data['size']+data['stroke'])**2)*np.pi
        stroke = data['stroke'] * SIZE_FACTOR
        color = to_rgba(data['color'], data['alpha'])

        if all(c is None for c in data['fill']):
            fill = color
        else:
            fill = to_rgba(data['fill'], data['alpha'])

        ax.scatter3D(
            data['x'],
            data['y'],
            data['z'],
            s=size,
            facecolor=fill,
            edgecolor=color,
            linewidth=stroke,  # use the stroke width computed above
            marker=data.loc[0, 'shape'],
        )

    @staticmethod
    def draw_group(data, panel_params, coord, ax, **params):
        data = coord.transform(data, panel_params)
        units = 'shape'
        for _, udata in data.groupby(units, dropna=False):
            udata.reset_index(inplace=True, drop=True)
            geom_point_3d.draw_unit(udata, panel_params, coord, ax, **params)

However, because it bypasses the self.facet.make_axes() call in _create_figure(), it does not support faceting.

While I would prefer plotnine to just support 3D plots, I fully respect it if this is not aligned with the vision of the project; in that case, the following places could be reworked to make extension simpler:

  • to make geom_point easier to extend, the geom_point.draw_group() method could be a classmethod rather than a static method, and instead of the hard-coded geom_point it would use cls; presumably this would also be needed for other geoms that can have 3D equivalents (line, area, etc.).
  • to allow special handling of zlabel alongside xlabel and ylabel, either a new common method Layout.label(self, labels, axis) could replace Layout.xlabel(self, labels) and Layout.ylabel(self, labels), or the hard-coded self.layout = Layout() in ggplot._build could be replaced with a new method ggplot._init_layout() or with a new class attribute layout_class = Layout (the call would then become self.layout = self.layout_class());
  • to support faceting, the fig.add_subplot(gs[i - 1]) call in _create_subplots should accept a projection='3d' keyword argument; this could be conditional on the layout having a projection attribute.
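The first two suggestions can be illustrated with a minimal, plotnine-free sketch. All class names here (Geom2D, Geom3D, Layout3D, Plot, Plot3D) are hypothetical stand-ins invented for illustration, not plotnine's actual classes; the point is only that a classmethod dispatching through cls and a layout_class attribute let a subclass swap behaviour without copying the parent's methods.

```python
class Geom2D:
    """Toy stand-in for a geom; NOT plotnine's geom_point."""

    @classmethod
    def draw_group(cls, groups, ax):
        # Dispatch through cls so subclasses reuse this loop unchanged
        for unit in groups:
            cls.draw_unit(unit, ax)

    @staticmethod
    def draw_unit(data, ax):
        ax.append(("2d", data))


class Geom3D(Geom2D):
    # Only the unit-level drawing needs to be overridden
    @staticmethod
    def draw_unit(data, ax):
        ax.append(("3d", data))


class Layout:
    """Toy stand-in for a layout that knows how to pick axis labels."""

    def xlabel(self, labels):
        return labels["x"]

    def ylabel(self, labels):
        return labels["y"]


class Layout3D(Layout):
    def zlabel(self, labels):
        return labels["z"]


class Plot:
    layout_class = Layout  # subclasses override this attribute

    def build(self):
        # Instantiate whatever layout the (sub)class declares
        self.layout = self.layout_class()
        return self.layout


class Plot3D(Plot):
    layout_class = Layout3D


calls = []  # a plain list plays the role of the axes object
Geom3D.draw_group([[1], [2]], calls)
layout = Plot3D().build()
```

With this shape, a 3D geom would no longer need to redefine draw_group at all, and a 3D plot class would only need to name its layout class.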

Of course, there are some more changes needed for full integration (e.g. flipping of coordinates would need to know around which axis), but I believe that very few changes would be needed to gain quite a lot in terms of usability.
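The faceting suggestion above (forwarding a projection to add_subplot) can be sketched in plain matplotlib. The helper create_subplots and its projection parameter are assumptions made for illustration and do not mirror plotnine's actual _create_subplots signature; only fig.add_subplot(..., projection='3d') and GridSpec are real matplotlib calls.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D  # also registers '3d' on older matplotlib


def create_subplots(nrow, ncol, projection=None):
    """Hypothetical helper: build a grid of axes, optionally 3D."""
    fig = plt.figure()
    gs = GridSpec(nrow, ncol, figure=fig)
    # Forward the keyword only when a projection was requested,
    # so the default 2D behaviour is left untouched
    kwargs = {} if projection is None else {"projection": projection}
    axs = [fig.add_subplot(gs[i], **kwargs) for i in range(nrow * ncol)]
    return fig, axs


# Two side-by-side panels, each a 3D axes
fig, axs = create_subplots(1, 2, projection="3d")
```

In plotnine itself, the projection value could come from an attribute on the layout, as proposed above, rather than from a function argument.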
