"""This module defines classes that handle calibration of parameters."""
from __future__ import annotations
import abc
import configparser
import copy
import dataclasses
import logging
import os
import sys
import tempfile
import time
from collections import defaultdict
from collections.abc import Mapping
from contextlib import contextmanager
from datetime import datetime
from numbers import Real
from pathlib import Path
from typing import Any, Callable, Sequence, Literal, cast, TypeVar
from unittest import mock
import noisyopt
import numpy as np
from adaptive.learner import BaseLearner
from adaptive.learner.learner1D import curvature_loss_function
from adaptive.learner.average_learner1D import AverageLearner1D
from adaptive.runner import _TimeGoal, BaseRunner, BlockingRunner
from adaptive.utils import SequentialExecutor
from dataclassabc import dataclassabc
from matplotlib import pyplot as plt
from numpy import typing as npt
from numpy.polynomial import Polynomial
from qcodes.parameters import Parameter
from qcodes.station import Station
from qcodes_contrib_drivers.drivers.Attocube.AMC100 import MultiAxisPosition
from qutil import const, functools, itertools, io
from scipy import interpolate, optimize
try:
from qutil.concurrent import BatchExecutor, SortedArgsBatchCaller
except ImportError:
BatchExecutor = None
SortedArgsBatchCaller = None
try:
sys.path.append(str(Path(os.environ['TIMETAGGER_INSTALL_PATH'], 'driver', 'python')))
import TimeTagger as tt
except (ImportError, KeyError):
tt = None
from . import _TIMESTAMP_FORMAT
from .helpers import mean_and_standard_error, timed
# Generic type variable for helpers in this module.
_T = TypeVar("_T")
# A (seed, x) pair as handed to the learners' ``tell``/``tell_pending``.
Point = tuple[int, Real]
# A list of (seed, x) pairs.
Points = list[tuple[int, Real]]
[docs]
class AverageSequenceLearner1D(AverageLearner1D):
    """Learns the values of noisy functions at given points.

    New points are taken from ``sequence`` in the order given there
    instead of being chosen adaptively. Everything else is inherited
    from :class:`AverageLearner1D`.

    Parameters
    ----------
    function :
        The function to learn. It is called with a ``(seed, x)`` tuple.
    sequence :
        The points to visit, in order.
    min_samples :
        Passed on to :class:`AverageLearner1D`.
    max_samples :
        Passed on to :class:`AverageLearner1D`.
    min_error :
        Passed on to :class:`AverageLearner1D`.
    """

    def __init__(self, function: Callable[[tuple[int, Real]], Real], sequence: Sequence[Any],
                 min_samples: int = 50, max_samples: int = sys.maxsize,
                 min_error: float = 0):
        super().__init__(function, itertools.minmax(sequence), min_samples=min_samples,
                         max_samples=max_samples, min_error=min_error)
        # Points of the sequence that have not been handed out yet.
        self._to_do_points = set(sequence)
        self._ntotal = len(sequence)
        self.sequence = list(sequence)
        # Position of each point in the sequence (first occurrence, like list.index) so
        # that picking the next point does not require a linear search per candidate.
        self._sequence_index = {}
        for index, x in enumerate(self.sequence):
            self._sequence_index.setdefault(x, index)

    def new(self) -> AverageSequenceLearner1D:
        """Return a fresh learner with the same function, sequence and settings."""
        return AverageSequenceLearner1D(self.function, self.sequence,
                                        min_samples=self.min_samples,
                                        max_samples=self.max_samples,
                                        min_error=self.min_error)

    def tell_pending(self, seed_x: Point) -> None:
        """Mark ``(seed, x)`` as pending and remove ``x`` from the to-do list."""
        super().tell_pending(seed_x)
        self._to_do_points.discard(seed_x[1])

    def tell(self, seed_x: Point, y: Real) -> None:
        """Record the value ``y`` at ``(seed, x)`` and remove ``x`` from the to-do list."""
        super().tell(seed_x, y)
        self._to_do_points.discard(seed_x[1])

    def remove_unfinished(self) -> None:
        """Forget pending points and put their x coordinates back on the to-do list."""
        # The pending points are the (seed, x) pairs passed to tell_pending. Only x belongs
        # in _to_do_points; adding the tuple breaks the ordering lookup in
        # _ask_points_without_adding.
        for _seed, x in self.pending_points:
            self._to_do_points.add(x)
        super().remove_unfinished()

    def done(self) -> bool:
        """Whether every point was handed out, nothing is pending, and none is undersampled."""
        return not (self._to_do_points or self.pending_points or self._undersampled_points)

    def _ask_points_without_adding(self, n: int) -> tuple[list[float], list[float]]:
        # Only single-point requests are supported.
        if n != 1:
            raise ValueError("This Learner is designed to only ask for 1 point here.")
        if not self._to_do_points:
            return [], [np.inf]
        # The earliest not-yet-visited point of the sequence.
        return [min(self._to_do_points, key=self._sequence_index.__getitem__)], [np.inf]
[docs]
@dataclasses.dataclass
class PowerCalibration:
    """Handles calibration of a variable neutral density filter.

    Parameters
    ----------
    wavelength :
        The laser wavelength.
    timestamp :
        The timestamp of the calibration.
    angle :
        The :class:`numpy:numpy.ndarray` of angle samples.
    power :
        The :class:`numpy:numpy.ndarray` of power samples.
    error :
        The :class:`numpy:numpy.ndarray` of errors of each sample.
    log :
        Log-spaced samples.
    """
    wavelength: float
    timestamp: float = dataclasses.field(repr=False)
    angle: npt.NDArray[float] = dataclasses.field(repr=False)
    power: npt.NDArray[float] = dataclasses.field(repr=False)
    error: npt.NDArray[float] = dataclasses.field(repr=False)
    log: bool = False

    def __post_init__(self):
        self._log = logging.getLogger(f'{__name__}.{self.__class__.__qualname__}')

    def __len__(self) -> int:
        return len(self.angle)

    def __call__(self, angle: npt.ArrayLike) -> npt.NDArray[float]:
        """Evaluate the calibration function at ``angle``."""
        return self.angle_to_power(angle)

    @property
    def age(self) -> float:
        """The calibration age in seconds."""
        return time.time() - self.timestamp

    @property
    def bounds(self) -> tuple[float, float]:
        """The bounds of the calibration range."""
        return itertools.minmax(self.angle)

    @functools.cached_property
    def angle_to_power(self):
        """The calibration function, a smoothing spline of power vs. angle.

        The spline is weighted by ``1 / error``, raises for angles outside
        the fitted interval (``ext='raise'``), and carries that interval as
        its ``bounds`` attribute.
        """
        # Alternatives that were tried:
        # - interpolate.InterpolatedUnivariateSpline(..., k=2) goes through every data
        #   point, i.e. no smoothing.
        # - interpolate.LSQUnivariateSpline(..., k=3) with equidistant interior knots
        #   smoothes, but fits badly for a large range of angles.
        # Reduce s to decrease smoothing.
        spline = interpolate.UnivariateSpline(self.angle, self.power, 1 / self.error,
                                              ext='raise', k=2, s=None)
        # Private scipy attribute; presumably _data[3:5] is the (xb, xe) fit interval.
        spline.bounds = spline._data[3:5]
        return spline

    @functools.cached_property
    def angle_to_power_deriv(self):
        """The calibration function's derivative."""
        spline = self.angle_to_power.derivative(n=1)
        spline.bounds = self.angle_to_power.bounds
        return spline

    @functools.cached_property
    def power_to_angle(self):
        """The inverted calibration function."""
        return np.vectorize(_invert_spline(self.angle_to_power))

    @timed
    def rescaled(self, wavelength: float, angle_callback: Parameter,
                 power_callback: Parameter, target_power: float | None = None,
                 target_angle: float | None = None, n_samples: int = 50):
        """The calibration function rescaled by a global factor.

        Parameters
        ----------
        wavelength :
            The wavelength at which to rescale.
        angle_callback :
            The callback to set and get the angle.
        power_callback :
            The callback to get the power.
        target_power :
            The power to reach at a given angle. Takes precedence over
            target_angle.
        target_angle :
            The angle at which to rescale for the given wavelength.
            Defaults to the cached angle.
        n_samples :
            The number of power samples to average over.

        Returns
        -------
        PowerCalibration
            A copy of this calibration at ``wavelength`` whose ``power``
            and ``error`` are multiplied by the measured scale factor. All
            other fields, including ``timestamp`` and ``log``, are kept.

        Raises
        ------
        ValueError
            If the optimization towards ``target_power`` fails. The angle
            is reset to its initial value in that case.
        """
        def objective_fun(x):
            # Only move the filter if the optimizer asks for a new angle.
            if angle_callback.cache.get() != x[0]:
                angle_callback(x[0])
            return (power_callback() - target_power) ** 2

        msg = f'Calibrating power at {wavelength} nm and {{:.2g}}{{}} in fast mode.'
        msg_result = 'Power at angle {:.2f}º is ({:.2g} ± {:.2g}) W.'
        alpha = 0.05
        x0 = angle_callback()
        if target_power is not None:
            self._log.info(msg.format(target_power, ' W'))
            result = noisyopt.minimizeCompass(objective_fun, x0=[x0],
                                              bounds=np.array([self.bounds]),
                                              errorcontrol=True,
                                              paired=False,
                                              scaling=[0.25],
                                              deltatol=0.04,
                                              feps=500e-12**2,
                                              alpha=alpha,
                                              disp=False)
            if not result.success:
                angle_callback(x0)
                raise ValueError('Could not rescale power(angle) calibration. Optimization result:'
                                 f'\n{result}')
            angle_callback(result.x[0])
            # Recover the power and its standard error from the squared-deviation objective.
            p = result.fun ** .5 + target_power
            pse = .5 / result.fun ** .5 * result.funse
            scale = target_power / self.angle_to_power(result.x[0])
            self._log.info(msg_result.format(result.x[0], p, pse))
        else:
            if target_angle is None:
                target_angle = angle_callback.cache.get()
            self._log.info(msg.format(target_angle, 'º'))
            angle_callback(target_angle)
            power = np.array([power_callback() for _ in range(n_samples)])
            p, pse = mean_and_standard_error(power, alpha=alpha)
            scale = p / self.angle_to_power(target_angle)
            self._log.info(msg_result.format(target_angle, p, pse))
        # This is not a "real" calibration, so keep the old timestamp. dataclasses.replace
        # carries over all fields not given here (timestamp, angle, log).
        return dataclasses.replace(self, wavelength=wavelength,
                                   power=self.power * scale, error=self.error * scale)

    def plot(self):
        """Plot the calibration function.

        Returns the figure and its primary :class:`~matplotlib.axes.Axes`.
        If a figure for this wavelength already exists and has axes, it is
        returned without redrawing.
        """
        num = f'Power calibration {self.wavelength} nm'
        if num in plt.get_figlabels():
            fig = plt.figure(num)
            if fig.axes:
                # Return the Axes itself (not the list) for a consistent return type.
                return fig, fig.axes[0]

        fig, ax = plt.subplots(layout='constrained', num=num)
        x = np.linspace(self.angle.min(), self.angle.max(), 1001)
        ax.errorbar(self.angle, self.power, self.error, fmt='r.')
        ax.plot(x, self.angle_to_power(x))
        ax.grid(True)
        ax.set_yscale('log')
        ax.set_xlabel('Angle (\u00b0)')
        ax.set_ylabel('Power (W)')
        ax.set_title(rf'$\lambda = ${self.wavelength} nm')
        # NOTE(review): the factor 13 presumably converts measured power to power at the
        # sample; confirm.
        ax2 = ax.secondary_yaxis('right', functions=(lambda x: x / 13, lambda x: x * 13))
        ax2.set_ylabel('Power at sample (W)')
        return fig, ax
@dataclasses.dataclass
class _Handler:
    """Common base of the handlers in this module.

    Holds a reference to the logical station and exposes its excitation
    and detection paths as shortcuts.
    """
    logical_station: Station = dataclasses.field(repr=False, metadata={'json_exclude': True})

    def __post_init__(self):
        cls_name = type(self).__qualname__
        self._log = logging.getLogger(f'{__name__}.{cls_name}')

    def _JSONEncoder(self) -> dict[str, Any]:  # noqa
        return _asdict(self)

    @property
    def detection_path(self):
        """Shortcut for ``self.logical_station.detection_path``."""
        station = self.logical_station
        return station.detection_path

    @property
    def excitation_path(self):
        """Shortcut for ``self.logical_station.excitation_path``."""
        station = self.logical_station
        return station.excitation_path
[docs]
@dataclasses.dataclass
class CalibrationHandler(abc.ABC, _Handler):
    """Abstract base class for calibration handlers.

    Subclasses provide the storage folder, the current calibration, and
    methods to load and to perform calibrations.
    """

    def __post_init__(self):
        super().__post_init__()
        # Subclasses declare ``folder`` as a ``str | Path`` field; normalize it to a Path.
        folder_path = Path(self.folder)
        self.folder = cast(Path, folder_path)  # noqa

    @property
    @abc.abstractmethod
    def folder(self) -> Path:
        """The folder in which to store serialized calibration data."""
        ...

    @property
    @abc.abstractmethod
    def calibration(self) -> Any:
        """The calibration function."""
        ...

    @abc.abstractmethod
    def load_calibration_from_file(self, *args, **kwargs) -> Any:
        """Loads serialized calibration data."""
        ...

    @abc.abstractmethod
    def calibrate(self, *args, **kwargs) -> Any:
        """Performs a calibration."""
        ...

    @staticmethod
    def _datetime_from_filename(file: Path) -> datetime:
        # The last 19 characters of the file stem hold the timestamp.
        stamp = file.stem[-19:]
        return datetime.strptime(stamp, _TIMESTAMP_FORMAT)

    @staticmethod
    def _extract_wavelength(file, precision):
        # e.g. 'power_calibration_795,123nm_<timestamp>' -> third '_' field '795,123nm'.
        wavelength_field = file.stem.split('_')[2]
        # Strip the 'nm' suffix and use a decimal point instead of the comma.
        number = wavelength_field[:-2].replace(',', '.')
        return round(float(number), precision)
[docs]
@dataclassabc
class PowerCalibrationHandler(CalibrationHandler):
    """Calibration handler for the ND filter.

    Parameters
    ----------
    sampling :
        Use a naive sequential or adaptive sampling. Defaults to the
        value of
        :attr:`~.instruments.logical_instruments.ExcitationPath.power_calibration_strategy`.
    loss :
        The maximum sampling loss. If None, it is not used for the
        goal.
    dtheta :
        The maximum spacing between adjacent sample points. Ignored if
        npoints is given or sampling is 'adaptive'.
    duration :
        The maximum duration of a calibration run in seconds. If None,
        there is no time limit.
    min_samples :
        The minimum number of samples taken for each average. It is
        passed to the learner in both sampling modes. If None, the
        learner's default is used and it is not used for the goal.
    nsamples :
        The minimum number of total samples. If None, it is not used
        for the goal.
    npoints :
        The minimum number of points to be sampled. If None, it is
        not used for the goal.
    ntasks :
        The number of points the executor visits per batch.
    wait :
        Sleep for this amount of time before taking a power
        measurement.
    folder :
        The folder to save learners in.
    learner_kwargs :
        Additional kwargs for the curvature loss function of the
        adaptive learner. They override the built-in defaults.
    """
    sampling: Literal['sequential', 'adaptive'] | None = None
    loss: float | None = 0.1
    dtheta: float | None = 0.5
    duration: float | None = 300.
    min_samples: int | None = 25
    nsamples: int | None = None
    npoints: int | None = None
    ntasks: int = 1
    wait: float = 0.0
    folder: str | Path = 'C:/Data/Triton/calibration/nd_filter'
    learner_kwargs: dict = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        # Declare private attributes. These cannot be dataclass fields because they might not
        # be picklable.
        self._log: logging.Logger
        self._runner: BaseRunner
        self._learner: BaseLearner
        self._calibration: PowerCalibration
        # Duration timer of the current calibration run (created lazily in _goal).
        self._duration_goal: Callable[[Any], bool] | None = None
        # Load the last used power calibration so that self._learner is set. Otherwise
        # self.calibration errors.
        self.load_calibration_from_file()
        if self.sampling is None:
            self.sampling = self.excitation_path.power_calibration_strategy()

    def _setup_adaptive(self, bounds: tuple[float, float]):
        """Create the learner and executor for the configured sampling strategy."""
        # Only forward min_samples if it is set; otherwise the learners' default applies.
        sample_kwargs = {} if self.min_samples is None else {'min_samples': self.min_samples}
        if self.sampling == 'adaptive':
            if BatchExecutor is None:
                raise ValueError('qutil.concurrent is not available, so the adaptive sampling '
                                 "strategy isn't either")
            self._learner = AverageLearner1D(
                self._objective_fun,
                bounds=bounds,
                loss_per_interval=curvature_loss_function(  # type: ignore
                    **({'area_factor': 5, 'euclid_factor': 0.02, 'horizontal_factor': 0.03}
                       | self.learner_kwargs)
                ),
                **sample_kwargs
            )
            self._executor = BatchExecutor({self._objective_fun: SortedArgsBatchCaller()})
        else:
            if self.npoints is None:
                npoints = int(abs(np.diff(bounds).item()) // self.dtheta) + 1
            else:
                npoints = self.npoints
            # Visit the points from the upper towards the lower bound.
            sequence = np.linspace(*bounds[::-1], npoints)
            self._learner = AverageSequenceLearner1D(self._objective_fun, sequence,
                                                     **sample_kwargs)
            self._executor = SequentialExecutor()

    def _objective_fun(self, seed_x):
        """Measure the power with the ND filter at angle x of the ``(seed, x)`` pair."""
        seed, x = seed_x
        position = self.excitation_path.nd_filter.position
        # Move on the first sample of a point, and also whenever the filter is elsewhere, e.g.
        # when the learner resamples a point after having visited others in between.
        if seed == 0 or position.cache.get() != x:
            self._log.debug(f'Setting position {x:.4g}')
            position(x)
        time.sleep(self.wait)
        return self.excitation_path.powermeter.power()

    def _goal(self, learner):
        """Return whether the runner should stop sampling ``learner``."""
        # The _time_goal property builds a new timer on every access, and a timer only starts
        # counting at its first call. Keep one timer per run (reset in calibrate) so that the
        # duration limit can trigger at all.
        if getattr(self, '_duration_goal', None) is None:
            self._duration_goal = self._time_goal
        if self._duration_goal(learner):
            return True
        if hasattr(learner, 'done'):
            return learner.done()
        achieved = True
        if self.min_samples is not None and hasattr(learner, 'min_samples_per_point'):
            achieved &= learner.min_samples_per_point >= self.min_samples
        if self.nsamples is not None:
            achieved &= learner.nsamples >= self.nsamples
        if self.loss is not None:
            achieved &= learner.loss() <= self.loss
        if self.npoints is not None:
            achieved &= learner.npoints >= self.npoints
        return achieved

    @property
    def _time_goal(self) -> Callable[[float], bool]:
        """A new duration goal, or one that never triggers if ``duration`` is None."""
        if self.duration is not None:
            return _TimeGoal(self.duration)

        def _time_goal(_):
            return False

        return _time_goal

    def _power_calibration_from_learner(self) -> PowerCalibration:
        """Build a :class:`PowerCalibration` from the current learner's data."""
        angles, powers = zip(*sorted(self._learner.data.items()))
        errors = [self._learner.error[a] for a in angles]
        return PowerCalibration(self._learner.wavelength,
                                self._learner.timestamp,
                                np.array(angles),
                                np.array(powers),
                                np.array(errors))

    def _find_calibration_file(self, wavelength: float):
        """Return ``(file, wavelength)`` of the newest calibration at ``wavelength``.

        If there is none, the newest file with the closest wavelength is
        returned instead.

        Raises
        ------
        FileNotFoundError
            If there are no calibration files in :attr:`folder`.
        """
        wavelength = round(wavelength, 3)
        files = sorted(self.folder.glob('power_calibration*.p'),
                       key=self._datetime_from_filename,
                       reverse=True)
        if not files:
            raise FileNotFoundError(f'No power calibration files found in {self.folder}.')
        # Aligned with files: an exact match returns early, every other file is appended.
        wavelengths_on_file = []
        for file in files:
            if (extracted_wavelength := self._extract_wavelength(file, 3)) == wavelength:
                return file, extracted_wavelength
            wavelengths_on_file.append(extracted_wavelength)
        idx = itertools.argmin((abs(wavelength - wavelength_on_file)
                                for wavelength_on_file in wavelengths_on_file))
        return files[idx], wavelengths_on_file[idx]

    @timed
    def calibrate(self, wavelength: float | None = None,
                  file: str | os.PathLike | Path | Literal[False] | None = None,
                  bounds: tuple[float, float] | None = None,
                  **kwargs):
        """Calibrate the power as function of rotation angle of the ND
        filter using adaptive or sequential sampling.

        Parameters
        ----------
        wavelength :
            The target wavelength for which to calibrate power. If None,
            the laser must already be locked.
        file :
            A file to save the results to. If False, the calibration
            is not saved.
        bounds :
            Angular bounds for the calibration. Defaults to the bounds of
            the current calibration.
        **kwargs
            Additional settings that will overwrite the settings stored
            in this dataclass instance as fields.
        """
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Update metadata of the instrument
        self.excitation_path.power_calibration_handler = self
        bounds = bounds or self.bounds
        self._setup_adaptive(bounds)
        if wavelength is not None:
            self.excitation_path.acquire_laser_lock(wavelength)
        else:
            assert self.excitation_path.laser.lock()
        # allow_changes=True: the runner moves the filter; its position is restored on exit.
        with (
            self.excitation_path.nd_filter.position.set_to(bounds[0], allow_changes=True),
            self.excitation_path.open_shutter()
        ):
            self._log.info(f'Calibrating power at {round(self.excitation_path.wavelength(), 3)} '
                           'nm.')
            self.excitation_path.wait_for_power_to_settle()
            # Start a fresh duration timer for this run (see _goal).
            self._duration_goal = None
            self._runner = BlockingRunner(self._learner, self._goal, executor=self._executor,
                                          ntasks=self.ntasks)
        now = datetime.now()
        self._learner.timestamp = now.timestamp()
        self._learner.wavelength = round(self.excitation_path.wavelength(), 3)
        self._calibration = self._power_calibration_from_learner()
        if file is not False:
            self._learner.save(file or self.folder / "power_calibration_{}nm_{}.p".format(
                str(self._learner.wavelength).replace('.', ','),
                now.strftime(_TIMESTAMP_FORMAT)
            ))

    def update_calibration(self, target_power: float | None = None,
                           target_angle: float | None = None, n_samples: int | None = None,
                           bounds: tuple[float, float] | None = None, attempts: int = 3):
        """Update the instrument's power calibration for a given
        wavelength.

        The mode depends on the `power_calibration_update_mode`
        parameter. After each calibration the power is checked again
        with :meth:`recalibration_required`; at most ``attempts``
        calibrations are performed.

        Parameters
        ----------
        target_power :
            The power level to target, if any. Else, the calibration is
            performed at target_angle.
        target_angle :
            The angle at which to calibrate if the
            power_calibration_update_mode parameter is 'fast'. Ignored
            if target_power is given. Defaults to the current angle.
        n_samples :
            The number of power samples to average over if target_angle
            is given. Ignored otherwise as sampling is adaptive in that
            case.
        bounds :
            Bounds for the calibration. Only used if update mode is
            'full'. Defaults to the bounds of the current calibration.
        attempts :
            The maximum number of recalibration attempts.

        Raises
        ------
        RuntimeError
            If the power is still off after ``attempts`` calibrations.
        """
        def _calibrate():
            if self.excitation_path.power_calibration_update_mode() == 'fast':
                self._calibration = self._calibration.rescaled(
                    wavelength, self.excitation_path.nd_filter.position,
                    self.excitation_path.power, target_power, target_angle,
                    n_samples or self.excitation_path.power_n_avg()
                )
            else:
                # Use the requested (or user-specified) bounds for the full calibration.
                self.calibrate(wavelength, bounds=bounds)

        bounds = bounds or self.bounds
        wavelength = round(self.excitation_path.wavelength(), 3)
        with (
            self.excitation_path.power_calibration_update_mode.restore_at_exit(),
            self.excitation_path.open_shutter()
        ):
            if target_power is None:
                if target_angle is None:
                    # Initial calibration without a target. No feedback necessary.
                    return _calibrate()
                self.excitation_path.nd_filter.position(target_angle)
            elif (
                target_power < self._calibration.power.min()
                and io.query_yes_no('Power outside of calibration range. '
                                    'Perform a full calibration with new bounds?')
            ):
                bounds = input(f'Specify calibration bounds. Previous: {self.bounds}\n')
                bounds = tuple(float(b) for b in bounds.strip('[()]').split(','))
                self.calibrate(wavelength, bounds=bounds)

            attempt = 0
            # Every calibration, including the last one, is verified before giving up.
            while self.recalibration_required(target_power):
                if attempt >= attempts:
                    raise RuntimeError(f'Could not calibrate power after {attempts} attempts.')
                if target_power is not None:
                    # Set angle according to current calibration
                    try:
                        self.excitation_path.nd_filter.position(
                            self._calibration.power_to_angle(target_power)
                        )
                    except ValueError:
                        # Inverting calibration failed
                        self.excitation_path.power_calibration_update_mode('full')
                _calibrate()
                attempt += 1

    def recalibration_required(self, target_power: float | None = None) -> bool:
        """Whether the measured power deviates too much from the expected one.

        Parameters
        ----------
        target_power :
            The expected power. Defaults to the calibration's prediction
            at the current ND filter position.

        Returns
        -------
        bool
            True if the relative deviation exceeds
            ``power_calibration_tolerance``.
        """
        if target_power is None:
            target_power = self._calibration.angle_to_power(
                self.excitation_path.nd_filter.position()
            )
        with self.excitation_path.open_shutter():
            power = self.excitation_path.wait_for_power_to_settle().mean()
        deviation = abs(1 - power / target_power)
        return bool(deviation > self.excitation_path.power_calibration_tolerance())

    @property
    def bounds(self) -> tuple[float, float]:
        """The angular bounds of the current calibration."""
        return self._calibration.bounds

    @property
    def wavelength(self) -> float:
        """The wavelength of the current calibration."""
        return self._calibration.wavelength

    @property
    def calibration(self) -> PowerCalibration:
        """The power calibration for the current laser wavelength.

        Loads the closest calibration from file if the wavelength changed.
        Offers a full calibration if the bounds of the loaded calibration
        overlap less than 50% with the previous ones. Updates the
        calibration if it is for another wavelength or older than
        ``power_calibration_max_age``.
        """
        current_wavelength = round(self.excitation_path.wavelength(), 3)
        previous_bounds = self.bounds
        if self.wavelength != current_wavelength:
            self.load_calibration_from_file(current_wavelength)
            if (
                _calculate_overlap(previous_bounds, self._calibration.bounds) < 0.5
                and io.query_yes_no('Overlap in calibration bounds less than 50%. '
                                    'Perform a full calibration with previous bounds?\n'
                                    f'Previous: {previous_bounds}.\n'
                                    f'Current: {self._calibration.bounds}.')
            ):
                with self.excitation_path.power_calibration_update_mode.set_to('full'):
                    self.update_calibration(bounds=previous_bounds)
        if self.wavelength != current_wavelength:
            self.load_calibration_from_file(current_wavelength)
        if (
            self.wavelength != current_wavelength
            or self._calibration.age > self.excitation_path.power_calibration_max_age()
        ):
            self.update_calibration()
        return self._calibration

    def load_calibration_from_file(self, wavelength: float | None = None, file=None):
        """Load a saved learner and make it the current calibration.

        Parameters
        ----------
        wavelength :
            Without ``file``, the wavelength to look up a calibration
            file for (default 795 nm); the wavelength of the file found is
            used. With ``file``, the wavelength to assign to it; by
            default it is parsed from the file name.
        file :
            A specific learner file to load.
        """
        if file is None:
            file, wavelength = self._find_calibration_file(
                795 if wavelength is None else wavelength
            )
        else:
            file = Path(file)
            if wavelength is None:
                wavelength = self._extract_wavelength(file, 3)
        # NOTE: learner files are presumably pickles; only load files from trusted locations.
        self._learner = AverageLearner1D(lambda _: 0, (-np.inf, np.inf))
        self._learner.load(str(file))
        self._learner.wavelength = wavelength
        self._learner.timestamp = self._datetime_from_filename(file).timestamp()
        self._learner.bounds = (min(self._learner.data.keys()),
                                max(self._learner.data.keys()))
        self._calibration = self._power_calibration_from_learner()
[docs]
@dataclassabc
class CcdCalibrationHandler(CalibrationHandler):
"""Calibration handler for the CCD.
Parameters
----------
exposure_time :
The duration of exposure for CCD data.
number_accumulations:
The number of accumulations to perform for each CCD data.
n_pts :
The number of calibration points (wavelenghts).
folder :
The folder where to store calibration data.
"""
exposure_time: float = 0.1
number_accumulations: int = 2
n_pts: int = 9
folder: str | Path = 'C:/Data/Triton/calibration/ccd'
_calibration_data: defaultdict[int, dict[int, dict[str, Any]]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: defaultdict()), init=False, repr=False
)
@contextmanager
def _measurement_context(self, grating: int, central_wavelength: float):
assert self.excitation_path.laser.lock()
with (
self.detection_path.active_detection_path.set_to('ccd'),
self.detection_path.active_grating.set_to(grating),
self.detection_path.central_wavelength.set_to(central_wavelength),
self.detection_path.ccd.exposure_time.set_to(self.exposure_time),
self.detection_path.ccd.accumulation_cycle_time.set_to(0),
self.detection_path.ccd.number_accumulations.set_to(self.number_accumulations),
self.detection_path.ccd.acquisition_mode.set_to('accumulate'),
self.detection_path.ccd.read_mode.set_to('full vertical binning'),
self.detection_path.ccd.cosmic_ray_filter_mode.set_to(True),
self.excitation_path.rejection_feedback.set_to(False),
self.excitation_path.wavelength.restore_at_exit(),
self.excitation_path.open_shutter()
):
yield
def _save_file(self, pixels: npt.ArrayLike, wavelengths: npt.ArrayLike, grating: int,
central_wavelength: float, folder: os.PathLike | str | None = None) -> Path:
parser = configparser.ConfigParser()
parser.optionxform = str # Preserve case of keys
parser.add_section('Manual X-Calibration')
parser.set('Manual X-Calibration', 'Type', str(2))
parser.set('Manual X-Calibration', 'Number', str(self.n_pts))
for i, tup in enumerate(zip(pixels, wavelengths), start=1):
parser.set('Manual X-Calibration', f'Point{i}', '{},{}'.format(*tup))
if folder is None:
folder = self.folder
now = datetime.now().strftime(_TIMESTAMP_FORMAT)
file = folder / 'ccd_calibration_{}nm_{}grating_{}.cfg'.format(
str(central_wavelength).replace('.', ','), int(grating), now
)
with file.open('w') as configfile:
parser.write(configfile, space_around_delimiters=False)
return file
def _find_calibration_file(self, grating, wavelength) -> tuple[Path | None, float | None]:
wavelength = round(wavelength, 1)
candidate_files = []
files = sorted(self.folder.glob('ccd_calibration*.cfg'),
key=self._datetime_from_filename,
reverse=True)
for file in files:
if self._extract_grating(file) == grating:
if (extracted_wavelength := self._extract_wavelength(file, 2)) == wavelength:
return file, extracted_wavelength
else:
candidate_files.append((file, extracted_wavelength))
if len(candidate_files):
# Grating matched but not wavelength. Return closest match
idx = itertools.argmin((abs(wavelength - wavelength_on_file)
for _, wavelength_on_file in candidate_files))
return candidate_files[idx]
return None, None
@staticmethod
def _extract_grating(file):
return int(file.stem.split('_')[3][:-7])
@functools.cached_property
def _domain(self) -> tuple[int, int]:
return 1, self.detection_path.ccd.detector_pixels.get_latest()[0]
def _calibrate_fast(self, grating, central_wavelength, old_calibration_data) -> Polynomial:
if not self.excitation_path.pump_laser.enabled():
self._log.warning('Could not perform calibration in fast mode because laser is not '
'enabled')
return Polynomial([0, 1], symbol='px')
self._log.info('Calibrating CCD in fast mode.')
with self._measurement_context(grating, central_wavelength):
center_pixel = self._calibrate_single(central_wavelength)
old_wavelength = list(old_calibration_data)[0]
px_shift = (2000 + 1) / 2 - center_pixel
wavelength_shift = old_wavelength - central_wavelength
px, wavelengths = map(np.array, zip(*old_calibration_data[old_wavelength]['points']))
# Store in a temporary file (it's not a real calibration)
file = self._save_file(px - px_shift, wavelengths - wavelength_shift, grating,
central_wavelength, folder=Path(tempfile.mkdtemp()))
self.load_calibration_from_file(file=file)
return Polynomial.fit(px - px_shift, wavelengths - wavelength_shift, deg=3,
domain=self._domain, symbol='px')
def _calibrate_dirty(self, grating, central_wavelength, old_calibration_data) -> Polynomial:
self._log.info('Calibrating CCD in dirty mode.')
old_wavelength = list(old_calibration_data)[0]
shift = old_wavelength - central_wavelength
px, wavelengths = map(np.array, zip(*old_calibration_data[old_wavelength]['points']))
# Store in a temporary file (it's not a real calibration)
file = self._save_file(px, wavelengths - shift, grating, central_wavelength,
folder=Path(tempfile.mkdtemp()))
self.load_calibration_from_file(file=file)
return Polynomial.fit(px, wavelengths - shift, deg=3, domain=self._domain, symbol='px')
def _calibrate_full(self, grating, central_wavelength) -> Polynomial:
if not self.excitation_path.pump_laser.enabled():
self._log.warning('Could not perform calibration in fast mode because laser is not '
'enabled')
return Polynomial([0, 1], symbol='px')
self._log.info('Calibrating CCD in full mode.')
bandwidth = 45 if grating == 600 else 8.5
wavelengths = np.linspace(-1, 1, self.n_pts) * (bandwidth / 2) + central_wavelength
with self._measurement_context(grating, central_wavelength):
px = np.array([self._calibrate_single(wavelength) for wavelength in wavelengths])
file = self._save_file(px, wavelengths, grating, central_wavelength)
self.load_calibration_from_file(file=file)
return Polynomial.fit(px, wavelengths, deg=3, domain=self._domain, symbol='px')
def _calibrate_single(self, wavelength: float) -> int:
self._log.debug(f'Measuring calibration wavelength {wavelength:.3g}.')
self.excitation_path.wavelength(wavelength)
vector = self.detection_path.ccd.ccd_data.get()
return np.argmax(vector)
def _calibrate_none(self, old_calibration_data) -> Polynomial:
wavelength = list(old_calibration_data)[0]
px, wavelengths = map(np.array, zip(*old_calibration_data[wavelength]['points']))
return Polynomial.fit(px, wavelengths, deg=3, domain=self._domain, symbol='px')
@timed
def calibrate(self, grating: int, central_wavelength: float,
old_calibration_data: dict[float, dict[str, Any]] | None = None) -> Polynomial:
"""Perform a calibration...
Parameters
----------
grating :
... for this grating.
central_wavelength :
... for this central wavelength of the spectrometer.
old_calibration_data :
A dictionary containing old calibration data to update.
"""
if old_calibration_data is None:
# If no file exists at all for the specified grating
return self._calibrate_full(grating, central_wavelength)
# Need to disable the get parser because it relies on the calibration!
with mock.patch.object(self.detection_path.ccd.horizontal_axis, 'get_parser', None):
match self.detection_path.ccd_calibration_update_mode():
case 'fast':
return self._calibrate_fast(grating, central_wavelength, old_calibration_data)
case 'dirty':
return self._calibrate_dirty(grating, central_wavelength, old_calibration_data)
case 'full':
return self._calibrate_full(grating, central_wavelength)
case None:
return self._calibrate_none(old_calibration_data)
case _:
raise ValueError
@property
def calibration(self) -> Polynomial:
"""A Polynomial fit for calibrating pixels to wavelength."""
grating_channel = self.detection_path.active_grating()
if grating_channel is None:
# Grating not moved yet, return the identity polynomial
return Polynomial([0, 1], symbol='px')
grating = int(grating_channel.name_parts[-1].split('_')[-1])
current_wavelength = round(self.detection_path.central_wavelength.get_latest(), 1)
if current_wavelength not in self._calibration_data[grating]:
wavelength_on_file = self.load_calibration_from_file(grating, current_wavelength)
else:
wavelength_on_file = current_wavelength
if (
current_wavelength != wavelength_on_file
or self._calibration_data[grating][wavelength_on_file]['timestamp'] <
self.detection_path.ccd_calibration_oldest_datetime().timestamp()
):
return self.calibrate(
grating, current_wavelength,
{wavelength_on_file: self._calibration_data[grating][wavelength_on_file]}
)
px, wavelength = map(np.array,
zip(*self._calibration_data[grating][current_wavelength]['points']))
# Andor ccd calibration tool uses 1-based indexing
return Polynomial.fit(px, wavelength, deg=3, domain=self._domain, symbol='px')
def load_calibration_from_file(self, grating: int = 600, central_wavelength: float = 795,
file=None) -> float:
if file is None:
file, extracted_wavelength = self._find_calibration_file(grating, central_wavelength)
if file is None:
# No calibration for this grating
self.calibrate(grating, central_wavelength, None)
return self.load_calibration_from_file(grating, central_wavelength)
else:
grating = self._extract_grating(file)
extracted_wavelength = self._extract_wavelength(file, 2)
parser = configparser.ConfigParser()
parser.read(file)
section = parser['Manual X-Calibration']
points = [tuple(map(float, val.split(',')))
for name, val in section.items()
if name.startswith('point')]
calibration = {'points': points,
'timestamp': self._datetime_from_filename(file).timestamp()}
self._calibration_data[grating].update({extracted_wavelength: calibration})
return extracted_wavelength
[docs]
@dataclasses.dataclass
class RejectionFeedbackHandler(_Handler):
"""Calibration handler for the excitation rejection feedback.
Parameters
----------
slit_width :
The spectrometer exit slit width.
fs :
The Time Tagger digitization rate.
t_avg_max :
Maximum averaging duration.
alpha :
Relative error to achieve when sampling.
redfactor :
Optimization parameter.
deltatol :
Optimization parameter.
deltainit :
Optimization parameter.
feps :
Optimization parameter.
mask :
A boolean mask defining the rotator axes to optimize.
"""
slit_width: float = 1_000
fs: float = 1.4e1
t_avg_max: float = 15
alpha: float = 5e-2
redfactor: float = 3.0
deltainit: float = 1e-1
deltatol: float = 1e-2
feps: float = 10
mask: tuple[int, int, int] = (False, True, True)
throughput: float = dataclasses.field(default=np.nan, init=False)
cache: dict[float, tuple[float, ...]] = dataclasses.field(default_factory=dict, init=False,
repr=False)
[docs]
@timed
def calibrate(self, new_wav: float, old_wav: float | None = None):
    """Minimize the laser light reaching the APDs by optimizing the rotator angles.

    The laser is locked to ``new_wav`` and the power at the sample is lowered to 50 nW.
    If that fails, an interactive full power recalibration is offered. The detection
    path is then switched to the APDs. The rotator axes enabled in ``self.mask`` are
    tuned with :func:`noisyopt.minimizeCompass`, minimizing the mean combined count
    rate of time tagger channels 1 and 2.

    The search starts from the optimum cached for the closest previously calibrated
    wavelength, or from the current angles if the cache is empty. It is bounded to
    +-2.5 around that start point.

    On success, the optimum is stored in ``self.cache[new_wav]``, the rotators are
    moved there and ``self.throughput`` is updated. If the optimization stage fails,
    the rotators are moved back to their initial angles and the exception is
    re-raised; ``self.throughput`` then stays NaN. Once created, the tagger virtual
    channel and measurements are always removed again and the ND filter position is
    restored.

    Parameters
    ----------
    new_wav :
        Target wavelength. It is multiplied by 1e-9 before being passed to
        ``const.lambda2eV``, so it is presumably given in nm.
    old_wav :
        Not used by this method.

    Raises
    ------
    RuntimeError
        - If TimeTagger is not installed.
        - If 50 nW cannot be set and recalibration is declined.
        - If no light is detected on the APDs after repeatedly toggling the exit port.
        - If the counter collects fewer than two samples.
        - If the final count-rate statistics are NaN.
    ValueError
        If the compass search reports failure.
    AssertionError
        If the laser does not report a lock, or the detection bandwidth is negative.
    """
    from contextlib import ExitStack

    # Latest mean and standard error of the combined count rate (set by measure_counter)
    values: dict[str, float] = {}

    def parser(value) -> list[Real]:
        """Return the coordinates of the axes enabled in ``self.mask``."""
        if isinstance(value, MultiAxisPosition):
            pass
        elif isinstance(value, Mapping):
            value = MultiAxisPosition(**value)
        else:
            value = MultiAxisPosition(*value)
        return list(itertools.compress(dataclasses.astuple(value), self.mask))

    def angle_callback(xs=None, *, cache=False):
        """Get the masked rotator angles (``xs is None``) or set them to ``xs``."""
        if xs is None:
            if cache is True:
                pos = self.excitation_path.rotators.multi_axis_position.cache
            else:
                pos = self.excitation_path.rotators.multi_axis_position
            return parser(pos.get())
        else:
            xs = dict(zip(itertools.compress(['axis_1', 'axis_2', 'axis_3'], self.mask), xs))
            self.excitation_path.rotators.multi_axis_position.set(xs)

    def objective_fun(xs):
        """Move to ``xs`` (only if it differs from the cached angles) and measure."""
        if not all(cached == x for cached, x in zip(angle_callback(cache=True), xs)):
            self._log.debug(f'Setting new angles: {xs}')
            angle_callback(xs)
        return measure_counter()

    def measure_counter() -> float:
        """Return the mean count rate, measured for up to ``t_avg_max`` seconds.

        The measurement stops early once stderr / mean drops below ``self.alpha``.
        """
        # Tagger durations are given in ps (hence 10**12); pass an int like binwidth
        counter.start_for(int(self.t_avg_max * 10**12), clear=True)
        while counter.is_running():
            while np.sum(~np.isnan(data := counter.data_normalized())) < 2:
                # Wait until we have at least two samples so that we can compute an std
                if not counter.is_running():
                    raise RuntimeError(f'No data within {self.t_avg_max} s')
            values['mean'] = np.nanmean(data)
            values['stderr'] = np.nanstd(data) / np.sqrt(np.sum(~np.isnan(data)))
            if values['stderr'] / values['mean'] < self.alpha:
                counter.stop()
                break
        return values['mean']

    def measure_countrate() -> float:
        """Return one count-rate reading taken over one counter bin (1/fs)."""
        countrate.start_for(int(10 ** 12 // self.fs), clear=True)
        countrate.wait_until_finished()
        return countrate.data().item()

    def assert_exit_port_selected(max_attempts: int = 5, timeout: float = 30):
        """Wait until the count rate exceeds the dark level.

        Between attempts, the active detection path is toggled. Raises RuntimeError
        after ``max_attempts`` unsuccessful attempts of ``timeout`` seconds each.
        """
        for _ in range(max_attempts):
            tic = time.perf_counter()
            while time.perf_counter() - tic < timeout:
                if measure_countrate() > 150:
                    # dark count rate ~80
                    return
            # Switch the detection path away and back to re-select the APD exit port
            self.detection_path.active_detection_path('ccd')
            self.detection_path.active_detection_path('apd')
        raise RuntimeError(f'No light detected on the APDs after {max_attempts} attempts. '
                           'Is the correct exit port selected?')

    def set_safe_power():
        """Set 50 nW at the sample, offering a full power recalibration on failure."""
        try:
            self.excitation_path.power_at_sample(50e-9)
        except RuntimeError:
            if not io.query_yes_no('Could not set power to 50 nW. Perform full recalibration?'):
                raise
            answer = input(f'Specify calibration bounds. Previous: {self.bounds}\n')
            new_bounds = tuple(float(b) for b in answer.strip('[()]').split(','))
            with (
                self.excitation_path.power_calibration_update_mode.set_to('full'),
                # Make sure APDs don't see too much light
                self.detection_path.bandwidth.set_to(0)
            ):
                self.excitation_path.power_calibration_handler.calibrate(new_wav,
                                                                         bounds=new_bounds)
            self.excitation_path.power_at_sample(50e-9)

    if tt is None:
        raise RuntimeError('TimeTagger is not installed')

    self._log.info('Optimizing excitation rejection.')
    self.excitation_path.acquire_laser_lock(new_wav)
    assert self.excitation_path.laser.lock()
    if self.detection_path.bandwidth() < 0:
        # Presumably a negative bandwidth means the side exit slit is uninitialized
        self.detection_path.spectrometer.slits.side_exit_slit.init()
    assert self.detection_path.bandwidth() >= 0.0

    tagger = self.detection_path.tagger
    with ExitStack() as cleanup:
        # Register each removal right after creation so nothing is left behind on the
        # tagger if any later step (including the power setup) fails.
        combiner = tagger.add_combiner_virtual_channel()
        cleanup.callback(tagger.combiner_virtual_channels.remove, combiner)
        combiner.channels([1, 2])

        countrate = tagger.add_count_rate_measurement()
        cleanup.callback(tagger.count_rate_measurements.remove, countrate)
        countrate.channels([combiner.get_channel()])

        counter = tagger.add_counter_measurement()
        cleanup.callback(tagger.counter_measurements.remove, counter)
        counter.channels([combiner.get_channel()])
        counter.binwidth(int(10**12 / self.fs))
        counter.n_values(int(self.t_avg_max * self.fs))
        counter.rolling(False)

        p0 = self.excitation_path.power_at_sample()
        x0 = angle_callback()
        closest_wav = min(self.cache, default=None, key=lambda wav: abs(wav - new_wav))
        x1 = x0 if closest_wav is None else self.cache[closest_wav]
        self.throughput = np.nan

        cleanup.enter_context(self.excitation_path.nd_filter.position.restore_at_exit())
        # Set a low enough power to avoid damaging the APDs
        set_safe_power()

        try:
            with (
                self.detection_path.active_detection_path.set_to('apd', allow_changes=True),
                self.detection_path.central_wavelength.set_to(new_wav),
                self.detection_path.spectrometer.slits.side_exit_slit.position.set_to(1000),
                self.excitation_path.rotators.axis_channels.amplitude.set_to(45),
                self.excitation_path.rotators.axis_channels.frequency.set_to(200),
                self.excitation_path.open_shutter()
            ):
                calibration_bounds = (
                    self.excitation_path.power_calibration_handler.calibration.bounds
                )
                tic = time.perf_counter()
                # Step the ND filter position down while all of these hold:
                #   - it is above the calibration's lower bound,
                #   - the power at the sample is below its initial value p0,
                #   - the count rate is below 100 kHz.
                while (
                    self.excitation_path.nd_filter.position() > calibration_bounds[0]
                    and self.excitation_path.power_at_sample() < p0
                    and measure_countrate() < 100e3
                ):
                    self.excitation_path.nd_filter.position.increment(-1)
                toc = time.perf_counter()
                self._log.info(f'Took {toc - tic:.2g} s to initialize power.')

                assert_exit_port_selected()
                p1 = self.excitation_path.power_at_sample()
                y1 = measure_counter()
                result = noisyopt.minimizeCompass(
                    objective_fun,
                    x0=x1,
                    # Search within +-2.5 of the start point on every axis
                    bounds=np.array(x1)[:, None] + np.array([-1, 1]) * 2.5,
                    redfactor=self.redfactor,
                    deltainit=self.deltainit,
                    deltatol=self.deltatol,
                    feps=self.feps,
                    errorcontrol=False,
                    paired=False,
                    disp=False
                )
                if not result.success:
                    raise ValueError('Could not optimize excitation rejection. '
                                     f'Optimization result:\n{result}')
                if any(np.isnan(val) for val in values.values()):
                    raise RuntimeError('Something went wrong. Is the tagger overflowing?')

                self.cache[new_wav] = result.x
                # Presumably photon energy [J] times the optimal count rate, i.e. the
                # detected power. It is taken relative to the power at the sample and
                # corrected for an assumed ~55% detection efficiency.
                self.throughput = (const.lambda2eV(new_wav * 1e-9) * result.fun * const.e
                                   / p1 / 0.55)
                deltas = ', '.join(f'{new - old:.3g}' for new, old in zip(result.x, x1))
                self._log.info(f"Combined count rate at optimum: ({values['mean']:.0f} ± "
                               f"{values['stderr']:.0f}) Hz (delta {y1 - result.fun:.0f})."
                               f" Delta from previous position ({deltas})º")
                self._log.info(f'Throughput: {self.throughput:.2g}')
                self._log.debug(f'Optimization result:\n{result}')
        except (Exception, KeyboardInterrupt) as exc:
            self._log.error('Unhandled exception', exc_info=exc)
            # Undo any angle changes made during the optimization
            angle_callback(x0)
            raise
        else:
            angle_callback(result.x)
def _invert_spline(fun):
    """Build a numerical inverse of a callable that carries a ``bounds`` attribute.

    Intended for a :class:`~scipy.interpolate.UnivariateSpline` with a ``bounds``
    attribute, a ``(lower, upper)`` search interval, attached to it.

    The returned function ``inverse(y)`` searches ``fun.bounds`` for the x that
    minimizes ``(fun(x) - y)**2``, using :func:`scipy.optimize.direct`. It returns the
    optimizer's ``x``, a length-1 array.

    It raises :class:`ValueError` in two cases:

    - the optimizer reports failure;
    - ``fun(x)`` is not close to ``y`` according to :func:`numpy.allclose`, e.g.
      because ``y`` lies outside the range of ``fun`` on its bounds.
    """
    def inverse(y):
        def squared_residual(x, *_):
            return np.power(fun(x) - y, 2).item()

        res = optimize.direct(squared_residual, bounds=[fun.bounds])
        if not res.success:
            raise ValueError(f'Could not invert {fun}')
        fx = fun(res.x)
        if np.allclose(y, fx):
            return res.x
        raise ValueError(f'Invalid y={y}. f(x)={fx}')

    return inverse
def _asdict(obj: Any, dict_factory: Callable[[dict[str, Any]], Any] = dict) -> dict[str, Any]:
    """Convert a dataclass instance to a dictionary, omitting excluded fields.

    Fields whose metadata has a truthy ``'json_exclude'`` entry are left out.
    Values that are dataclass *instances* are converted recursively with the same
    rules. All other values, including dataclass types stored in a field, are
    deep-copied.

    Parameters
    ----------
    obj :
        The dataclass instance to convert.
    dict_factory :
        Called once per (nested) dataclass with the ``{name: value}`` dict of its
        included fields. Its return value is used as the converted result.

    Returns
    -------
    The result of ``dict_factory`` for ``obj``.
    """
    result = {}
    for fld in dataclasses.fields(obj):
        if fld.metadata.get('json_exclude', False):
            continue
        value = getattr(obj, fld.name)
        # Recurse into dataclass instances only. A field that merely holds a dataclass
        # *type* also has ``__dataclass_fields__`` but must be kept as a value.
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            result[fld.name] = _asdict(value, dict_factory)
        else:
            result[fld.name] = copy.deepcopy(value)
    return dict_factory(result)
def _calculate_overlap(a: tuple[float, float], b: tuple[float, float]) -> float:
    """Return the fraction of interval ``a`` that is covered by interval ``b``.

    Both intervals are ``(start, end)`` pairs. The length of their intersection is
    divided by the length of ``a``, and the ratio is clamped to ``[0, 1]``, so
    disjoint intervals give 0.0. ``a`` must have non-zero length, otherwise the
    division fails.
    """
    (a_start, b_start), (a_end, b_end) = zip(a, b)
    intersection = min(a_end, b_end) - max(a_start, b_start)
    fraction = intersection / (a_end - a_start)
    return min(1.0, max(0.0, fraction))