from echo import delay_callback, CallbackProperty
import warnings
import numpy as np
from numpy.linalg import norm
from glue.config import data_translator
from glue.core import BaseData
from glue.core.edit_subset_mode import NewMode, ReplaceMode
from glue.core.exceptions import IncompatibleAttribute
from glue.core.roi import CircularAnnulusROI, CircularROI, EllipticalROI, RectangularROI
from glue.core.subset import Subset
from glue.core.subset_group import GroupedSubset
from glue.viewers.histogram.state import HistogramViewerState
from glue.viewers.scatter.state import ScatterViewerState
from glue.viewers.scatter.state import ScatterLayerState as BqplotScatterLayerState
from glue.utils import avoid_circular
from glue_astronomy.spectral_coordinates import SpectralCoordinates
from glue_jupyter.bqplot.profile import BqplotProfileView
from glue_jupyter.bqplot.histogram import BqplotHistogramView
from glue_jupyter.bqplot.image import BqplotImageView
from glue_jupyter.bqplot.scatter import BqplotScatterView
from glue_jupyter.table import TableViewer
from astropy.utils import deprecated
from astropy import units as u
from astropy.nddata import (
NDDataArray, StdDevUncertainty, VarianceUncertainty, InverseVariance
)
from specutils import Spectrum
from traitlets import Bool, Unicode
from jdaviz.components.toolbar_nested import NestedJupyterToolbar
from jdaviz.configs.default.plugins.data_menu import DataMenu
from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin
from jdaviz.core.custom_units_and_equivs import _eqv_sb_per_pixel_to_per_angle
from jdaviz.core.events import (SnackbarMessage,
NewViewerMessage,
ViewerRemovedMessage,
ViewerVisibleLayersChangedMessage,
RestoreToolbarMessage,
TableSelectRowClickMessage)
from jdaviz.core.freezable_state import FreezableProfileViewerState
from jdaviz.core.marks import (LineUncertainties, ScatterMask,
OffscreenLinesMarks, TableSelectionMark)
from jdaviz.core.registries import viewer_registry
from jdaviz.core.template_mixin import WithCache, TemplateMixin, show_widget
from jdaviz.core.tools import _get_skycoords_from_table, _get_pixel_coords_from_table
from jdaviz.core.user_api import ViewerUserApi
from jdaviz.core.unit_conversion_utils import (check_if_unit_is_per_solid_angle,
flux_conversion_general,
all_flux_unit_conversion_equivs)
from jdaviz.utils import (ColorCycler, get_subset_type, _wcs_only_label,
layer_is_image_data, layer_is_not_dq, layer_is_3d)
# Short uncertainty-type names mapped to the astropy.nddata uncertainty classes.
uncertainty_str_to_cls_mapping = dict(
    std=StdDevUncertainty,
    var=VarianceUncertainty,
    ivar=InverseVariance,
)

__all__ = ['JdavizViewerMixin', 'JdavizProfileView']

# Register the stock glue-jupyter viewer classes under jdaviz registry names
# (same names, labels and order as individual ``viewer_registry.add`` calls).
for _vid, _vlabel, _vcls in (("g-profile-viewer", "Profile 1D", BqplotProfileView),
                             ("g-image-viewer", "Image 2D", BqplotImageView),
                             ("g-table-viewer", "Table", TableViewer)):
    viewer_registry.add(_vid, label=_vlabel, cls=_vcls)
# keep the module namespace identical to the unrolled form
del _vid, _vlabel, _vcls
[docs]
class JdavizViewerMixin(WithCache):
    """
    Mixin adding jdaviz-specific behavior to glue-jupyter viewers.

    It provides subset move/resize detection for ``apply_roi``, the default
    viewer ``user_api``, axis-limit helpers, layer-default and data-menu
    bookkeeping, and accessors for the owning Jdaviz application.
    """
    toolbar = None
    tools_nested = []
    _prev_limits = None
    # Marks whose exact class name (subclasses are not matched) is listed here are
    # reported by ``native_marks``; every other mark is reported by ``custom_marks``.
    _native_mark_classnames = ('Lines', 'LinesGL', 'FRBImage', 'Contour')

    def __init__(self, *args, **kwargs):
        # NOTE: anything here most likely won't be called by viewers because of inheritance order
        super().__init__(*args, **kwargs)
        # Allow each viewer to cycle through colors for each new addition to the viewer:
        self.color_cycler = ColorCycler()
        # Separate color cycler for scatter layers (catalogs) that has brighter colors,
        # starting with neon green
        self.scatter_color_cycler = ColorCycler()
        self.scatter_color_cycler.default_color_palette = [
            '#00FF00',  # neon green
            '#FF00FF',  # magenta
            '#00FFFF',  # cyan
            '#FF0000',  # red
            '#FFFF00',  # yellow
            '#FF8800',  # orange
            '#8800FF',  # purple
            '#0088FF',  # blue
        ]
        self._data_menu = DataMenu(viewer=self, app=self.jdaviz_app)

    @staticmethod
    def _is_circular_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for circular ROIs.

        Checks for moves, i.e. the radius is unchanged within ``rtol_move``, and
        resizes, i.e. the radius changed but the new center lies within
        ``rtol_resize`` times the old radius of the old center.

        Parameters
        ----------
        old_roi : CircularROI
            The existing circular ROI.
        new_roi : CircularROI
            The newly proposed circular ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        if old_roi.radius <= 0:
            return None

        # Move: radius unchanged.
        if abs(old_roi.radius - new_roi.radius) < rtol_move * max(old_roi.radius, 1):
            return 'move'

        # Resize: radius changed but center stayed very close
        dist = norm([new_roi.xc - old_roi.xc, new_roi.yc - old_roi.yc])
        if dist < (rtol_resize * max(old_roi.radius, 1)):
            return 'resize'
        return None

    @staticmethod
    def _is_annulus_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for annular ROIs.

        Checks for moves, i.e. both radii are unchanged within ``rtol_move``, and
        resizes, i.e. the radii change but the new center lies within
        ``rtol_resize`` times the old outer radius of the old center.

        Parameters
        ----------
        old_roi : CircularAnnulusROI
            The existing annular ROI.
        new_roi : CircularAnnulusROI
            The newly proposed annular ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        # Degenerate (non-positive or inverted) annuli cannot be classified; ``<= 0``
        # matches the guards used by the other shape handlers.
        if old_roi.outer_radius <= 0 or old_roi.outer_radius <= old_roi.inner_radius:
            return None

        # Move: both radii unchanged. Both radii are compared with ``rtol_move``: using the
        # (much looser) resize tolerance here would report a genuine outer-radius resize
        # as a move.
        inner_unchanged = (abs(old_roi.inner_radius - new_roi.inner_radius)
                           < rtol_move * max(old_roi.inner_radius, 1))
        outer_unchanged = (abs(old_roi.outer_radius - new_roi.outer_radius)
                           < rtol_move * max(old_roi.outer_radius, 1))
        if inner_unchanged and outer_unchanged:
            return 'move'

        # Resize: radii changed but center stayed very close
        dist = norm([new_roi.xc - old_roi.xc, new_roi.yc - old_roi.yc])
        if dist < (rtol_resize * max(old_roi.outer_radius, 1)):
            return 'resize'
        return None

    @staticmethod
    def _is_elliptical_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for elliptical ROIs.

        Checks for moves, i.e. both radii are unchanged within ``rtol_move``, and
        resizes, i.e. either radius changes but the new center lies within
        ``rtol_resize`` times the larger old radius of the old center.

        Parameters
        ----------
        old_roi : EllipticalROI
            The existing elliptical ROI.
        new_roi : EllipticalROI
            The newly proposed elliptical ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        size = max(old_roi.radius_x, old_roi.radius_y)
        if size <= 0 or min(old_roi.radius_x, old_roi.radius_y) <= 0:
            return None

        # Move: both radii unchanged.
        if (abs(old_roi.radius_x - new_roi.radius_x) < rtol_move * max(old_roi.radius_x, 1)
                and abs(old_roi.radius_y - new_roi.radius_y) < rtol_move * max(old_roi.radius_y, 1)):  # noqa
            return 'move'

        # Resize: radii changed but center stayed very close
        dist = norm([new_roi.xc - old_roi.xc, new_roi.yc - old_roi.yc])
        if dist < rtol_resize * max(size, 1):
            return 'resize'
        return None

    @staticmethod
    def _is_rectangular_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for rectangular ROIs.

        Checks for moves, i.e. the width and height are unchanged within
        ``rtol_move``, and resizes, i.e. the dimensions change but the new
        center lies within ``rtol_resize`` times half the larger old dimension
        of the old center.

        Parameters
        ----------
        old_roi : RectangularROI
            The existing rectangular ROI.
        new_roi : RectangularROI
            The newly proposed rectangular ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        old_w = old_roi.xmax - old_roi.xmin
        old_h = old_roi.ymax - old_roi.ymin
        new_w = new_roi.xmax - new_roi.xmin
        new_h = new_roi.ymax - new_roi.ymin
        size = max(old_w, old_h) / 2
        if size <= 0 or old_w <= 0 or old_h <= 0:
            return None

        # Move: same width and height.
        if (abs(old_w - new_w) < rtol_move * max(old_w, 1) and
                abs(old_h - new_h) < rtol_move * max(old_h, 1)):
            return 'move'

        # Resize: dimensions changed but center stayed very close
        old_cx = (old_roi.xmin + old_roi.xmax) / 2
        old_cy = (old_roi.ymin + old_roi.ymax) / 2
        new_cx = (new_roi.xmin + new_roi.xmax) / 2
        new_cy = (new_roi.ymin + new_roi.ymax) / 2
        dist = norm([new_cx - old_cx, new_cy - old_cy])
        if dist < rtol_resize * max(size, 1):
            return 'resize'
        return None

    def _is_roi_edit(self, old_roi, new_roi):
        """
        Classify the change from *old_roi* to *new_roi*.

        Both ROIs must be of exactly the same type; the comparison is then
        delegated to the shape-specific ``_is_*_edit`` helper.

        Parameters
        ----------
        old_roi : ROI
            The existing subset's ROI.
        new_roi : ROI
            The newly drawn ROI.

        Returns
        -------
        edit_type : 'move', 'resize', or None
            `None` for mismatched or unsupported ROI types.
        """
        # mutual isinstance checks == "same type" (a subclass on either side is rejected)
        if not isinstance(new_roi, type(old_roi)) or not isinstance(old_roi, type(new_roi)):
            return None

        handlers = [(CircularAnnulusROI, self._is_annulus_edit),
                    (CircularROI, self._is_circular_edit),
                    (EllipticalROI, self._is_elliptical_edit),
                    (RectangularROI, self._is_rectangular_edit)]
        for roi_type, handler in handlers:
            if isinstance(old_roi, roi_type):
                return handler(old_roi, new_roi)
        return None

    @staticmethod
    def _is_range_edit(old_range, new_range, rtol=1e-6):
        """
        Classify the change from *old_range* to *new_range* for 1-D ranges.

        Checks for moves, i.e. the width is unchanged within ``rtol``, and
        resizes, i.e. the width changes while one endpoint is exactly fixed.

        Parameters
        ----------
        old_range : RangeSubsetState
            The existing 1-D subset state.
        new_range : ROI
            The new ROI with ``min`` and ``max`` attributes.
        rtol : float, optional
            Relative tolerance for the width comparison.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        if not (hasattr(new_range, 'min') and hasattr(new_range, 'max')):
            return None

        old_w = old_range.hi - old_range.lo
        new_w = new_range.max - new_range.min
        if old_w <= 0 or new_w <= 0:
            return None

        # Move: same width.
        if abs(new_w - old_w) < (rtol * max(abs(old_w), 1)):
            return 'move'

        # Resize: width changed, one endpoint fixed.
        if old_range.lo == new_range.min or old_range.hi == new_range.max:
            return 'resize'
        return None

    def apply_roi(self, roi, use_current=False):
        """
        Apply an ROI to the viewer.

        Detects whether the user is resizing or moving an existing
        subset. When detected and mode is NewMode, temporarily
        switches to ReplaceMode so the subset is modified in-place
        rather than duplicated. A detected move also emits a
        ``UserWarning``.

        Parameters
        ----------
        roi : `~glue.core.roi.Roi`
            The region of interest to apply.
        use_current : bool, optional
            Forwarded unchanged to the parent ``apply_roi``.
        """
        from glue.core.subset import RoiSubsetState, RangeSubsetState

        edit_subset_mode = self.session.edit_subset_mode
        needs_override = False
        edit_type = None

        # Region imports and non-NewMode edits are never reinterpreted.
        if (not getattr(self.jdaviz_app, '_importing_regions', False)
                and len(edit_subset_mode.edit_subset) > 0
                and edit_subset_mode.mode is NewMode):
            existing_state = edit_subset_mode.edit_subset[0].subset_state
            if isinstance(existing_state, RoiSubsetState):
                old_roi = existing_state.roi
                edit_type = self._is_roi_edit(old_roi, roi)
            elif isinstance(existing_state, RangeSubsetState):
                edit_type = self._is_range_edit(existing_state, roi)
            needs_override = edit_type is not None

        if needs_override and edit_type == 'move':
            # Warn when the new subset has identical dimensions to the
            # existing one, it will be treated as a move in-place rather
            # than a new subset.
            warnings.warn(
                'The new subset has the same dimensions as the '
                'existing subset and will be treated as a move '
                'operation. To create a new subset with the same '
                'dimensions, use import_region instead.',
                stacklevel=2)

        original_mode = edit_subset_mode.mode
        if needs_override:
            # set the private attribute so the temporary switch is undone in ``finally``
            edit_subset_mode._mode = ReplaceMode
        try:
            super().apply_roi(roi, use_current=use_current)
        finally:
            if needs_override:
                edit_subset_mode._mode = original_mode

    @property
    def user_api(self):
        """
        Public API for this viewer, exposing a viewer-type-dependent set of attributes.
        """
        # default exposed user APIs. Can override this method in any particular viewer.
        if not (isinstance(self, TableViewer) and self.__class__.__name__ == 'MosvizTableViewer'):
            # TODO: eventually remove data_labels_loaded
            # and data_labels_visible once deprecation period passes
            expose = ['data_labels_loaded', 'data_labels_visible', 'data_menu']
        else:
            expose = []
        if self.jdaviz_app.config == 'deconfigged':
            expose += ['clone_viewer']
        if isinstance(self, BqplotImageView):
            if isinstance(self, AstrowidgetsImageViewerMixin):
                expose += ['save',
                           'center_on', 'offset_by', 'zoom_level', 'zoom',
                           'colormap_options', 'set_colormap',
                           'stretch_options', 'stretch',
                           'autocut_options', 'cuts',
                           'marker', 'add_markers', 'remove_markers', 'reset_markers',
                           'blink_once', 'reset_limits', 'get_viewport_region']
            # else: cubeviz image viewers don't inherit from AstrowidgetsImageViewerMixin
            # yet, but also shouldn't expose set_limits because of equal aspect ratio
            # concerns, so nothing extra is exposed.
        elif not isinstance(self, TableViewer):
            # table viewers expose nothing extra; all other viewers get limit helpers
            expose += ['set_limits', 'reset_limits', 'set_tick_format']
        return ViewerUserApi(self, expose=expose)

    @property
    def data_menu(self):
        """User API of this viewer's data menu."""
        return self._data_menu.user_api

    def _deprecated_data_menu(self):
        # temporary method to allow for opening new data-menu from old button. This should
        # be removed anytime after the old button is removed (likely in 4.3)
        self.data_menu.open_menu()

    @property
    @deprecated(since="4.1", alternative="viewer.data_menu.data_labels_loaded")
    def data_labels_loaded(self):
        """
        List of data labels loaded in this viewer.

        Returns
        -------
        data_labels : list
            list of strings
        """
        return self.data_menu.data_labels_loaded

    @property
    @deprecated(since="4.1", alternative="viewer.data_menu.data_labels_visible")
    def data_labels_visible(self):
        """
        List of data labels visible in this viewer.

        Returns
        -------
        data_labels : list
            list of strings
        """
        return self.data_menu.data_labels_visible

    def _get_clone_viewer_reference(self):
        """Reference name the helper would assign to a clone of this viewer."""
        return self.jdaviz_helper._get_clone_viewer_reference(self.reference)

    def _clone_viewer_outside_app(self):
        """Create a copy of this viewer (state and layers) that is not registered in the app."""
        new_viewer = type(self)(session=self.session)
        # TODO: this is needed for jdaviz only, without it it should also
        # work for glue-jupyter
        new_viewer._reference_id = self.reference_id + "_TODO_IS_THIS_OK?"
        d = self.state.as_dict()
        new_viewer.state.update_from_dict(d)
        for layer in self.layers:
            # NOTE(review): the original layer *state* object is shared with the new
            # layer artist rather than copied — confirm this is intended.
            layer_state = layer.state
            new_layer = type(layer)(view=new_viewer, viewer_state=new_viewer.state,
                                    layer_state=layer_state)
            new_layer.update()
        return new_viewer

    def clone_viewer(self):
        """
        Create a new viewer in the app that mirrors this one.

        Loaded data (with visibility), ``_clone_attrs`` attributes, viewer state
        and per-layer state (except ``layer``) are copied to the clone.

        Returns
        -------
        user_api
            The user API of a `JdavizViewerWindow` wrapping the new viewer.
        """
        name = self.jdaviz_helper._get_clone_viewer_reference(self.reference)
        self.jdaviz_app._on_new_viewer(NewViewerMessage(self.__class__,
                                                        data=None,
                                                        sender=self.jdaviz_app),
                                       vid=name, name=name)
        new_viewer = self.jdaviz_app.get_viewer(name)
        visible_layers = self.data_menu.data_labels_visible
        for layer in self.data_menu.data_labels_loaded[::-1]:
            visible = layer in visible_layers
            new_viewer.data_menu.add_data(layer)
            if hasattr(new_viewer.data_menu, 'set_layer_visibility'):
                new_viewer.data_menu.set_layer_visibility(layer, visible)
            # TODO: don't revert color when adding same data to a new viewer

        # allow viewers to set attributes (not in state) on cloned viewers
        for attr in getattr(self, '_clone_attrs', []):
            if hasattr(self, attr):
                setattr(new_viewer, attr, getattr(self, attr))

        new_viewer.state.update_from_dict(self.state.as_dict())
        for this_layer_state, new_layer_state in zip(self.state.layers, new_viewer.state.layers):
            for k, v in this_layer_state.as_dict().items():
                if k in ('layer',):
                    continue
                setattr(new_layer_state, k, v)
        return JdavizViewerWindow(new_viewer, app=self.jdaviz_app).user_api

    def reset_limits(self):
        """
        Reset viewer axes limits.
        """
        self.state.reset_limits()

    def set_limits(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """
        Set viewer axes limits.

        Parameters
        ----------
        x_min : float or None, optional
            lower-limit of x-axis (in current axes units)
        x_max: float or None, optional
            upper-limit of x-axis (in current axes units)
        y_min : float or None, optional
            lower-limit of y-axis (in current axes units)
        y_max: float or None, optional
            upper-limit of y-axis (in current axes units)

        Raises
        ------
        TypeError
            If any limit is not `None`, a Python int/float, or a NumPy
            integer/floating scalar.
        """
        for val in (x_min, x_max, y_min, y_max):
            # NumPy integer scalars (e.g. np.int64) do not subclass ``int`` and not every
            # NumPy float (e.g. np.float32) subclasses ``float``, so accept the NumPy scalar
            # base classes explicitly.
            if val is not None and not isinstance(val, (float, int, np.integer, np.floating)):
                raise TypeError('All arguments must be None, int, or float, '
                                f'but got: {type(val)}')

        with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
            if x_min is not None:
                self.state.x_min = x_min
            if x_max is not None:
                self.state.x_max = x_max
            # NOTE: for some reason, setting ymax first avoids an issue
            # where back-to-back calls of get_limits and set_limits
            # give different results for y limits.
            if y_max is not None:
                self.state.y_max = y_max
            if y_min is not None:
                self.state.y_min = y_min

    def get_limits(self):
        """Return current viewer axes limits.

        Returns
        -------
        x_min, x_max, y_min, y_max : float
            Lower/upper X/Y limits, respectively.
        """
        return self.state.x_min, self.state.x_max, self.state.y_min, self.state.y_max

    @property
    def native_marks(self):
        """
        Return all marks whose class name is exactly one of ``_native_mark_classnames``
        (Lines, LinesGL, FRBImage, Contour); subclasses of these are not included.
        """
        return [m for m in self.figure.marks
                if m.__class__.__name__ in self._native_mark_classnames]

    @property
    def custom_marks(self):
        """
        Return all marks whose class name is not in ``_native_mark_classnames``
        (this includes subclasses of the native mark classes).
        """
        return [m for m in self.figure.marks
                if m.__class__.__name__ not in self._native_mark_classnames]

    def _subscribe_to_layers_update(self):
        """Initialize layer bookkeeping and watch ``state.layers`` for changes."""
        # subscribe to new layers
        self._expected_subset_layers = []
        self._layers_with_defaults_applied = []
        self.state.add_callback('layers', self._on_layers_update)

    def _get_layer(self, label):
        """Return the first layer state whose layer label is *label*, or `None`."""
        for layer in self.state.layers:
            if layer.layer.label == label:
                return layer
        return None

    def _apply_layer_defaults(self, layer_state):
        """
        Apply jdaviz defaults to a newly added layer (called once per layer).

        Handles subset visibility on WCS-only data, ``as_steps`` for profile
        layers, marker size for scatter layers and a visibility validator.
        """
        # Hide subsets on orientation layers to avoid rendering artifacts.
        if (layer_state.layer.label != layer_state.layer.data.label
                and layer_state.layer.data.meta.get(_wcs_only_label, False)):
            layer_state.visible = False

        if hasattr(layer_state, 'as_steps'):
            if layer_state.layer.label != layer_state.layer.data.label:
                # then this is a subset, so default based on the parent data layer
                layer_state.as_steps = self._get_layer(layer_state.layer.data.label).as_steps
            else:
                # default to not plotting with as_steps (despite glue defaulting to True)
                layer_state.as_steps = False
            # whenever as_steps changes, we need to redraw the uncertainties (if enabled)
            layer_state.add_callback('as_steps', self._show_uncertainty_changed)

        # set default size for scatter layers (e.g., catalog markers) to 1% of the mean
        # of the last two dimensions of the first layer with >= 2-D data, falling back on
        # 10 when no such layer exists (e.g., no image layers)
        if isinstance(layer_state, BqplotScatterLayerState):
            marker_size = 10  # fallback default for marker size
            for layer in self.state.layers:
                if (hasattr(layer, 'layer') and
                        hasattr(layer.layer, 'data') and
                        hasattr(layer.layer.data, 'shape') and
                        len(layer.layer.data.shape) >= 2):
                    shape = layer.layer.data.shape
                    avg_dimension = (shape[-2] + shape[-1]) / 2
                    marker_size = avg_dimension * 0.01
                    # set a lower limit on marker size, which is the glue default
                    # of three pixels.
                    marker_size = max(marker_size, 3)
                    break
            layer_state.size = marker_size

        # use echo-validator to ensure visible sets & updates properly in plot options & data menu
        if (hasattr(layer_state, 'visible') and get_subset_type(layer_state.layer) != 'spatial'):
            # NOTE(review): this assigns a ``CallbackProperty`` *instance* onto the layer
            # object itself (not on a class) — confirm this is intended.
            layer_state.layer.visible = CallbackProperty()
            layer_state.add_callback('layer',
                                     self._expected_subset_layer_default,
                                     validator=True)

    def _expected_subset_layer_default(self, layer_state):
        """Set the default visibility of a newly created subset layer."""
        if self.__class__.__name__ == 'RampvizImageView':
            # Do not override default for subsets as for some reason
            # this isn't getting called when they're first added, but rather when
            # the next state change is made (for example: manually changing the visibility)
            return

        # default visibility based on the visibility of the "parent" data layer
        if self.__class__.__name__ == 'RampvizProfileView':
            # Rampviz doesn't show subset profiles by default:
            layer_state.visible = False
        elif (self.__class__.__name__ == 'CubevizImageView' and
              get_subset_type(layer_state.layer) != 'spatial'):
            # set visibility of spectral subsets to false in Cubeviz image-viewers
            layer_state.visible = False
        else:
            layer_state.visible = self._get_layer(layer_state.layer.data.label).visible

    def _update_layer_icons(self):
        """
        Rebuild ``self._data_menu.visible_layers`` (label -> color, linewidth, icon)
        from the visible, non-WCS-only layers, iterating ``state.layers`` in reverse.
        """
        # update visible_layers (TODO: move this somewhere that can update on color change, etc)
        def _get_layer_color(layer):
            if isinstance(layer, BqplotScatterLayerState):
                # then this could be a scatter layer in an image viewer,
                # so we'll ignore the color_mode
                return layer.color
            if getattr(self.state, 'color_mode', None) == 'Colormaps':
                for subset in self.jdaviz_app.data_collection.subset_groups:
                    if subset.label == layer.layer.label:
                        # then we still want to show the color for a subset
                        return layer.color
                # then this is a data-layer in colormap mode, so we'll ignore the color
                return ''
            return getattr(layer, 'color', '')

        def _get_layer_linewidth(layer):
            # icons cap the displayed linewidth at 6
            linewidth = getattr(layer, 'linewidth', 0)
            return min(linewidth, 6)

        def _get_layer_info(layer):
            # returns (prefix icon name, subset type or None)
            if 'Trace' in layer.layer.data.meta:
                return "mdi-chart-line-stacked", None
            for subset in self.jdaviz_app.data_collection.subset_groups:
                if subset.label == layer.layer.label:
                    subset_type = get_subset_type(subset)
                    if subset_type == 'spatial':
                        return "mdi-chart-scatter-plot", subset_type
                    else:
                        return "mdi-chart-bell-curve", subset_type
            return "", None

        visible_layers = {}
        for layer in self.state.layers[::-1]:
            layer_is_wcs_only = (
                hasattr(layer.layer, 'meta') and
                layer.layer.meta.get(_wcs_only_label, False)
            )
            if layer.visible and not layer_is_wcs_only:
                prefix_icon, subset_type = _get_layer_info(layer)
                if (
                    subset_type == 'spatial' and
                    self.__class__.__name__ in ('CubevizProfileView',
                                                'RampvizProfileView',
                                                'Spectrum1DViewer')
                ):
                    # do not show spatial subsets in profile viewer
                    continue
                visible_layers[layer.layer.label] = {'color': _get_layer_color(layer),
                                                     'linewidth': _get_layer_linewidth(layer),
                                                     'prefix_icon': prefix_icon}

        self._data_menu.visible_layers = visible_layers

    @avoid_circular
    def _on_layers_update(self, layers=None):
        """
        Callback for changes to ``state.layers``.

        Recomputes the per-data visibility summary ('visible', 'mixed' or
        'hidden') of the viewer item, refreshes data-menu icons, applies
        one-time layer defaults and subset defaults, and broadcasts a
        `ViewerVisibleLayersChangedMessage`.
        """
        if self.__class__.__name__ == 'MosvizTableViewer':
            # MosvizTableViewer uses this as a mixin, but we do not need any of this layer
            # logic there
            return

        viewer_item = self.jdaviz_app._viewer_item_by_id(self.reference_id)
        if viewer_item is None:
            return
        selected_data_items = viewer_item.get('selected_data_items', {})

        # update selected_data_items (values are replaced in-place; no keys are added)
        for data_id in selected_data_items:
            label = next((x['name'] for x in self.jdaviz_app.state.data_items
                          if x['id'] == data_id), None)

            visibilities = []
            for layer in self.state.layers:
                if layer.layer.data.label == label:
                    visibilities.append(layer.visible)

            # NOTE(review): with no matching layers ``np.all([])`` is True, so the item is
            # reported as 'visible' — confirm this is intended.
            if np.all(visibilities):
                selected_data_items[data_id] = 'visible'
            elif np.any(visibilities):
                selected_data_items[data_id] = 'mixed'
            else:
                selected_data_items[data_id] = 'hidden'

        self._update_layer_icons()

        # we'll make a copy so that we can remove entries from the self._expected_subset_layers
        # to avoid recursion, but also handle multiple layers for the same subset
        expected_subset_layers = self._expected_subset_layers[:]
        for layer in self.state.layers:
            layer_info = {'data_label': layer.layer.data.label,
                          'layer_label': layer.layer.label}
            if layer_info not in self._layers_with_defaults_applied:
                self._layers_with_defaults_applied.append(layer_info)
                self._apply_layer_defaults(layer)

            if layer.layer.label in expected_subset_layers:
                if layer.layer.label in self._expected_subset_layers:
                    self._expected_subset_layers.remove(layer.layer.label)
                self._expected_subset_layer_default(layer)

        self.hub.broadcast(ViewerVisibleLayersChangedMessage(
            viewer_reference=self.reference, visible_layers=selected_data_items, sender=self))

    def _on_subset_create(self, msg):
        """Remember the label of a newly created subset so its layers get subset defaults."""
        # local import, presumably to avoid a circular import at module load
        from jdaviz.configs.mosviz.plugins.viewers import MosvizTableViewer
        if isinstance(self, MosvizTableViewer):
            # MosvizTableViewer uses this as a mixin, but we do not need any of this layer
            # logic there
            return

        # NOTE: the subscription to this method is handled in ConfigHelper
        # we don't have access to the actual subset yet to tell if its spectral or spatial, so
        # we'll store the name of this new subset and change the default linewidth when the
        # layers are added
        if not hasattr(self, '_expected_subset_layers'):
            return
        if msg.subset.label not in self._expected_subset_layers and msg.subset.label:
            self._expected_subset_layers.append(msg.subset.label)

    def _on_subset_delete(self, msg):
        """
        This is needed to remove the "ghost" subset left over when the subset tool is active,
        and the active subset is deleted. https://github.com/spacetelescope/jdaviz/issues/2499
        is open to revert/update this if it ends up being addressed upstream in
        https://github.com/glue-viz/glue-jupyter/issues/401.
        """
        from jdaviz.configs.mosviz.plugins.viewers import MosvizTableViewer
        if isinstance(self, MosvizTableViewer):
            # MosvizTableViewer uses this as a mixin, but we do not need any of this layer
            # logic there
            return

        toolbar = self.toolbar
        if toolbar is None:
            # ``toolbar`` defaults to None on this class: no subset tool to deactivate.
            return

        subset_tools = ['bqplot:truecircle', 'bqplot:rectangle', 'bqplot:ellipse',
                        'bqplot:circannulus', 'bqplot:xrange']
        if not len(self.session.edit_subset_mode.edit_subset):
            if toolbar.active_tool_id in subset_tools:
                if (hasattr(toolbar, "default_tool_priority") and
                        len(toolbar.default_tool_priority)):
                    toolbar.active_tool_id = toolbar.default_tool_priority[0]
                else:
                    toolbar.active_tool = None

    @property
    def active_image_layer(self):
        """Active image layer in the viewer, if available."""
        # Find visible layers
        visible_layers = [layer for layer in self.state.layers
                          if (layer.visible and
                              layer_is_image_data(layer.layer) and
                              layer_is_not_dq(layer.layer) and
                              (getattr(layer, 'bitmap_visible', False) or
                               getattr(layer, 'contour_visible', False)))]

        if len(visible_layers) == 0:
            return None

        # the layer with the highest zorder (first one on ties) is considered active
        z_order = [layer.zorder for layer in visible_layers]
        active_index = np.argmax(z_order)
        return visible_layers[active_index]

    @property
    def active_cube_layer(self):
        """Active cube layer in the viewer, if available."""
        # Find visible layers. ``getattr`` defaults (as in ``active_image_layer``) keep
        # layer states without bitmap/contour attributes from raising.
        visible_layers = [layer for layer in self.state.layers
                          if (layer.visible and
                              layer_is_3d(layer.layer) and
                              layer_is_not_dq(layer.layer) and
                              (getattr(layer, 'bitmap_visible', False) or
                               getattr(layer, 'contour_visible', False)))]

        if len(visible_layers) == 0:
            return None

        return visible_layers[-1]

    @property
    def tools(self):
        """Ids of all tools registered on the toolbar."""
        # NOTE: this overrides the default list of tools for the BasicJupyterToolbar by
        # returning a flattened version of self.tools_nested
        return list(self.toolbar.tools.keys())

    @property
    def jdaviz_app(self):
        """The Jdaviz application tied to the viewer."""
        return self.session.jdaviz_app

    @property
    def jdaviz_helper(self):
        """The Jdaviz configuration helper tied to the viewer."""
        return self.jdaviz_app._jdaviz_helper

    @property
    def hub(self):
        """The glue hub of the viewer's session."""
        return self.session.hub

    @property
    def reference_id(self):
        """Internal id of this viewer within the app."""
        return self._reference_id

    @property
    def reference(self):
        """
        User-facing reference name of this viewer.

        If the app has no viewer item for ``reference_id`` (``_viewer_item_by_id``
        returned `None`), this raises ``AttributeError``; ``_ref_or_id`` relies on
        ``getattr`` with a default to fall back in that case.
        """
        return self.jdaviz_app._viewer_item_by_id(self.reference_id).get('reference')

    @property
    def _ref_or_id(self):
        """``reference`` if available and not `None`, otherwise ``reference_id``."""
        reference = getattr(self, 'reference', None)
        if reference is not None:
            return reference
        return self.reference_id

    def set_plot_axes(self):
        """Hook for subclasses to style axes; no-op by default."""
        # individual viewers can override to set custom axes labels/ticks/styling
        return
class JdavizViewerWindow(TemplateMixin):
    """
    Container widget holding a glue viewer together with its toolbar and data menu.

    The figure, toolbar and data-menu widgets are referenced through their
    widget model ids (``IPY_MODEL_<id>``) for the ``viewer_window.vue``
    template, and ``user_api`` calls other than ``show`` are redirected to
    the wrapped viewer.
    """
    template_file = __file__, "../../../viewer_window.vue"

    # synced traits describing the wrapped viewer
    id = Unicode().tag(sync=True)
    name = Unicode().tag(sync=True)
    reference = Unicode().tag(sync=True)
    config = Unicode().tag(sync=True)

    # widget model references rendered by the template
    figure_widget = Unicode().tag(sync=True)
    toolbar_widget = Unicode().tag(sync=True)
    data_menu_widget = Unicode().tag(sync=True)

    # mirrored toolbar state and lifecycle flag
    tool_override_mode = Unicode("").tag(sync=True)
    viewer_destroyed = Bool(False).tag(sync=True)

    def __init__(self, viewer, *args, reference="", name="", **kwargs):
        super().__init__(*args, **kwargs)
        self.glue_viewer = viewer
        self.config = self._app.config

        viewer_id = viewer._reference_id
        self.id = viewer_id
        # fall back: name -> viewer id; reference -> name -> viewer id
        self.name = name or viewer_id
        self.reference = reference or name or viewer_id

        toolbar = viewer.toolbar
        self.figure_widget = "IPY_MODEL_" + viewer.figure_widget.model_id
        if toolbar:
            self.toolbar_widget = "IPY_MODEL_" + toolbar.model_id
        else:
            self.toolbar_widget = ''
        if hasattr(viewer, '_data_menu'):
            self.data_menu_widget = 'IPY_MODEL_' + viewer._data_menu.model_id
        else:
            self.data_menu_widget = ''

        # Only NestedJupyterToolbar has ``tool_override_mode``; mirror it when present.
        if toolbar and hasattr(toolbar, 'tool_override_mode'):
            self.tool_override_mode = toolbar.tool_override_mode
            toolbar.observe(self._on_toolbar_override_change, names=['tool_override_mode'])

        self.hub.subscribe(self, ViewerRemovedMessage, self._on_viewer_removed)

    def _on_toolbar_override_change(self, change):
        # keep the mirrored trait in step with the toolbar's value
        new_mode = change['new']
        self.tool_override_mode = new_mode

    @property
    def user_api(self):
        """User API exposing ``show`` here and redirecting everything else to the viewer."""
        from jdaviz.core.user_api import ViewerWindowUserApi
        exposed = ['show']
        return ViewerWindowUserApi(self, expose=exposed)

    def _on_viewer_removed(self, msg):
        # flag (via a synced trait) that the wrapped viewer has been removed
        if msg.viewer_id != self.id:
            return
        self.viewer_destroyed = True

    def show(self, loc="inline", title=None, height=None):  # pragma: no cover
        """Display the viewer window UI.

        Parameters
        ----------
        loc : str
            The display location determines where to present the viewer UI.
            Supported locations:

            "inline": Display the viewer inline in a notebook.

            "sidecar": Display the viewer in a separate JupyterLab window from the
            notebook. Where it opens is controlled by the anchor; ``right`` is the default.
            Other anchors:

            * ``sidecar:right`` (The default, opens a tab to the right of display)
            * ``sidecar:tab-before`` (Full-width tab before the current notebook)
            * ``sidecar:tab-after`` (Full-width tab after the current notebook)
            * ``sidecar:split-right`` (Split-tab in the same window right of the notebook)
            * ``sidecar:split-left`` (Split-tab in the same window left of the notebook)
            * ``sidecar:split-top`` (Split-tab in the same window above the notebook)
            * ``sidecar:split-bottom`` (Split-tab in the same window below the notebook)

            See `jupyterlab-sidecar <https://github.com/jupyter-widgets/jupyterlab-sidecar>`_
            for the most up-to-date options.

            "popout": Display the viewer in a detached display. By default, a new
            window will open. Browser popup permissions required.
            Other anchors:

            * ``popout:window`` (The default, opens Jdaviz in a new, detached popout)
            * ``popout:tab`` (Opens Jdaviz in a new, detached tab in your browser)

        title : str, optional
            The title of the sidecar tab. Defaults to the name of the viewer.

            NOTE: Only applicable to a "sidecar" display.

        height : int, optional
            The height of the viewer display, in pixels. Only applicable if loc is "inline".

        Notes
        -----
        If "sidecar" is requested in the "classic" Jupyter notebook, the viewer will appear inline,
        as only JupyterLab has a mechanism to have multiple tabs.
        """
        if title is None:
            title = self.name
        show_widget(self, loc=loc, title=title, height=height)
[docs]
@viewer_registry("jdaviz-profile-viewer", label="Profile 1D")
class JdavizProfileView(JdavizViewerMixin, BqplotProfileView):
# categories: zoom resets, zoom, pan, subset, select tools, shortcuts
tools_nested = [
['jdaviz:homezoom', 'jdaviz:prevzoom'],
['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
['bqplot:xrange'],
['jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
]
default_class = NDDataArray
_state_cls = FreezableProfileViewerState
_default_profile_subset_type = None
def __init__(self, *args, **kwargs):
    """
    Create a jdaviz profile viewer.

    Parameters
    ----------
    *args
        Forwarded to the parent viewer classes.
    default_tool_priority : list of str, optional
        Keyword only. Passed to ``initialize_toolbar``. Defaults to ``[]``.
    default_collapse_function : str, optional
        Keyword only. Initial value of ``state.function``. Defaults to ``'sum'``.
    **kwargs
        Remaining keywords, forwarded to the parent viewer classes.
    """
    # Pop the jdaviz-only keywords *before* forwarding ``kwargs`` upstream; otherwise
    # ``default_collapse_function`` would be passed on to the parent constructors, which
    # are not expected to accept it.
    default_tool_priority = kwargs.pop('default_tool_priority', [])
    default_collapse_function = kwargs.pop('default_collapse_function', 'sum')

    super().__init__(*args, **kwargs)
    self._subscribe_to_layers_update()
    self.initialize_toolbar(default_tool_priority=default_tool_priority)
    self._offscreen_lines_marks = OffscreenLinesMarks(self)
    self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks

    self.state.add_callback('show_uncertainty', self._show_uncertainty_changed)

    self.display_mask = False

    # Change collapse function (defaults to sum)
    self.state.function = default_collapse_function
def _expected_subset_layer_default(self, layer_state):
super()._expected_subset_layer_default(layer_state)
layer_state.linewidth = 3
[docs]
def data(self, cls=None):
# Grab the user's chosen statistic for collapsing data
statistic = getattr(self.state, 'function', None)
data = []
for layer_state in self.state.layers:
if hasattr(layer_state, 'layer'):
lyr = layer_state.layer
# For raw data, just include the data itself
if isinstance(lyr, BaseData):
_class = cls or self.default_class
if _class is not None:
cache_key = (lyr.label, statistic)
if cache_key in self.jdaviz_app._get_object_cache:
layer_data = self.jdaviz_app._get_object_cache[cache_key]
else:
# If spectrum, collapse via the defined statistic
if _class == Spectrum:
layer_data = lyr.get_object(cls=_class, statistic=statistic)
else:
layer_data = lyr.get_object(cls=_class)
self.jdaviz_app._get_object_cache[cache_key] = layer_data
data.append(layer_data)
# For subsets, make sure to apply the subset mask to the layer data first
elif isinstance(lyr, Subset):
layer_data = lyr
if _class is not None:
handler, _ = data_translator.get_handler_for(_class)
try:
layer_data = handler.to_object(layer_data, statistic=statistic)
except IncompatibleAttribute:
continue
data.append(layer_data)
return data
[docs]
def get_scales(self):
fig = self.figure
# Deselect any pan/zoom or subsetting tools so they don't interfere
# with the scale retrieval
if self.toolbar.active_tool is not None:
self.toolbar.active_tool = None
return {'x': fig.interaction.x_scale, 'y': fig.interaction.y_scale}
def _show_uncertainty_changed(self, msg=None):
# this is subscribed in init to watch for changes to the state
# object since uncertainty handling is in jdaviz instead of glue/glue-jupyter
if self.state.show_uncertainty:
self._plot_uncertainties()
else:
self._clean_error()
[docs]
def show_mask(self):
self.display_mask = True
self._plot_mask()
[docs]
def clean(self):
# Remove extra traces, in case they exist.
self.display_mask = False
self._clean_mask()
# this will automatically call _clean_error via _show_uncertainty_changed
self.state.show_uncertainty = False
def _clean_mask(self):
fig = self.figure
fig.marks = [x for x in fig.marks if not isinstance(x, ScatterMask)]
def _clean_error(self):
fig = self.figure
fig.marks = [x for x in fig.marks if not isinstance(x, LineUncertainties)]
[docs]
def add_data(self, data, color=None, alpha=None, **layer_state):
"""
Overrides the base class to add markers for plotting
uncertainties and data quality flags.
Parameters
----------
spectrum : :class:`glue.core.data.Data`
Data object with the spectrum.
color : obj
Color value for plotting.
alpha : float
Alpha value for plotting.
Returns
-------
result : bool
`True` if successful, `False` otherwise.
"""
# If this is the first loaded data, set things up for unit conversion.
if len(self.layers) == 0:
reset_plot_axes = True
else:
# Check if the new data flux unit is actually compatible since flux not linked.
try:
if (self.state.y_display_unit not in ['None', None, 'DN'] and
hasattr(data.get_component('flux').data, 'units')):
psc = data.meta.get('_pixel_scale_factor', None)
cube_wave = data.get_component('spectral')
eqv = all_flux_unit_conversion_equivs(pixar_sr=psc, cube_wave=cube_wave)
flux_conversion_general([1, 1],
data.get_component('flux').data.units,
self.state.y_display_unit,
equivalencies=eqv)
except Exception as e:
# Raising exception here introduces a dirty state that messes up next load_data
# but not raising exception also causes weird behavior unless we remove the data
# completely.
self.session.hub.broadcast(SnackbarMessage(
f"Failed to load {data.label}, so removed it: {repr(e)}",
sender=self, color='error', traceback=e))
self.jdaviz_app.data_collection.remove(data)
return False
reset_plot_axes = False
# The base class handles the plotting of the main
# trace representing the profile itself.
result = super().add_data(data, color, alpha, **layer_state)
if reset_plot_axes:
x_units = data.get_component(self.state.x_att.label).units
y_axis_component = (
'flux' if 'flux' in [comp.label for comp in self.state.layers[0].layer.components]
else 'data'
)
y_units = data.get_component(y_axis_component).units
with delay_callback(self.state, "x_display_unit", "y_display_unit"):
self.state.x_display_unit = x_units if len(x_units) else None
self.state.y_display_unit = y_units if len(y_units) else None
self.set_plot_axes()
self._plot_uncertainties()
self._plot_mask()
# Set default linewidth on any created spectral subset layers
# NOTE: this logic will need updating if we add support for multiple cubes as this assumes
# that new data entries (from model fitting or gaussian smooth, etc) will only be spectra
# and all subsets affected will be spectral
for layer in self.state.layers:
if (isinstance(layer.layer, GroupedSubset)
and get_subset_type(layer.layer) == self._default_profile_subset_type
and layer.layer.data.label == data.label):
layer.linewidth = 3
return result
def _plot_mask(self):
if not self.display_mask:
return
# Remove existing mask marks
self._clean_mask()
# Loop through all active data in the viewer
for index, layer_state in enumerate(self.state.layers):
lyr = layer_state.layer
comps = [str(component) for component in lyr.components]
# Skip subsets
if hasattr(lyr, "subset_state"):
continue
# Ignore data that does not have a mask component
if "mask" in comps:
mask = np.array(lyr['mask'].data)
data_obj = lyr.data.get_object(cls=self.default_class)
if self.default_class == Spectrum:
data_x = data_obj.spectral_axis.value
data_y = data_obj.flux.value
else:
data_x = np.arange(data_obj.shape[-1])
data_y = data_obj.data.value
# For plotting markers only for the masked data
# points, erase un-masked data from trace.
y = np.where(np.asarray(mask) == 0, np.nan, data_y)
# A subclass of the bqplot Scatter object, ScatterMask places
# 'X' marks where there is masked data in the viewer.
color = layer_state.color
alpha_shade = layer_state.alpha / 3
mask_line_mark = ScatterMask(scales=self.scales,
marker='cross',
x=data_x,
y=y,
stroke_width=0.5,
colors=[color],
default_size=25,
default_opacities=[alpha_shade]
)
# Add mask marks to viewer
self.figure.marks = list(self.figure.marks) + [mask_line_mark]
def _plot_uncertainties(self):
if not self.state.show_uncertainty:
return
# Remove existing error bars
self._clean_error()
# Loop through all active data in the viewer
for index, layer_state in enumerate(self.state.layers):
lyr = layer_state.layer
# Skip subsets
if hasattr(lyr, "subset_state"):
continue
comps = [str(component) for component in lyr.components]
# Ignore data that does not have an uncertainty component
if "uncertainty" in comps: # noqa
error = np.array(lyr['uncertainty'].data)
# ensure that the uncertainties are represented as stddev:
uncertainty_type_str = lyr.meta.get('uncertainty_type', 'stddev')
uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str]
error = uncert_cls(error).represent_as(StdDevUncertainty).array
if 'spectral_axis_index' in lyr.data.meta:
spectral_axis_index = lyr.data.meta['spectral_axis_index']
else:
# We have to make an assumption in this case.
# TODO: Should we have other handling for non-spectral data (e.g. light curves?)
spectral_axis_index = -1
data_obj = lyr.data.get_object(cls=self.default_class, statistic=None)
lyr_coords = lyr.data.coords
if isinstance(lyr_coords, SpectralCoordinates):
spectral_wcs = lyr_coords
data_x = spectral_wcs.pixel_to_world_values(
np.arange(lyr.data.shape[spectral_axis_index])
)
if isinstance(data_x, tuple):
data_x = data_x[0]
else:
if hasattr(lyr_coords, 'spectral_wcs'):
spectral_wcs = lyr_coords.spectral_wcs
elif hasattr(lyr_coords, 'spectral'):
spectral_wcs = lyr_coords.spectral
elif hasattr(lyr_coords, "world_n_dim") and lyr_coords.world_n_dim == 1:
# 1D GWCS in this case, just use the coords
spectral_wcs = lyr_coords
data_x = spectral_wcs.pixel_to_world(
np.arange(lyr.data.shape[spectral_axis_index])
)
data_y = data_obj.data
# The shaded band around the spectrum trace is bounded by
# two lines, above and below the spectrum trace itself.
data_x_list = np.ndarray.tolist(data_x)
x = [data_x_list, data_x_list]
y = [np.ndarray.tolist(data_y - error),
np.ndarray.tolist(data_y + error)]
if layer_state.as_steps:
for i in (0, 1):
a = np.insert(x[i], 0, 2*x[i][0] - x[i][1])
b = np.append(x[i], 2*x[i][-1] - x[i][-2])
edges = (a + b) / 2
x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:]))
y[i] = np.repeat(y[i], 2)
x, y = np.asarray(x), np.asarray(y)
# A subclass of the bqplot Lines object, LineUncertainties keeps
# track of uncertainties plotted in the viewer. LineUncertainties
# appear with two lines and shaded area in between.
color = layer_state.color
alpha_shade = layer_state.alpha / 3
error_line_mark = LineUncertainties(viewer=self,
x=[x],
y=[y],
scales=self.scales,
stroke_width=1,
colors=[color, color],
fill_colors=[color, color],
opacities=[0.0, 0.0],
fill_opacities=[alpha_shade,
alpha_shade],
fill='between',
close_path=False
)
# Add error lines to viewer
self.figure.marks = list(self.figure.marks) + [error_line_mark]
[docs]
def set_plot_axes(self):
# Set x and y axes labels for the spectrum viewer
y_display_unit = self.state.y_display_unit
y_unit = (
u.Unit(y_display_unit) if y_display_unit and y_display_unit != 'None'
else u.dimensionless_unscaled
)
# Get local units.
locally_defined_flux_units = [
u.Jy, u.mJy, u.uJy, u.MJy,
u.W / (u.m**2 * u.Hz),
u.eV / (u.s * u.m**2 * u.Hz),
u.erg / (u.s * u.cm**2),
u.erg / (u.s * u.cm**2 * u.Angstrom),
u.erg / (u.s * u.cm**2 * u.Hz),
u.ph / (u.s * u.cm**2 * u.Angstrom),
u.ph / (u.s * u.cm**2 * u.Hz),
u.bol, u.AB, u.ST
]
# get square angle from 'sb' display unit
sb_unit = self.jdaviz_app._get_display_unit(axis='sb')
if sb_unit is not None:
solid_angle_unit = check_if_unit_is_per_solid_angle(sb_unit, return_unit=True)
else:
solid_angle_unit = None
# if solid angle is present in denominator, check physical type of numerator
# if numerator is a flux type the display unit is a 'surface brightness', otherwise
# default to the catchall 'flux density' label
flux_unit_type = None
for un in locally_defined_flux_units:
locally_defined_sb_unit = un / solid_angle_unit if solid_angle_unit is not None else None # noqa
# create an equivalency for each flux unit for flux <> flux/pix2.
# for similar reasons to the 'untranslatable units' issue, custom
# equivs. can't be combined, so a workaround is creating an eqiv
# for each flux that may need an additional equiv.
angle_to_pixel_equiv = _eqv_sb_per_pixel_to_per_angle(un)
if (locally_defined_sb_unit is not None
and y_unit.is_equivalent(locally_defined_sb_unit, angle_to_pixel_equiv)):
flux_unit_type = "Surface Brightness"
elif y_unit.is_equivalent(un):
flux_unit_type = 'Flux'
elif y_unit.is_equivalent(u.electron / u.s) or y_unit.physical_type == 'dimensionless': # noqa
# electron / s or 'dimensionless_unscaled' should be labeled counts
flux_unit_type = "Counts"
elif y_unit.is_equivalent(u.W):
flux_unit_type = "Luminosity"
if flux_unit_type is not None:
# if we determined a label, stop checking
break
else:
# default to Flux Density for flux density or uncaught types
flux_unit_type = "Flux density"
# Set x axes labels for the spectrum viewer
x_disp_unit = self.state.x_display_unit
x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled
if x_unit.is_equivalent(u.m):
spectral_axis_unit_type = "Wavelength"
elif x_unit.is_equivalent(u.Hz):
spectral_axis_unit_type = "Frequency"
elif x_unit.is_equivalent(u.pixel):
spectral_axis_unit_type = "Pixel"
elif x_unit.is_equivalent(u.dimensionless_unscaled):
# case for rampviz
spectral_axis_unit_type = "Group"
else:
spectral_axis_unit_type = str(x_unit.physical_type).title()
with self.figure.hold_sync():
self.figure.axes[0].label = f"{spectral_axis_unit_type}" + (
f" [{self.state.x_display_unit}]"
if self.state.x_display_unit not in ["None", None] else ""
)
self.figure.axes[1].label = f"{flux_unit_type}" + (
f"[{self.state.y_display_unit}]"
if self.state.y_display_unit not in ["None", None] else ""
)
# Make it so axis labels are not covering tick numbers.
self.figure.fig_margin["left"] = 95
self.figure.fig_margin["bottom"] = 60
self.figure.send_state('fig_margin') # Force update
self.figure.axes[0].label_offset = "40"
self.figure.axes[1].label_offset = "-70"
# NOTE: with tick_style changed below, the default responsive ticks in bqplot result
# in overlapping tick labels. For now we'll hardcode at 8, but this could be removed
# (default to None) if/when bqplot auto ticks react to styling options.
self.figure.axes[1].num_ticks = 8
# Set Y-axis to scientific notation
self.figure.axes[1].tick_format = '0.1e'
for i in (0, 1):
self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600}
@viewer_registry("scatter-viewer", label="scatter")
class ScatterViewer(JdavizViewerMixin, BqplotScatterView):
    """Jdaviz scatter-plot viewer built on the glue-jupyter bqplot scatter view.

    The y axis uses scientific tick labels, and both axis labels are
    shrunk to 12px so the offset y label fits beside the ticks.
    """

    # Toolbar groups, in order: zoom resets, zoom, pan, subset,
    # select tools, shortcuts.
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange', 'bqplot:yrange', 'bqplot:rectangle'],
        [],
        ['jdaviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export'],
    ]

    _state_cls = ScatterViewerState
    _native_mark_classnames = ('Scatter', 'ScatterGL')

    def __init__(self, session, *args, **kwargs):
        super().__init__(session, *args, **kwargs)

        x_axis = self.figure.axes[0]
        y_axis = self.figure.axes[1]

        # Scientific notation keeps y tick labels readable.
        y_axis.tick_format = '0.1e'
        # Shift the y label left so it does not sit on top of the tick labels.
        y_axis.label_offset = "-50"

        # Give both axis labels the same, slightly smaller font so the
        # y label fits.
        x_axis.axis_label_style = {'font-size': '12px'}
        y_axis.axis_label_style = {'font-size': '12px'}
@viewer_registry("histogram-viewer", label="histogram")
class HistogramViewer(JdavizViewerMixin, BqplotHistogramView):
    """Jdaviz histogram viewer built on the glue-jupyter bqplot histogram view."""

    # Toolbar groups, in order: zoom resets, zoom, pan, subset,
    # select tools, shortcuts.
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange'],
        [],
        ['jdaviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export'],
    ]

    _state_cls = HistogramViewerState
    # bqplot mark class names drawn natively by this viewer
    _native_mark_classnames = ('Bars', 'BarsGL')
@viewer_registry("table-viewer", label="table")
class JdavizTableViewer(JdavizViewerMixin, TableViewer):
    """Jdaviz table viewer for catalog data.

    Rows checked in the table are highlighted with a ``TableSelectionMark``
    in every ``ImvizImageView``. Clicks forwarded from image viewers (via
    ``TableSelectRowClickMessage``) toggle the nearest table row.
    """
    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:table_highlight_selected'],
        ['jdaviz:table_zoom_to_selected'],
        ['jdaviz:table_subset'],
        ['jdaviz:viewer_clone']
    ]

    def __init__(self, session, *args, **kwargs):
        super().__init__(session, *args, **kwargs)
        # enable scrolling: # https://github.com/glue-viz/glue-jupyter/pull/287
        self.widget_table.scrollable = True
        # hide checkboxes by default (shown when TableSubset tool is activated)
        self.widget_table.selection_enabled = False
        self.data_menu._obj.dataset.add_filter('is_catalog')
        self.widget_table.observe(lambda _: self.toolbar._update_tool_visibilities(),
                                  names=['checked'])
        # Also update selection highlight marks when checked rows change
        self.widget_table.observe(self._on_checked_changed, names=['checked'])
        self.widget_table.observe(self._on_selection_enabled_changed, names=['selection_enabled'])

        # Subscribe to RestoreToolbarMessage to clean up checkbox state
        # when toolbar is restored (e.g., by clicking X on custom toolbar)
        self.hub.subscribe(self, RestoreToolbarMessage,
                           handler=self._on_restore_toolbar)

        # Subscribe to TableSelectRowClickMessage to handle clicks from image viewers
        self.hub.subscribe(self, TableSelectRowClickMessage,
                           handler=self._on_table_select_row_click)

        # Subscribe to ViewerRemovedMessage to clean up toolbar overrides
        # if this table viewer is removed while tools are active
        self.hub.subscribe(self, ViewerRemovedMessage,
                           handler=self._on_viewer_removed)

    def _on_table_select_row_click(self, msg):
        """Handle click from image viewer to select/toggle closest table row.

        The click position (``msg.x``, ``msg.y``) is compared against the
        catalog positions in the reference-data pixel frame. The nearest
        row is toggled in ``widget_table.checked``. Failures are ignored
        (best effort).
        """
        # Only respond if this message is for this table viewer
        if msg.table_viewer_id != self.reference_id:
            return
        if not len(self.layers):
            return

        # Get pixel coordinates from the message (these are in reference data frame)
        click_x, click_y = msg.x, msg.y

        try:
            layer = self.layers[0].layer

            # Get sky coordinates for WCS-accurate comparison.
            # Click coordinates are in the viewer's reference frame, so catalog
            # coordinates must also be converted to that frame for proper matching.
            xs, ys = None, None
            skycoords = _get_skycoords_from_table(layer)
            if skycoords is not None:
                # Convert sky coordinates to pixels in the viewer's reference frame
                for viewer in self.jdaviz_app.get_viewers_of_cls('ImvizImageView'):
                    if viewer.state.reference_data is None:
                        continue
                    if viewer.state.reference_data.coords is None:
                        continue
                    pixel_result = viewer.state.reference_data.coords.world_to_pixel(skycoords)
                    xs, ys = pixel_result[0], pixel_result[1]
                    break
            else:
                # Fall back to pixel coordinates only if no sky coordinates available
                pixel_coords = _get_pixel_coords_from_table(layer)
                if pixel_coords is not None:
                    xs, ys = pixel_coords

            if xs is None or ys is None:
                return

            # Coerce sequences to arrays (asanyarray keeps masked/quantity
            # subclasses) so the distance math below works element-wise.
            xs = np.asanyarray(xs)
            ys = np.asanyarray(ys)
            if xs.size == 0:
                # Empty catalog: there is no nearest row to toggle.
                return

            # Find nearest point and toggle its selection
            distsq = (xs - click_x)**2 + (ys - click_y)**2
            ind = int(np.argmin(distsq))

            current_checked = list(self.widget_table.checked)
            if ind in current_checked:
                current_checked.remove(ind)
            else:
                current_checked.append(ind)
            self.widget_table.checked = current_checked
        except Exception:  # nosec # pragma: no cover
            # Best effort: a failed lookup must not break the image viewer click.
            pass

    def _on_checked_changed(self, change):
        """Update highlight marks in image viewers when checked rows change."""
        self._update_selection_marks()

    def _on_selection_enabled_changed(self, change):
        """Show/hide selection marks when selection is enabled/disabled."""
        if not change['new']:
            # Selection disabled, clear all marks
            self._clear_selection_marks()
        else:
            # Selection enabled, update marks
            self._update_selection_marks()

    def _get_selection_mark(self, viewer):
        """Get or create a selection highlight mark for the given viewer."""
        matches = [mark for mark in viewer.figure.marks if isinstance(mark, TableSelectionMark)]
        if len(matches):
            return matches[0]
        mark = TableSelectionMark(viewer)
        viewer.figure.marks = viewer.figure.marks + [mark]
        return mark

    def _update_selection_marks(self):
        """Update selection highlight marks in all image viewers.

        Does nothing while selection is disabled. Marks are hidden when
        no rows are checked or no positions can be resolved.
        """
        if not self.widget_table.selection_enabled:
            return

        checked_rows = self.widget_table.checked
        if not len(checked_rows) or not len(self.layers):
            self._clear_selection_marks()
            return

        layer = self.layers[0].layer

        # Get sky coordinates for WCS-accurate placement across different images.
        # Pixel coordinates from the catalog are in the catalog's original image frame,
        # which may differ from the viewer's reference data frame.
        skycoords = _get_skycoords_from_table(layer, checked_rows)
        pixel_coords = None
        if skycoords is None:
            # Fall back to pixel coordinates only if no sky coordinates available
            pixel_coords = _get_pixel_coords_from_table(layer, checked_rows)
            if pixel_coords is None:
                self._clear_selection_marks()
                return

        # Update marks in all image viewers
        for viewer in self.jdaviz_app.get_viewers_of_cls('ImvizImageView'):
            try:
                if skycoords is not None:
                    # Convert sky coordinates to pixels for this viewer's reference frame
                    coords = viewer.state.reference_data.coords.world_to_pixel(skycoords)
                    xs, ys = coords[0], coords[1]
                else:
                    # Use pixel coordinates directly (last resort when no sky coords)
                    xs, ys = pixel_coords[0], pixel_coords[1]

                mark = self._get_selection_mark(viewer)
                mark.update_xy(xs, ys)
                mark.visible = True
            except Exception:  # nosec # pragma: no cover
                # Best effort per viewer (e.g. viewer without reference data/WCS).
                pass

    def _clear_selection_marks(self):
        """Hide selection highlight marks in all image viewers."""
        for viewer in self.jdaviz_app.get_viewers_of_cls('ImvizImageView'):
            for mark in viewer.figure.marks:
                if isinstance(mark, TableSelectionMark):
                    mark.visible = False

    def _on_restore_toolbar(self, msg=None):
        """Clean up checkbox state when toolbar is restored.

        ``msg`` is unused; it defaults to None (not a shared mutable dict)
        so the handler can also be called directly.
        """
        # Clear selection marks
        self._clear_selection_marks()
        # Hide checkboxes (they should always be hidden when default toolbar is shown)
        self.widget_table.selection_enabled = False

    def _on_viewer_removed(self, msg):
        """Clean up selection marks if this table viewer is removed."""
        if msg.viewer_id != self.reference_id:
            return
        # Clear selection marks in image viewers when this table viewer is removed
        # (toolbar cleanup is handled generically by NestedJupyterToolbar)
        self._clear_selection_marks()