"""This module provides plotting functions, both offline and online."""
from __future__ import annotations
import warnings
from collections import defaultdict
from collections.abc import Hashable
from contextlib import nullcontext
from copy import copy
from typing import Any, ContextManager, Literal, Mapping
from unittest import mock
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from cycler import cycler
from matplotlib import cm, colors, ticker
from matplotlib.axes import Axes
from matplotlib.colorbar import Colorbar
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.widgets import Slider
from qcodes.dataset import DataSetProtocol, load_by_run_spec
from qcodes.dataset.data_export import reshape_2D_data
from qcodes.plotting.axis_labels import _UNITS_FOR_RESCALING, find_scale_and_prefix
from qutil import functools, itertools, misc
from qutil.plotting import BlitManager, norm_to_scale, reformat_axis
from qutil.plotting.live_view import StyleT
from xarray import DataArray, Dataset
from mjolnir.plotting.live_view import _eV2lambda, _lambda2eV
# Keys accepted by the ``autoscale`` argument of plot_nd
_AUTOSCALE_KEYS: set[str] = {'leakage', 'power', 'array', 'slice'}
# Also rescale count rates and counts to SI prefixes in qcodes' axis labeling
_UNITS_FOR_RESCALING.update({'cps', 'cts'})
# Semi-transparent gray used for invalid ('bad') values in the colormaps
_GRAY = colors.to_rgba('tab:gray', 0.3)
# '0.0' and '1.0' are matplotlib grayscale color strings (black and white)
CMAP = cm.inferno.with_extremes(bad=_GRAY, under='0.0', over='1.0')
DIVERGING_CMAP = cm.RdBu_r.with_extremes(bad=_GRAY, under='0.0', over='0.0')
def style_context(style: StyleT | list[StyleT] | None = None, fast: bool = True,
                  after_reset: bool = True) -> ContextManager:
    """Return a context manager to define a plot style.
    Parameters
    ----------
    style :
        A valid matplotlib style or a list of styles. If None, defaults
        to no change. A list passed here is never modified.
    fast :
        Add the matplotlib 'fast' style for speedier plotting.
    after_reset :
        Reset all style preferences before applying the style.
    Returns
    -------
    context :
        The style context manager.
    """
    if style is None:
        if fast:
            return mpl.style.context('fast', after_reset=after_reset)
        return nullcontext()
    if fast:
        # Build a new list rather than appending in place: appending would
        # grow the caller's list by one 'fast' entry on every call.
        style = [*style, 'fast'] if isinstance(style, list) else [style, 'fast']
    return mpl.style.context(style, after_reset=after_reset)
[docs]
def plot_nd(dataset_or_run_id: Dataset | DataSetProtocol | int, array_target: str | None = None,
            vertical_target: str | None = None, horizontal_target: str | None = None,
            plot_leakage: bool | None = None, plot_power: bool | None = None,
            plot_slice: bool = True,
            autoscale: bool | Mapping[Literal['leakage', 'power', 'array', 'slice'], bool] = True,
            xscale: str | None = None, yscale: str | None = None,
            slider_scale: Mapping[str, str] | None = None,
            norm: colors.Normalize | None = None, valfmt: str | Mapping[str, str] = '.3g',
            pad_subplots: bool = True, fast: bool | None = None,
            style: StyleT | list[StyleT] | None = None, fig_kw: Mapping[str, Any] | None = None,
            **kwargs):
    """Plots QCoDeS measurements with >= 2 sweep axes.
    Plots are color plots of 2d slices through the nd data. Extra
    dimensions get a slider which can be moved to select slices along
    that coordinate.
    Parameters
    ----------
    dataset_or_run_id :
        One of
         - a qcodes dataset
         - the run ID of the measurement in the currently open database
         - a xarray dataset
    array_target :
        (Part of) the parameter name to be plotted in the main plot.
        If None, the parameter depending on the most other parameters.
    vertical_target :
        (Part of) the parameter name to be plotted on the vertical axis.
    horizontal_target :
        (Part of) the parameter name to be plotted on the horizontal
        axis.
    plot_leakage :
        Plot leakage values (parameters with 'current' in their name
        that depend on the vertical coordinate), if any, in a subplot.
        If None, plot them if found. If True, additionally warn if none
        are found. If False, never plot them.
    plot_power :
        Plot laser power readings (parameters with 'power' in their name
        that depend on the vertical coordinate), if any, in a subplot.
        None/True/False behave as for *plot_leakage*.
    plot_slice :
        Plot a horizontal 1d-slice through the image data in a subplot,
        selected with a horizontal slider.
    autoscale :
        Autoscale axes/colormaps when moving sliders. If True, autoscale
        is applied to all axes and colormaps. If dict, should have any
        of the keys {'leakage', 'power', 'array', 'slice'} with a
        boolean value.
    xscale :
        The scale of the x-axis.
    yscale :
        The scale of the y-axis.
    slider_scale :
        The scale of a given slider. Should be a mapping whose keys are
        the full coordinate names of the sliders (looked up verbatim);
        sliders without an entry use 'linear'.
    norm :
        The :class:`~matplotlib:matplotlib.colors.Normalize` instance to
        normalize array data to the colormap. If None, ``vmin`` defaults
        to 0.
    valfmt :
        A format string to be used for slider labels. If a mapping, keys
        are full coordinate names; coordinates without an entry use
        '.2g'.
    pad_subplots :
        Add white space between subplots or have them tightly against
        each other.
    fast :
        Use a fast plot style (and pcolorfast for the main plot). If
        None, True if there are sliders and no log axes.
    style :
        A valid matplotlib plotting style.
    fig_kw :
        Kwargs passed on to the :func:`~matplotlib:matplotlib.pyplot.figure`
        constructor.
    **kwargs :
        Kwargs passed on to the :meth:`xarray.DataArray.plot` method.
        ``cmap`` defaults to this module's ``CMAP``.
    Returns
    -------
    fig :
        The figure.
    axes :
        Nested dict of the axes: ``axes['plots']`` holds the plot axes,
        ``axes['sliders']`` the slider axes, if any.
    sliders :
        Dict mapping coordinate names to their slider widgets.
    Raises
    ------
    ValueError :
        If the (squeezed) data has fewer than two dimensions.
    """
    # Per-subplot autoscale flags. For 'array' this is a dict telling whether vmin/vmax may
    # be autoscaled, or False if both are fixed. Note that it is parsed before the vmin
    # default below is set, so that default does not disable autoscaling of vmin.
    autoscale = _parse_autoscale(autoscale, norm, **kwargs)
    if slider_scale is None:
        slider_scale = {}
    # Slider value formats, looked up by coordinate name
    if isinstance(valfmt, str):
        _ = copy(valfmt)
        valfmt = defaultdict(lambda: _)
    else:
        # NOTE(review): coordinates missing from the mapping get '.2g', not the '.3g'
        # default used for a plain string -- confirm this is intended
        valfmt = defaultdict(lambda: '.2g') | valfmt
    if norm is None:
        kwargs.setdefault('vmin', 0)
    if isinstance(dataset_or_run_id, int):
        dataset = load_by_run_spec(captured_run_id=dataset_or_run_id)
    else:
        dataset = dataset_or_run_id
    is_qcodes_dataset = isinstance(dataset, DataSetProtocol)
    array_target, horizontal_target, vertical_target = _parse_targets(dataset, array_target,
                                                                      horizontal_target,
                                                                      vertical_target)
    array_data = _to_homogeneous_dataarray(dataset, array_target)
    if array_data.ndim > 1:
        kwargs.setdefault('cmap', CMAP)
    if array_data.squeeze().ndim < 2:
        raise ValueError("This function should be called plot_n>=2d()")
    # Transpose so that raveling the array is in the correct order for
    # horizontal and vertical targets
    array_data = array_data.transpose(..., vertical_target, horizontal_target)
    # Index of the displayed 2d slice: the first element along every coordinate that is
    # not on a plot axis (i.e., every slider coordinate)
    array_idx: dict[Hashable, slice | int | npt.NDArray[np.int_]] = {
        coord: 0 for coord in array_data.coords
        if coord not in {horizontal_target, vertical_target}
    }
    # Leakage: 'current' parameters depending on the vertical coordinate. Only warn about
    # missing ones if leakage plotting was explicitly requested.
    if plot_leakage is not False:
        try:
            leakage_targets = _find_params_depending_on('current', vertical_target, dataset)
        except ValueError as err:
            if plot_leakage is True:
                warnings.warn(str(err), UserWarning)
            plot_leakage = False
        else:
            plot_leakage = True
            leakage_dataset = _to_xarray_dataset(dataset, *leakage_targets)
            # TODO: index results from inhomogeneous data
            leakage_idx: dict[Hashable, slice | int | npt.NDArray[np.int_]] = {
                coord: 0 for coord in leakage_dataset.coords
                if coord not in {vertical_target, 'channel', 'index'}
            }
    # Power: the first 'power' parameter depending on the vertical coordinate
    if plot_power is not False:
        try:
            power_target = _find_params_depending_on('power', vertical_target, dataset)[0]
        except ValueError as err:
            if plot_power is True:
                warnings.warn(str(err), UserWarning)
            plot_power = False
        else:
            plot_power = True
            power_data = _to_xarray_dataset(dataset, power_target)[power_target]
            power_idx: dict[Hashable, slice | int | npt.NDArray[np.int_]] = {
                coord: 0 for coord in power_data.coords
                if coord not in {vertical_target}
            }
    # Index of the 1d slice: the 2d slice index plus a row along the vertical coordinate
    if plot_slice:
        slice_idx = copy(array_idx)
        slice_idx[vertical_target] = 0
    # One vertical slider per extra dimension; the slice plot gets a horizontal slider
    # along the vertical coordinate
    vertical_slider_dim = max(0, array_data.ndim - 2)
    slider_coords = {}
    if vertical_slider_dim:
        slider_coords['vertical'] = set(array_data.coords).difference(
            {horizontal_target, vertical_target}
        )
    if plot_slice:
        slider_coords['horizontal'] = {vertical_target}
    # Default to the fast style whenever there are sliders, unless log axes are requested
    if fast is None:
        fast = (vertical_slider_dim or plot_slice) and 'log' not in (xscale, yscale)
    if fast and 'log' in (xscale, yscale):
        warnings.warn('pcolorfast does not support log axes. Consider setting fast=False.')
    # Set up a function that generates an update callback for sliders. The nesting level is needed
    # to scope the variables
    def generate_array_callback(coord, slider):
        # Callback for a vertical slider: select a new 2d slice along ``coord`` and update
        # every artist depending on it
        def callback(_):
            array_idx[coord] = _idx_from_val(array_data, coord, slider.val)
            # The image object returned by xarray's plot (pcolormesh or, with ``fast``,
            # pcolorfast) takes new data through different methods
            try:
                img.set_data(array_data[array_idx].squeeze())
            except AttributeError:
                img.set_array(array_data[array_idx].squeeze())
            except TypeError:
                # NOTE(review): presumably an image type whose set_data also takes the
                # x/y cell edges (stored as _Ax/_Ay) -- confirm
                img.set_data(img._Ax, img._Ay, array_data[array_idx].squeeze())
            if autoscale['array']:
                # Silence norm change notifications while limits are reset to None
                with img.norm.callbacks.blocked():
                    # Do not use img.autoscale because vmin/max might be given which it ignores
                    if autoscale['array']['vmin']:
                        img.norm.vmin = None
                    if autoscale['array']['vmax']:
                        img.norm.vmax = None
                img.autoscale_None()
                _rescale_and_relabel(img.colorbar, array_data[array_idx], 'c', plot_leakage)
            if plot_leakage and coord in leakage_idx:
                for line, data in zip(leakage_lines, leakage_dataset.values()):
                    leakage_idx[coord] = _idx_from_val(data, coord, slider.val)
                    line.set_xdata(data[leakage_idx].squeeze())
                if autoscale['leakage']:
                    axes['plots']['vertical']['leakage'].relim()
                    axes['plots']['vertical']['leakage'].autoscale(axis='x', tight=True)
                    _rescale_and_relabel(axes['plots']['vertical']['leakage'], leakage_dataset,
                                         'x', plot_leakage)
            if plot_power and coord in power_idx:
                power_idx[coord] = _idx_from_val(power_data, coord, slider.val)
                power_line.set_xdata(power_data[power_idx].squeeze())
                if autoscale['power']:
                    axes['plots']['vertical']['power'].relim()
                    axes['plots']['vertical']['power'].autoscale(axis='x', tight=True)
                    _rescale_and_relabel(axes['plots']['vertical']['power'], power_data, 'x',
                                         plot_leakage)
            # slice_idx is rebound by the augmented assignment below
            nonlocal slice_idx
            if plot_slice and coord in slice_idx:
                # array might've changed
                slice_idx |= {key: val for key, val in array_idx.items()
                              if not key == vertical_target}
                _update_slice_line(slice_lines[1], array_data[slice_idx].squeeze(),
                                   axes['plots']['horizontal']['bottom'], autoscale['slice'],
                                   kwargs)
            _update_ticks_and_labels(axes['plots'], array_data, horizontal_target, vertical_target)
            manager.update()
        return callback
    def generate_slice_callback(coord, slider):
        # Callback for the horizontal slider: move the slice marker line in the main plot
        # and show the newly selected row in the slice subplot
        def callback(_):
            slice_idx[coord] = _idx_from_val(array_data, coord, slider.val)
            slice_lines[0].set_ydata(
                [array_data.coords[vertical_target][slice_idx[vertical_target]].item()] * 2
            )
            _update_slice_line(slice_lines[1], array_data[slice_idx].squeeze(),
                               axes['plots']['horizontal']['bottom'], autoscale['slice'], kwargs)
            _update_ticks_and_labels(axes['plots'], array_data, horizontal_target, vertical_target)
            manager.update()
        return callback
    # We can finally plot all the things!
    with style_context(style, fast):
        color_cycle = mpl.rcParams['axes.prop_cycle'].by_key().get(
            'color', ['tab:blue', 'tab:orange']
        )
        fig = plt.figure(**(fig_kw if fig_kw is not None else {}))
        if is_qcodes_dataset:
            sample_name = dataset.sample_name
            captured_run_id = dataset.captured_run_id
        else:
            sample_name = dataset.attrs['sample_name']
            captured_run_id = dataset.attrs['captured_run_id']
        fig.suptitle(f"{sample_name}, run #{captured_run_id}")
        axes = _setup_axes(fig, vertical_slider_dim, plot_leakage, plot_power, plot_slice,
                           array_data, horizontal_target, pad_subplots, fraction=0.03)
        # So that we can use xarray's plot method with pcolorfast
        with (mock.patch.object(axes['plots']['main'], 'pcolormesh',
                                axes['plots']['main'].pcolorfast)
              if fast else nullcontext()):
            # Artists that change with the vertical sliders are marked animated,
            # presumably so that the BlitManager below can redraw only them
            img = array_data[array_idx].plot(ax=axes['plots']['main'],
                                             x=horizontal_target,
                                             y=vertical_target,
                                             cbar_ax=axes['plots']['cbar'],
                                             xscale=xscale, yscale=yscale, norm=norm,
                                             animated=vertical_slider_dim, **kwargs)
        # Artists redrawn by the BlitManager when sliders move
        blitted_artists = []
        if vertical_slider_dim:
            blitted_artists.append(img)
        _rescale_and_relabel(img.colorbar, array_data[array_idx], 'c', plot_leakage)
        axes['plots']['main'].set_title(None)
        if plot_leakage:
            leakage_lines = []
            if len(leakage_dataset) > 4:
                warnings.warn('Leakage plotting only implemented for <= 4 datasets.', UserWarning)
            # One line style per leakage parameter; zip drops parameters beyond four
            for data, sty in zip(leakage_dataset[leakage_idx].values(),
                                 cycler(ls=['solid', 'dashed', 'dotted', 'dashdot'])):
                leakage_lines.extend(data.plot(ax=axes['plots']['vertical']['leakage'],
                                               y=vertical_target, yscale=yscale,
                                               color=color_cycle[0], **sty))
            if vertical_slider_dim:
                blitted_artists.extend(leakage_lines)
            _rescale_and_relabel(axes['plots']['vertical']['leakage'], leakage_dataset, 'x',
                                 plot_leakage)
            _color_axes(axes['plots']['vertical']['leakage'], leakage_lines[0])
            axes['plots']['vertical']['leakage'].invert_xaxis()
        if plot_power:
            power_line, = power_data[power_idx].plot(
                ax=axes['plots']['vertical']['power'], y=vertical_target, yscale=yscale,
                color=color_cycle[0 if not plot_leakage else 1]
            )
            if vertical_slider_dim:
                blitted_artists.append(power_line)
            _rescale_and_relabel(axes['plots']['vertical']['power'], power_data[power_idx], 'x',
                                 plot_leakage)
            _color_axes(axes['plots']['vertical']['power'], power_line)
            axes['plots']['vertical']['power'].invert_xaxis()
        if plot_slice:
            # slice_lines[0]: marker of the sliced row in the main plot,
            # slice_lines[1]: the 1d slice itself in the subplot below
            slice_lines = []
            slice_lines.append(axes['plots']['main'].axhline(
                array_data.coords[vertical_target][slice_idx[vertical_target]],
                color='tab:grey', linestyle=':'
            ))
            slice_lines.extend(array_data[slice_idx].plot(ax=axes['plots']['horizontal']['bottom'],
                                                          yscale=norm_to_scale(img.norm)))
            blitted_artists.extend(slice_lines)
            # Match the slice's value axis to the colorbar
            axes['plots']['horizontal']['bottom'].yaxis.set_major_formatter(img.colorbar.formatter)
            axes['plots']['horizontal']['bottom'].set_ylim(img.norm.vmin, img.norm.vmax)
        _update_ticks_and_labels(axes['plots'], array_data, horizontal_target, vertical_target)
        # The manager is only used by the slider callbacks
        if vertical_slider_dim or plot_slice:
            manager = BlitManager(fig.canvas, blitted_artists)
        sliders = {}
        if vertical_slider_dim:
            for i, (coord, ax) in enumerate(zip(slider_coords['vertical'],
                                                axes['sliders']['vertical'].values())):
                slider = _create_slider(array_data, ax, dataset, coord, 'vertical',
                                        slider_scale.get(coord, 'linear'), valfmt[coord],
                                        no=vertical_slider_dim - i - 1)
                slider.on_changed(generate_array_callback(coord, slider))
                sliders[coord] = slider
                # Sync the stored index with the slider's initial value
                array_idx[coord] = _idx_from_val(array_data, coord, slider.valinit)
        if plot_slice:
            for coord, ax in zip(slider_coords['horizontal'],
                                 axes['sliders']['horizontal'].values()):
                slider = _create_slider(array_data, ax, dataset, coord, 'horizontal',
                                        yscale or 'linear', valfmt[coord])
                slider.on_changed(generate_slice_callback(coord, slider))
                sliders[coord] = slider
                slice_idx[coord] = _idx_from_val(array_data, coord, slider.valinit)
        with misc.filter_warnings('ignore', UserWarning):
            try:
                fig.tight_layout()
            except RuntimeError:
                # Different layout engine
                pass
    return fig, axes, sliders
def _parse_targets(dataset: DataSetProtocol | Dataset,
                   array_target: str | None,
                   horizontal_target: str | None,
                   vertical_target: str | None) -> tuple[str, str, str]:
    """Resolve (partial) target names to full parameter names.
    Explicitly given targets are matched with :func:`_find_param`. A
    missing horizontal or vertical target defaults to the last parameter
    that the array target depends on and that is not already used by
    the other axis.
    Parameters
    ----------
    dataset :
        The qcodes or xarray dataset.
    array_target, horizontal_target, vertical_target :
        (Part of) the parameter names, or None for the default.
    Returns
    -------
    array_target, horizontal_target, vertical_target :
        The full parameter names.
    """
    is_qcodes_dataset = isinstance(dataset, DataSetProtocol)
    if array_target is None:
        # Default to parameter dependent on most other parameters ;)
        if is_qcodes_dataset:
            array_target = sorted(dataset.paramspecs.values(),
                                  key=lambda x: len(x.depends_on_))[-1].name
        else:
            array_target = sorted(dataset.data_vars.values(), key=lambda x: len(x.coords))[-1].name
    else:
        array_target = _find_param(array_target, dataset, dependent=True, qcodes=is_qcodes_dataset)
    # Resolve explicitly given axis targets *before* picking defaults for the other axis.
    # Otherwise a partial vertical name (e.g. 'gate') would be compared verbatim against
    # full parameter names and both axes could end up on the same parameter.
    if horizontal_target is not None:
        horizontal_target = _find_param(horizontal_target, dataset, dependent=False,
                                        qcodes=is_qcodes_dataset)
    if vertical_target is not None:
        vertical_target = _find_param(vertical_target, dataset, dependent=False,
                                      qcodes=is_qcodes_dataset)
    # Parameters the array target depends on; the last one should be the innermost sweep / get
    if is_qcodes_dataset:
        candidates = list(dataset.paramspecs[array_target].depends_on_)
    else:
        candidates = list(dataset.data_vars[array_target].coords)
    if horizontal_target is None:
        # Default to the last independent parameter not used on the vertical axis
        horizontal_target = itertools.first_true(
            reversed(candidates),
            pred=lambda x: x not in {array_target, vertical_target}
        )
    if vertical_target is None:
        # Default to the last independent parameter not used on the horizontal axis
        vertical_target = itertools.first_true(
            reversed(candidates),
            pred=lambda x: x not in {array_target, horizontal_target}
        )
    return array_target, horizontal_target, vertical_target
class _SIValfmt(str):
    """Hacks SI unit scaling into the slider valfmt arg.
    The slider formats its value text %-style (``valfmt % val``). This
    ``str`` subclass intercepts ``%`` to multiply the value by
    ``10**-scale`` and append ``' {prefix}{unit}'``. Without a unit it
    behaves like the plain %-format string.
    Parameters
    ----------
    s :
        The format spec, with or without leading '%' (e.g. '.3g').
    unit :
        The unit to append. If empty, the value is not rescaled.
    prefix :
        The SI prefix corresponding to *scale* (e.g. 'm').
    scale :
        Decimal exponent of the prefix (e.g. -3 for milli).
    """
    def __new__(cls, s, unit: str = '', prefix: str = '', scale: int = 0):
        if not s.startswith('%'):
            s = f'%{s}'
        obj = super().__new__(cls, s)
        obj.unit = unit
        obj.prefix = prefix
        obj.scale = scale
        return obj
    def __mod__(self, other):
        if not self.unit:
            return str.__mod__(self, other)
        return str.__mod__(self, other * 10**(-self.scale)) + f' {self.prefix}{self.unit}'
    # Deliberately no __format__ override: one evaluating ``self % format_spec`` would
    # %-format the *spec string*, so ``format(obj)`` and f-strings would always raise
    # TypeError. str.__format__ formats the string itself, as expected.
def _create_slider(array_data, ax_slider, dataset, coord, orientation, scale, valfmt, no: int = 0):
    """Create a slider selecting values of coordinate *coord*.
    Parameters
    ----------
    array_data :
        The data whose coordinate *coord* provides the slider values.
    ax_slider :
        The axes to draw the slider in.
    dataset :
        The qcodes or xarray dataset, used for the label and unit.
    coord :
        The coordinate name.
    orientation :
        'horizontal' or 'vertical'.
    scale :
        Axis scale of the slider axes (e.g. 'linear' or 'log').
    valfmt :
        %-format spec of the value text (leading '%' optional).
    no :
        Position of the slider, passed on to :func:`_update_slider_label`.
    Returns
    -------
    slider :
        The slider, stepping through the coordinate values and starting
        at the first one.
    """
    if isinstance(dataset, DataSetProtocol):
        label = dataset.paramspecs[coord].label
        unit = dataset.paramspecs[coord].unit
    else:
        label = dataset.coords[coord].label
        unit = dataset.coords[coord].units
    if orientation == 'horizontal':
        ax_slider.set_xscale(scale)
    else:
        ax_slider.set_yscale(scale)
    values = array_data.coords[coord]
    slider = Slider(
        ax=ax_slider,
        label=label,
        valmin=values.min().item(),
        valmax=values.max().item(),
        # Start at the first coordinate value: plot_nd initially shows index 0 along every
        # slider coordinate. Starting at the minimum would make slider and plot disagree
        # for descending sweeps.
        valinit=values[0].item(),
        valstep=values.to_numpy(),
        valfmt=_SIValfmt(valfmt, unit, *find_scale_and_prefix(values, unit)),
        orientation=orientation
    )
    _update_slider_label(slider, orientation, no)
    return slider
def _gridspec_width_ratios(vertical_slider_dim, cbar_fraction):
    """Return the column width ratios of the parent gridspec.
    Columns are [vertical sliders, spacer,] plots, colorbar. The plot
    column takes whatever width the other columns leave over, so the
    ratios sum to one.
    """
    leading = [0.07 * vertical_slider_dim, 0.02] if vertical_slider_dim else []
    plot_width = 1 - sum([*leading, cbar_fraction])
    return [*leading, plot_width, cbar_fraction]
def _setup_axes(
        fig: Figure,
        vertical_slider_dim: int,
        plot_leakage: bool,
        plot_power: bool,
        plot_slice: bool,
        array_data: DataArray,
        horizontal_target: str,
        pad_subplots: bool = False,
        **kwargs
) -> dict[str, dict[str, Axes | dict[str | int, Axes]]]:
    """Create the gridspecs and axes for :func:`plot_nd`.
    Parameters
    ----------
    fig :
        The figure to add the axes to.
    vertical_slider_dim :
        Number of vertical sliders (extra data dimensions).
    plot_leakage, plot_power :
        Add a subplot left of the main plot for leakage and/or power.
        If both, they share one axes (the leakage axes is a ``twiny``).
    plot_slice :
        Add a 1d-slice subplot below the main plot and a horizontal
        slider.
    array_data :
        The plotted data. Only the ``units`` of its
        ``ccd_horizontal_axis`` coordinate are inspected.
    horizontal_target :
        Name of the coordinate on the horizontal axis.
    pad_subplots :
        Add white space between the subplots.
    **kwargs :
        ``cbar_fraction`` or ``fraction`` (as for ``fig.colorbar()``):
        width fraction of the colorbar column, default 0.03;
        ``cbar_fraction`` takes precedence. ``pad``: padding the grid
        spacing is derived from, default 0.05.
    Returns
    -------
    axes :
        Nested dict. ``axes['plots']`` holds 'main', 'cbar', 'horizontal'
        and, with leakage/power, 'vertical'. ``axes['sliders']`` holds
        the slider axes, if any.
    """
    # Settings stolen from fig.colorbar(). plot_nd passes the colorbar width under the
    # fig.colorbar() name ``fraction``, so accept that as well as ``cbar_fraction``.
    cbar_fraction = kwargs.get('cbar_fraction', kwargs.get('fraction', 0.03))
    pad = kwargs.pop('pad', .05)
    wh_space = 2 * pad / (1 - pad)
    # Set up gridspecs. Parent gridspec:
    # There is one extra column used as a spacer in case there are vertical sliders and one extra
    # row in case there are horizontal ones.
    # The main gridspec is subdivided into the main plot grid (of variable size), the cbar grid,
    # and optional slider grids both horizontal and vertical. This is done so the spacing between
    # plots can be independently adjusted from the spacing between other elements.
    gs_kwargs = {'nrows': 1 + 3 * plot_slice,
                 'ncols': 2 + 2 * bool(vertical_slider_dim),
                 'figure': fig,
                 'hspace': wh_space,
                 'wspace': wh_space,
                 'width_ratios': _gridspec_width_ratios(vertical_slider_dim, cbar_fraction)}
    gs_plots_kwargs = {'nrows': 1 + plot_slice,
                       'ncols': 1 + (plot_leakage or plot_power),
                       'hspace': wh_space / 2 if pad_subplots else 0.0,
                       'wspace': wh_space / 2 if pad_subplots else 0.0}
    gs_cbar_kwargs = {'nrows': 1 + plot_slice,
                      'ncols': 1,
                      'hspace': wh_space / 2 if pad_subplots else 0.0}
    gs_sliders_vertical_kwargs = {'nrows': 1 + plot_slice,
                                  'ncols': 2 * vertical_slider_dim,
                                  'hspace': 0.0,
                                  'width_ratios': [1, 4] * vertical_slider_dim}
    gs_sliders_horizontal_kwargs = {'nrows': 2,
                                    'ncols': 1 + (plot_leakage or plot_power),
                                    'wspace': 0.0}
    if plot_leakage or plot_power:
        gs_plots_kwargs['width_ratios'] = [0.15, 0.85]
        gs_sliders_horizontal_kwargs['width_ratios'] = [0.15, 0.85]
    if plot_slice:
        gs_kwargs['height_ratios'] = [0.74, 0.13, 0.095, 0.035]
        gs_kwargs['bottom'] = 0.05
        gs_plots_kwargs['height_ratios'] = [0.85, 0.15]
        gs_cbar_kwargs['height_ratios'] = [0.85, 0.15]
        gs_sliders_vertical_kwargs['height_ratios'] = [0.85, 0.15]
    if vertical_slider_dim:
        gs_kwargs['left'] = 0.05
    if vertical_slider_dim or plot_slice:
        if (
            horizontal_target == 'ccd_horizontal_axis'
            # 'units', the same attribute checked for this coordinate further below
            and array_data.coords['ccd_horizontal_axis'].units in ('nm', 'eV')
        ):
            # Leave room for the secondary (nm <-> eV) x axis on top
            gs_kwargs['top'] = 0.85
        else:
            gs_kwargs['top'] = 0.92
    gs = GridSpec(**gs_kwargs)
    # Subdivide the plot gridspec to make room for the colorbar
    gs_plots = gs[:1 + plot_slice, -2:-1].subgridspec(**gs_plots_kwargs)
    gs_cbar = gs[:1 + plot_slice, -1].subgridspec(**gs_cbar_kwargs)
    if vertical_slider_dim:
        gs_sliders_vertical = gs[:1 + plot_slice, 0].subgridspec(**gs_sliders_vertical_kwargs)
    if plot_slice:
        gs_sliders_horizontal = gs[-1, -2:-1].subgridspec(**gs_sliders_horizontal_kwargs)
    # Set up axes
    axes = defaultdict(dict)
    # Slider axes
    if plot_slice:
        axes['sliders']['horizontal'] = {0: fig.add_subplot(gs_sliders_horizontal[0, -1])}
    if vertical_slider_dim:
        axes['sliders']['vertical'] = {i: fig.add_subplot(gs_sliders_vertical[0, i])
                                       for i in range(0, 2 * vertical_slider_dim, 2)}
    # Main axes (plots)
    axes['plots']['main'] = fig.add_subplot(gs_plots[0, -1])
    axes['plots']['cbar'] = fig.add_subplot(gs_cbar[0, 0])
    if plot_leakage or plot_power:
        axes['plots']['vertical'] = {}
        if plot_leakage and not plot_power:
            axes['plots']['vertical'] = {
                'leakage': fig.add_subplot(gs_plots[0, 0], sharey=axes['plots']['main'])
            }
        elif not plot_leakage and plot_power:
            axes['plots']['vertical'] = {
                'power': fig.add_subplot(gs_plots[0, 0], sharey=axes['plots']['main'])
            }
        else:
            axes['plots']['vertical']['power'] = fig.add_subplot(gs_plots[0, 0],
                                                                 sharey=axes['plots']['main'])
            axes['plots']['vertical']['leakage'] = axes['plots']['vertical']['power'].twiny()
    axes['plots']['horizontal'] = {}
    if plot_slice:
        axes['plots']['horizontal']['bottom'] = fig.add_subplot(gs_plots[1, -1],
                                                                sharex=axes['plots']['main'])
    if horizontal_target == 'ccd_horizontal_axis':
        if array_data.coords['ccd_horizontal_axis'].units == 'nm':
            functions = (functools.scaled(1e+9)(_lambda2eV),
                         functools.scaled(1e-9)(_eV2lambda))
        elif array_data.coords['ccd_horizontal_axis'].units == 'eV':
            functions = (functools.scaled(1e+9)(_eV2lambda),
                         functools.scaled(1e-9)(_lambda2eV))
        else:
            return axes
        axes['plots']['horizontal']['top'] = axes['plots']['main'].secondary_xaxis(
            'top', functions=functions
        )
    return axes
def _parse_autoscale(autoscale, norm, **kwargs) -> dict[str, bool | dict[str, bool]]:
    """Normalize the ``autoscale`` argument of :func:`plot_nd` to a dict.
    Parameters
    ----------
    autoscale :
        A bool applied to every key of ``_AUTOSCALE_KEYS``, or a mapping
        with a subset of those keys (missing keys default to True).
    norm :
        The colormap normalization, or None.
    **kwargs :
        Plot kwargs; only ``vmin`` and ``vmax`` are inspected.
    Returns
    -------
    autoscale :
        One bool per key. A truthy ``'array'`` entry is replaced by a
        ``{'vmin': bool, 'vmax': bool}`` dict telling which colormap
        limit may be autoscaled, or by False if both limits are fixed.
    Raises
    ------
    ValueError :
        If a mapping has keys outside ``_AUTOSCALE_KEYS``.
    """
    if isinstance(autoscale, Mapping):
        if not _AUTOSCALE_KEYS.issuperset(autoscale):
            raise ValueError('autoscale should be bool or mapping with possible keys '
                             f'{_AUTOSCALE_KEYS}')
        parsed = dict.fromkeys(_AUTOSCALE_KEYS, True) | autoscale
    else:
        parsed = dict.fromkeys(_AUTOSCALE_KEYS, bool(autoscale))
    if not parsed['array']:
        return parsed
    def _limit_given(limit):
        # A colorscale limit is fixed if passed explicitly or already set on the norm
        return (kwargs.get(limit) is not None
                or (norm is not None and getattr(norm, limit) is not None))
    free_limits = {limit: not _limit_given(limit) for limit in ('vmin', 'vmax')}
    parsed['array'] = free_limits if any(free_limits.values()) else False
    return parsed
def _idx_from_val(array, coord, val):
    """Return the indices at which coordinate *coord* of *array* equals *val*."""
    matches = np.asarray(array.coords[coord] == val)
    return matches.nonzero()[0]
def _find_param(target, dataset, dependent=True, qcodes=True):
    """Return the full name of the parameter matching *target*.
    Parameters
    ----------
    target :
        (Part of) the parameter name.
    dataset :
        A qcodes dataset (``qcodes=True``) or an xarray dataset.
    dependent :
        Search dependent parameters (data variables) if True, else
        independent ones (coordinates).
    qcodes :
        Whether *dataset* is a qcodes dataset.
    Returns
    -------
    name :
        *target* itself if it is an exact parameter name, otherwise the
        first parameter name (in dataset order) containing *target*.
    Raises
    ------
    ValueError :
        If no parameter name contains *target*.
    """
    # Ordered lists instead of sets: with sets, which of several matching names was
    # returned depended on (randomized) string hashing.
    if qcodes:
        dependent_parameters = [param.name for param in dataset.dependent_parameters]
        independent_parameters = [name for name in dataset.paramspecs
                                  if name not in dependent_parameters]
    else:
        # xarray
        dependent_parameters = list(dataset.data_vars)
        independent_parameters = list(dataset.coords)
    param_names = dependent_parameters if dependent else independent_parameters
    if target in param_names:
        # An exact match wins over names merely containing target (e.g. 'gate' vs 'gate_2')
        return target
    param_name = next((name for name in param_names if target in name), None)
    if param_name is None:
        raise ValueError(f'target {target} not found in parameters {param_names}')
    return param_name
def _find_params_depending_on(target, depends_on, dataset) -> list[str]:
    """Return the dependent parameters containing *target* that depend on *depends_on*.
    Parameters
    ----------
    target :
        Substring the parameter name must contain.
    depends_on :
        Name of the independent parameter the result must depend on.
    dataset :
        An xarray Dataset or a qcodes dataset.
    Raises
    ------
    ValueError :
        If no such parameter exists.
    """
    matches = []
    if isinstance(dataset, Dataset):
        # xarray: data variables having ``depends_on`` among their coordinates
        for name, data_var in dataset.data_vars.items():
            if depends_on in data_var.coords and target in name:
                matches.append(name)
    else:
        # qcodes: dependent parameters whose paramspec lists ``depends_on``
        for param in dataset.dependent_parameters:
            spec = dataset.paramspecs[param.name]
            if depends_on in spec.depends_on_ and target in spec.name:
                matches.append(spec.name)
    if not matches:
        raise ValueError(f'Could not find parameter with name {target} depending on {depends_on}')
    return matches
def _to_homogeneous_dataarray(dataset, target):
    """Return the data of parameter *target* as an xarray DataArray.
    For an xarray dataset, the data variable is returned directly. For a
    qcodes dataset with a live cache, the cached data is tried first so
    that 2d data stored flat (missing values and no shape given at
    dataset creation time) can be put on a grid with ``reshape_2D_data``.
    Otherwise ``dataset.to_xarray_dataset(target)[target]`` is returned.
    """
    if isinstance(dataset, Dataset):
        # xarray
        return dataset.data_vars[target]
    # Try using qcodes to nudge data into 2d arrays even if they're missing values, in which
    # case the dataset stores them as 1d
    # TODO: Test if not simpler to do dataarray = dataset.to_xarray_dataset(target)[target]
    if dataset.cache.live:
        try:
            data = dataset.cache.data()[target]
            dataarray = dataset.to_xarray_dataarray_dict(target)[target]
            if len(data) > 2 and data[target].ndim == 1:
                # 2d data but with invalid values and no shape given at ds creation time.
                # Read without popping: ``data`` is (at least in current qcodes) the
                # cache's own mapping, and removing the target from it would break every
                # later access to the cached data.
                z_data = data[target]
                x_dim, y_dim = [name for name in data if name != target][:2]
                x, y, z = reshape_2D_data(z=z_data, x=data[x_dim], y=data[y_dim])
                dataarray = dataarray.from_dict({
                    'data': z,
                    # reshape_2D_data returns z with shape (len(y), len(x)). Without explicit
                    # dims xarray infers (x_dim, y_dim) from the order of the coords dict,
                    # which is transposed with respect to z.
                    'dims': (y_dim, x_dim),
                    'coords': {x_dim: _coords_from_dataset(dataset, x_dim, x),
                               y_dim: _coords_from_dataset(dataset, y_dim, y)},
                    'attrs': dataarray.attrs,
                    'name': target
                })
        except KeyError:
            pass
        else:
            return dataarray
    return dataset.to_xarray_dataset(target)[target]
def _to_xarray_dataset(dataset: DataSetProtocol | Dataset, *targets: str):
    """Return an xarray Dataset restricted to *targets*.
    Works for both xarray datasets (selection by variable names) and
    qcodes datasets (export via ``to_xarray_dataset``).
    """
    if not isinstance(dataset, Dataset):
        return dataset.to_xarray_dataset(*targets)
    return dataset[list(targets)]
def _convert_legacy_label_maybe(label):
    """Expand a three-character legacy label into a descriptive one.
    The first character selects 'top' ('T') or 'bottom', the second
    'guard' ('G') or 'central', and the last is the trap identifier,
    e.g. 'TG1' -> 'Trap 1 guard top'. Labels of any other length are
    returned unchanged.
    """
    if len(label) != 3:
        return label
    # Yay, legacy!
    position = 'top' if label[0] == 'T' else 'bottom'
    kind = 'guard' if label[-2] == 'G' else 'central'
    return f"Trap {label[-1]} {kind} {position}"
def _update_slider_label(slider, orientation, no: int):
    """Reposition the label and value text of *slider*.
    Vertical sliders get their label rotated and placed left of the
    slider. For odd *no* the value text moves above the slider
    (presumably so texts of neighboring sliders do not collide).
    Horizontal sliders get their label centered below the slider and
    their value text right-aligned to its left.
    """
    label, valtext = slider.label, slider.valtext
    if orientation != 'vertical':
        label.set_x(0.5)
        label.set_y(-0.15)
        label.set_verticalalignment('top')
        label.set_horizontalalignment('center')
        label.set_rotation('horizontal')
        valtext.set_x(-0.02)
        valtext.set_horizontalalignment('right')
        return
    label.set_x(-0.02)
    label.set_y(0.5)
    label.set_verticalalignment('center')
    label.set_horizontalalignment('right')
    label.set_rotation('vertical')
    if no % 2:
        valtext.set_y(1.02)
        valtext.set_verticalalignment('bottom')
def _update_ticks_and_labels(axes, array_data, horizontal_target, vertical_target):
    """Tidy up ticks, grids and labels and rescale the axes to SI prefixes.
    Parameters
    ----------
    axes :
        The ``'plots'`` entry of the dict returned by :func:`_setup_axes`.
    array_data :
        The plotted data; its *horizontal_target* and *vertical_target*
        coordinates are used for axis scaling and labels.
    horizontal_target, vertical_target :
        The coordinate names on the horizontal and vertical axes.
    """
    if 'vertical' in axes:
        for typ, ax in axes['vertical'].items():
            if len(axes['vertical']) == 1:
                ax.grid(True)
                ax.xaxis.tick_top()
                ax.xaxis.label_position = 'top'
                ax.xaxis.labelpad = 12.0
            else:
                ax.grid(False)
            if typ == 'power' and len(axes['vertical']) == 2:
                ax.set_xlabel(ax.get_xlabel(), loc='right')
            else:
                # Capitalize the first letter of the label
                label = ax.get_xlabel()
                ax.set_xlabel(label[:1].upper() + label[1:])
            ax.set_title(None)
            ax.set_ylabel(ax.get_ylabel().replace('\n', ' ').replace('[', '(').replace(']', ')'))
            ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=1, prune='both'))
            _rescale_and_relabel(ax, array_data.coords[vertical_target], 'y')
        # The vertical subplots share the y axis with the main plot and carry its labels
        axes['main'].set_ylabel(None)
        axes['main'].tick_params(axis='x', which='both', bottom=False)
        axes['main'].tick_params(axis='y', which='both', labelleft=False, left=False)
    else:
        _rescale_and_relabel(axes['main'], array_data.coords[vertical_target], 'y')
    horizontal = axes.get('horizontal', {})
    if 'bottom' in horizontal:
        bottom = horizontal['bottom']
        bottom.grid(True)
        bottom.set_title(None)
        bottom.set_ylabel('')
        bottom.yaxis.tick_right()
        bottom.yaxis.set_major_locator(ticker.MaxNLocator(nbins=2, prune='upper'))
        # The slice subplot below carries the horizontal axis label and ticks
        xlabel = axes['main'].xaxis.get_label()
        xlabel.set_visible(False)
        _rescale_and_relabel(bottom, array_data.coords[horizontal_target], 'x')
        axes['main'].tick_params(axis='x', which='both', labelbottom=False)
    else:
        # No slice subplot: the main axes carry the horizontal label. _setup_axes always
        # creates a (possibly empty) 'horizontal' dict, so testing only for that key
        # would never rescale and relabel the main x axis.
        _rescale_and_relabel(axes['main'], array_data.coords[horizontal_target], 'x')
    if 'top' in horizontal:
        # Secondary axis showing wavelength <-> photon energy
        coord = array_data.coords[horizontal_target]
        if coord.units == 'eV':
            top_unit = 'nm'
        elif coord.units == 'nm':
            top_unit = 'eV'
        horizontal['top'].set_xlabel(f"{coord.label} ({top_unit})")
def _update_slice_line(line, data, ax, autoscale, kwargs):
    """Show *data* on the slice *line* and, if *autoscale*, rescale its axes.
    When autoscaling, the y limits are fit tightly to the data, except
    that ``vmin``/``vmax`` present in the plot *kwargs* pin the lower
    and upper limits. The y axis is then SI-rescaled and relabeled.
    """
    line.set_ydata(data)
    if not autoscale:
        return
    ax.relim()
    ax.autoscale(axis='y', tight=True)
    for key, side in (('vmin', 'bottom'), ('vmax', 'top')):
        if key in kwargs:
            ax.set_ylim(**{side: kwargs[key]})
    _rescale_and_relabel(ax, data, 'y')
def _coords_from_dataset(dataset, dim: str, data: Any) -> dict:
    """Return an xarray ``from_dict`` coordinate spec for qcodes parameter *dim*.
    The attributes mirror the paramspec of *dim*. The unit is stored as
    ``unit`` (qcodes naming) and as ``units`` (xarray/CF naming), and
    the label as ``label`` and ``long_name``, so that code reading
    either convention finds them.
    """
    spec = dataset.paramspecs[dim]
    return {'dims': dim,
            'data': data,
            'attrs': {'name': spec.name,
                      'unit': spec.unit,
                      'units': spec.unit,
                      'label': spec.label,
                      'paramtype': spec.type,
                      'long_name': spec.label,
                      'depends_on': spec.depends_on_,
                      'inferred_from': spec.inferred_from_}}
def _color_axes(ax, line):
    """Color the x-axis label and x ticks of *ax* like *line*."""
    color = line.get_color()
    ax.xaxis.label.set_color(color)
    ax.tick_params(axis='x', colors=color)
def _unit_of(array) -> str:
    """Return the unit of *array*.
    qcodes data store it as the ``unit`` attribute, xarray/CF-style data
    as ``units`` (as read e.g. by ``_create_slider``). Neither means no unit.
    """
    attrs = array.attrs
    if 'unit' in attrs:
        return attrs['unit']
    return attrs.get('units', '')
def _label_of(array):
    """Return the ``label`` attribute of *array*, else ``long_name``, else its name."""
    attrs = array.attrs
    if 'label' in attrs:
        return attrs['label']
    return attrs.get('long_name', array.name)
def _rescale_and_relabel(ax_or_cbar: Axes | Colorbar, data: Dataset | DataArray,
                         which: Literal['x', 'y', 'c'], plot_leakage: bool | None = None):
    """Rescale an axis or colorbar to an SI prefix and relabel it.
    Parameters
    ----------
    ax_or_cbar :
        The axes or colorbar to modify.
    data :
        The data shown. A Dataset (several lines on one axis) is labeled
        jointly and must have a common unit.
    which :
        The axis to modify: 'x', 'y', or 'c' (the colorbar).
    plot_leakage :
        Whether leakage is plotted too; decides whether a long power
        label gets its unit on a new line.
    Returns
    -------
    prefix :
        The SI prefix used, or None if the Dataset's variables have
        different units (a RuntimeWarning is issued and nothing changes).
    """
    if isinstance(data, Dataset):
        names, arrays = zip(*data.items())
        labels = [_label_of(array) for array in arrays]
        units = [_unit_of(array) for array in arrays]
        if len(set(units)) != 1:
            warnings.warn('Not all data have same unit.', category=RuntimeWarning, stacklevel=2)
            return
    else:
        # smooth ...
        arrays = [data]
        names = [data.name]
        labels = [_label_of(data)]
        units = [_unit_of(data)]
    unit = units[0]
    prefix = reformat_axis(ax_or_cbar, np.concatenate(arrays), unit, which, only_SI=True)
    if len(arrays) > 1:
        if all('current' in name.lower() for name in names):
            label = 'Leakage '
        else:
            label = f'{labels[-1]} '
    else:
        name = names[0]
        label = labels[0]
        if 'power' in name.lower() and len(label.split()) > 1 and plot_leakage and which == 'x':
            # Long label, move unit to new line (but only if not on upper axis)
            label = f'{label}\n'
        else:
            label = f'{label} '
    if unit:
        label = label + f'({prefix}{unit})'
    else:
        label = label.rstrip('\n ')
    if which == 'x':
        ax_or_cbar.set_xlabel(label)
    elif which == 'y':
        ax_or_cbar.set_ylabel(label)
    else:
        ax_or_cbar.set_label(label)
    return prefix