Source code for mjolnir.measurements.measures

"""This module provides classes for measuring things."""
from __future__ import annotations

import dataclasses
import warnings
from collections.abc import Iterable, Callable, Sequence
from datetime import timedelta
from typing import Self, Generic, TypeVar, Literal, Any, cast

import numpy as np
import numpy.typing as npt
from matplotlib import pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from qcodes.parameters import (ParameterBase, Parameter, ParameterWithSetpoints, DelegateParameter,
                               MultiParameter)
from qutil import itertools, functools
from qutil.plotting import BlitManager, reformat_axis

from .. import _MONITOR_SIZE_PX
from ..helpers import find_param_source
from ..parameters import TimeParameter
from ..plotting.plot_nd import style_context

# Type variable for the values returned by measured parameters (see Measure.get), bound to
# numpy scalar types.
_T = TypeVar('_T', bound=np.generic)


def Id(x: _T) -> _T:
    """Identity transform: return `x` unchanged.

    Sweep transforms are compared against this very function object by
    identity to decide whether a live plot needs a dynamic x-axis, so it
    should not be replaced by an equivalent function.
    """
    unchanged = x
    return unchanged
@dataclasses.dataclass
class Measure(Generic[_T]):
    r"""Defines a gettable during a measurement.

    This class represents a parameter to measure during a measurement
    executed by :meth:`.handler.MeasurementHandler.measure`. Additional
    parameters that derive from the main parameter (through
    :class:`~qcodes:qcodes.parameters.DelegateParameter`) may be measured
    by passing them as a sequence as `delegates` to the constructor.
    `delegate`\ s are always gotten after the parameter, and only their
    parameter's cache is queried to avoid unnecessary instrument
    communication. Hence, the delegates work on the same raw data that is
    returned from the get call to the main parameter. This can, for
    example, be used for storing processed data (i.e., like a procfn)
    together with the raw data.

    In addition to the measurement, this class implements live plotting
    such that, when :attr:`plot_callback` is called, a plot is updated with
    the latest cached value.

    Parameters
    ----------
    param :
        A gettable qcodes Parameter.
    delegates :
        A sequence of gettable
        :class:`~qcodes:qcodes.parameters.DelegateParameter` s deriving
        from `param`.
    live_plot :
        Enable live plotting during measurements using
        :class:`~.handler.MeasurementHandler`.
    """
    param: ParameterBase
    delegates: Sequence[DelegateParameter] = dataclasses.field(default_factory=list)
    live_plot: bool = True
    # Live-plot update callbacks, one per plotted parameter (param first, then the
    # delegates). Each is called as ``callback(j, k)`` and returns ``(j, k)``.
    _plot_callbacks: list[Callable[[int, int], tuple[int, int]]] = dataclasses.field(
        default_factory=list, repr=False, init=False, compare=False
    )

    def __post_init__(self):
        """Validate `param` and `delegates`.

        Raises
        ------
        TypeError
            If `param` or any delegate is not a :class:`ParameterBase`.
        ValueError
            If a delegate does not share the source parameter of `param`.
        """
        if not isinstance(self.param, ParameterBase):
            raise TypeError(f'Expected param to be of class ParameterBase, not {type(self.param)}')
        elif not isinstance(self.param, (Parameter, MultiParameter)) and self.live_plot:
            warnings.warn('Live plotting only enabled for Parameter subclasses and '
                          f'MultiParameter, not {type(self.param)}. Disabling.', UserWarning)
            self.live_plot = False

        for delegate in self.delegates:
            if not isinstance(delegate, ParameterBase):
                raise TypeError('Expected all delegates to be of class ParameterBase, not '
                                f'{type(delegate)}')
            if find_param_source(delegate) is not find_param_source(self.param):
                raise ValueError('delegates should all have the same source as param')

    def __hash__(self):
        # Measures are identified by their parameter. Defining __hash__ explicitly keeps the
        # (mutable, eq=True) dataclass hashable so that it can be stored in a MeasureSet.
        return hash(self.param)

    def __or__(self, other) -> 'MeasureSet':
        if isinstance(other, MeasureSet):
            return MeasureSet([self, *other])
        return MeasureSet([self, other])

    def __ror__(self, other) -> 'MeasureSet':
        if isinstance(other, MeasureSet):
            return MeasureSet([*other, self])
        return MeasureSet([other, self])

    @property
    def shape(self) -> tuple[int, ...] | tuple[tuple[int, ...], ...]:
        """The shape of the result that :meth:`get` returns.

        For a :class:`MultiParameter`, this is the tuple of the shapes of its
        items; for scalar parameters, the empty tuple.
        """
        if isinstance(self.param, ParameterWithSetpoints):
            return self.param.vals.shape
        elif isinstance(self.param, MultiParameter):
            return self.param.shapes
        return ()

    @property
    def shape_delegates(self) -> list[tuple[int, ...]]:
        """The shapes of the results that :meth:`get_delegates` returns."""
        return [delegate.vals.shape if isinstance(delegate, ParameterWithSetpoints) else ()
                for delegate in self.delegates]

    def get(self, cache: bool = False) -> _T:
        """The parameter value.

        Calls ``param.get()`` unless `cache` is True, in which case the
        parameter's cached value is returned.
        """
        if cache:
            return self.param.cache.get()
        return self.param.get()

    def get_delegates(self, cache: bool = True) -> list[_T]:
        """The delegate values, defaults to the cache."""
        if cache:
            return [delegate.cache.get() for delegate in self.delegates]
        return [delegate.get() for delegate in self.delegates]

    def initialize_plotting(self, sweeps: 'SweepList', axs: npt.NDArray[Axes],
                            dynamic_x_axis: bool = False) -> Sequence[Artist]:
        """Set up live plotting for a measurement defined by sweeps.

        Does nothing if :attr:`live_plot` is False.

        Parameters
        ----------
        sweeps :
            The :class:`~.sweeps.SweepList` defining the measurement sweep.
        axs :
            The axes to plot in, one for `param` followed by one per delegate.
        dynamic_x_axis :
            Whether the x-axis is dynamically updated to the current data
            range or has a static range of the setpoints. This is
            automatically set to True if any of the sweeps' transforms are
            unequal to the identity.

        Returns
        -------
        artists :
            The animated artists created for the live plot.

        Raises
        ------
        NotImplementedError
            If live plotting is not supported for `param` or one of the
            delegates. Any callbacks registered before the failure are
            discarded.

        Notes
        -----
        TODO:
         * Handle direct product setpoints (sweep & sweep).
            * Could add extra x-axis
         * Plot 2d if requested and sweeps is at least 2d.
            * Could add a 1d subplot.
            * Would need to check if not too slow.
         * Plot 2d for 2d ParameterWithSetpoints?
         * Plot all parameters of a single axis (only the first one is
           currently selected).
            * Could add additional axis spines.
            * Would be in conflict with current multiple line plots for
              'small' ParameterWithSetpoints.
        """
        if not self.live_plot:
            return ()

        self.reset_plotting()

        transforms = set(itertools.chain.from_iterable(sweep.transform for sweep in sweeps))
        dynamic_x_axis |= any(trafo is not Id for trafo in transforms)

        artists = []
        try:
            for param, ax in zip(itertools.chain((self.param,), self.delegates), axs,
                                 strict=True):
                artists.extend(self._initialize_plotting(param, ax, sweeps, dynamic_x_axis))
        except NotImplementedError:
            # Do not leave callbacks of parameters set up before the failure behind; they
            # would keep drawing into axes the caller discards.
            self.reset_plotting()
            raise
        return artists

    def _initialize_plotting(self, param: ParameterBase, ax: Axes, sweeps: 'SweepList',
                             dynamic_x_axis: bool = False) -> Sequence[Artist]:
        """Set up the live plot of a single parameter in `ax`.

        Registers an update callback in :attr:`_plot_callbacks` and returns
        the created animated lines.

        Raises
        ------
        NotImplementedError
            If live plotting is not supported for `param` (see
            :func:`_parse_param`).
        """

        def update_view(j: int, k: int, buffered: bool, force_update_x: bool = False):
            """Update view limits.

            Parameters
            ----------
            j :
                The direction index.
            k :
                The setpoint index for the given direction.
            buffered :
                Is the plot for a buffered acquisition? In this case, the x
                limits do not depend on the sweep.
            force_update_x :
                Force an update of x limits?
            """
            # TODO: This is pretty slow. Maybe replace live plotting by pyqtgraph
            ydata = list(itertools.chain.from_iterable(line.get_ydata() for line in lines))
            ax.dataLim.update_from_data_y(ydata, ignore=True)
            ax.set_ylabel(y_label.format(reformat_axis(ax, ydata, _get_unit(y), 'y')))

            if not (k == 0 or dynamic_x_axis or force_update_x):
                ax.autoscale_view(scalex=False, scaley=True)
                return

            # Update x
            if buffered:
                xdata = x.cache.get()
            elif dynamic_x_axis:
                xdata = list(itertools.chain.from_iterable(line.get_xdata() for line in lines))
            else:
                xdata = sweeps[-1][j][:, 0]

            if k == 0:
                # The scale only needs to be determined once at the beginning of a sweep.
                scale, kwargs = _determine_xscale(xdata)
                ax.set_xscale(scale, **kwargs)

            # The data limits are set manually since the lines are not drawn during
            # initialization (their data are all NaN or empty).
            ax.dataLim.update_from_data_x(xdata, ignore=True)
            ax.set_xlabel(x_label.format(reformat_axis(ax, xdata, _get_unit(x), 'x')))
            ax.autoscale_view(scalex=True, scaley=True)

        def prepare_multiple_lines(x_setpoints: Parameter | None) -> list[Line2D]:
            # One line per item / setpoint value, identified through a legend.
            if isinstance(param, ParameterWithSetpoints):
                new_lines, title = _prepare_paramwithsetpoints_plot(ax, param, x_setpoints)
            elif isinstance(param, MultiParameter):
                new_lines, title = _prepare_multiparam_plot(ax, param, x_setpoints)
            else:
                raise RuntimeError(f'param {param} not ParameterWithSetpoints or MultiParameter')
            # Constrain legend loc, otherwise it moves during updates.
            legend = ax.legend(loc='upper right')
            legend.set_title(title)
            return new_lines

        def label_template(p: ParameterBase) -> str:
            # The '{}' placeholder later receives the unit prefix string returned by
            # reformat_axis; literal braces are escaped so that str.format keeps them.
            label = _prepare_string(_get_label(p))
            if unit := _get_unit(p):
                return f'{label} ({{}}{_prepare_string(unit)})'
            # Append dummy placeholder for '' prefix string
            return f'{label}{{}}'

        plot_buffered, plot_multiple = _parse_param(param)
        param = cast(ParameterWithSetpoints | MultiParameter, param)
        y = param

        if plot_buffered:
            # Each get returns a whole trace, plotted against the parameter's own
            # (innermost) setpoints. The line data are replaced on every update.
            x = _get_setpoints(param)
            if plot_multiple:
                lines = prepare_multiple_lines(x)
            else:
                # Start from all-NaN data. np.full instead of np.full_like so that integer
                # setpoints do not impose an integer dtype, into which NaN cannot be cast.
                lines = ax.plot(x.get(), np.full(np.shape(x.cache.get()), np.nan),
                                animated=True)

            def update(j: int, k: int) -> tuple[int, int]:
                x_val = np.asarray(x.cache.get())
                y_vals = np.atleast_2d(y.cache.get())
                force_update_x = False
                for line, y_val in zip(lines, y_vals):
                    line.set_ydata(y_val)
                    # np.shape: get_xdata() returns the original input, which may be a list.
                    x_old = line.get_xdata()
                    if np.shape(x_val) != np.shape(x_old) or not np.allclose(x_val, x_old):
                        # setpoints changed, need to update
                        line.set_xdata(x_val)
                        force_update_x = True
                update_view(j, k, buffered=True, force_update_x=force_update_x)
                return j, k
        else:
            # Each get returns a single point per line, plotted against the parameter of
            # the last (inner axis) sweep. Points are appended to the lines one by one.
            x = sweeps[-1].param[0]
            if plot_multiple:
                lines = prepare_multiple_lines(None)
            else:
                lines = ax.plot([], [], animated=True)

            def update(j: int, k: int) -> tuple[int, int]:
                x_val = x.cache.get()
                y_vals = np.atleast_1d(y.cache.get())
                for line, y_val in zip(lines, y_vals):
                    if k == 0:
                        # First point of inner axis sweep. Reset line data
                        line.set_data([x_val], [y_val])
                    else:
                        x_old, y_old = line.get_data()
                        line.set_data(np.append(x_old, x_val), np.append(y_old, y_val))
                # Only update view once all lines have been updated
                update_view(j, k, buffered=False)
                return j, k

        x_label = label_template(x)
        y_label = label_template(y)
        if y.instrument is not None:
            # Add the instrument label in a new line for easier identification in crowded figure
            y_label = _prepare_string(y.instrument.label) + '\n' + y_label

        ax.set_xlabel(x_label.format(''))
        ax.set_ylabel(y_label.format(''))

        self._plot_callbacks.append(update)
        return lines

    def reset_plotting(self):
        """Reset plotting so that calling :attr:`plot_callback` raises a :class:`RuntimeError`."""
        self._plot_callbacks.clear()

    @property
    def plot_callback(self) -> functools.FunctionChain:
        """The registered live-plot update callbacks chained together.

        The chain is called as ``callback(j, k)`` with the direction index
        `j` and setpoint index `k`.

        Raises
        ------
        RuntimeError
            If plotting was not initialized or has been reset.
        """
        if not self._plot_callbacks:
            raise RuntimeError('Plotting not initialized.')
        return functools.chain(*self._plot_callbacks, n_args=2)
class MeasureSet(set[Measure[_T]]):
    """A set of :class:`Measure` s.

    Bare parameters passed to the constructor, :meth:`add`, :meth:`update`,
    ``|`` and ``|=`` are wrapped in :class:`Measure` s.
    """

    def __init__(self, params: Iterable[Measure | ParameterBase] = ()):
        super().__init__(_to_measure(param) for param in params)
        self._plot_callback: Callable[[int, int], None] | None = None
        self._fig: Figure | None = None
        self._axes: list[npt.NDArray[Axes]] | None = None

    def __sub__(self, other) -> Self:
        """Return the members not matched by `other`.

        `other` may be a single :class:`Measure` or parameter, or an
        iterable of them. A member is removed if it equals a given
        :class:`Measure`, or if its :attr:`~Measure.param` is one of the
        given parameters.
        """
        if isinstance(other, (ParameterBase, Measure)):
            other = [other]
        elif not np.iterable(other):
            return NotImplemented

        measures = set()
        params = []
        for item in other:
            if isinstance(item, Measure):
                measures.add(item)
            else:
                params.append(item)

        # Parameters must be matched against Measure.param explicitly: a bare parameter never
        # compares equal to the Measure wrapping it, so plain set difference would not remove it.
        return self.__class__(
            measure for measure in self
            if measure not in measures and measure.param not in params
        )

    def __or__(self, other) -> Self:
        if isinstance(other, (ParameterBase, Measure)):
            return self.__class__(super().__or__({_to_measure(other)}))
        elif np.iterable(other):
            return self.__class__(super().__or__(set(other)))
        else:
            return NotImplemented

    def __ior__(self, other) -> Self:
        if isinstance(other, (ParameterBase, Measure)):
            self.add(_to_measure(other))
        elif isinstance(other, self.__class__):
            self.update(other)
        elif np.iterable(other):
            self.update(self.__class__(other))
        else:
            return NotImplemented
        return self

    @property
    def params(self) -> list[ParameterBase]:
        """All parameters of member :class:`Measure` s."""
        return [measure.param for measure in self]

    @property
    def delegates(self) -> list[DelegateParameter]:
        """All delegates of member :class:`Measure` s."""
        return list(itertools.chain.from_iterable(measure.delegates for measure in self))

    @property
    def shapes(self) -> list[tuple[int, ...]]:
        """The shapes of the result that :meth:`get` returns."""
        return [measure.shape for measure in self]

    @property
    def shapes_delegates(self) -> list[tuple[int, ...]]:
        """The shapes of the result that :meth:`get_delegates` returns."""
        return list(itertools.chain.from_iterable(measure.shape_delegates for measure in self))

    @property
    def full_names(self) -> list[str]:
        """The full names of the member :class:`Measure` s' parameters."""
        return [measure.param.full_name for measure in self]

    @property
    def full_names_delegates(self) -> list[str]:
        """The full names of all delegates of member :class:`Measure` s."""
        return [delegate.full_name for measure in self for delegate in measure.delegates]

    def get(self, cache: bool = False) -> list[_T]:
        """The parameter values."""
        return [measure.get(cache) for measure in self]

    def get_delegates(self, cache: bool = True) -> list[_T]:
        """The delegate values of all members, defaults to the cache."""
        return list(itertools.chain.from_iterable(measure.get_delegates(cache)
                                                  for measure in self))

    def add(self, __element):
        """Add an element, wrapping bare parameters in a :class:`Measure`."""
        super().add(_to_measure(__element))

    def update(self, *s):
        """Add the elements of all iterables in `s`, wrapping bare parameters in a :class:`Measure`."""
        # Convert each iterable separately: the constructor accepts a single iterable only.
        super().update(*(self.__class__(iterable) for iterable in s))

    def initialize_plotting(self, sweeps: 'SweepList', dynamic_x_axis: bool = False,
                            plot_style=None):
        """Initialize plotting for all member :class:`Measure` s for which
        :attr:`~Measure.live_plot` is true.

        Members whose live plotting is not supported have their axes removed.

        Parameters
        ----------
        sweeps :
            The :class:`~.sweeps.SweepList` defining the measurement sweep.
        dynamic_x_axis :
            Whether the x-axis is dynamically updated to the current data
            range or has a static range of the setpoints.
        plot_style :
            A matplotlib plot style to use for the plotting.
        """

        def get_title():
            # One 'label: value unit' line per swept parameter.
            s = '\n'.join(['{}: {} {}'] * len(sweeps))
            labels = (_get_label(param) for param in sweeps.params)
            units = (_get_unit(param) for param in sweeps.params)
            values = _format_values(sweeps.params, sweeps.current_setpoints)
            return s.format(*itertools.flatten(zip(labels, values, units)))

        def update(j, k):
            for meas in self:
                # Delegate updating plot data to measures. Only the lookup is guarded:
                # a RuntimeError raised inside a callback must not be silenced.
                try:
                    callback = meas.plot_callback
                except RuntimeError:
                    # Plotting not initialized for this measure.
                    continue
                callback(j, k)
            # Update setpoint tuple in title
            title.set_text(get_title())
            # Trigger the blit manager
            manager.update()
            # Draw rescaled axes
            self.fig.canvas.draw_idle()

        # Number of axes per measure: one for the param plus one per delegate, or none.
        do_live_plot = [measure.live_plot * (1 + len(measure.delegates)) for measure in self]
        n_axes = sum(do_live_plot)
        if not n_axes:
            return

        with style_context(style=plot_style, fast=True):
            self._fig, self._axes = plt.subplots(n_axes, layout='tight',
                                                 figsize=_calculate_figsize(n_axes))
            # Group the axes per measure; the last split is always empty.
            self._axes = np.split(np.atleast_1d(self._axes),
                                  list(itertools.accumulate(filter(None, do_live_plot))))[:-1]
            # suptitle is not well handled by tight layout
            title = self.axes[0][0].set_title(' ' + '\n' * (len(sweeps) - 1), animated=True)

            lines = []
            for axs, measure in zip(self.axes, itertools.compress(self, do_live_plot)):
                try:
                    lines.extend(measure.initialize_plotting(sweeps, axs, dynamic_x_axis))
                except NotImplementedError:
                    # Drop callbacks registered before the failure, they would draw into
                    # the removed axes.
                    measure.reset_plotting()
                    for ax in axs:
                        ax.remove()

            self.fig.tight_layout()
            manager = BlitManager(self.fig.canvas, lines + [title])

        self._plot_callback = update

    def reset_plotting(self):
        """Reset plotting so that calling :attr:`plot_callback` raises a :class:`RuntimeError`."""
        self._plot_callback = None

    @property
    def plot_callback(self) -> Callable[[int, int], None]:
        """Update all live plots for direction index `j` and setpoint index `k`.

        Raises
        ------
        RuntimeError
            If plotting was not initialized or has been reset.
        """
        if self._plot_callback is None:
            raise RuntimeError('Plotting not initialized.')
        return self._plot_callback

    @property
    def fig(self) -> Figure:
        """The live-plot figure.

        Raises
        ------
        RuntimeError
            If plotting was not initialized.
        """
        if self._fig is None:
            raise RuntimeError('Plotting not initialized.')
        return self._fig

    @property
    def axes(self) -> list[npt.NDArray[Axes]]:
        """The live-plot axes, grouped per plotted :class:`Measure`.

        Raises
        ------
        RuntimeError
            If plotting was not initialized.
        """
        if self._axes is None:
            raise RuntimeError('Plotting not initialized.')
        return self._axes
def _to_measure(obj) -> Measure:
    """Return `obj` if it is a :class:`Measure`, otherwise wrap it in one."""
    if isinstance(obj, Measure):
        return obj
    return Measure(obj)


def _prepare_string(s):
    """Escape curly braces so that `s` survives a subsequent :meth:`str.format` call."""
    return s.replace('{', '{{').replace('}', '}}')


def _determine_xscale(
        x: npt.NDArray, rel_thresh: float = 2e-2
) -> tuple[Literal['linear', 'log', 'asinh'], dict[str, Any]]:
    """Guess a matplotlib x-axis scale for the data `x`.

    Zero and non-finite points are ignored. If the remaining points are
    approximately equally spaced on a logarithmic scale, a logarithmic
    scale is chosen. "Approximately" means that the relative standard
    deviation of the differences of ``log10(|x|)`` is at most `rel_thresh`.
    The scale is ``'asinh'`` if any point is negative, with the smallest
    absolute value as linear width, and ``'log'`` otherwise. In all other
    cases, including fewer than three usable points, the scale is linear.

    Returns
    -------
    scale :
        The scale name.
    kwargs :
        Keyword arguments for :meth:`~matplotlib.axes.Axes.set_xscale`.
    """
    x = np.asarray(x)
    # NaN/inf would otherwise poison the spacing statistics below.
    x = x[np.isfinite(x) & (x != 0)]
    if x.size <= 2:
        return 'linear', {}

    # A vanishing mean log-spacing (e.g. symmetric data) divides by zero; that simply
    # means 'not log-spaced', so silence the warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        logdiff = np.diff(np.log10(np.abs(x)))
        is_log_spaced = logdiff.std() / np.abs(logdiff.mean()) <= rel_thresh

    if not is_log_spaced:
        return 'linear', {}
    if (x <= 0).any():
        return 'asinh', {'linear_width': np.abs(x[np.nanargmin(np.abs(x))])}
    return 'log', {}


def _format_values(params: Iterable[ParameterBase], values: Iterable[Iterable[Any]]):
    """Format the (flattened) `values`, one per parameter in `params`.

    Values of a :class:`TimeParameter` are formatted as a
    :class:`~datetime.timedelta`, all others with ``'.4g'``. Values whose
    formatting raises :class:`TypeError` (e.g. ``None``) are replaced by
    ``'None'``. Since `params` and the flattened `values` are zipped, the
    output stops at the shorter of the two.
    """
    # Use datetime formatting for TimeParameter's, float/scientific else
    return itertools.replace_except(
        (format(timedelta(seconds=val)) if isinstance(param, TimeParameter)
         else format(val, '.4g')
         for param, val in zip(params, itertools.flatten(values))),
        TypeError,
        replacement=str(None)
    )


def _parse_param(param: ParameterBase) -> tuple[bool, bool]:
    """Determine how `param` is live-plotted.

    Returns
    -------
    plot_buffered :
        Whether each get returns a whole trace that is plotted against the
        parameter's own setpoints.
    plot_multiple :
        Whether multiple lines (with a legend) share the same axes.

    Raises
    ------
    NotImplementedError
        If live plotting is not supported for `param`.
    """
    if not isinstance(param, ParameterBase):
        raise NotImplementedError('Live plotting not implemented for Measures without '
                                  'parameters.')

    if isinstance(param, ParameterWithSetpoints):
        n_setpoints = len(param.setpoints)
        if n_setpoints > 2:
            raise NotImplementedError('Live plotting not implemented for Measures with >2d '
                                      'setpoints.')
        if n_setpoints == 2:
            if len(param.setpoints[0].get()) > 10:
                raise NotImplementedError('Live plotting not implemented for Measures with 2d '
                                          'setpoints and shape (>10, x).')
            # A few traces, one line per value of the first setpoint.
            return True, True
        # len(setpoints) == 1
        if len(param.setpoints[0].get()) > 10:
            # Long trace: a single line against the setpoints.
            return True, False
        # Few values: one line per setpoint value, growing along the sweep.
        return False, True

    if isinstance(param, MultiParameter):
        shapes = param.shapes
        if len(set(shapes)) != 1:
            raise NotImplementedError('Live plotting not implemented for MultiParameter Measures '
                                      'with heterogeneous setpoints.')
        if len(shapes[0]) > 1:
            raise NotImplementedError('Live plotting not implemented for MultiParameter Measures '
                                      'with >1d setpoints.')
        if len(shapes[0]) == 1 and shapes[0] > (1,):
            # Traces are plotted against the setpoints, so they must exist and be shared.
            if param.setpoints is None:
                raise NotImplementedError('Live plotting not implemented for 1d MultiParameter '
                                          'Measures without setpoints.')
            if not all(itertools.starmap(
                    np.array_equal, zip(param.setpoints[1:], param.setpoints[:-1])
            )):
                raise NotImplementedError('Live plotting not implemented for MultiParameter '
                                          'Measures with unequal setpoints.')
            return True, True
        # Scalar items: one line per item, growing along the sweep.
        return False, True

    return False, False


def _prepare_lines(ax, x: Parameter | None, label: str) -> list[Line2D]:
    """Add an animated line labelled `label` to `ax`.

    Without `x`, the line starts empty; otherwise it spans the values of
    `x` with all-NaN y-data until the first update.
    """
    if x is None:
        return ax.plot([], [], animated=True, label=label)
    return ax.plot(x.get(), np.full(np.shape(x.cache.get()), np.nan),
                   animated=True, label=label)


def _prepare_paramwithsetpoints_plot(ax, y: ParameterWithSetpoints,
                                     x: Parameter | None = None) -> tuple[list[Line2D], str]:
    """Add one line per value of the first setpoint of `y` to `ax`.

    Returns the lines and a legend title made of the first setpoint's
    label and unit.
    """
    setpoints = y.setpoints[0]
    values = np.ravel(setpoints.get())
    lines = []
    # _format_values zips parameters with values, so one parameter per value is required;
    # otherwise only the first line would be created.
    for label in _format_values([setpoints] * values.size, [values]):
        lines.extend(_prepare_lines(ax, x, label))

    # Legend titles are used verbatim (not passed through str.format), so no brace escaping.
    if unit := _get_unit(setpoints):
        title = f'{_get_label(setpoints)} ({unit})'
    else:
        title = _get_label(setpoints)
    return lines, title


def _prepare_multiparam_plot(ax, y: MultiParameter,
                             x: Parameter | None = None) -> tuple[list[Line2D], str]:
    """Add one line per item of `y` to `ax`, labelled by the item's label and unit.

    Returns the lines and an empty legend title.
    """
    lines = []
    for label, unit in zip(y.labels, y.units):
        # Line labels are used verbatim (not passed through str.format), so no brace escaping.
        if unit:
            label = f'{label} ({unit})'
        lines.extend(_prepare_lines(ax, x, label))
    return lines, ''


def _get_setpoints(param: ParameterWithSetpoints | MultiParameter) -> Parameter:
    """Return the last setpoints of `param` as a :class:`Parameter`.

    If ``param.setpoints[-1]`` already is a :class:`Parameter` (as for a
    :class:`ParameterWithSetpoints`), it is returned. Otherwise, a
    :class:`Parameter` is created whose get returns the first array of
    ``param.setpoints[-1]``, named and labelled after the last setpoint
    name and label, if available.
    """
    setpoints = param.setpoints[-1]
    if isinstance(setpoints, Parameter):
        return setpoints

    class _Setpoints(Parameter):
        def get_raw(self):
            return setpoints[0]

    if param.setpoint_names is not None:
        name = param.setpoint_names[-1]
    else:
        name = f'{param.name}_setpoints'
    if param.setpoint_labels is not None:
        label = param.setpoint_labels[-1]
    else:
        label = None
    return _Setpoints(name, instrument=None, label=label)


def _get_label(param: ParameterBase) -> str:
    """The parameter's label, or its name if it has no ``label`` attribute."""
    return param.label if hasattr(param, 'label') else param.name


def _get_unit(param: ParameterBase) -> str:
    """The parameter's unit, or ``''`` if it has no ``unit`` attribute."""
    return param.unit if hasattr(param, 'unit') else ''


def _calculate_figsize(n_axes: int) -> tuple[float, float]:
    """Figure size in inches for `n_axes` vertically stacked axes.

    The height is 3/4 of matplotlib's default figure height per axes,
    capped by the monitor size.
    """
    dpi = plt.rcParams['figure.dpi']
    _, default_height = plt.rcParams['figure.figsize']
    # NOTE(review): _MONITOR_SIZE_PX appears to be ordered (height, width) -- confirm.
    width = _MONITOR_SIZE_PX[1] / dpi
    height = default_height * 0.75 * n_axes
    # presumably 31 px are reserved for window decorations -- TODO confirm
    max_height = (_MONITOR_SIZE_PX[0] - 31) / dpi
    return width, min(height, max_height)