# Source code for mjolnir.plotting.live_view

"""This module provides several live-plotting classes that use
:mod:`qutil:qutil.plotting.live_view`.
"""
from __future__ import annotations

import abc
import queue
import threading
import time
from collections import deque
from contextlib import nullcontext
from typing import Literal, Mapping, Any, ContextManager, Sequence

import matplotlib as mpl
import numpy as np
from matplotlib import cm, colors
from matplotlib.axes import Axes
from numpy import typing as npt
from qcodes import Station, logger
from qcodes.plotting import find_scale_and_prefix
from qcodes_contrib_drivers.drivers.SwabianInstruments.Swabian_Instruments_Time_Tagger import (
    TimeTagger, CounterMeasurement
)
from qutil import functools, itertools, const, misc
from qutil.plotting import live_view
from qutil.plotting.live_view import DataSource, TimerT, LiveViewT, StyleT, ScaleT, ArtistT

# Grayscale colormap for CCD images. Values flagged as "bad" are drawn in a
# translucent red so they stand out from the gray data.
GRAY_CMAP = cm.gray.with_extremes(bad=colors.to_rgba('tab:red', alpha=0.3))


class _QcodesMixin(metaclass=abc.ABCMeta):
    """Mixin that makes a live view use its own :meth:`_put_data` as data source.

    Subclasses implement :meth:`_put_data`. It is used whenever :meth:`attach`
    is called without an explicit data source.
    """

    def attach(self, data_source: DataSource | None = None,
               event_source: TimerT | None = None, **data_source_kwargs):
        """Attach a data source (and optionally an event source).

        Parameters
        ----------
        data_source :
            The data source to attach. Defaults to :meth:`_put_data`.
        event_source :
            Forwarded to the parent class's ``attach``.
        **data_source_kwargs :
            Forwarded to the parent class's ``attach``.
        """
        if data_source is None:
            data_source = self._put_data
        # Forward event_source to the parent implementation instead of dropping it.
        super().attach(data_source, event_source=event_source, **data_source_kwargs)

    @abc.abstractmethod
    def _put_data(self: LiveViewT, data_queue: queue.Queue, stop_event: threading.Event, **kwargs):
        """Put data into ``data_queue`` until ``stop_event`` is set."""


class _CcdMixin(_QcodesMixin, metaclass=abc.ABCMeta):

    def __init__(self: LiveViewT, station: Station, exposure_time: float,
                 subtract_background: bool = True, data_rate: bool = True,
                 light_source: Literal['laser', 'white_light'] = 'laser',
                 horizontal_scale: Literal['energy', 'wavelength'] = 'wavelength',
                 update_interval_ms: int = int(1e3 / 60), autoscale: bool | None = True,
                 autoscale_interval_ms: int | None = 1000, show_fps: bool = False,
                 useblit: bool = True, fig_kw: Mapping[str, Any] | None = None,
                 log_level: int | None = None, **kwargs):

        self.station = station

        units = {self._data_axis: 'cps' if data_rate else 'cts'}

        match horizontal_scale:
            case 'energy':
                xlabel = '$E$'
                units['x'] = 'eV'
                self._xlabel_text2 = r'$\lambda$ ({{{}}}{})'.format('', 'nm')
            case 'wavelength':
                xlabel = r'$\lambda$'
                units['x'] = 'nm'
                self._xlabel_text2 = '$E$ ({{{}}}{})'.format('', 'eV')
            case _:
                raise ValueError('horizontal_scale')

        label_kwarg = {f'{self._data_axis}label': 'CCD Signal'}
        if autoscale is True:
            autoscale = self._data_axis

        super().__init__(self._put_data, update_interval_ms=update_interval_ms,
                         autoscale=autoscale, autoscale_interval_ms=autoscale_interval_ms,
                         show_fps=show_fps, useblit=useblit, blocking_queue=True, xlabel=xlabel,
                         units=units | self._units, fig_kw=fig_kw, exposure_time=exposure_time,
                         subtract_background=subtract_background, data_rate=data_rate,
                         light_source=light_source, log_level=log_level, **label_kwarg, **kwargs)

    @property
    @abc.abstractmethod
    def read_mode(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def _data_axis(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def _read_mode_settings_context(self) -> ContextManager:
        pass

    @functools.cached_property
    @abc.abstractmethod
    def _axes(self) -> tuple[npt.NDArray, ...]:
        pass

    @property
    def _units(self) -> dict[Literal['x', 'y', 'c'], str]:
        """Custom unit override."""
        return {}

    @property
    def ccd(self: LiveViewT):
        return self.station.detection_path.ccd

    def _add_axes(self: LiveViewT, ax=None):
        super()._add_axes(ax)

        self.axes['main'].margins(x=0)

        if self.units['x'] == 'nm':
            functions = (functools.scaled(1e+9)(_lambda2eV),
                         functools.scaled(1e-9)(_eV2lambda))
        elif self.units['x'] == 'eV':
            functions = (functools.scaled(1e+9)(_eV2lambda),
                         functools.scaled(1e-9)(_lambda2eV))
        else:
            raise RuntimeError("Something's wrong with the horizontal axis")

        if (ax_secondary := live_view.find_artist_with_label(self.fig.axes, 'secondary')) is None:
            ax_secondary = self.axes['main'].secondary_xaxis('top', functions=functions,
                                                             label='secondary')

        self.axes['secondary'] = ax_secondary

    def _add_static_artists(self: LiveViewT):
        super()._add_static_artists()
        self.static_artists['xlabel2'] = self.axes['secondary'].set_xlabel(
            self._xlabel_text2.format('')
        )

    def _update(self: LiveViewT, frame) -> set[ArtistT]:
        if self._first_frame:
            self.axes['main'].set_xlim(self.ccd.horizontal_axis.get()[[0, -1]])
            self.axes['main'].set_autoscalex_on(True)

        return super()._update(frame)

    def _put_data(self: LiveViewT,
                  data_queue: queue.Queue[tuple[npt.NDArray[np.float64 | np.int_], np.int32]],
                  stop_event: threading.Event,
                  *,
                  exposure_time: float = 0.1,
                  subtract_background: bool = True,
                  data_rate: bool = True,
                  light_source: Literal['laser', 'white_light'] = 'laser'):

        scale = 'wavelength' if self.units['x'] == 'nm' else 'energy'
        if light_source == 'laser' and not self.station.excitation_path.pump_laser.enabled():
            # Don't turn on laser if it wasn't on before.
            light_source = None

        with (
            self.station.excitation_path.active_light_source.set_to(light_source,
                                                                    allow_changes=True),
            self.station.detection_path.ccd_horizontal_axis_scale.set_to(scale),
            self.ccd.exposure_time.set_to(exposure_time),
            self.ccd.cosmic_ray_filter_mode.set_to(False),
            self.ccd.read_mode.set_to(self.read_mode),
            self._read_mode_settings_context,
            logger.filter_instrument(self.ccd, level='ERROR')
        ):

            if subtract_background:
                if not self.ccd.background_is_valid:
                    with (
                        self.station.excitation_path.active_light_source.set_to(None),
                        self.ccd.acquisition_mode.set_to('single scan')
                    ):
                        self._LOG.info('Acquiring background.')
                        background = self.ccd.background.get()
                else:
                    background = self.ccd.background.get_latest()

            acquisition_timings = self.ccd.get_acquisition_timings()
            producer = self.ccd.yield_till_abort()

            while not stop_event.is_set():
                data = next(producer)
                # XXX: data has a time axis of size 1
                data = data[0]
                if subtract_background:
                    data -= background
                if data_rate:
                    data = data / acquisition_timings.exposure_time
                data_queue.put(self._axes + (data,))

    def _on_close(self: LiveViewT, event=None):
        self.ccd.cancel_wait()
        self.ccd.abort_acquisition()
        super()._on_close(event)


class SpectrumLiveView(_CcdMixin, live_view.BatchedLiveView1D):
    """Show a video of CCD live spectra.

    Parameters
    ----------
    station :
        The logical station in which the
        :class:`~instruments.logical_instruments.ExcitationPath` instrument
        lives.
    exposure_time :
        The (maximum) exposure time.
    subtract_background :
        Subtract a background image from the data before displaying.
    data_rate :
        Plot the count rate instead of the total number of counts.
    light_source :
        Use the white light or the laser.
    horizontal_scale :
        Show energy or wavelength on the bottom axis.
    read_mode :
        Use full vertical binning (default) or single track mode.
    single_track_settings :
        Applies only if `read_mode` is single track.
    autoscale :
        Axes to autoscale (e.g. ``'x'``). ``True`` autoscales the signal
        (y) axis.
    fig_kw :
        Keyword arguments for the figure. Defaults to a figure named
        'CCD Spectrum Viewer'.
    """
    DEFAULT_FIGSIZE = (885, 630)

    def __init__(
            self,
            station: Station,
            exposure_time: float = 0.1,
            subtract_background: bool = True,
            data_rate: bool = True,
            light_source: Literal['laser', 'white_light'] = 'laser',
            horizontal_scale: Literal['energy', 'wavelength'] = 'wavelength',
            read_mode: Literal['single track', 'full vertical binning'] = 'full vertical binning',
            single_track_settings: tuple[int, int] | None = None,
            update_interval_ms: int = int(1e3 / 60),
            autoscale: bool | str | None = 'x',
            autoscale_interval_ms: int | None = 1000,
            show_fps: bool = False,
            useblit: bool = True,
            yscale: str | ScaleT = 'linear',
            style: StyleT | Sequence[StyleT] = ('dark_background', 'fast'),
            fig_kw: Mapping[str, Any] | None = None,
            log_level: int | None = None,
    ):
        self._read_mode = read_mode
        self.single_track_settings = single_track_settings
        # dict(...) so that any Mapping (not just dict) can be merged.
        fig_kw = {'num': 'CCD Spectrum Viewer'} | dict(fig_kw or {})
        super().__init__(station, exposure_time, subtract_background, data_rate,
                         light_source, horizontal_scale,
                         update_interval_ms=update_interval_ms, autoscale=autoscale,
                         autoscale_interval_ms=autoscale_interval_ms, show_fps=show_fps,
                         useblit=useblit, yscale=yscale, style=style, fig_kw=fig_kw,
                         log_level=log_level)
        if 'figsize' not in self.fig_kw:
            _set_figure_position(self.fig, self.DEFAULT_FIGSIZE, loc='bottom right')

    @property
    def read_mode(self) -> str:
        return self._read_mode

    @property
    def _data_axis(self) -> str:
        return 'y'

    @property
    def _read_mode_settings_context(self) -> ContextManager:
        # Only single track mode has settings to apply.
        if self.read_mode == 'full vertical binning':
            return nullcontext()
        return self.ccd.single_track_settings.set_to(self.single_track_settings)

    @property
    def _axes(self):
        # Re-query the horizontal axis only if it is autoscaled; otherwise use
        # the cached value.
        if 'x' in self.autoscale:
            return (self.ccd.horizontal_axis.get(),)
        return (self.ccd.horizontal_axis.cache.get(),)

    def _add_axes(self: LiveViewT, ax=None):
        super()._add_axes(ax)
        self.axes['main'].margins(y=0.2)
        self.axes['main'].grid(True, axis='y')
[docs] class ImageLiveView(_CcdMixin, live_view.BatchedLiveView2D): """Show a video of CCD live images. Parameters ---------- station : The logical station in which the :class:`~instruments.logical_instruments.ExcitationPath` instrument lives. exposure_time : The (maximum) exposure time. subtract_background : Subtract a background image from the data before displaying. data_rate : Plot the count rate instead of the total number of counts. light_source : Use the white light or the laser. horizontal_scale : Show energy or wavelength on the bottom axis. read_mode : Use image, random track, or multi track mode. read_mode_settings : Read mode settings to apply for data acquisition. """ DEFAULT_FIGSIZE = (885, 630) def __init__(self, station: Station, exposure_time: float = 0.1, subtract_background: bool = True, data_rate: bool = False, light_source: Literal['laser', 'white_light'] = 'laser', horizontal_scale: Literal['energy', 'wavelength'] = 'wavelength', read_mode: Literal['image', 'random track', 'multi track'] = 'image', read_mode_settings: Any | None = None, update_interval_ms: int = int(1e3), autoscale: bool | None = True, autoscale_interval_ms: int | None = 1000, show_fps: bool = False, useblit: bool = True, style: StyleT | Sequence[StyleT] = ('dark_background', 'fast'), img_kw: Mapping[str, Any] | None = None, fig_kw: Mapping[str, Any] | None = None, log_level: int | None = None): self._read_mode = read_mode self.read_mode_settings = read_mode_settings img_kw = {'cmap': GRAY_CMAP} | (img_kw if img_kw is not None else {}) fig_kw = {'num': 'CCD Image Viewer'} | (fig_kw if fig_kw is not None else {}) super().__init__(station, exposure_time, subtract_background, data_rate, light_source, horizontal_scale, update_interval_ms=update_interval_ms, autoscale='c', autoscale_interval_ms=autoscale_interval_ms, show_fps=show_fps, useblit=useblit, img_kw=img_kw, fig_kw=fig_kw, ylabel=station.detection_path.ccd.vertical_axis.label, style=style, log_level=log_level) if 'figsize' 
not in self.fig_kw: _set_figure_position(self.fig, self.DEFAULT_FIGSIZE, loc='bottom right') @property def read_mode(self) -> str: return self._read_mode @property def _data_axis(self) -> str: return 'c' @property def _read_mode_settings_context(self) -> ContextManager: match self.read_mode: case 'image': if self.read_mode_settings is None: self.read_mode_settings = [1, 1, 1, 2000, 1, 256] return self.ccd.image_settings.set_to(self.read_mode_settings) case 'random track': if self.read_mode_settings is None: self.read_mode_settings = np.ravel(list(zip(np.arange(1, 257, 16), np.arange(16, 257, 16)))) return self.ccd.random_track_settings.set_to(self.read_mode_settings) case 'multi track': if self.read_mode_settings is None: self.read_mode_settings = [1, 20, 118] return self.ccd.multi_track_settings.set_to(self.read_mode_settings) case _: raise ValueError @property def _axes(self) -> tuple[npt.NDArray, ...]: if 'x' in self.autoscale: x = self.ccd.horizontal_axis.get() else: x = self.ccd.horizontal_axis.cache.get() if 'y' in self.autoscale: y = self.ccd.vertical_axis.get() else: y = self.ccd.vertical_axis.cache.get() return x, y @property def _units(self) -> dict[Literal['x', 'y', 'c'], str]: return {'y': 'px'} def _add_axes(self: LiveViewT, ax=None): super()._add_axes(ax) if self.plot_line: self.axes['line'].margins(y=0.2) self.axes['line'].grid(True, axis='y') def _update(self: LiveViewT, frame) -> set[ArtistT]: if self._first_frame: self.axes['main'].set_ylim(self._axes[1][[0, -1]]) return super()._update(frame)
class PowerLiveView(_QcodesMixin, live_view.BatchedLiveView1D):
    """Monitor readings from the power meter.

    Parameters
    ----------
    station :
        The logical station in which the
        :class:`~instruments.logical_instruments.ExcitationPath` instrument
        lives.
    display_duration :
        The time in seconds to display in total.
    averaging_time :
        The averaging time per point in seconds. Should be smaller than
        498 ms, otherwise the instrument times out. Defaults to
        `update_interval_ms` (converted to seconds).
    number_update_interval_ms :
        Interval in milliseconds at which the large numeric power reading is
        updated. It shows the mean of the readings received since its last
        update.
    """
    DEFAULT_FIGSIZE = (885, 388)

    def __init__(self, station: Station, display_duration: float = 10.,
                 averaging_time: float | None = None,
                 update_interval_ms: int = int(1e3 / 60),
                 number_update_interval_ms: int = int(1e3 / 12),
                 autoscale: Literal['', 'y'] | None = 'y',
                 autoscale_interval_ms: int | None = 1000, show_fps: bool = False,
                 useblit: bool = True, yscale: str | ScaleT = 'linear',
                 ylim: tuple[float, float] | None = None,
                 style: StyleT | Sequence[StyleT] = ('dark_background', 'fast'),
                 fig_kw: Mapping[str, Any] | None = None,
                 log_level: int | None = None):
        self.station = station
        self.display_duration = display_duration
        self.number_update_interval_ms = number_update_interval_ms

        update_interval_ms = max(1, update_interval_ms)
        # maxlen >= 1 because _update reads self._title_clock[-1]
        # (assuming Clock keeps at most maxlen timestamps).
        self._title_clock = live_view.Clock(
            maxlen=max(1, int(self.display_duration // (update_interval_ms * 1e-3)))
        )
        self._title_text = '{:.3g} {}W'

        xlabel = r'$\Delta t$'
        ylabel = 'Power'
        units: dict[Literal['x', 'y'], str] = {'x': 's', 'y': 'W'}
        # dict(...) so that any Mapping (not just dict) can be merged.
        fig_kw = {'num': 'Powermeter Viewer'} | dict(fig_kw or {})

        super().__init__(self._put_data, update_interval_ms=update_interval_ms,
                         autoscale=autoscale, autoscale_interval_ms=autoscale_interval_ms,
                         show_fps=show_fps, useblit=useblit, blocking_queue=True,
                         xlabel=xlabel, ylabel=ylabel, units=units, yscale=yscale,
                         xlim=(-self.display_duration, 0), ylim=ylim, style=style,
                         fig_kw=fig_kw, log_level=log_level, averaging_time=averaging_time)
        if 'figsize' not in self.fig_kw:
            _set_figure_position(self.fig, self.DEFAULT_FIGSIZE, loc='top right')

    def _add_axes(self, ax=None):
        super()._add_axes(ax)
        # We need a separate axes for the number power reading because
        # Animation only blits the axes, not the entire figure.
        gs = self.axes['main'].get_subplotspec().subgridspec(nrows=2, ncols=1,
                                                             height_ratios=[1, 9],
                                                             hspace=0.0)
        if (ax_number := live_view.find_artist_with_label(self.fig.axes, 'number')) is None:
            ax_number = self.fig.add_subplot(gs[0], label='number')
            ax_number.axis('off')

        self.axes['number'] = ax_number
        self.axes['main'].set_subplotspec(gs[1])
        self.axes['main'].yaxis.set_tick_params(which='both', left=True, right=True,
                                                labelright=True, labelleft=False)
        self.axes['main'].yaxis.set_label_position('right')
        self.axes['main'].grid(True, axis='y')

    def _add_animated_artists(self):
        super()._add_animated_artists()
        number = self.axes['number'].text(
            0.5, 0.5, self._title_text.format(0.0, ''),
            fontdict=dict(size=mpl.rcParams['axes.titlesize'],
                          weight=mpl.rcParams['axes.titleweight']),
            verticalalignment='center', horizontalalignment='center',
            animated=self.useblit
        )
        self.animated_artists['number'] = number

    def _initialize(self) -> set[ArtistT]:
        self._title_clock.clear()
        self._title_clock()
        return super()._initialize()

    def _update(
            self, frame: tuple[npt.NDArray[float], npt.NDArray[float]] | None = None
    ) -> set[ArtistT]:
        if frame is None:
            return set(self.animated_artists.values())

        times, power = frame
        mask = times > -self.display_duration
        animated_artists = super()._update((times[mask], power[mask]))

        elapsed = time.perf_counter() - self._title_clock[-1]
        stale_title = elapsed > self.number_update_interval_ms * 1e-3
        if self._first_frame or stale_title:
            # Update the number at a separate rate to avoid blurring. Show the
            # average over the values that arrived since we last updated the
            # number. The last sample has time 0, so the selection is never
            # empty.
            mean_power = power[times >= -elapsed].mean()
            prefix, scale = find_scale_and_prefix(mean_power, self.units['y'])
            self.animated_artists['number'].set_text(self._title_text.format(
                mean_power * 10 ** (-scale), prefix
            ))
            self._title_clock()
            animated_artists.add(self.animated_artists['number'])

        return animated_artists

    def _put_data(self,
                  data_queue: queue.Queue[tuple[npt.NDArray[np.float64],
                                                npt.NDArray[np.float64]]],
                  stop_event: threading.Event,
                  *,
                  averaging_time: float | None = None):
        """Poll the power meter and put ``(relative times, powers)`` into the queue.

        Times are relative to the most recent reading (which has time 0).
        Runs until ``stop_event`` is set.
        """
        if averaging_time is None:
            averaging_time = self.update_interval_ms * 1e-3

        # At least one sample, so that times[-1] below is always valid even if
        # display_duration < averaging_time.
        maxlen = max(1, int(self.display_duration // averaging_time))
        power = deque(maxlen=maxlen)
        times = deque(maxlen=maxlen)
        powermeter = self.station.excitation_path.powermeter

        with powermeter.averaging_time.set_to(averaging_time):
            while not stop_event.is_set():
                power.append(powermeter.power())
                times.append(time.perf_counter())
                data_queue.put((np.array(times) - times[-1], np.array(power)))
class TimeTaggerLiveView(_QcodesMixin, live_view.BatchedLiveView1D):
    """Monitor count rates from the time tagger.

    Parameters
    ----------
    station :
        The logical station whose ``detection_path`` holds the time tagger.
    channels :
        Time tagger channels to count on. One line is plotted per channel,
        labeled by the matching virtual channel's label if there is one.
    data_rate :
        Plot the normalized count rate instead of the number of counts.
    display_duration :
        The time in seconds to display in total.
    averaging_time :
        The counting bin width in seconds.
    """
    DEFAULT_FIGSIZE = (885, 388)

    tagger: TimeTagger
    counter: CounterMeasurement

    def __init__(self, station: Station, channels: Sequence[int], data_rate: bool = True,
                 display_duration: float = 60., averaging_time: float = 0.1,
                 update_interval_ms: int = int(1e3 / 60),
                 autoscale: Literal['', 'y'] | None = 'y',
                 autoscale_interval_ms: int | None = 1000, show_fps: bool = False,
                 useblit: bool = True, yscale: str | ScaleT = 'linear',
                 ylim: tuple[float, float] | None = None,
                 style: StyleT | Sequence[StyleT] = ('dark_background', 'fast'),
                 fig_kw: Mapping[str, Any] | None = None,
                 log_level: int | None = None):
        self.station = station
        self.tagger = station.detection_path.tagger

        update_interval_ms = max(1, update_interval_ms)

        xlabel = r'$\Delta t$'
        ylabel = 'Count rate' if data_rate else 'Counts'
        units: dict[Literal['x', 'y'], str] = {'x': 's', 'y': 'cps' if data_rate else 'cts'}
        # dict(...) so that any Mapping (not just dict) can be merged.
        fig_kw = {'num': 'Time Tagger Viewer'} | dict(fig_kw or {})

        super().__init__(self._put_data, n_lines=len(channels), plot_legend='upper left',
                         update_interval_ms=update_interval_ms, autoscale=autoscale,
                         autoscale_interval_ms=autoscale_interval_ms, show_fps=show_fps,
                         useblit=useblit, blocking_queue=True, xlabel=xlabel, ylabel=ylabel,
                         units=units, yscale=yscale, xlim=(-display_duration, 0), ylim=ylim,
                         style=style, fig_kw=fig_kw, log_level=log_level, channels=channels,
                         display_duration=display_duration, averaging_time=averaging_time,
                         data_rate=data_rate)
        if 'figsize' not in self.fig_kw:
            _set_figure_position(self.fig, self.DEFAULT_FIGSIZE, loc='top left')

    def _add_axes(self, ax_main: Axes | None = None):
        super()._add_axes(ax_main)
        self.axes['main'].yaxis.set_tick_params(which='both', left=True, right=True,
                                                labelright=True, labelleft=False)
        self.axes['main'].yaxis.set_label_position('right')
        self.axes['main'].grid(True, axis='y')

    def _put_data(self,
                  data_queue: queue.Queue[tuple[npt.NDArray[np.float64],
                                                npt.NDArray[np.float64]]],
                  stop_event: threading.Event,
                  *,
                  channels: Sequence[int],
                  display_duration: float = 60.,
                  averaging_time: float | None = 0.1,
                  data_rate: bool = True):
        """Run a rolling counter measurement and put ``(times, {label: data})``.

        Times are in seconds relative to the last bin. Runs until
        ``stop_event`` is set.
        """
        # NOTE(review): purpose of this check is unclear from here (presumably
        # it validates/caches the bandwidth before it is changed below). Asserts
        # are skipped under `python -O`.
        assert self.station.detection_path.bandwidth() >= 0

        if averaging_time is None:
            averaging_time = self.update_interval_ms * 1e-3

        # The counter measurement is created once and reused across restarts.
        if not hasattr(self, 'counter'):
            self.counter = self.tagger.add_counter_measurement(
                'live_view_counter', label='Time Tagger Count Viewer'
            )

        self.counter.channels(channels)
        self.counter.binwidth(int(10**12 * averaging_time))
        self.counter.n_values(int(display_duration // averaging_time))
        self.counter.rolling(True)

        times = self.counter.time_bins() / 10**12
        times -= times[-1]
        data = self.counter.data_normalized if data_rate else self.counter.data
        labels = self._resolve_labels(channels)

        with (
            self.station.excitation_path.open_shutter(),
            self.station.detection_path.active_detection_path.set_to('apd',
                                                                     allow_changes=True),
            self.station.detection_path.bandwidth.set_to(0, allow_changes=True)
        ):
            self.counter.clear()
            self.counter.start()
            while not stop_event.is_set():
                data_queue.put((times, dict(zip(labels, data.get()))))

    def _resolve_labels(self, channels: Sequence[int]):
        """Return a label per channel: its virtual channel's label, else its number."""
        labels = []
        for channel in channels:
            for virtual_channel in itertools.chain.from_iterable(
                    self.tagger.virtual_channel_lists
            ):
                if virtual_channel.get_channel() == channel:
                    labels.append(virtual_channel.label)
                    break
            else:
                labels.append(str(channel))
        return labels

    def __del__(self):
        # Best-effort removal of the counter measurement from the tagger.
        # `counter` only exists once _put_data has run, and `tagger` may be
        # missing if __init__ failed early.
        counter = getattr(self, 'counter', None)
        tagger = getattr(self, 'tagger', None)
        if counter is None or tagger is None:
            return
        try:
            tagger.counter_measurements.remove(counter)
        except (KeyError, ValueError):
            pass
def _lambda2eV(lambd: float | npt.ArrayLike) -> float | npt.NDArray[float]:
    """Convert meters to eV.

    Parameters
    ----------
    lambd :
        Value in wavelength (meters).

    Returns
    -------
    eV :
        Value in electron Volts. Infinite results (presumably from zero
        wavelength) are replaced by 1e16.
    """
    with misc.filter_warnings('ignore', RuntimeWarning):
        result = const.lambda2eV(lambd)
    # np.where (instead of item assignment) also works for scalar input;
    # ``[()]`` turns a 0-d result back into a scalar and leaves arrays as is.
    return np.where(np.isinf(result), 1e16, result)[()]


def _eV2lambda(eV: float | npt.ArrayLike) -> float | npt.NDArray[float]:
    """Convert eV to meters.

    Parameters
    ----------
    eV :
        Value in electron Volts.

    Returns
    -------
    lambd :
        Value in wavelength (meters). Infinite results (presumably from zero
        energy) are replaced by 1e16.
    """
    with misc.filter_warnings('ignore', RuntimeWarning):
        result = const.eV2lambda(eV)
    # See _lambda2eV: np.where supports scalar input, unlike item assignment.
    return np.where(np.isinf(result), 1e16, result)[()]


def _set_figure_position(fig, figsize: tuple[int, int], loc: str = 'bottom right',
                         screen_size: tuple[int, int] = (1920, 1080)):
    """Resize and place a figure window on screen, then tighten its layout.

    Parameters
    ----------
    fig :
        The matplotlib figure.
    figsize :
        Window size ``(width, height)`` in pixels.
    loc :
        Screen corner, e.g. ``'bottom right'`` or ``'top left'``. Checked for
        the substrings ``'right'`` and ``'bottom'``.
    screen_size :
        Screen size ``(width, height)`` in pixels.
    """
    w, h = figsize
    screen_w, screen_h = screen_size
    # Windows title bar looks to be 31 pixels tall.
    # Task bar is 150 pixels wide atm.
    dw, dh = 150, 31

    x = screen_w - w if 'right' in loc else dw
    y = screen_h - h if 'bottom' in loc else dh

    try:
        fig.canvas.manager.window.setGeometry(x, y, w, h)
    except AttributeError:
        # setGeometry is a Qt widget method. Other backends (or a figure
        # without a manager) presumably lack it; keep the default placement.
        pass

    with misc.filter_warnings('ignore', UserWarning):
        try:
            fig.tight_layout()
        except RuntimeError:
            # Different layout engine
            pass