Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chaco/abstract_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class AbstractDataSource(HasTraits):
# Abstract methods
#------------------------------------------------------------------------

def get_data(self):
def get_data(self, lod=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should update the docstring below to describe the new optional argument

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

"""get_data() -> data_array

Returns a data array of the dimensions of the data source. This data
Expand Down
77 changes: 58 additions & 19 deletions chaco/image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from numpy import fmax, fmin, swapaxes

# Enthought library imports
from traits.api import Bool, Int, Property, ReadOnly, Tuple
from traits.api import Any, Bool, Int, Property, ReadOnly, Tuple, Unicode

# Local relative imports
from .base import DimensionTrait, ImageTrait
from .abstract_data_source import AbstractDataSource
from .base import DimensionTrait, ImageTrait


class ImageData(AbstractDataSource):
"""
Expand Down Expand Up @@ -60,6 +61,15 @@ class ImageData(AbstractDataSource):
#: A read-only attribute that exposes the underlying array.
raw_value = Property(ImageTrait)

#: Flag that data source support retrieving data with specified
#: level of details (LOD)
support_downsampling = Bool(False)

#: An entry point to the LOD data which maps LOD to corresponding data
lod_data_entry = Any

#: Key pattern for lod data stored in the **lod_data_entry**
lod_key_pattern = Unicode

#------------------------------------------------------------------------
# Private traits
Expand Down Expand Up @@ -99,41 +109,64 @@ def fromfile(cls, filename):
(filename, fmt))
return imgdata

def get_width(self):
def get_width(self, lod=None):
""" Returns the shape of the x-axis.
"""
data = self.get_data(lod, transpose_inplace=False)
if self.transposed:
return self._data.shape[0]
return data.shape[0]
else:
return self._data.shape[1]
return data.shape[1]

def get_height(self):
def get_height(self, lod=None):
""" Returns the shape of the y-axis.
"""
data = self.get_data(lod, transpose_inplace=False)
if self.transposed:
return self._data.shape[1]
return data.shape[1]
else:
return self._data.shape[0]
return data.shape[0]

def get_array_bounds(self):
def get_array_bounds(self, lod=None):
""" Always returns ((0, width), (0, height)) for x-bounds and y-bounds.
"""
data = self.get_data(lod, transpose_inplace=False)
if self.transposed:
b = ((0,self._data.shape[0]), (0,self._data.shape[1]))
b = ((0, data.shape[0]), (0, data.shape[1]))
else:
b = ((0,self._data.shape[1]), (0,self._data.shape[0]))
b = ((0, data.shape[1]), (0, data.shape[0]))
return b

#------------------------------------------------------------------------
# Datasource interface
#------------------------------------------------------------------------

def get_data(self):
def get_data(self, lod=None, transpose_inplace=True):
""" Returns the data for this data source.

Implements AbstractDataSource.

Parameters
----------
lod : int
Level of detail for data to retrieve. If None, use the in-memory
`self._data`
transpose_inplace : bool
Whether to transpose the data before returning it when the raw data
stored is transposed.

Returns
-------
data : array-like
Requested image data
"""
return self.data
if lod is None:
data = self._data
else:
data = self.get_lod_data(lod)
if self.transposed and transpose_inplace:
data = swapaxes(data, 0, 1)
return data

def is_masked(self):
"""is_masked() -> False
Expand All @@ -160,13 +193,15 @@ def get_bounds(self):
self._bounds_cache_valid = True
return self._cached_bounds

def get_size(self):
def get_size(self, lod=None):
"""get_size() -> int

Implements AbstractDataSource.
"""
if self._data is not None and self._data.shape[0] != 0:
return self._data.shape[0] * self._data.shape[1]
image = self.get_data(lod)

if image is not None and image.shape[0] != 0:
return image.shape[0] * image.shape[1]
else:
return 0

Expand All @@ -180,6 +215,13 @@ def set_data(self, data):
"""
self._set_data(data)

def get_lod_data(self, lod):
if not self.lod_key_pattern:
key = str(lod)
else:
key = self.lod_key_pattern.format(lod)
return self.lod_data_entry[key]

#------------------------------------------------------------------------
# Private methods
#------------------------------------------------------------------------
Expand All @@ -198,7 +240,6 @@ def _set_data(self, newdata):
def _get_raw_value(self):
return self._data


#------------------------------------------------------------------------
# Event handlers
#------------------------------------------------------------------------
Expand All @@ -210,6 +251,4 @@ def _metadata_items_changed(self, event):
self.metadata_changed = True




# EOF
111 changes: 87 additions & 24 deletions chaco/image_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
from math import ceil, floor, pi
from contextlib import contextmanager

import six
import six.moves as sm

import numpy as np

# Enthought library imports.
from traits.api import (Bool, Either, Enum, Instance, List, Range, Trait,
Tuple, Property, cached_property)
Tuple, Property, cached_property, on_trait_change)
from traits_futures.api import CallFuture, TraitsExecutor
from kiva.agg import GraphicsContextArray
from traitsui.api import Handler

# Local relative imports
from .base_2d_plot import Base2DPlot
Expand All @@ -43,7 +44,7 @@
KIVA_DEPTH_MAP = {3: "rgb24", 4: "rgba32"}


class ImagePlot(Base2DPlot):
class ImagePlot(Base2DPlot, Handler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't inherit from Handler. This creates a problem with the life-cycle of the executor - we may need Enable to send a new type of event to indicate that the plot is being closed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about overriding the cleanup method of Component and adding codes that stop the executor there?

""" A plot based on an image.
"""
#------------------------------------------------------------------------
Expand All @@ -63,6 +64,9 @@ class ImagePlot(Base2DPlot):
#: Bool indicating whether y-axis is flipped.
y_axis_is_flipped = Property(depends_on=['orientation', 'origin'])

#: Does the plot use downsampling?
use_downsampling = Bool(False)

#------------------------------------------------------------------------
# Private traits
#------------------------------------------------------------------------
Expand All @@ -81,6 +85,12 @@ class ImagePlot(Base2DPlot):
# The name "principal diagonal" is borrowed from linear algebra.
_origin_on_principal_diagonal = Property(depends_on='origin')

#: The Traits executor for the background jobs.
_traits_executor = Instance(TraitsExecutor, ())

#: Submitted job. Only keeping track of the last submitted one.
_future = Instance(CallFuture)

#------------------------------------------------------------------------
# Properties
#------------------------------------------------------------------------
Expand Down Expand Up @@ -111,6 +121,29 @@ def _value_data_changed_fired(self):
self._image_cache_valid = False
self.request_redraw()

@on_trait_change("index_mapper:updated, bounds[]")
def _update_lod_cache_image(self):
if not self.use_downsampling:
return
lod = self._calculate_necessary_lod()
self._future = self._traits_executor.submit_call(
self._compute_cached_image, lod=lod
)

@on_trait_change("_future:done", dispatch='ui')
def _handle_lod_cached_image(self):
self._cached_image, self._cached_dest_rect = self._future.result
self._image_cache_valid = True
self.request_redraw()

#------------------------------------------------------------------------
# Hander interface
#------------------------------------------------------------------------

def closed(self, info, is_ok):
self._traits_executor.stop()
super(ImagePlot, self).closed(info, is_ok)

#------------------------------------------------------------------------
# Base2DPlot interface
#------------------------------------------------------------------------
Expand All @@ -121,7 +154,9 @@ def _render(self, gc):
Implements the Base2DPlot interface.
"""
if not self._image_cache_valid:
self._compute_cached_image()
self._cached_image, self._cached_dest_rect = \
self._compute_cached_image()
self._image_cache_valid = True

scale_x = -1 if self.x_axis_is_flipped else 1
scale_y = 1 if self.y_axis_is_flipped else -1
Expand Down Expand Up @@ -234,42 +269,55 @@ def _calc_virtual_screen_bbox(self):
y_min += 0.5
return [x_min, y_min, virtual_x_size, virtual_y_size]

def _compute_cached_image(self, data=None, mapper=None):
""" Computes the correct screen coordinates and renders an image into
`self._cached_image`.
def _compute_cached_image(self, mapper=None, lod=None):
""" Computes the correct screen coordinates and image cache

Parameters
----------
data : array
Image data. If None, image is derived from the `value` attribute.
mapper : function
Allows subclasses to transform the displayed values for the visible
region. This may be used to adapt grayscale images to RGB(A)
images.
lod : int
Level of detail for cached image. If None, use the in-memory part
`self.value._data`.

Returns
-------
cache_image : `kiva.agg.GraphicsContextArray`
Computed cache image.
cache_dest_rect : 4-tuple
(x, y, width, height) rectangle describing the pixels bounds where
the image will be rendered in the plot
"""
if data is None:
data = self.value.data
# Not to transpose the full matrix ahead in case it is too large
data = self.value.get_data(lod=lod, transpose_inplace=False)

virtual_rect = self._calc_virtual_screen_bbox()
index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect)
index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect,
lod=lod)
col_min, col_max, row_min, row_max = index_bounds

view_rect = self.position + self.bounds
sub_array_size = (col_max - col_min, row_max - row_min)
screen_rect = trim_screen_rect(screen_rect, view_rect, sub_array_size)

data = data[row_min:row_max, col_min:col_max]
if self.value.transposed:
# Swap after slicing to avoid transposing the whole matrix
data = data[col_min:col_max, row_min:row_max]
data = data.swapaxes(0, 1)
else:
data = data[row_min:row_max, col_min:col_max]

if mapper is not None:
data = mapper(data)

if len(data.shape) != 3:
raise RuntimeError("`ImagePlot` requires color images.")

# Update cached image and rectangle.
self._cached_image = self._kiva_array_from_numpy_array(data)
self._cached_dest_rect = screen_rect
self._image_cache_valid = True
cached_image = self._kiva_array_from_numpy_array(data)
cached_dest_rect = screen_rect
return cached_image, cached_dest_rect

def _kiva_array_from_numpy_array(self, data):
if data.shape[2] not in KIVA_DEPTH_MAP:
Expand All @@ -281,7 +329,7 @@ def _kiva_array_from_numpy_array(self, data):
data = np.ascontiguousarray(data)
return GraphicsContextArray(data, pix_format=kiva_depth)

def _calc_zoom_coords(self, image_rect):
def _calc_zoom_coords(self, image_rect, lod=None):
""" Calculates the coordinates of a zoomed sub-image.

Because of floating point limitations, it is not advisable to request a
Expand All @@ -307,12 +355,12 @@ def _calc_zoom_coords(self, image_rect):
if 0 in (image_width, image_height) or 0 in self.bounds:
return (None, None)

array_bounds = self._array_bounds_from_screen_rect(image_rect)
array_bounds = self._array_bounds_from_screen_rect(image_rect, lod=lod)
col_min, col_max, row_min, row_max = array_bounds
# Convert array indices back into screen coordinates after its been
# clipped to fit within the bounds.
array_width = self.value.get_width()
array_height = self.value.get_height()
array_width = self.value.get_width(lod=lod)
array_height = self.value.get_height(lod=lod)
x_min = float(col_min) / array_width * image_width + ix
x_max = float(col_max) / array_width * image_width + ix
y_min = float(row_min) / array_height * image_height + iy
Expand All @@ -333,7 +381,7 @@ def _calc_zoom_coords(self, image_rect):
screen_rect = [x_min, y_min, x_max - x_min, y_max - y_min]
return index_bounds, screen_rect

def _array_bounds_from_screen_rect(self, image_rect):
def _array_bounds_from_screen_rect(self, image_rect, lod=None):
""" Transform virtual-image rectangle into array indices.

The virtual-image rectangle is in screen coordinates and can be outside
Expand All @@ -357,8 +405,8 @@ def _array_bounds_from_screen_rect(self, image_rect):
x_max = x_min + plot_width
y_max = y_min + plot_height

array_width = self.value.get_width()
array_height = self.value.get_height()
array_width = self.value.get_width(lod=lod)
array_height = self.value.get_height(lod=lod)
# Convert screen coordinates to array indexes
col_min = floor(float(x_min) / image_width * array_width)
col_max = ceil(float(x_max) / image_width * array_width)
Expand All @@ -372,3 +420,18 @@ def _array_bounds_from_screen_rect(self, image_rect):
row_max = min(row_max, array_height)
Copy link

@siddhantwahal siddhantwahal Dec 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An edge case that isn't handled here occurs when col_min is greater than array_width or col_max is less than 0. These only occur when the image disappears from the plot, e.g., panning the image to the right as much as possible and then zooming in leads to negative x_max and col_max.

While x_min and col_min are also negative in this example, col_min is explicitly clipped to 0, whereas col_max isn't, leading to col_max - col_min < 0.

When these bounds are used to compute the LOD in _calculate_necessary_lod:

chaco/chaco/image_plot.py

Lines 432 to 438 in e41cfd3

for lod in range(len(self.value.lod_data_entry))[::-1]:
index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect, lod=lod)
array_width = index_bounds[1] - index_bounds[0]
array_height = index_bounds[3] - index_bounds[2]
if (array_width >= screen_rect[2]) and (array_height >= screen_rect[3]):
break
return lod

no LOD will satisfy col_max - col_min >= screen_rect[2], and the necessary LOD is set to 0. This causes the highest resolution image to be unnecessarily cached, potentially crashing the program if the image is too large to fit in memory (actually, only the slice [row_min:row_max, :-abs(col_max)] is cached, but that can still be large).

The fix is simple:

-        col_min = max(col_min, 0)
-        col_max = min(col_max, array_width)
-        row_min = max(row_min, 0)
-        row_max = min(row_max, array_height)
+        col_min, col_max = np.clip([col_min, col_max], 0, array_width)
+        row_min, row_max = np.clip([row_min, row_max], 0, array_height)


return col_min, col_max, row_min, row_max

def _calculate_necessary_lod(self):
""" Computes the necessary lod so that array has more pixels than
the screen rectangle.
"""
virtual_rect = self._calc_virtual_screen_bbox()
# NOTE: LOD numbers are assumed to be continuous integers
# starting from 0
for lod in range(len(self.value.lod_data_entry))[::-1]:
index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect, lod=lod)
array_width = index_bounds[1] - index_bounds[0]
array_height = index_bounds[3] - index_bounds[2]
if (array_width >= screen_rect[2]) and (array_height >= screen_rect[3]):
break
return lod
1 change: 1 addition & 0 deletions ci/edmtool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"cython",
# Needed to install enable from source
"swig",
"traits_futures"
}

extra_dependencies = {
Expand Down
Loading