Source code for jdaviz.configs.default.plugins.viewers

from echo import delay_callback, CallbackProperty

import warnings

import numpy as np
from numpy.linalg import norm

from glue.config import data_translator
from glue.core import BaseData
from glue.core.edit_subset_mode import NewMode, ReplaceMode
from glue.core.exceptions import IncompatibleAttribute
from glue.core.roi import CircularAnnulusROI, CircularROI, EllipticalROI, RectangularROI
from glue.core.subset import Subset
from glue.core.subset_group import GroupedSubset
from glue.viewers.histogram.state import HistogramViewerState
from glue.viewers.scatter.state import ScatterViewerState
from glue.viewers.scatter.state import ScatterLayerState as BqplotScatterLayerState
from glue.utils import avoid_circular

from glue_astronomy.spectral_coordinates import SpectralCoordinates
from glue_jupyter.bqplot.profile import BqplotProfileView
from glue_jupyter.bqplot.histogram import BqplotHistogramView
from glue_jupyter.bqplot.image import BqplotImageView
from glue_jupyter.bqplot.scatter import BqplotScatterView
from glue_jupyter.table import TableViewer

from astropy.utils import deprecated
from astropy import units as u
from astropy.nddata import (
    NDDataArray, StdDevUncertainty, VarianceUncertainty, InverseVariance
)
from specutils import Spectrum

from traitlets import Bool, Unicode

from jdaviz.components.toolbar_nested import NestedJupyterToolbar
from jdaviz.configs.default.plugins.data_menu import DataMenu
from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin
from jdaviz.core.custom_units_and_equivs import _eqv_sb_per_pixel_to_per_angle
from jdaviz.core.events import (SnackbarMessage,
                                NewViewerMessage,
                                ViewerRemovedMessage,
                                ViewerVisibleLayersChangedMessage,
                                RestoreToolbarMessage,
                                TableSelectRowClickMessage)
from jdaviz.core.freezable_state import FreezableProfileViewerState
from jdaviz.core.marks import (LineUncertainties, ScatterMask,
                               OffscreenLinesMarks, TableSelectionMark)
from jdaviz.core.registries import viewer_registry
from jdaviz.core.template_mixin import WithCache, TemplateMixin, show_widget
from jdaviz.core.tools import _get_skycoords_from_table, _get_pixel_coords_from_table
from jdaviz.core.user_api import ViewerUserApi
from jdaviz.core.unit_conversion_utils import (check_if_unit_is_per_solid_angle,
                                               flux_conversion_general,
                                               all_flux_unit_conversion_equivs)
from jdaviz.utils import (ColorCycler, get_subset_type, _wcs_only_label,
                          layer_is_image_data, layer_is_not_dq, layer_is_3d)

# Maps short uncertainty-type strings to the corresponding astropy.nddata
# uncertainty classes: standard deviation, variance and inverse variance.
uncertainty_str_to_cls_mapping = {
    "std": StdDevUncertainty,
    "var": VarianceUncertainty,
    "ivar": InverseVariance
}


# Names exported by ``from <this module> import *``.
__all__ = ['JdavizViewerMixin', 'JdavizProfileView']

# Register the stock glue-jupyter viewer classes in the jdaviz viewer registry
# under the given registry names and human-readable labels.
viewer_registry.add("g-profile-viewer", label="Profile 1D", cls=BqplotProfileView)
viewer_registry.add("g-image-viewer", label="Image 2D", cls=BqplotImageView)
viewer_registry.add("g-table-viewer", label="Table", cls=TableViewer)


class JdavizViewerMixin(WithCache):
    """
    Mixin adding jdaviz-specific behaviour to glue-jupyter viewers.

    Provides per-viewer color cyclers, a data menu, and helpers that classify
    a newly drawn ROI as a move or resize of an existing subset.
    """
    toolbar = None
    tools_nested = []
    _prev_limits = None
    # Class names (exact match, no subclasses) treated as "native" marks.
    _native_mark_classnames = ('Lines', 'LinesGL', 'FRBImage', 'Contour')

    def __init__(self, *args, **kwargs):
        # NOTE: anything here most likely won't be called by viewers because of inheritance order
        super().__init__(*args, **kwargs)

        # Allow each viewer to cycle through colors for each new addition to the viewer:
        self.color_cycler = ColorCycler()

        # Separate color cycler for scatter layers (catalogs) that has brighter colors,
        # starting with neon green
        self.scatter_color_cycler = ColorCycler()
        self.scatter_color_cycler.default_color_palette = [
            '#00FF00',  # neon green
            '#FF00FF',  # magenta
            '#00FFFF',  # cyan
            '#FF0000',  # red
            '#FFFF00',  # yellow
            '#FF8800',  # orange
            '#8800FF',  # purple
            '#0088FF',  # blue
        ]

        self._data_menu = DataMenu(viewer=self, app=self.jdaviz_app)

    @staticmethod
    def _is_circular_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for circular ROIs.

        The change is a 'move' when the radius is unchanged within
        ``rtol_move``. Otherwise it is a 'resize' when the center moved by
        less than ``rtol_resize`` of the old radius. Both tolerances are
        scaled by ``max(old radius, 1)``.

        Parameters
        ----------
        old_roi : CircularROI
            The existing circular ROI.
        new_roi : CircularROI
            The newly proposed circular ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        if old_roi.radius <= 0:
            return None

        # Move: radius unchanged.
        if abs(old_roi.radius - new_roi.radius) < rtol_move * max(old_roi.radius, 1):
            return 'move'

        # Resize: radius changed but center stayed very close
        dist = norm([new_roi.xc - old_roi.xc, new_roi.yc - old_roi.yc])
        if dist < (rtol_resize * max(old_roi.radius, 1)):
            return 'resize'

        return None

    @staticmethod
    def _is_annulus_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for annular ROIs.

        The change is a 'move' when both radii are unchanged within
        ``rtol_move``. Otherwise it is a 'resize' when the center moved by
        less than ``rtol_resize`` of the old outer radius. Tolerances are
        scaled by ``max(old radius, 1)``.

        Parameters
        ----------
        old_roi : CircularAnnulusROI
            The existing annular ROI.
        new_roi : CircularAnnulusROI
            The newly proposed annular ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        if old_roi.outer_radius == 0 or old_roi.outer_radius <= old_roi.inner_radius:
            return None

        # Move: both radii unchanged. Both comparisons use the (tight) move
        # tolerance; using the resize tolerance for the outer radius would
        # classify outer-radius changes of up to 25% as a move.
        inner_unchanged = (abs(old_roi.inner_radius - new_roi.inner_radius)
                           < rtol_move * max(old_roi.inner_radius, 1))
        outer_unchanged = (abs(old_roi.outer_radius - new_roi.outer_radius)
                           < rtol_move * max(old_roi.outer_radius, 1))
        if inner_unchanged and outer_unchanged:
            return 'move'

        # Resize: radii changed but center stayed very close
        dist = norm([new_roi.xc - old_roi.xc, new_roi.yc - old_roi.yc])
        if dist < (rtol_resize * max(old_roi.outer_radius, 1)):
            return 'resize'

        return None

    @staticmethod
    def _is_elliptical_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for elliptical ROIs.

        The change is a 'move' when both radii are unchanged within
        ``rtol_move``. Otherwise it is a 'resize' when the center moved by
        less than ``rtol_resize`` of the larger old radius. Tolerances are
        scaled by ``max(old value, 1)``.

        Parameters
        ----------
        old_roi : EllipticalROI
            The existing elliptical ROI.
        new_roi : EllipticalROI
            The newly proposed elliptical ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        size = max(old_roi.radius_x, old_roi.radius_y)
        if size <= 0 or min(old_roi.radius_x, old_roi.radius_y) <= 0:
            return None

        # Move: both radii unchanged.
        if (abs(old_roi.radius_x - new_roi.radius_x) < rtol_move * max(old_roi.radius_x, 1)
                and abs(old_roi.radius_y - new_roi.radius_y) < rtol_move * max(old_roi.radius_y, 1)):  # noqa
            return 'move'

        # Resize: radii changed but center stayed very close
        dist = norm([new_roi.xc - old_roi.xc, new_roi.yc - old_roi.yc])
        if dist < rtol_resize * max(size, 1):
            return 'resize'

        return None

    @staticmethod
    def _is_rectangular_edit(old_roi, new_roi, rtol_move=1e-3, rtol_resize=0.25):
        """
        Classify the change from *old_roi* to *new_roi* for rectangular ROIs.

        The change is a 'move' when width and height are unchanged within
        ``rtol_move``. Otherwise it is a 'resize' when the center moved by
        less than ``rtol_resize`` of half the larger old dimension.
        Tolerances are scaled by ``max(old value, 1)``.

        Parameters
        ----------
        old_roi : RectangularROI
            The existing rectangular ROI.
        new_roi : RectangularROI
            The newly proposed rectangular ROI.
        rtol_move : float, optional
            Relative tolerance for move-related comparisons.
        rtol_resize : float, optional
            Relative tolerance for resize-related comparisons.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        old_w = old_roi.xmax - old_roi.xmin
        old_h = old_roi.ymax - old_roi.ymin
        new_w = new_roi.xmax - new_roi.xmin
        new_h = new_roi.ymax - new_roi.ymin

        size = max(old_w, old_h) / 2
        if size <= 0 or old_w <= 0 or old_h <= 0:
            return None

        # Move: same width and height.
        if (abs(old_w - new_w) < rtol_move * max(old_w, 1)
                and abs(old_h - new_h) < rtol_move * max(old_h, 1)):
            return 'move'

        # Resize: dimensions changed but center stayed very close
        old_cx = (old_roi.xmin + old_roi.xmax) / 2
        old_cy = (old_roi.ymin + old_roi.ymax) / 2
        new_cx = (new_roi.xmin + new_roi.xmax) / 2
        new_cy = (new_roi.ymin + new_roi.ymax) / 2
        dist = norm([new_cx - old_cx, new_cy - old_cy])
        if dist < rtol_resize * max(size, 1):
            return 'resize'

        return None

    def _is_roi_edit(self, old_roi, new_roi):
        """
        Classify the change from *old_roi* to *new_roi*.

        Returns None unless each ROI is an instance of the other's type;
        otherwise dispatches to the shape-specific classifier.

        Parameters
        ----------
        old_roi : ROI
            The existing subset's ROI.
        new_roi : ROI
            The newly drawn ROI.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        if not isinstance(new_roi, type(old_roi)) or not isinstance(old_roi, type(new_roi)):
            return None

        # The first matching isinstance check wins, so order matters if any
        # of these ROI classes subclasses another.
        handlers = [(CircularAnnulusROI, self._is_annulus_edit),
                    (CircularROI, self._is_circular_edit),
                    (EllipticalROI, self._is_elliptical_edit),
                    (RectangularROI, self._is_rectangular_edit)]

        for roi_type, handler in handlers:
            if isinstance(old_roi, roi_type):
                return handler(old_roi, new_roi)

        return None

    @staticmethod
    def _is_range_edit(old_range, new_range, rtol=1e-6):
        """
        Classify the change from *old_range* to *new_range* for 1-D ranges.

        The change is a 'move' when the width is unchanged within ``rtol``
        (scaled by ``max(|old width|, 1)``). Otherwise it is a 'resize' when
        one endpoint is exactly equal to the old one.

        Parameters
        ----------
        old_range : RangeSubsetState
            The existing 1-D subset state.
        new_range : ROI
            The new ROI with ``min`` and ``max`` attributes.
        rtol : float, optional
            Relative tolerance for the width comparison.

        Returns
        -------
        edit_type : 'move', 'resize', or None
        """
        if not (hasattr(new_range, 'min') and hasattr(new_range, 'max')):
            return None

        old_w = old_range.hi - old_range.lo
        new_w = new_range.max - new_range.min
        if old_w <= 0 or new_w <= 0:
            return None

        # Move: same width.
        if abs(new_w - old_w) < (rtol * max(abs(old_w), 1)):
            return 'move'

        # Resize: width changed, one endpoint fixed.
        if old_range.lo == new_range.min or old_range.hi == new_range.max:
            return 'resize'

        return None
def apply_roi(self, roi, use_current=False):
    """
    Apply an ROI to the viewer, editing the active subset in place if needed.

    When regions are not being imported, a subset is currently being edited
    and the edit mode is NewMode, the new ROI is compared with the existing
    subset state. If it is classified as a move or resize, the edit mode is
    switched to ReplaceMode for the duration of the call so that the existing
    subset is modified rather than duplicated; the previous mode is always
    restored afterwards. A warning is emitted when a move is detected.
    """
    from glue.core.subset import RoiSubsetState, RangeSubsetState

    subset_mode = self.session.edit_subset_mode

    edit_type = None
    if (not getattr(self.jdaviz_app, '_importing_regions', False)
            and len(subset_mode.edit_subset) > 0
            and subset_mode.mode is NewMode):
        current_state = subset_mode.edit_subset[0].subset_state
        if isinstance(current_state, RoiSubsetState):
            edit_type = self._is_roi_edit(current_state.roi, roi)
        elif isinstance(current_state, RangeSubsetState):
            edit_type = self._is_range_edit(current_state, roi)

    # Any recognised edit (move or resize) replaces the subset in place.
    replace_in_place = edit_type is not None

    if edit_type == 'move':
        # Warn when the new subset has identical dimensions to the
        # existing one, it will be treated as a move in-place rather
        # than a new subset.
        warnings.warn(
            'The new subset has the same dimensions as the '
            'existing subset and will be treated as a move '
            'operation. To create a new subset with the same '
            'dimensions, use import_region instead.',
            stacklevel=2)

    previous_mode = subset_mode.mode
    if replace_in_place:
        subset_mode._mode = ReplaceMode
    try:
        super().apply_roi(roi, use_current=use_current)
    finally:
        # Restore the user's mode even if applying the ROI failed.
        if replace_in_place:
            subset_mode._mode = previous_mode
@property
def user_api(self):
    """
    Build the user-facing API wrapper for this viewer.

    This is the default set of exposed attributes; individual viewers can
    override this property.
    """
    is_mosviz_table = (isinstance(self, TableViewer)
                       and self.__class__.__name__ == 'MosvizTableViewer')
    # TODO: eventually remove data_labels_loaded
    # and data_labels_visible once deprecation period passes
    exposed = [] if is_mosviz_table else ['data_labels_loaded', 'data_labels_visible',
                                          'data_menu']

    if self.jdaviz_app.config == 'deconfigged':
        exposed.append('clone_viewer')

    if isinstance(self, BqplotImageView):
        # cubeviz image viewers don't inherit from AstrowidgetsImageViewerMixin yet,
        # but also shouldn't expose set_limits because of equal aspect ratio concerns,
        # so they get nothing extra here.
        if isinstance(self, AstrowidgetsImageViewerMixin):
            exposed.extend(['save',
                            'center_on', 'offset_by', 'zoom_level', 'zoom',
                            'colormap_options', 'set_colormap',
                            'stretch_options', 'stretch',
                            'autocut_options', 'cuts',
                            'marker', 'add_markers', 'remove_markers', 'reset_markers',
                            'blink_once', 'reset_limits', 'get_viewport_region'])
    elif not isinstance(self, TableViewer):
        # table viewers get no additional entries
        exposed.extend(['set_limits', 'reset_limits', 'set_tick_format'])

    return ViewerUserApi(self, expose=exposed)


@property
def data_menu(self):
    """User API of the data menu attached to this viewer."""
    menu = self._data_menu
    return menu.user_api


def _deprecated_data_menu(self):
    # temporary method to allow for opening new data-menu from old button.  This should
    # be removed anytime after the old button is removed (likely in 4.3)
    menu_api = self.data_menu
    menu_api.open_menu()


@property
@deprecated(since="4.1", alternative="viewer.data_menu.data_labels_loaded")
def data_labels_loaded(self):
    """
    Labels of the data loaded in this viewer.

    Returns
    -------
    data_labels : list
        list of strings
    """
    menu_api = self.data_menu
    return menu_api.data_labels_loaded


@property
@deprecated(since="4.1", alternative="viewer.data_menu.data_labels_visible")
def data_labels_visible(self):
    """
    Labels of the data visible in this viewer.

    Returns
    -------
    data_labels : list
        list of strings
    """
    menu_api = self.data_menu
    return menu_api.data_labels_visible


def _get_clone_viewer_reference(self):
    """Ask the configuration helper for the reference name a clone should use."""
    helper = self.jdaviz_helper
    return helper._get_clone_viewer_reference(self.reference)


def _clone_viewer_outside_app(self):
    """
    Create a copy of this viewer (same session) without registering it in the app.

    The viewer state is copied via its dict representation and a new layer
    artist is created for every existing layer, re-using the layer's state.
    """
    clone = type(self)(session=self.session)

    # TODO: this is needed for jdaviz only, without it it should also
    # work for glue-jupyter
    clone._reference_id = self.reference_id + "_TODO_IS_THIS_OK?"

    state_dict = self.state.as_dict()
    clone.state.update_from_dict(state_dict)

    for source_layer in self.layers:
        shared_layer_state = source_layer.state
        cloned_layer = type(source_layer)(view=clone,
                                          viewer_state=clone.state,
                                          layer_state=shared_layer_state)
        cloned_layer.update()

    return clone
def clone_viewer(self):
    """
    Create a new viewer in the app that mirrors this one.

    The clone receives the same data (with matching layer visibility), any
    attributes listed in ``_clone_attrs``, and copies of the viewer and
    per-layer state.

    Returns
    -------
    The user API of a ``JdavizViewerWindow`` wrapping the new viewer.
    """
    clone_name = self.jdaviz_helper._get_clone_viewer_reference(self.reference)
    app = self.jdaviz_app

    creation_msg = NewViewerMessage(self.__class__, data=None, sender=app)
    app._on_new_viewer(creation_msg, vid=clone_name, name=clone_name)
    clone = app.get_viewer(clone_name)

    shown_labels = self.data_menu.data_labels_visible
    for label in self.data_menu.data_labels_loaded[::-1]:
        is_shown = label in shown_labels
        clone.data_menu.add_data(label)
        if hasattr(clone.data_menu, 'set_layer_visibility'):
            clone.data_menu.set_layer_visibility(label, is_shown)
        # TODO: don't revert color when adding same data to a new viewer

    # allow viewers to set attributes (not in state) on cloned viewers
    for attr_name in getattr(self, '_clone_attrs', []):
        if hasattr(self, attr_name):
            setattr(clone, attr_name, getattr(self, attr_name))

    clone.state.update_from_dict(self.state.as_dict())

    layer_state_pairs = zip(self.state.layers, clone.state.layers)
    for source_layer_state, target_layer_state in layer_state_pairs:
        for key, value in source_layer_state.as_dict().items():
            # the 'layer' entry is the one key that is not copied over
            if key not in ('layer',):
                setattr(target_layer_state, key, value)

    return JdavizViewerWindow(clone, app=self.jdaviz_app).user_api
def reset_limits(self):
    """
    Reset the viewer axes limits by delegating to the viewer state.
    """
    viewer_state = self.state
    viewer_state.reset_limits()
def set_limits(self, x_min=None, x_max=None, y_min=None, y_max=None):
    """
    Set viewer axes limits.

    Limits passed as None are left unchanged.

    Parameters
    ----------
    x_min : float or None, optional
        lower-limit of x-axis (in current axes units)
    x_max: float or None, optional
        upper-limit of x-axis (in current axes units)
    y_min : float or None, optional
        lower-limit of y-axis (in current axes units)
    y_max: float or None, optional
        upper-limit of y-axis (in current axes units)

    Raises
    ------
    TypeError
        If any argument is neither None nor a real number (Python ``int`` /
        ``float`` or a NumPy integer / floating scalar).
    """
    # np.floating / np.integer accept every real NumPy scalar (e.g. the
    # np.int64 or np.float32 obtained by indexing an array), not only
    # np.float32 and the float subclass np.float64.
    allowed_types = (float, int, np.floating, np.integer)
    for val in (x_min, x_max, y_min, y_max):
        if val is not None and not isinstance(val, allowed_types):
            raise TypeError('All arguments must be None, int, or float, '
                            f'but got: {type(val)}')

    # Delay callbacks so listeners see all limits change at once.
    with delay_callback(self.state, 'x_min', 'x_max', 'y_min', 'y_max'):
        if x_min is not None:
            self.state.x_min = x_min
        if x_max is not None:
            self.state.x_max = x_max
        # NOTE: for some reason, setting ymax first avoids an issue
        # where back-to-back calls of get_limits and set_limits
        # give different results for y limits.
        if y_max is not None:
            self.state.y_max = y_max
        if y_min is not None:
            self.state.y_min = y_min
def get_limits(self):
    """Return current viewer axes limits.

    Returns
    -------
    x_min, x_max, y_min, y_max : float
        Lower/upper X/Y limits, respectively, as stored on the viewer state.
    """
    x_limits = (self.state.x_min, self.state.x_max)
    y_limits = (self.state.y_min, self.state.y_max)
    return x_limits + y_limits
def set_tick_format(self, fmt, axis):
    """
    Manually set the tick format of one of the axes.

    Parameters
    ----------
    fmt : str
        Format of tick marks. For example, ``'0.1e'`` to set scientific notation
        or ``'0.2f'`` to turn it off.
    axis : {x, y}
        The viewer axis.

    Raises
    ------
    ValueError
        If ``axis`` is neither ``'x'`` nor ``'y'``.
    """
    if axis not in ('x', 'y'):
        raise ValueError("axis must be 'x' or 'y'")

    # 'x' is the first entry of figure.axes and 'y' the second.
    axis_positions = {'x': 0, 'y': 1}
    target_axis = self.figure.axes[axis_positions[axis]]
    target_axis.tick_format = fmt
@property
def native_marks(self):
    """
    Return all figure marks whose class name is listed in
    ``_native_mark_classnames`` (exact class-name match, so subclasses with a
    different name are excluded).
    """
    return [m for m in self.figure.marks if m.__class__.__name__ in self._native_mark_classnames]


@property
def custom_marks(self):
    """
    Return all figure marks whose class name is NOT listed in
    ``_native_mark_classnames`` (this includes subclasses of the native marks).
    """
    return [m for m in self.figure.marks if m.__class__.__name__ not in self._native_mark_classnames]


def _subscribe_to_layers_update(self):
    """
    Initialise the layer bookkeeping lists and register ``_on_layers_update``
    as a callback on the viewer state's ``layers`` property.
    """
    # subscribe to new layers
    self._expected_subset_layers = []
    self._layers_with_defaults_applied = []
    self.state.add_callback('layers', self._on_layers_update)


def _get_layer(self, label):
    """
    Return the first layer state whose layer label equals ``label``
    (implicitly None when there is no match).
    """
    for layer in self.state.layers:
        if layer.layer.label == label:
            return layer


def _apply_layer_defaults(self, layer_state):
    """
    Apply jdaviz defaults (visibility, ``as_steps``, scatter marker size) to a
    newly seen layer state.
    """
    # Hide subsets on orientation layers to avoid rendering artifacts.
    if (layer_state.layer.label != layer_state.layer.data.label
            and layer_state.layer.data.meta.get(_wcs_only_label, False)):
        layer_state.visible = False

    if hasattr(layer_state, 'as_steps'):
        if layer_state.layer.label != layer_state.layer.data.label:
            # then this is a subset, so default based on the parent data layer
            layer_state.as_steps = self._get_layer(layer_state.layer.data.label).as_steps
        else:
            # default to not plotting with as_steps (despite glue defaulting to True)
            layer_state.as_steps = False

        # whenever as_steps changes, we need to redraw the uncertainties (if enabled)
        layer_state.add_callback('as_steps', self._show_uncertainty_changed)

    # set default size for scatter layers (e.g., catalog markers)
    # based on 1% of the average viewer dimension
    if isinstance(layer_state, BqplotScatterLayerState):
        # set the marker default size to 1% of the size of the viewer (average
        # x and y dimensions to account for viewer not being square), and
        # fall back on a 10 pt marker size if we can't get the viewer
        # dimensions for some reason (e.g., no image layers)
        marker_size = 10  # fallback default for marker size, in points
        for layer in self.state.layers:
            if (hasattr(layer, 'layer') and
                    hasattr(layer.layer, 'data') and
                    hasattr(layer.layer.data, 'shape') and
                    len(layer.layer.data.shape) >= 2):
                # the last two entries of shape are averaged as the viewer size
                shape = layer.layer.data.shape
                avg_dimension = (shape[-2] + shape[-1]) / 2
                marker_size = avg_dimension * 0.01
                # set a lower limit on marker size, which is the glue default
                # of three pixels.
                marker_size = max(marker_size, 3)
                # only the first layer with at least two dimensions is used
                break
        layer_state.size = marker_size

    # use echo-validator to ensure visible sets & updates properly in plot options & data menu
    if (hasattr(layer_state, 'visible') and
            get_subset_type(layer_state.layer) != 'spatial'):
        layer_state.layer.visible = CallbackProperty()
        layer_state.add_callback('layer', self._expected_subset_layer_default,
                                 validator=True)


def _expected_subset_layer_default(self, layer_state):
    """
    Set the default visibility of a subset layer; the rule depends on the
    viewer's class name.
    """
    if self.__class__.__name__ == 'RampvizImageView':
        # Do not override default for subsets as for some reason
        # this isn't getting called when they're first added, but rather when
        # the next state change is made (for example: manually changing the visibility)
        return

    # default visibility based on the visibility of the "parent" data layer
    if self.__class__.__name__ == 'RampvizProfileView':
        # Rampviz doesn't show subset profiles by default:
        layer_state.visible = False
    elif (self.__class__.__name__ == 'CubevizImageView' and
            get_subset_type(layer_state.layer) != 'spatial'):
        # set visibility of spectral subsets to false in Cubeviz image-viewers
        layer_state.visible = False
    else:
        layer_state.visible = self._get_layer(layer_state.layer.data.label).visible


def _update_layer_icons(self):
    """
    Rebuild the data menu's ``visible_layers`` mapping (label -> color,
    linewidth and prefix icon) from the currently visible layers.
    """
    # update visible_layers (TODO: move this somewhere that can update on color change, etc)
    def _get_layer_color(layer):
        if isinstance(layer, BqplotScatterLayerState):
            # then this could be a scatter layer in an image viewer,
            # so we'll ignore the color_mode
            return layer.color

        if getattr(self.state, 'color_mode', None) == 'Colormaps':
            for subset in self.jdaviz_app.data_collection.subset_groups:
                if subset.label == layer.layer.label:
                    # then we still want to show the color for a subset
                    return layer.color

            # then this is a data-layer in colormap mode, so we'll ignore the color
            return ''

        return getattr(layer, 'color', '')

    def _get_layer_linewidth(layer):
        # the linewidth reported to the data menu is capped at 6
        linewidth = getattr(layer, 'linewidth', 0)
        return min(linewidth, 6)

    def _get_layer_info(layer):
        # returns (prefix icon name, subset type or None)
        if 'Trace' in layer.layer.data.meta:
            return "mdi-chart-line-stacked", None

        for subset in self.jdaviz_app.data_collection.subset_groups:
            if subset.label == layer.layer.label:
                subset_type = get_subset_type(subset)
                if subset_type == 'spatial':
                    return "mdi-chart-scatter-plot", subset_type
                else:
                    return "mdi-chart-bell-curve", subset_type

        return "", None

    visible_layers = {}
    # layers are visited in reverse order, which fixes the insertion order of the dict
    for layer in self.state.layers[::-1]:
        layer_is_wcs_only = (
            hasattr(layer.layer, 'meta') and
            layer.layer.meta.get(_wcs_only_label, False)
        )
        if layer.visible and not layer_is_wcs_only:
            prefix_icon, subset_type = _get_layer_info(layer)
            if (
                subset_type == 'spatial' and
                self.__class__.__name__ in ('CubevizProfileView',
                                            'RampvizProfileView',
                                            'Spectrum1DViewer')
            ):
                # do not show spatial subsets in profile viewer
                continue

            visible_layers[layer.layer.label] = {'color': _get_layer_color(layer),
                                                 'linewidth': _get_layer_linewidth(layer),
                                                 'prefix_icon': prefix_icon}

    self._data_menu.visible_layers = visible_layers


@avoid_circular
def _on_layers_update(self, layers=None):
    """
    Callback for changes to the viewer state's ``layers``.

    Refreshes the per-data visibility summary, updates the data-menu icons,
    applies defaults to layers not seen before, and broadcasts a
    ``ViewerVisibleLayersChangedMessage``.
    """
    if self.__class__.__name__ == 'MosvizTableViewer':
        # MosvizTableViewer uses this as a mixin, but we do not need any of this layer
        # logic there
        return

    viewer_item = self.jdaviz_app._viewer_item_by_id(self.reference_id)
    if viewer_item is None:
        return

    selected_data_items = viewer_item.get('selected_data_items', {})

    # update selected_data_items
    for data_id, visibility in selected_data_items.items():
        label = next((x['name'] for x in self.jdaviz_app.state.data_items
                      if x['id'] == data_id), None)

        visibilities = []
        for layer in self.state.layers:
            if layer.layer.data.label == label:
                visibilities.append(layer.visible)

        # NOTE(review): np.all of an empty list is True, so data with no
        # matching layer is reported as 'visible' - confirm this is intended.
        if np.all(visibilities):
            selected_data_items[data_id] = 'visible'
        elif np.any(visibilities):
            selected_data_items[data_id] = 'mixed'
        else:
            selected_data_items[data_id] = 'hidden'

    self._update_layer_icons()

    # we'll make a (shallow) copy so that we can remove entries from
    # self._expected_subset_layers to avoid recursion, but also handle
    # multiple layers for the same subset
    expected_subset_layers = self._expected_subset_layers[:]
    for layer in self.state.layers:
        layer_info = {'data_label': layer.layer.data.label,
                      'layer_label': layer.layer.label}
        if layer_info not in self._layers_with_defaults_applied:
            self._layers_with_defaults_applied.append(layer_info)
            self._apply_layer_defaults(layer)

        if layer.layer.label in expected_subset_layers:
            if layer.layer.label in self._expected_subset_layers:
                self._expected_subset_layers.remove(layer.layer.label)
            self._expected_subset_layer_default(layer)

    self.hub.broadcast(ViewerVisibleLayersChangedMessage(
        viewer_reference=self.reference,
        visible_layers=selected_data_items,
        sender=self))


def _on_subset_create(self, msg):
    """
    Remember the label of a newly created subset so that defaults can be
    applied once its layers are added.
    """
    from jdaviz.configs.mosviz.plugins.viewers import MosvizTableViewer
    if isinstance(self, MosvizTableViewer):
        # MosvizTableViewer uses this as a mixin, but we do not need any of this layer
        # logic there
        return

    # NOTE: the subscription to this method is handled in ConfigHelper

    # we don't have access to the actual subset yet to tell if its spectral or spatial, so
    # we'll store the name of this new subset and change the default linewidth when the
    # layers are added
    if not hasattr(self, '_expected_subset_layers'):
        # the layers subscription has not been set up on this viewer
        return
    if msg.subset.label not in self._expected_subset_layers and msg.subset.label:
        self._expected_subset_layers.append(msg.subset.label)


def _on_subset_delete(self, msg):
    """
    This is needed to remove the "ghost" subset left over when the subset tool is active,
    and the active subset is deleted. https://github.com/spacetelescope/jdaviz/issues/2499
    is open to revert/update this if it ends up being addressed upstream in
    https://github.com/glue-viz/glue-jupyter/issues/401.
    """
    from jdaviz.configs.mosviz.plugins.viewers import MosvizTableViewer
    if isinstance(self, MosvizTableViewer):
        # MosvizTableViewer uses this as a mixin, but we do not need any of this layer
        # logic there
        return

    subset_tools = ['bqplot:truecircle', 'bqplot:rectangle', 'bqplot:ellipse',
                    'bqplot:circannulus', 'bqplot:xrange']

    # only act when no subset is being edited any more and a subset tool is active
    if not len(self.session.edit_subset_mode.edit_subset):
        if self.toolbar.active_tool_id in subset_tools:
            if (hasattr(self.toolbar, "default_tool_priority") and
                    len(self.toolbar.default_tool_priority)):
                # fall back to the first default tool
                self.toolbar.active_tool_id = self.toolbar.default_tool_priority[0]
            else:
                self.toolbar.active_tool = None


@property
def active_image_layer(self):
    """
    Active image layer in the viewer, if available.

    This is the visible, non-DQ image layer (bitmap or contour shown) with the
    highest ``zorder``; None when there is no such layer.
    """
    # Find visible layers
    visible_layers = [layer for layer in self.state.layers
                      if (layer.visible and layer_is_image_data(layer.layer) and
                          layer_is_not_dq(layer.layer) and
                          (getattr(layer, 'bitmap_visible', False) or
                           getattr(layer, 'contour_visible', False)))]

    if len(visible_layers) == 0:
        return None

    z_order = [layer.zorder for layer in visible_layers]
    active_index = np.argmax(z_order)
    return visible_layers[active_index]


@property
def active_cube_layer(self):
    """
    Active cube layer in the viewer, if available.

    This is the last visible, non-DQ 3D layer (bitmap or contour shown) in
    ``state.layers``; None when there is no such layer.
    """
    # Find visible layers
    visible_layers = [layer for layer in self.state.layers
                      if (layer.visible and layer_is_3d(layer.layer) and
                          layer_is_not_dq(layer.layer) and
                          (layer.bitmap_visible or layer.contour_visible))]

    if len(visible_layers) == 0:
        return None

    return visible_layers[-1]
def initialize_toolbar(self, default_tool_priority=None):
    """
    Create the nested jdaviz toolbar for this viewer and store it on ``self.toolbar``.

    Parameters
    ----------
    default_tool_priority : list or None, optional
        Passed through to ``NestedJupyterToolbar``. ``None`` (the default)
        is replaced by a new empty list.
    """
    # NOTE: this overrides glue_jupyter.IPyWidgetView

    # A fresh list per call: a mutable default argument would hand the same
    # list object to the toolbar of every viewer that relies on the default.
    if default_tool_priority is None:
        default_tool_priority = []
    self.toolbar = NestedJupyterToolbar(self, self.tools_nested, default_tool_priority)
@property
def tools(self):
    # NOTE: this overrides the default list of tools for the BasicJupyterToolbar by
    # returning a flattened version of self.tools_nested
    registered_tools = self.toolbar.tools
    return list(registered_tools.keys())


@property
def jdaviz_app(self):
    """The Jdaviz application tied to the viewer (taken from the session)."""
    session = self.session
    return session.jdaviz_app


@property
def jdaviz_helper(self):
    """The Jdaviz configuration helper tied to the viewer."""
    app = self.jdaviz_app
    return app._jdaviz_helper


@property
def hub(self):
    """The hub of the viewer's session."""
    session = self.session
    return session.hub


@property
def reference_id(self):
    """The viewer's internal reference id."""
    return self._reference_id


@property
def reference(self):
    """The 'reference' entry of this viewer's item in the application."""
    viewer_item = self.jdaviz_app._viewer_item_by_id(self.reference_id)
    return viewer_item.get('reference')


@property
def _ref_or_id(self):
    """The viewer reference if one is available, otherwise the reference id."""
    ref = getattr(self, 'reference', None)
    if ref is None:
        return self.reference_id
    return ref
def set_plot_axes(self):
    """
    Hook for axes styling; does nothing by default.

    Individual viewers can override this to set custom axes
    labels/ticks/styling.
    """
    return None
class JdavizViewerWindow(TemplateMixin):
    """
    Single container around a glue viewer that also holds its toolbar and
    data-menu, while redirecting ``user_api`` calls to the underlying viewer.
    """
    template_file = __file__, "../../../viewer_window.vue"

    # Synced traits consumed by the template.
    id = Unicode().tag(sync=True)
    name = Unicode().tag(sync=True)
    reference = Unicode().tag(sync=True)
    config = Unicode().tag(sync=True)
    figure_widget = Unicode().tag(sync=True)
    toolbar_widget = Unicode().tag(sync=True)
    data_menu_widget = Unicode().tag(sync=True)
    tool_override_mode = Unicode("").tag(sync=True)
    viewer_destroyed = Bool(False).tag(sync=True)

    def __init__(self, viewer, *args, reference="", name="", **kwargs):
        """
        Parameters
        ----------
        viewer : object
            The glue viewer to wrap; kept on ``self.glue_viewer``.
        reference : str, optional
            Viewer reference; falls back to ``name`` and then to the viewer's
            ``_reference_id`` when empty.
        name : str, optional
            Display name; falls back to the viewer's ``_reference_id`` when empty.
        """
        super().__init__(*args, **kwargs)
        self.glue_viewer = viewer
        self.config = self._app.config

        viewer_id = viewer._reference_id
        display_name = name if name else viewer_id
        self.id = viewer_id
        self.name = display_name
        self.reference = reference if reference else display_name

        # Widgets are handed to the template by their ipywidgets model id.
        self.figure_widget = "IPY_MODEL_" + viewer.figure_widget.model_id

        toolbar = viewer.toolbar
        if toolbar:
            self.toolbar_widget = "IPY_MODEL_" + toolbar.model_id
        else:
            self.toolbar_widget = ''

        if hasattr(viewer, '_data_menu'):
            self.data_menu_widget = 'IPY_MODEL_' + viewer._data_menu.model_id
        else:
            self.data_menu_widget = ''

        # Link tool_override_mode from toolbar (only for NestedJupyterToolbar)
        if toolbar and hasattr(toolbar, 'tool_override_mode'):
            self.tool_override_mode = toolbar.tool_override_mode
            toolbar.observe(self._on_toolbar_override_change,
                            names=['tool_override_mode'])

        self.hub.subscribe(self, ViewerRemovedMessage, self._on_viewer_removed)

    def _on_toolbar_override_change(self, change):
        """Copy the toolbar's new ``tool_override_mode`` value onto this window."""
        self.tool_override_mode = change['new']

    @property
    def user_api(self):
        """User API exposing ``show`` here and redirecting the rest to the viewer."""
        # expose show methods at this level and redirect others to the underlying viewers
        from jdaviz.core.user_api import ViewerWindowUserApi
        return ViewerWindowUserApi(self, expose=['show'])

    def _on_viewer_removed(self, msg):
        """Flag this window as destroyed when the removed viewer id matches ``self.id``."""
        if msg.viewer_id != self.id:
            return
        self.viewer_destroyed = True

    def show(self, loc="inline", title=None, height=None):  # pragma: no cover
        """Display the viewer window UI.

        Parameters
        ----------
        loc : str
            The display location determines where to present the viewer UI.
            Supported locations:

            "inline": Display the viewer inline in a notebook.

            "sidecar": Display the viewer in a separate JupyterLab window from the
            notebook, the location of which is decided by the 'anchor.' right is the default

                Other anchors:

                * ``sidecar:right`` (The default, opens a tab to the right of display)
                * ``sidecar:tab-before`` (Full-width tab before the current notebook)
                * ``sidecar:tab-after`` (Full-width tab after the current notebook)
                * ``sidecar:split-right`` (Split-tab in the same window right of the notebook)
                * ``sidecar:split-left`` (Split-tab in the same window left of the notebook)
                * ``sidecar:split-top`` (Split-tab in the same window above the notebook)
                * ``sidecar:split-bottom`` (Split-tab in the same window below the notebook)

                See `jupyterlab-sidecar <https://github.com/jupyter-widgets/jupyterlab-sidecar>`_
                for the most up-to-date options.

            "popout": Display the viewer in a detached display. By default, a new
            window will open. Browser popup permissions required.

                Other anchors:

                * ``popout:window`` (The default, opens Jdaviz in a new, detached popout)
                * ``popout:tab`` (Opens Jdaviz in a new, detached tab in your browser)

        title : str, optional
            The title of the sidecar tab.  Defaults to the name of the viewer.

            NOTE: Only applicable to a "sidecar" display.

        height : int, optional
            The height of the viewer display, in pixels.  Only applicable if loc is "inline".

        Notes
        -----
        If "sidecar" is requested in the "classic" Jupyter notebook, the viewer will appear
        inline, as only JupyterLab has a mechanism to have multiple tabs.
        """
        if title is None:
            title = self.name
        show_widget(self, loc=loc, title=title, height=height)
@viewer_registry("jdaviz-profile-viewer", label="Profile 1D")
class JdavizProfileView(JdavizViewerMixin, BqplotProfileView):
    """
    1D profile viewer registered as ``"jdaviz-profile-viewer"``.

    On top of the bqplot profile view it keeps extra marks on the figure for
    off-screen lines and reacts to the state's ``show_uncertainty`` flag.
    """
    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange'],
        ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
    ]

    default_class = NDDataArray
    _state_cls = FreezableProfileViewerState
    _default_profile_subset_type = None

    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        *args, **kwargs
            Forwarded to the parent viewer classes after the two keywords
            below have been removed.
        default_tool_priority : list, optional
            Passed to ``initialize_toolbar``; defaults to an empty list.
        default_collapse_function : str, optional
            Assigned to ``self.state.function``; defaults to ``'sum'``.
        """
        # Both jdaviz-only keywords are removed *before* delegating so they
        # are never forwarded to the parent constructor.  Previously
        # ``default_collapse_function`` was popped only after ``super().__init__``
        # had already received it.
        default_tool_priority = kwargs.pop('default_tool_priority', [])
        default_collapse_function = kwargs.pop('default_collapse_function', 'sum')
        super().__init__(*args, **kwargs)

        self._subscribe_to_layers_update()
        self.initialize_toolbar(default_tool_priority=default_tool_priority)
        self._offscreen_lines_marks = OffscreenLinesMarks(self)
        self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks

        # Uncertainty handling lives in jdaviz, so watch the state flag here.
        self.state.add_callback('show_uncertainty', self._show_uncertainty_changed)

        self.display_mask = False

        # Change collapse function to sum (unless the caller asked otherwise)
        self.state.function = default_collapse_function

    def _expected_subset_layer_default(self, layer_state):
        """Apply the parent's subset-layer defaults, then set the line width to 3."""
        super()._expected_subset_layer_default(layer_state)

        layer_state.linewidth = 3
def data(self, cls=None):
    """
    Collect an object for each layer currently in the viewer.

    Parameters
    ----------
    cls : class, optional
        Class each layer is translated into.  Falls back to
        ``self.default_class`` when not given (or falsy).

    Returns
    -------
    data : list
        One entry per data layer (only when a target class is resolved) and
        per subset layer.  Subset layers whose translation raises
        ``IncompatibleAttribute`` are skipped; when no target class is
        resolved the subset itself is appended untranslated.
    """
    # Grab the user's chosen statistic for collapsing data
    statistic = getattr(self.state, 'function', None)

    # Resolve the target class once, before the loop.  It used to be assigned
    # only inside the ``BaseData`` branch, so a subset layer encountered
    # before any data layer raised ``UnboundLocalError`` in the branch below.
    _class = cls or self.default_class

    data = []

    for layer_state in self.state.layers:
        if not hasattr(layer_state, 'layer'):
            continue
        lyr = layer_state.layer

        # For raw data, just include the data itself
        if isinstance(lyr, BaseData):
            if _class is not None:
                # NOTE(review): the cache key does not include the target
                # class, so a cached object of another class can be returned.
                cache_key = (lyr.label, statistic)
                if cache_key in self.jdaviz_app._get_object_cache:
                    layer_data = self.jdaviz_app._get_object_cache[cache_key]
                else:
                    # If spectrum, collapse via the defined statistic
                    if _class == Spectrum:
                        layer_data = lyr.get_object(cls=_class, statistic=statistic)
                    else:
                        layer_data = lyr.get_object(cls=_class)
                    self.jdaviz_app._get_object_cache[cache_key] = layer_data

                data.append(layer_data)

        # For subsets, make sure to apply the subset mask to the layer data first
        elif isinstance(lyr, Subset):
            layer_data = lyr

            if _class is not None:
                handler, _ = data_translator.get_handler_for(_class)
                try:
                    layer_data = handler.to_object(layer_data, statistic=statistic)
                except IncompatibleAttribute:
                    continue
            data.append(layer_data)

    return data
def get_scales(self):
    """
    Return the x and y scales of the figure's current interaction.

    Returns
    -------
    dict
        ``{'x': <x_scale>, 'y': <y_scale>}`` read from ``self.figure.interaction``.
    """
    figure = self.figure
    toolbar = self.toolbar

    # Deselect any pan/zoom or subsetting tools so they don't interfere
    # with the scale retrieval
    if toolbar.active_tool is not None:
        toolbar.active_tool = None

    # Read the interaction only after the active tool has been cleared.
    scales = {}
    scales['x'] = figure.interaction.x_scale
    scales['y'] = figure.interaction.y_scale
    return scales
def _show_uncertainty_changed(self, msg=None):
    """
    Draw or remove the uncertainty marks to match ``state.show_uncertainty``.

    This is subscribed in init to watch for changes to the state object, since
    uncertainty handling is in jdaviz instead of glue/glue-jupyter.  ``msg`` is
    accepted for the callback signature and is not used.
    """
    if not self.state.show_uncertainty:
        self._clean_error()
        return
    self._plot_uncertainties()
def show_mask(self):
    """Enable mask display on this viewer and (re)draw the mask marks."""
    # The flag must be set first: ``_plot_mask`` returns early while it is False.
    self.display_mask = True
    self._plot_mask()
def clean(self):
    """Remove extra traces (mask and uncertainty marks), in case they exist."""
    # Mask marks: turn the flag off, then drop the marks from the figure.
    self.display_mask = False
    self._clean_mask()

    # Uncertainty marks: flipping the state flag is enough, this will
    # automatically call _clean_error via _show_uncertainty_changed
    self.state.show_uncertainty = False
def _clean_mask(self):
    """Drop every ``ScatterMask`` mark from the figure, keeping all other marks."""
    figure = self.figure
    kept = []
    for mark in figure.marks:
        if isinstance(mark, ScatterMask):
            continue
        kept.append(mark)
    figure.marks = kept


def _clean_error(self):
    """Drop every ``LineUncertainties`` mark from the figure, keeping all other marks."""
    figure = self.figure
    kept = []
    for mark in figure.marks:
        if isinstance(mark, LineUncertainties):
            continue
        kept.append(mark)
    figure.marks = kept
def add_data(self, data, color=None, alpha=None, **layer_state):
    """
    Overrides the base class to add markers for plotting
    uncertainties and data quality flags.

    Parameters
    ----------
    data : :class:`glue.core.data.Data`
        Data object with the spectrum.
    color : obj
        Color value for plotting.
    alpha : float
        Alpha value for plotting.
    **layer_state
        Extra keyword arguments handed unchanged to the base class ``add_data``.

    Returns
    -------
    result : bool
        `False` when the flux-unit compatibility check fails (the data is then
        removed from the data collection); otherwise the value returned by the
        base class ``add_data``.
    """
    # If this is the first loaded data, set things up for unit conversion.
    is_first_layer = len(self.layers) == 0

    if not is_first_layer:
        # Check if the new data flux unit is actually compatible since flux not linked.
        try:
            if self.state.y_display_unit not in ['None', None, 'DN']:
                flux_values = data.get_component('flux').data
                if hasattr(flux_values, 'units'):
                    pixel_scale = data.meta.get('_pixel_scale_factor', None)
                    wavelengths = data.get_component('spectral')
                    equivalencies = all_flux_unit_conversion_equivs(pixar_sr=pixel_scale,
                                                                    cube_wave=wavelengths)
                    # Only the success of the conversion matters, not its result.
                    flux_conversion_general([1, 1], flux_values.units,
                                            self.state.y_display_unit,
                                            equivalencies=equivalencies)
        except Exception as e:
            # Raising exception here introduces a dirty state that messes up next load_data
            # but not raising exception also causes weird behavior unless we remove the data
            # completely.
            failure = SnackbarMessage(
                f"Failed to load {data.label}, so removed it: {repr(e)}",
                sender=self, color='error', traceback=e)
            self.session.hub.broadcast(failure)
            self.jdaviz_app.data_collection.remove(data)
            return False

    # The base class handles the plotting of the main
    # trace representing the profile itself.
    result = super().add_data(data, color, alpha, **layer_state)

    if is_first_layer:
        x_units = data.get_component(self.state.x_att.label).units
        first_layer_labels = [comp.label
                              for comp in self.state.layers[0].layer.components]
        if 'flux' in first_layer_labels:
            y_axis_component = 'flux'
        else:
            y_axis_component = 'data'
        y_units = data.get_component(y_axis_component).units

        # Update both display units together before their callbacks run.
        with delay_callback(self.state, "x_display_unit", "y_display_unit"):
            self.state.x_display_unit = x_units if len(x_units) else None
            self.state.y_display_unit = y_units if len(y_units) else None
        self.set_plot_axes()

    self._plot_uncertainties()
    self._plot_mask()

    # Set default linewidth on any created spectral subset layers
    # NOTE: this logic will need updating if we add support for multiple cubes as this assumes
    # that new data entries (from model fitting or gaussian smooth, etc) will only be spectra
    # and all subsets affected will be spectral
    for subset_layer in self.state.layers:
        target = subset_layer.layer
        if (isinstance(target, GroupedSubset)
                and get_subset_type(target) == self._default_profile_subset_type
                and target.data.label == data.label):
            subset_layer.linewidth = 3

    return result
def _plot_mask(self):
    """
    Redraw the mask markers for every non-subset layer that has a ``mask`` component.

    Does nothing unless ``self.display_mask`` is set.  Existing ``ScatterMask``
    marks are removed first (via ``_clean_mask``), then one ``ScatterMask`` is
    appended to the figure marks per qualifying layer.
    """
    if not self.display_mask:
        return

    # Remove existing mask marks
    self._clean_mask()

    # Loop through all active data in the viewer
    for index, layer_state in enumerate(self.state.layers):
        lyr = layer_state.layer
        comps = [str(component) for component in lyr.components]

        # Skip subsets
        if hasattr(lyr, "subset_state"):
            continue

        # Ignore data that does not have a mask component
        if "mask" in comps:
            mask = np.array(lyr['mask'].data)

            data_obj = lyr.data.get_object(cls=self.default_class)

            # The x/y values come from different attributes depending on
            # whether the translated object is a Spectrum.
            if self.default_class == Spectrum:
                data_x = data_obj.spectral_axis.value
                data_y = data_obj.flux.value
            else:
                data_x = np.arange(data_obj.shape[-1])
                data_y = data_obj.data.value

            # For plotting markers only for the masked data
            # points, erase un-masked data from trace.
            y = np.where(np.asarray(mask) == 0, np.nan, data_y)

            # A subclass of the bqplot Scatter object, ScatterMask places
            # 'X' marks where there is masked data in the viewer.
            color = layer_state.color
            # Markers use a third of the layer's alpha.
            alpha_shade = layer_state.alpha / 3
            mask_line_mark = ScatterMask(scales=self.scales,
                                         marker='cross',
                                         x=data_x,
                                         y=y,
                                         stroke_width=0.5,
                                         colors=[color],
                                         default_size=25,
                                         default_opacities=[alpha_shade]
                                         )
            # Add mask marks to viewer
            self.figure.marks = list(self.figure.marks) + [mask_line_mark]


def _plot_uncertainties(self):
    """
    Redraw the uncertainty band for every non-subset layer that has an
    ``uncertainty`` component.

    Does nothing unless ``self.state.show_uncertainty`` is set.  Existing
    ``LineUncertainties`` marks are removed first (via ``_clean_error``), then
    one ``LineUncertainties`` mark, bounded by ``data - error`` and
    ``data + error``, is appended to the figure marks per qualifying layer.
    """
    if not self.state.show_uncertainty:
        return

    # Remove existing error bars
    self._clean_error()

    # Loop through all active data in the viewer
    for index, layer_state in enumerate(self.state.layers):
        lyr = layer_state.layer

        # Skip subsets
        if hasattr(lyr, "subset_state"):
            continue

        comps = [str(component) for component in lyr.components]

        # Ignore data that does not have an uncertainty component
        if "uncertainty" in comps:  # noqa
            error = np.array(lyr['uncertainty'].data)

            # ensure that the uncertainties are represented as stddev:
            # (the stored type defaults to 'stddev' when the metadata has no entry)
            uncertainty_type_str = lyr.meta.get('uncertainty_type', 'stddev')
            uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str]
            error = uncert_cls(error).represent_as(StdDevUncertainty).array

            if 'spectral_axis_index' in lyr.data.meta:
                spectral_axis_index = lyr.data.meta['spectral_axis_index']
            else:
                # We have to make an assumption in this case.
                # TODO: Should we have other handling for non-spectral data (e.g. light curves?)
                spectral_axis_index = -1
            data_obj = lyr.data.get_object(cls=self.default_class, statistic=None)
            lyr_coords = lyr.data.coords
            if isinstance(lyr_coords, SpectralCoordinates):
                spectral_wcs = lyr_coords
                data_x = spectral_wcs.pixel_to_world_values(
                    np.arange(lyr.data.shape[spectral_axis_index])
                )
                # Keep only the first element when a tuple is returned.
                if isinstance(data_x, tuple):
                    data_x = data_x[0]
            else:
                # NOTE(review): if none of the three branches below matches,
                # ``spectral_wcs`` is never assigned and the call that follows
                # raises UnboundLocalError -- confirm whether that can occur.
                if hasattr(lyr_coords, 'spectral_wcs'):
                    spectral_wcs = lyr_coords.spectral_wcs
                elif hasattr(lyr_coords, 'spectral'):
                    spectral_wcs = lyr_coords.spectral
                elif hasattr(lyr_coords, "world_n_dim") and lyr_coords.world_n_dim == 1:
                    # 1D GWCS in this case, just use the coords
                    spectral_wcs = lyr_coords
                data_x = spectral_wcs.pixel_to_world(
                    np.arange(lyr.data.shape[spectral_axis_index])
                )

            data_y = data_obj.data

            # The shaded band around the spectrum trace is bounded by
            # two lines, above and below the spectrum trace itself.
            data_x_list = np.ndarray.tolist(data_x)
            x = [data_x_list, data_x_list]
            y = [np.ndarray.tolist(data_y - error),
                 np.ndarray.tolist(data_y + error)]

            if layer_state.as_steps:
                # Step rendering: turn the n sample positions into n + 1 bin
                # edges (midpoints, with the two outer edges extrapolated from
                # the end spacing), then repeat interior edges and every y
                # value twice so both arrays have length 2n.
                for i in (0, 1):
                    a = np.insert(x[i], 0, 2*x[i][0] - x[i][1])
                    b = np.append(x[i], 2*x[i][-1] - x[i][-2])
                    edges = (a + b) / 2
                    x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:]))
                    y[i] = np.repeat(y[i], 2)
            x, y = np.asarray(x), np.asarray(y)

            # A subclass of the bqplot Lines object, LineUncertainties keeps
            # track of uncertainties plotted in the viewer. LineUncertainties
            # appear with two lines and shaded area in between.
            color = layer_state.color
            # The fill uses a third of the layer's alpha; the two bounding
            # lines themselves are fully transparent (opacities of 0.0).
            alpha_shade = layer_state.alpha / 3
            error_line_mark = LineUncertainties(viewer=self,
                                                x=[x],
                                                y=[y],
                                                scales=self.scales,
                                                stroke_width=1,
                                                colors=[color, color],
                                                fill_colors=[color, color],
                                                opacities=[0.0, 0.0],
                                                fill_opacities=[alpha_shade,
                                                                alpha_shade],
                                                fill='between',
                                                close_path=False
                                                )

            # Add error lines to viewer
            self.figure.marks = list(self.figure.marks) + [error_line_mark]
def set_plot_axes(self):
    """
    Set the x and y axis labels and tick styling of the spectrum viewer.

    The y label is chosen from the current ``y_display_unit``
    ("Surface Brightness", "Flux", "Counts", "Luminosity", falling back to
    "Flux density"); the x label from the current ``x_display_unit``
    ("Wavelength", "Frequency", "Pixel", "Group", falling back to the unit's
    physical type).  Each label gets a `` [unit]`` suffix unless the display
    unit is ``None`` or the string ``'None'``.
    """
    # Set x and y axes labels for the spectrum viewer
    y_display_unit = self.state.y_display_unit
    y_unit = (
        u.Unit(y_display_unit) if y_display_unit and y_display_unit != 'None'
        else u.dimensionless_unscaled
    )

    # Get local units.
    locally_defined_flux_units = [
        u.Jy, u.mJy, u.uJy, u.MJy,
        u.W / (u.m**2 * u.Hz),
        u.eV / (u.s * u.m**2 * u.Hz),
        u.erg / (u.s * u.cm**2),
        u.erg / (u.s * u.cm**2 * u.Angstrom),
        u.erg / (u.s * u.cm**2 * u.Hz),
        u.ph / (u.s * u.cm**2 * u.Angstrom),
        u.ph / (u.s * u.cm**2 * u.Hz),
        u.bol, u.AB, u.ST
    ]

    # get square angle from 'sb' display unit
    sb_unit = self.jdaviz_app._get_display_unit(axis='sb')
    if sb_unit is not None:
        solid_angle_unit = check_if_unit_is_per_solid_angle(sb_unit, return_unit=True)
    else:
        solid_angle_unit = None

    # if solid angle is present in denominator, check physical type of numerator
    # if numerator is a flux type the display unit is a 'surface brightness', otherwise
    # default to the catchall 'flux density' label
    flux_unit_type = None

    for un in locally_defined_flux_units:
        locally_defined_sb_unit = un / solid_angle_unit if solid_angle_unit is not None else None  # noqa

        # create an equivalency for each flux unit for flux <> flux/pix2.
        # for similar reasons to the 'untranslatable units' issue, custom
        # equivs. can't be combined, so a workaround is creating an eqiv
        # for each flux that may need an additional equiv.
        angle_to_pixel_equiv = _eqv_sb_per_pixel_to_per_angle(un)

        if (locally_defined_sb_unit is not None
                and y_unit.is_equivalent(locally_defined_sb_unit, angle_to_pixel_equiv)):
            flux_unit_type = "Surface Brightness"
        elif y_unit.is_equivalent(un):
            flux_unit_type = 'Flux'
        elif y_unit.is_equivalent(u.electron / u.s) or y_unit.physical_type == 'dimensionless':  # noqa
            # electron / s or 'dimensionless_unscaled' should be labeled counts
            flux_unit_type = "Counts"
        elif y_unit.is_equivalent(u.W):
            flux_unit_type = "Luminosity"
        if flux_unit_type is not None:
            # if we determined a label, stop checking
            break
    else:
        # default to Flux Density for flux density or uncaught types
        flux_unit_type = "Flux density"

    # Set x axes labels for the spectrum viewer
    x_disp_unit = self.state.x_display_unit
    # Treat the string 'None' like a missing unit, exactly as done for the y
    # unit above; ``u.Unit('None')`` cannot be parsed, while the label code
    # below already expects that value to be possible.
    x_unit = (
        u.Unit(x_disp_unit) if x_disp_unit and x_disp_unit != 'None'
        else u.dimensionless_unscaled
    )
    if x_unit.is_equivalent(u.m):
        spectral_axis_unit_type = "Wavelength"
    elif x_unit.is_equivalent(u.Hz):
        spectral_axis_unit_type = "Frequency"
    elif x_unit.is_equivalent(u.pixel):
        spectral_axis_unit_type = "Pixel"
    elif x_unit.is_equivalent(u.dimensionless_unscaled):
        # case for rampviz
        spectral_axis_unit_type = "Group"
    else:
        spectral_axis_unit_type = str(x_unit.physical_type).title()

    with self.figure.hold_sync():
        self.figure.axes[0].label = f"{spectral_axis_unit_type}" + (
            f" [{self.state.x_display_unit}]"
            if self.state.x_display_unit not in ["None", None] else ""
        )
        # The y suffix now carries the same leading space as the x suffix,
        # so the label reads "Flux [Jy]" rather than "Flux[Jy]".
        self.figure.axes[1].label = f"{flux_unit_type}" + (
            f" [{self.state.y_display_unit}]"
            if self.state.y_display_unit not in ["None", None] else ""
        )

    # Make it so axis labels are not covering tick numbers.
    self.figure.fig_margin["left"] = 95
    self.figure.fig_margin["bottom"] = 60
    self.figure.send_state('fig_margin')  # Force update
    self.figure.axes[0].label_offset = "40"
    self.figure.axes[1].label_offset = "-70"
    # NOTE: with tick_style changed below, the default responsive ticks in bqplot result
    # in overlapping tick labels.  For now we'll hardcode at 8, but this could be removed
    # (default to None) if/when bqplot auto ticks react to styling options.
    self.figure.axes[1].num_ticks = 8

    # Set Y-axis to scientific notation
    self.figure.axes[1].tick_format = '0.1e'

    for i in (0, 1):
        self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600}
@viewer_registry("scatter-viewer", label="scatter")
class ScatterViewer(JdavizViewerMixin, BqplotScatterView):
    """Scatter viewer registered as ``"scatter-viewer"``."""

    def __init__(self, session, *args, **kwargs):
        """Create the viewer, then adjust y tick format and axis label styling."""
        super().__init__(session, *args, **kwargs)
        # make y axis scientific notation for readability
        self.figure.axes[1].tick_format = '0.1e'
        # offset y axis label so it doesn't cover y tick labels
        self.figure.axes[1].label_offset = "-50"
        # make axis labels a little smaller so that they're the same size, but the
        # y axis label fits
        self.figure.axes[0].axis_label_style = {'font-size': '12px'}  # x
        self.figure.axes[1].axis_label_style = {'font-size': '12px'}  # y

    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange', 'bqplot:yrange', 'bqplot:rectangle'],
        [],
        ['jdaviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
    ]

    _state_cls = ScatterViewerState
    _native_mark_classnames = ('Scatter', 'ScatterGL')


@viewer_registry("histogram-viewer", label="histogram")
class HistogramViewer(JdavizViewerMixin, BqplotHistogramView):
    """Histogram viewer registered as ``"histogram-viewer"``."""

    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange'],
        [],
        ['jdaviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
    ]

    _state_cls = HistogramViewerState
    _native_mark_classnames = ('Bars', 'BarsGL')


@viewer_registry("table-viewer", label="table")
class JdavizTableViewer(JdavizViewerMixin, TableViewer):
    """
    Table viewer registered as ``"table-viewer"``.

    Besides showing the table it keeps ``TableSelectionMark`` highlight marks
    in the ``ImvizImageView`` viewers in step with the table's checked rows.
    """

    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:table_highlight_selected'],
        ['jdaviz:table_zoom_to_selected'],
        ['jdaviz:table_subset'],
        ['jdaviz:viewer_clone']
    ]

    def __init__(self, session, *args, **kwargs):
        """Create the viewer and wire up the table observers and hub subscriptions."""
        super().__init__(session, *args, **kwargs)
        # enable scrolling: # https://github.com/glue-viz/glue-jupyter/pull/287
        self.widget_table.scrollable = True
        # hide checkboxes by default (shown when TableSubset tool is activated)
        self.widget_table.selection_enabled = False

        # Restrict the data menu's dataset choices with the 'is_catalog' filter.
        self.data_menu._obj.dataset.add_filter('is_catalog')

        # Refresh tool visibilities whenever the checked rows change.
        self.widget_table.observe(lambda _: self.toolbar._update_tool_visibilities(),
                                  names=['checked'])

        # Also update selection highlight marks when checked rows change
        self.widget_table.observe(self._on_checked_changed, names=['checked'])
        self.widget_table.observe(self._on_selection_enabled_changed,
                                  names=['selection_enabled'])

        # Subscribe to RestoreToolbarMessage to clean up checkbox state
        # when toolbar is restored (e.g., by clicking X on custom toolbar)
        self.hub.subscribe(self, RestoreToolbarMessage,
                           handler=self._on_restore_toolbar)

        # Subscribe to TableSelectRowClickMessage to handle clicks from image viewers
        self.hub.subscribe(self, TableSelectRowClickMessage,
                           handler=self._on_table_select_row_click)

        # Subscribe to ViewerRemovedMessage to clean up toolbar overrides
        # if this table viewer is removed while tools are active
        self.hub.subscribe(self, ViewerRemovedMessage,
                           handler=self._on_viewer_removed)

    def _on_table_select_row_click(self, msg):
        """Handle click from image viewer to select/toggle closest table row."""
        # Only respond if this message is for this table viewer
        if msg.table_viewer_id != self.reference_id:
            return

        if not len(self.layers):
            return

        # Get pixel coordinates from the message (these are in reference data frame)
        click_x, click_y = msg.x, msg.y

        # Best effort: any exception raised below is swallowed on purpose.
        try:
            layer = self.layers[0].layer

            # Get sky coordinates for WCS-accurate comparison.
            # Click coordinates are in the viewer's reference frame, so catalog
            # coordinates must also be converted to that frame for proper matching.
            xs, ys = None, None
            skycoords = _get_skycoords_from_table(layer)
            if skycoords is not None:
                # Convert sky coordinates to pixels in the viewer's reference frame
                # (the first image viewer with reference data that has coords is used)
                for viewer in self.jdaviz_app.get_viewers_of_cls('ImvizImageView'):
                    if viewer.state.reference_data is None:
                        continue
                    if viewer.state.reference_data.coords is None:
                        continue
                    pixel_result = viewer.state.reference_data.coords.world_to_pixel(skycoords)
                    xs, ys = pixel_result[0], pixel_result[1]
                    break
            else:
                # Fall back to pixel coordinates only if no sky coordinates available
                pixel_coords = _get_pixel_coords_from_table(layer)
                if pixel_coords is not None:
                    xs, ys = pixel_coords

            # No usable coordinates were found: nothing to toggle.
            if xs is None or ys is None:
                return

            # Find nearest point and toggle its selection
            distsq = (xs - click_x)**2 + (ys - click_y)**2
            ind = int(np.argmin(distsq))

            # Work on a copy and assign it back to the ``checked`` trait.
            current_checked = list(self.widget_table.checked)
            if ind in current_checked:
                current_checked.remove(ind)
            else:
                current_checked.append(ind)
            self.widget_table.checked = current_checked
        except Exception:  # nosec  # pragma: no cover
            pass

    def _on_checked_changed(self, change):
        """Update highlight marks in image viewers when checked rows change."""
        self._update_selection_marks()

    def _on_selection_enabled_changed(self, change):
        """Show/hide selection marks when selection is enabled/disabled."""
        if not change['new']:
            # Selection disabled, clear all marks
            self._clear_selection_marks()
        else:
            # Selection enabled, update marks
            self._update_selection_marks()

    def _get_selection_mark(self, viewer):
        """Get or create a selection highlight mark for the given viewer."""
        matches = [mark for mark in viewer.figure.marks
                   if isinstance(mark, TableSelectionMark)]
        if len(matches):
            return matches[0]
        # None found: create one and append it to the viewer's figure marks.
        mark = TableSelectionMark(viewer)
        viewer.figure.marks = viewer.figure.marks + [mark]
        return mark

    def _update_selection_marks(self):
        """Update selection highlight marks in all image viewers."""
        if not self.widget_table.selection_enabled:
            return

        checked_rows = self.widget_table.checked
        # Nothing checked (or no layer loaded): hide the marks instead.
        if not len(checked_rows) or not len(self.layers):
            self._clear_selection_marks()
            return

        layer = self.layers[0].layer

        # Get sky coordinates for WCS-accurate placement across different images.
        # Pixel coordinates from the catalog are in the catalog's original image frame,
        # which may differ from the viewer's reference data frame.
        skycoords = _get_skycoords_from_table(layer, checked_rows)

        pixel_coords = None
        if skycoords is None:
            # Fall back to pixel coordinates only if no sky coordinates available
            pixel_coords = _get_pixel_coords_from_table(layer, checked_rows)
            if pixel_coords is None:
                self._clear_selection_marks()
                return

        # Update marks in all image viewers
        for viewer in self.jdaviz_app.get_viewers_of_cls('ImvizImageView'):
            # Best effort per viewer: an exception here is swallowed on
            # purpose and the loop moves on to the next viewer.
            try:
                if skycoords is not None:
                    # Convert sky coordinates to pixels for this viewer's reference frame
                    coords = viewer.state.reference_data.coords.world_to_pixel(skycoords)
                    xs, ys = coords[0], coords[1]
                else:
                    # Use pixel coordinates directly (last resort when no sky coords)
                    xs, ys = pixel_coords[0], pixel_coords[1]

                mark = self._get_selection_mark(viewer)
                mark.update_xy(xs, ys)
                mark.visible = True
            except Exception:  # nosec  # pragma: no cover
                pass

    def _clear_selection_marks(self):
        """Clear selection highlight marks from all image viewers."""
        # Marks are hidden (``visible = False``), not removed from the figure.
        for viewer in self.jdaviz_app.get_viewers_of_cls('ImvizImageView'):
            for mark in viewer.figure.marks:
                if isinstance(mark, TableSelectionMark):
                    mark.visible = False

    def _on_restore_toolbar(self, msg={}):
        """Clean up checkbox state when toolbar is restored."""
        # Clear selection marks
        self._clear_selection_marks()
        # Hide checkboxes (they should always be hidden when default toolbar is shown)
        self.widget_table.selection_enabled = False

    def _on_viewer_removed(self, msg):
        """Clean up selection marks if this table viewer is removed."""
        if msg.viewer_id != self.reference_id:
            return

        # Clear selection marks in image viewers when this table viewer is removed
        # (toolbar cleanup is handled generically by NestedJupyterToolbar)
        self._clear_selection_marks()