"""This module defines some helper functions."""
from __future__ import annotations
import atexit
import datetime
import importlib
import logging
import os
import re
import time
import warnings
from collections.abc import Iterator, Sequence, Callable
from pathlib import Path
from typing import ParamSpec, TypeVar, Literal
import mjolnir
import numpy as np
import numpy.typing as npt
from qcodes.dataset import load_by_run_spec
from qcodes.instrument import InstrumentBase, ChannelTuple, InstrumentChannel
from qcodes.metadatable import MetadatableWithName
from qcodes.monitor import Monitor
from qcodes.parameters import ParameterBase, Parameter
from qcodes.parameters.command import Command
from qcodes.station import Station
from qcodes_contrib_drivers.drivers.QDevil import QDAC2
from qutil import io, itertools, functools
from qutil.ui import ThreadedWebserver
from scipy import special, stats
_P = ParamSpec('_P')
_T = TypeVar('_T')
def setup_logging(levels: dict[str, int]):
"""Set up logging (only for calibration module currently).
Since qcodes controls the root logger, we need to make sure loggers
are independent of it.
"""
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
root_logger = logging.getLogger('mjolnir')
root_logger.setLevel(logging.WARNING)
root_logger.propagate = False
for name, level in levels.items():
logger = logging.getLogger(name)
logger.setLevel(level)
        for handler in list(logger.handlers):
            # Iterate over a copy: removing handlers while iterating the live
            # list would skip entries.
            handler.close()
            logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setLevel(level)
handler.setFormatter(formatter)
logger.addHandler(handler)
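
# Illustrative usage sketch (not part of the original module); the logger names
# and levels below are assumptions chosen for demonstration.
def _example_setup_logging() -> None:
    setup_logging({'mjolnir.calibration': logging.DEBUG,
                   'mjolnir.helpers': logging.INFO})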
def _get_channel_wrapper_class(channel_type: str | None) -> type[InstrumentChannel] | None:
"""Get channel class from string specified in yaml.
Copied from qcodes.instrument.delegate.delegate_instrument.py
"""
if channel_type is None:
return None
else:
channel_type_elems = str(channel_type).split(".")
module_name = ".".join(channel_type_elems[:-1])
instr_class_name = channel_type_elems[-1]
module = importlib.import_module(module_name)
return getattr(module, instr_class_name)
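
# For example, a yaml entry of 'qcodes.instrument.channel.InstrumentChannel'
# would resolve to the InstrumentChannel class via importlib; the exact dotted
# path depends on the installed qcodes version.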
def save_to_hdf5(run_id: int, savepath: os.PathLike, *params_to_skip,
method: Literal['dataset', 'dataarray'] = 'dataset', compress: bool = False):
"""Extract *run_id* from the database and save to hdf5.
Parameters
----------
run_id :
Captured run id.
savepath :
File where the dataset should be saved.
*params_to_skip :
Names of dataset parameters to skip exporting.
method :
Export to an xarray ``Dataset`` or ``DataArray``.
compress :
Use level 5 gzip compression.
"""
ds = load_by_run_spec(captured_run_id=run_id)
path = io.query_overwrite(savepath).with_suffix('.h5')
if not path.parent.exists() and io.query_yes_no('Directory does not exist. Create?'):
path.parent.mkdir(parents=True, exist_ok=True)
comp = dict(zlib=True, complevel=5, shuffle=True, fletcher32=True)
match method:
case 'dataset':
xds = ds.to_xarray_dataset(
*set(spec.name for spec in ds.dependent_parameters).difference(params_to_skip)
)
xds.to_netcdf(path, engine='h5netcdf',
encoding={var: comp for var in xds.data_vars} if compress else None)
print(f"Saved to {path}.")
case 'dataarray':
raise NotImplementedError
case _:
            raise ValueError(f"Unknown export method: {method!r}")
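
# Hedged usage sketch (not from the original module); the run id, target path,
# and skipped parameter name are placeholders.
def _example_save_to_hdf5() -> None:
    save_to_hdf5(42, Path('data/run_42'), 'lockin_X', method='dataset', compress=True)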
def find_param_source(param: ParameterBase) -> ParameterBase:
"""Return the root source of a parameter."""
if hasattr(param, 'source'):
return find_param_source(param.source)
return param
def find_parameters(instrument: MetadatableWithName, *params_to_exclude: Parameter,
instruments_to_exclude: Sequence[InstrumentBase] = (),
underlying_instruments_to_exclude: Sequence[InstrumentBase] = (),
param_classes_to_exclude: tuple[type[Parameter], ...] | type[Parameter] = (),
only_gettable: bool = False,
max_nesting_level: int = 0) -> Iterator[Parameter]:
"""Recursively find all parameters on *instrument* except
*params_to_exclude*, those of class *param_classes_to_exclude*, or
those whose physical instrument is in
*underlying_instruments_to_exclude*, up to a nesting level of
*max_nesting_level*."""
def valid(param: ParameterBase) -> bool:
return (
isinstance(param, Parameter)
and param not in params_to_exclude
and param.instrument not in instruments_to_exclude
and param.underlying_instrument not in underlying_instruments_to_exclude
and not isinstance(find_param_source(param), param_classes_to_exclude)
and (not only_gettable or param.gettable)
)
def recurse(inst_or_chan: MetadatableWithName, recursion_level: int) -> Iterator[Parameter]:
if recursion_level == max_nesting_level:
return
if isinstance(inst_or_chan, ChannelTuple):
for channel in inst_or_chan:
yield from recurse(channel, recursion_level + 1)
elif isinstance(inst_or_chan, InstrumentBase):
yield from (param for param in inst_or_chan.parameters.values() if valid(param))
for submodule in inst_or_chan.submodules.values():
yield from recurse(submodule, recursion_level + 1)
yield from recurse(instrument, -1)
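
# Hedged usage sketch: ``dmm`` stands in for any qcodes instrument; the
# arguments are illustrative.
def _example_find_parameters(dmm: InstrumentBase) -> list[Parameter]:
    return list(find_parameters(dmm, only_gettable=True, max_nesting_level=2))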
def find_parameters_with_nesting_level_below(instrument: InstrumentBase,
nesting_level: int) -> Iterator[Parameter]:
"""Find all parameters in instrument and submodules at most
nesting_level deep."""
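    # Two passes: first collect the parameters that are at most nesting_level
    # deep, then exclude them from an effectively unbounded search so that only
    # the more deeply nested parameters remain.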
params_with_nesting_level_above = find_parameters(instrument, max_nesting_level=nesting_level)
params_with_nesting_level_below = find_parameters(instrument, *params_with_nesting_level_above,
max_nesting_level=1_000)
return params_with_nesting_level_below
def _exclude_all_snapshots(instrument: InstrumentBase, *params_to_exclude: ParameterBase,
param_classes_to_exclude: tuple[type[Parameter], ...] = (),
max_nesting_level: int = 0):
"""Disable snapshotting on all child parameters of instr_or_channel.
Use at your own risk.
"""
for parameter in find_parameters(instrument, *params_to_exclude,
param_classes_to_exclude=param_classes_to_exclude,
max_nesting_level=max_nesting_level):
parameter.snapshot_exclude = True
def load_station(filename: str | Sequence[str], use_monitor: bool = False,
update_snapshot: bool = False) -> Station:
"""Load the station specified in *filename*, which should be a file
living in mjolnir/config."""
if isinstance(filename, str):
filenames = [filename]
else:
filenames = filename
config_files = [
str((Path(mjolnir.__file__).parent / 'config' / filename).with_suffix('.yaml'))
for filename in filenames
]
return Station(config_file=config_files,
use_monitor=use_monitor,
update_snapshot=update_snapshot)
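
# Hedged usage sketch: 'fridge' is a made-up config name that would resolve to
# mjolnir/config/fridge.yaml.
def _example_load_station() -> Station:
    return load_station('fridge', use_monitor=False, update_snapshot=False)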
def start_monitor(station: Station, *params_to_exclude: ParameterBase,
instruments_to_exclude: Sequence[InstrumentBase] = (),
underlying_instruments_to_exclude: Sequence[InstrumentBase] = (),
param_classes_to_exclude: tuple[type[Parameter], ...] | type[Parameter] = (),
max_nesting_level: int = 0,
update: bool = False, url: str = 'localhost', port: int = 3000,
show: bool = False,
**monitor_kwargs) -> tuple[ThreadedWebserver, Monitor]:
"""Start a monitor on parameters of instruments on station.
Parameters
----------
station :
The qcodes station hosting the instruments.
*params_to_exclude :
Parameters to exclude from the monitor.
underlying_instruments_to_exclude :
Exclude parameters bound to these underlying instruments.
instruments_to_exclude :
Exclude parameters bound to these instruments.
param_classes_to_exclude :
Exclude parameters of these types.
    max_nesting_level :
        Include parameters of submodules nested up to this many levels
        deep.
update :
Update parameter values.
url :
The url of the server.
port :
The port under which to reach the server.
show :
Open a new webbrowser tab with the monitor.
**monitor_kwargs :
Keyword arguments passed to the
:class:`~qcodes:qcodes.monitor.Monitor` constructor.
Returns
-------
server :
The web server running the monitor.
monitor :
The :class:`~qcodes:qcodes.monitor.Monitor` instance.
"""
# The set might not be necessary but include it to be sure no duplicates are added.
parameters = set(itertools.chain(*(
find_parameters(
component, *params_to_exclude,
instruments_to_exclude=instruments_to_exclude,
underlying_instruments_to_exclude=underlying_instruments_to_exclude,
param_classes_to_exclude=param_classes_to_exclude,
only_gettable=True,
max_nesting_level=max_nesting_level
)
for component in station.components.values()
)))
gettable_parameters = []
if update:
for param in parameters:
try:
param.get()
except Exception:
continue
gettable_parameters.append(param)
failed_to_get = parameters.difference(gettable_parameters)
if len(failed_to_get):
warnings.warn('Failed to get the following parameters, so I excluded them from the '
"monitor:\n - {}".format('\n - '.join(map(str, failed_to_get))),
RuntimeWarning)
else:
gettable_parameters.extend(parameters)
monitor_kwargs.setdefault('use_root_instrument', False)
server = ThreadedWebserver(qcodes_monitor_mode=True, url=url, port=port)
monitor = Monitor(*gettable_parameters, **monitor_kwargs)
atexit.register(server.stop)
atexit.register(monitor.stop)
if show:
server.show()
return server, monitor
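
# Hedged usage sketch: ``station`` is assumed to be an already loaded Station;
# the port and nesting level are arbitrary choices.
def _example_start_monitor(station: Station) -> tuple[ThreadedWebserver, Monitor]:
    return start_monitor(station, max_nesting_level=2, update=True, port=3000, show=True)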
def calculate_rise_time(RC: float, thresh=0.9) -> float:
"""Computes the rise time of a single-stage RC filter."""
return -RC * np.log(1 - thresh)
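
# For instance, a 10 kΩ × 100 nF filter (RC = 1 ms) takes
# -1e-3 * np.log(1 - 0.9) ≈ 2.3 ms to reach 90 % of a step's final value.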
def calculate_ramp_time_with_filter(slew_rate: float, voltage_swing: float, RC: float,
thresh: float) -> float:
r"""Calculate the time for a linear ramp with an analogue RC filter
to settle at *thresh*.
    The settling time can be obtained by inverse Laplace transform of
    the transfer function

    .. math::

        C(s) = \frac{1}{s\tau + 1} \times \frac{k}{s^2},

    where :math:`\tau = RC` and :math:`k` is the slope of the ramp.
Parameters
----------
slew_rate :
Slope of the ramp.
voltage_swing :
Range of the ramp.
RC :
Time constant of the filter.
thresh :
Settling threshold that should be achieved.
"""
ramp_time = thresh * abs(voltage_swing) / slew_rate
try:
settle_time = RC * (1 + special.lambertw(-np.exp(-(RC * thresh) / RC - 1))).real
except ZeroDivisionError:
settle_time = 0.0
return ramp_time + settle_time
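
# Illustrative numbers (placeholders, not calibrated values): ramping 0.5 V at
# 10 V/s through an RC = 1 ms filter would block for
# calculate_ramp_time_with_filter(10, 0.5, 1e-3, thresh=0.99) seconds.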
def update_qdac2_line_label(channel, gate_type, side):
"""Give QDAC2 channels more meaningful labels."""
def set_label_maybe(obj):
if hasattr(obj, 'label') and obj.label == f'ch{channel.number}':
obj.label = f"{gate_type} {side}"
def recurse(obj):
try:
for param in obj.parameters.values():
recurse(param)
except AttributeError:
set_label_maybe(obj)
else:
set_label_maybe(obj)
recurse(channel)
def make_dc_constant_V_blocking(channel: QDAC2.QDac2Channel, thresh: float = 0.99):
"""Modify the dc_constant_V parameter of a QDAC-II channel to block
until the voltage has settled at *thresh* of its designated value.
"""
if not isinstance(channel, QDAC2.QDac2Channel):
raise TypeError(f'Expected QDac2Channel, not {type(channel)}.')
def set_fixed_voltage_immediately_then_block(v):
ramptime = calculate_ramp_time_with_filter(channel.dc_slew_rate_V_per_s.get_latest(),
abs(v - channel.dc_constant_V.get_latest()),
max_filter_RC, thresh)
tic = time.time()
channel._set_fixed_voltage_immediately(v)
if (elapsed := time.time() - tic) < ramptime:
time.sleep(ramptime - elapsed)
# Let's assume for simplicity that all filters are 1st order RC
# TODO: de-hardcode
try:
bob_filter_RC = channel.bob_filter_RC()
except AttributeError:
bob_filter_RC = 0
cold_filter_RC = 1e3 * 10e-9
max_filter_RC = max(bob_filter_RC, cold_filter_RC)
param = channel.dc_constant_V
param.set_raw = Command(arg_count=1, cmd=set_fixed_voltage_immediately_then_block)
param.set = param._wrap_set(param.set_raw)
def make_dc_constant_V_nonblocking(channel: QDAC2.QDac2Channel):
"""Revert :func:`make_dc_constant_V_blocking`."""
param = channel.dc_constant_V
param.set_raw = Command(arg_count=1, cmd=channel._set_fixed_voltage_immediately)
param.set = param._wrap_set(param.set_raw)
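
# Hedged usage sketch: ``channel`` is a hypothetical QDAC-II channel instance.
def _example_blocking_set(channel: QDAC2.QDac2Channel) -> None:
    make_dc_constant_V_blocking(channel, thresh=0.99)
    channel.dc_constant_V(0.1)  # now blocks until the filtered voltage has settled
    make_dc_constant_V_nonblocking(channel)  # revert to the driver default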
def rise_time_dependent_post_delay(self, _: float, val: float,
channel: QDAC2.QDac2Channel, thresh: float) -> None:
"""Side effect for bob_filter_RC parameter of QDAC-II channels.
    Waits at least as long as a voltage step takes to settle at
    *thresh*.
Use together with :func:`py:functools.partial`.
TODO:
Unused as of now.
"""
if hasattr(self, '_dc_constant_V_static_post_delay'):
channel.dc_constant_V.post_delay = self._dc_constant_V_static_post_delay
self._dc_constant_V_static_post_delay = channel.dc_constant_V.post_delay
channel.dc_constant_V.post_delay = max(calculate_rise_time(val, thresh),
self._dc_constant_V_static_post_delay)
def timed(func: Callable[_P, _T]) -> Callable[_P, _T]:
"""Adds timing to func and reports it to stdout."""
@functools.wraps(func)
def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
tic = time.perf_counter()
result = func(*args, **kwargs)
toc = time.perf_counter()
        print(f'{func.__qualname__} took '
              f'{datetime.timedelta(seconds=round(toc - tic, 3))} (h:mm:ss).')
return result
return wrapped
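
# Hedged usage sketch: a toy function decorated with ``timed``; the sleep merely
# simulates work.
@timed
def _example_slow_add(a: float, b: float) -> float:
    time.sleep(0.1)
    return a + b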
def mean_and_standard_error(data: npt.ArrayLike, axis: int = 0,
alpha: float = 0.05) -> tuple[float, float]:
"""Compute the mean and standard error on the mean.
References
----------
https://en.wikipedia.org/wiki/Student%27s_t-distribution#Confidence_intervals
"""
data = np.asanyarray(data)
n = data.shape[axis]
mean = data.mean(axis)
standard_error = data.std(axis, ddof=1) / n ** 0.5 * stats.t.ppf(1 - alpha, df=n - 1)
return mean, standard_error
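
# Hedged usage sketch with synthetic data; sample size and seed are arbitrary.
def _example_mean_and_standard_error() -> tuple[float, float]:
    rng = np.random.default_rng(0)
    return mean_and_standard_error(rng.normal(size=100), alpha=0.05)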
def camel_to_snake(name: str) -> str:
"""Convert a CamelCase name to snake_case."""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
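
# For example, camel_to_snake('QDac2Channel') == 'q_dac2_channel' and
# camel_to_snake('InstrumentBase') == 'instrument_base'.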