Source code for jdaviz.configs.default.plugins.data_quality.data_quality

from itertools import chain
import os
from traitlets import Any, Dict, Bool, List, Unicode, Float, observe

import numpy as np
from glue_jupyter.common.toolbar_vuetify import read_icon
from echo import delay_callback
from matplotlib.colors import hex2color

from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (
    PluginTemplateMixin, LayerSelect, ViewerSelectMixin, _is_image_viewer as is_image_viewer
)
from jdaviz.core.tools import ICON_DIR
from jdaviz.core.user_api import PluginUserApi
from jdaviz.configs.default.plugins.data_quality.dq_utils import (
    decode_flags, generate_listed_colormap, dq_flag_map_paths, load_flag_map
)


__all__ = ['DataQuality']

# Maps each flag-map key (lowercase, hyphen-separated) to the display label
# used for that mission or instrument in the flag-map selection list.
telescope_names = {"jwst": "JWST", "roman": "Roman"}

# every HST instrument follows the same "hst-<name>" -> "HST/<NAME>" pattern:
telescope_names.update(
    (f"hst-{instrument.lower()}", f"HST/{instrument}")
    for instrument in ("STIS", "ACS", "WFC3-UVIS", "WFC3-IR", "COS")
)


@tray_registry('g-data-quality', label="Data Quality", category="data:analysis")
class DataQuality(PluginTemplateMixin, ViewerSelectMixin):
    """
    See the :ref:`Data Quality Plugin Documentation <imviz-data-quality>` for more details.

    Only the following attributes and methods are available through the
    :ref:`public plugin API <plugin-apis>`:

    * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.show`
    * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.open_in_tray`
    * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.close_in_tray`
    * ``science_layer`` (:class:`~jdaviz.core.template_mixin.LayerSelect`)
    * ``dq_layer`` (:class:`~jdaviz.core.template_mixin.LayerSelect`):
      data quality layer corresponding to the science data in ``science_layer``
    * ``dq_layer_opacity``: Opacity of the data quality layer.
    * ``decoded_flags``: List of decoded flags from the selected flag map.
    * ``flags_filter``: List of flags to display.
    * ``flag_map_definitions_selected``: Dictionary of the selected flag map.
    """
    template_file = __file__, "data_quality.vue"

    # `science_layer` is the science data layer
    science_layer_multiselect = Bool(False).tag(sync=True)
    science_layer_items = List().tag(sync=True)
    science_layer_selected = Any().tag(sync=True)  # Any needed for multiselect

    # `dq_layer` is the data quality layer corresponding to the
    # science data in `science_layer`
    dq_layer_multiselect = Bool(False).tag(sync=True)
    dq_layer_items = List().tag(sync=True)
    dq_layer_selected = Any().tag(sync=True)  # Any needed for multiselect
    # multiplied by the science layer's alpha in `update_opacity`
    dq_layer_opacity = Float(0.9).tag(sync=True)

    flag_map_definitions = Dict().tag(sync=True)
    flag_map_selected = Any().tag(sync=True)
    flag_map_definitions_selected = Dict().tag(sync=True)
    flag_map_items = List().tag(sync=True)
    decoded_flags = List().tag(sync=True)
    flags_filter = List().tag(sync=True)

    icons = Dict().tag(sync=True)
    icon_radialtocheck = Unicode(read_icon(os.path.join(ICON_DIR, 'radialtocheck.svg'), 'svg+xml')).tag(sync=True)  # noqa
    icon_checktoradial = Unicode(read_icon(os.path.join(ICON_DIR, 'checktoradial.svg'), 'svg+xml')).tag(sync=True)  # noqa

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # description displayed under plugin title in tray
        self._plugin_description = 'Data Quality layer visualization options.'

        # shallow copy of the application's icons for use by this plugin
        self.icons = dict(self._app.state.icons.items())

        self.science_layer = LayerSelect(
            self, 'science_layer_items', 'science_layer_selected',
            'viewer_selected', 'science_layer_multiselect',
            is_root=True, has_children=True
        )

        self.dq_layer = LayerSelect(
            self, 'dq_layer_items', 'dq_layer_selected',
            'viewer_selected', 'dq_layer_multiselect',
            is_root=False, is_child_of=self.science_layer.selected
        )
        self.dq_layer.add_filter('is_dq_layer')

        self.load_default_flag_maps()
        self.init_decoding()
        self._set_irrelevant()
        self._update_available_viewers()

        if self.config == 'deconfigged':
            self.observe_traitlets_for_relevancy(traitlets_to_observe=['dq_layer_items'])

    def _update_available_viewers(self):
        """Restrict the viewer selection to image viewers and select all of them.

        Does nothing until the ``viewer`` select component exists.
        """
        if not hasattr(self, 'viewer'):
            return

        # `is_image_viewer` is imported under an alias, so its ``__name__`` is
        # not guaranteed to be 'is_image_viewer'. Match on identity first so
        # the filter is not appended again on every call; the name comparison
        # is kept for filters registered under that name, and ``getattr``
        # tolerates filters that carry no ``__name__``.
        already_filtered = any(
            filt is is_image_viewer
            or getattr(filt, '__name__', None) == 'is_image_viewer'
            for filt in self.viewer.filters
        )
        if not already_filtered:
            self.viewer.add_filter(is_image_viewer)

        # by default, select all image viewers to sync each DQ layer's
        # options across all viewers that they're visible in:
        self.viewer_multiselect = True
        self.viewer.select_all()

    @observe('dq_layer_items')
    def _set_irrelevant(self, *args):
        """Refresh the viewer selection and set ``irrelevant_msg``.

        The message is empty when at least one entry of the application's
        ``_data_associations`` has children, and explains that no Data Quality
        layers are available otherwise.
        """
        self._update_available_viewers()

        associations = getattr(self._app, '_data_associations', {})
        children_available = any(
            len(assoc['children']) > 0 for assoc in associations.values()
        )

        self.irrelevant_msg = (
            '' if children_available else
            "No Data Quality layers available."
        )

    @observe('science_layer_selected')
    def update_dq_layer(self, *args):
        """Re-filter the DQ layer choices to children of the selected science layer."""
        if not hasattr(self, 'dq_layer'):
            return

        self.dq_layer.filter_is_child_of = self.science_layer_selected
        self.dq_layer._update_items()

        # listen for changes on the image opacity, and update the
        # data quality layer opacity on changes to the science layer opacity
        plot_options = self._app.get_tray_item_from_name('g-plot-options')
        plot_options.observe(self.update_opacity, 'image_opacity_value')

    def load_default_flag_maps(self):
        """Load every flag map named in ``dq_flag_map_paths``.

        Each map is stored in ``flag_map_definitions`` under its name, and the
        matching ``telescope_names`` label is appended to ``flag_map_items``.
        """
        for name in dq_flag_map_paths:
            self.flag_map_definitions[name] = load_flag_map(name)
            self.flag_map_items = self.flag_map_items + [telescope_names[name]]

    @property
    def dq_layer_selected_flattened(self):
        """The selected DQ layer object(s) as a flat list (empty if none)."""
        if not hasattr(self, 'dq_layer'):
            return []

        selected_dq = self.dq_layer.selected_obj

        if not len(selected_dq):
            return []

        if isinstance(selected_dq, list) and isinstance(selected_dq[0], list):
            # flatten a nested list:
            selected_dq = list(chain.from_iterable(selected_dq))
        elif not isinstance(selected_dq, list):
            # if it's a single layer, make it a list:
            selected_dq = [selected_dq]

        return selected_dq

    @property
    def unique_flags(self):
        """Unique non-NaN values in the first selected DQ layer's image data.

        Returns an empty list when no DQ layer is selected.
        """
        selected_dq = self.dq_layer_selected_flattened

        if selected_dq is None or not len(selected_dq):
            return []

        dq = selected_dq[0].get_image_data()
        return np.unique(dq[~np.isnan(dq)])

    @property
    def validate_flag_decode_possible(self):
        """Whether a flag map is selected and a DQ layer with flags is selected."""
        return (
            # observers can fire before `dq_layer` is created in `__init__`:
            hasattr(self, 'dq_layer') and
            self.flag_map_selected is not None and
            len(self.dq_layer.selected_obj) > 0 and
            len(self.unique_flags) > 0
        )

    @observe('flag_map_selected')
    def update_flag_map_definitions_selected(self, event):
        """Load the definitions of the newly selected flag map and re-decode."""
        if self.flag_map_selected is None:
            # `mission_or_instrument_from_meta` assigns None when the telescope
            # has no entry in `telescope_names`; there is nothing to look up.
            return

        # convert the display label (e.g. "HST/ACS") back to its key ("hst-acs"):
        flag_map_key = self.flag_map_selected.lower().replace('/', '-')
        selected = self.flag_map_definitions[flag_map_key]
        self.flag_map_definitions_selected = selected

        # re-decode the flags with the new map, then refresh the colormap:
        self.init_decoding()
        self._update_cmap()

    @observe('dq_layer_selected')
    def init_decoding(self, event=None, viewers=None):
        """Decode the unique DQ flags and initialize the DQ layers' display state.

        Parameters
        ----------
        event : dict, optional
            Change payload supplied when invoked as an observer; unused.
        viewers : viewer or list of viewers, optional
            Viewers whose DQ layers are updated. Defaults to the viewers
            selected in this plugin.
        """
        if not self.validate_flag_decode_possible:
            return

        unique_flags = self.unique_flags
        cmap, rgba_colors = generate_listed_colormap(n_flags=len(unique_flags))
        self.decoded_flags = decode_flags(
            flag_map=self.flag_map_definitions_selected,
            unique_flags=unique_flags,
            rgba_colors=rgba_colors
        )
        self.send_state('decoded_flags')

        dq_layers = self.get_dq_layers(viewers=viewers)
        if not dq_layers:
            # `get_dq_layers` returns None when no DQ layer is selected
            return

        # identical for every layer, so compute it once:
        flag_bits = np.array([flag['flag'] for flag in self.decoded_flags])

        for dq_layer in dq_layers:
            dq_layer.composite._allow_bad_alpha = True

            # NOTE(review): this does not depend on `dq_layer`; it is kept in
            # the loop to preserve the original order of viewer updates.
            self._allow_bad_alpha_in_uncert_viewer()

            dq_layer.state.stretch = 'lookup'
            stretch_object = dq_layer.state.stretch_object
            stretch_object.flags = flag_bits

            with delay_callback(dq_layer.state, 'alpha', 'cmap', 'v_min', 'v_max'):
                if len(flag_bits):
                    dq_layer.state.v_min = flag_bits.min()
                    dq_layer.state.v_max = flag_bits.max()

                dq_layer.state.alpha = self.dq_layer_opacity
                dq_layer.state.cmap = cmap

    def _allow_bad_alpha_in_uncert_viewer(self):
        """Allow bad alpha on the uncertainty viewer's image layers (cubeviz/rampviz only)."""
        # for cubeviz, also change uncert-viewer defaults to
        # map the out-of-bounds regions to the cmap's `bad` color:
        if self._app.config not in ('cubeviz', 'rampviz'):
            return

        viewer = self._app.get_viewer(
            getattr(
                self._app._jdaviz_helper,
                '_default_uncert_viewer_reference_name',
                'level-2'
            )
        )

        for layer in viewer.layers:
            # allow bad alpha for image layers, not subsets:
            if not hasattr(layer, 'subset_array'):
                layer.composite._allow_bad_alpha = True
                layer.force_update()

    def _layers_matching_label(self, label, viewers=None):
        """Return every layer in ``viewers`` whose ``layer.label`` equals ``label``.

        ``viewers`` defaults to the viewers selected in this plugin; a single
        viewer (anything without ``__len__``) is wrapped in a list.
        """
        if viewers is None:
            viewers = self.viewer.selected_obj

        if not hasattr(viewers, '__len__'):
            viewers = [viewers]

        return [
            layer
            for viewer in viewers
            for layer in viewer.layers
            if layer.layer.label == label
        ]

    def get_dq_layers(self, viewers=None):
        """Return the layers in ``viewers`` labeled ``dq_layer_selected``.

        Parameters
        ----------
        viewers : viewer or list of viewers, optional
            Defaults to the viewers selected in this plugin.

        Returns
        -------
        list or None
            Matching layers, or ``None`` when ``dq_layer_selected`` is an
            empty string.
        """
        if self.dq_layer_selected == '':
            return

        return self._layers_matching_label(self.dq_layer_selected, viewers=viewers)

    def get_science_layers(self, viewers=None):
        """Return the layers in ``viewers`` labeled ``science_layer_selected``.

        Parameters
        ----------
        viewers : viewer or list of viewers, optional
            Defaults to the viewers selected in this plugin.

        Returns
        -------
        list
            Matching layers (possibly empty).
        """
        return self._layers_matching_label(self.science_layer_selected, viewers=viewers)

    @observe('dq_layer_opacity')
    def update_opacity(self, event=None):
        """Set each selected DQ layer's alpha relative to its science layer.

        Parameters
        ----------
        event : dict, optional
            Change payload supplied when invoked as an observer; unused.
        """
        science_layers = self.get_science_layers()
        selected_dq = self.dq_layer_selected_flattened

        # `zip` yields nothing when either list is empty
        for sci_layer, dq_layer in zip(science_layers, selected_dq):
            # DQ opacity is a fraction of the science layer's opacity:
            dq_layer.state.alpha = self.dq_layer_opacity * sci_layer.state.alpha

    def _is_flag_hidden(self, flag, flags_filter):
        """Whether ``flag`` is toggled off or excluded by a non-empty ``flags_filter``."""
        # hide the flag if the visibility toggle is False:
        if not flag['show']:
            return True

        # hide the flag if `flags_filter` has entries but not this one:
        if len(flags_filter):
            decomposed_bits = [int(bit) for bit in flag['decomposed']]
            return not np.isin(decomposed_bits, flags_filter).any()

        return False

    @observe('decoded_flags', 'flags_filter')
    def _update_cmap(self, event=None, viewers=None):
        """Push flag colors, visibility and limits to the DQ layers.

        Parameters
        ----------
        event : dict, optional
            Change payload supplied when invoked as an observer; unused.
        viewers : viewer or list of viewers, optional
            Viewers whose DQ layers are updated. Defaults to the viewers
            selected in this plugin.
        """
        dq_layers = self.get_dq_layers(viewers=viewers)
        if dq_layers is None:
            return

        flag_bits = np.array([flag['flag'] for flag in self.decoded_flags])
        rgb_colors = [hex2color(flag['color']) for flag in self.decoded_flags]

        flags_filter = list(self.flags_filter)
        hidden_flags = np.array([
            flag['flag'] for flag in self.decoded_flags
            if self._is_flag_hidden(flag, flags_filter)
        ])

        for dq_layer in dq_layers:
            with delay_callback(
                dq_layer.state, 'v_min', 'v_max', 'alpha', 'stretch', 'cmap'
            ):
                # set correct stretch and limits; the 'lookup' stretch itself
                # is assigned in `init_decoding` and is not re-assigned here:
                stretch_object = dq_layer.state.stretch_object
                stretch_object.flags = flag_bits
                stretch_object.dq_array = dq_layer.get_image_data()
                stretch_object.hidden_flags = hidden_flags

                # update the colors of the listed colormap without
                # reassigning the layer.state.cmap object
                cmap = dq_layer.state.cmap
                cmap.colors = rgb_colors
                cmap._init()

                # trigger updates to cmap in viewer:
                dq_layer.update()

                if len(flag_bits):
                    dq_layer.state.v_min = flag_bits.min()
                    dq_layer.state.v_max = flag_bits.max()

                dq_layer.state.alpha = self.dq_layer_opacity

    def update_visibility(self, index):
        """Toggle the ``show`` state of the decoded flag at ``index`` and redraw."""
        self.decoded_flags[index]['show'] = not self.decoded_flags[index]['show']
        self.vue_update_cmap()

    def vue_update_cmap(self):
        """Send ``decoded_flags`` to the frontend and refresh the colormap."""
        # in-place edits to the list entries are not synced automatically
        self.send_state('decoded_flags')
        self._update_cmap()

    def vue_update_visibility(self, index):
        """Frontend hook: toggle visibility of the flag at ``index``."""
        self.update_visibility(index)

    def update_color(self, index, color):
        """Set the color of the decoded flag at ``index`` and redraw.

        ``color`` is later parsed with ``matplotlib.colors.hex2color``.
        """
        self.decoded_flags[index]['color'] = color
        self.vue_update_cmap()

    def vue_update_color(self, args):
        """Frontend hook: ``args`` is an ``(index, color)`` pair."""
        index, color = args
        self.update_color(index, color)

    @observe('science_layer_selected')
    def mission_or_instrument_from_meta(self, event):
        """Choose the flag map from the selected science layer's metadata.

        Sets ``flag_map_selected`` to the ``telescope_names`` label for the
        telescope (and, for HST, instrument/detector) found in the metadata,
        or to ``None`` when there is no matching entry. Nothing is changed
        when no telescope can be found.
        """
        if not hasattr(self, 'science_layer'):
            return

        science_layers = self.get_science_layers()
        if not len(science_layers):
            return

        meta = science_layers[0].layer.meta

        # this is defined for JWST and ROMAN, should be upper case:
        telescope = meta.get('telescope', None)

        if telescope is None:
            # for spectral cubes in Cubeviz:
            telescope = meta.get('_primary_header', {}).get('TELESCOP', None)

        if telescope is None:
            return

        if telescope == 'HST':
            primary_header = meta.get('_primary_header', {})
            instrument = primary_header.get('INSTRUME', None)

            # only WFC3 has separate flag maps per detector:
            if instrument == 'WFC3':
                detector = primary_header.get('DETECTOR', None)
            else:
                detector = None

            telescope = '-'.join(
                i for i in [telescope, instrument, detector] if i is not None
            )

        self.flag_map_selected = telescope_names.get(telescope.lower())

    def vue_hide_all_flags(self, event):
        """Frontend hook: hide every decoded flag."""
        for flag in self.decoded_flags:
            flag['show'] = False

        self.vue_update_cmap()

    def vue_clear_flags_filter(self, event):
        """Frontend hook: remove all entries from ``flags_filter``."""
        self.flags_filter = []
        self.vue_update_cmap()

    def vue_show_all_flags(self, event):
        """Frontend hook: show every decoded flag and clear ``flags_filter``."""
        for flag in self.decoded_flags:
            flag['show'] = True

        self.flags_filter = []
        self.vue_update_cmap()

    @property
    def user_api(self):
        """The subset of this plugin exposed through the public plugin API."""
        return PluginUserApi(
            self,
            expose=(
                'science_layer', 'dq_layer',
                'decoded_flags', 'flags_filter',
                'dq_layer_opacity', 'flag_map_definitions_selected',
            )
        )