Source code for mjolnir.calibration

"""This module defines classes that handle calibration of parameters."""
from __future__ import annotations

import abc
import configparser
import copy
import dataclasses
import logging
import os
import sys
import tempfile
import time
from collections import defaultdict
from collections.abc import Mapping
from contextlib import contextmanager
from datetime import datetime
from numbers import Real
from pathlib import Path
from typing import Any, Callable, Sequence, Literal, cast, TypeVar
from unittest import mock

import noisyopt
import numpy as np
from adaptive.learner import BaseLearner
from adaptive.learner.learner1D import curvature_loss_function
from adaptive.learner.average_learner1D import AverageLearner1D
from adaptive.runner import _TimeGoal, BaseRunner, BlockingRunner
from adaptive.utils import SequentialExecutor
from dataclassabc import dataclassabc
from matplotlib import pyplot as plt
from numpy import typing as npt
from numpy.polynomial import Polynomial
from qcodes.parameters import Parameter
from qcodes.station import Station
from qcodes_contrib_drivers.drivers.Attocube.AMC100 import MultiAxisPosition
from qutil import const, functools, itertools, io
from scipy import interpolate, optimize

try:
    from qutil.concurrent import BatchExecutor, SortedArgsBatchCaller
except ImportError:
    BatchExecutor = None
    SortedArgsBatchCaller = None
try:
    sys.path.append(str(Path(os.environ['TIMETAGGER_INSTALL_PATH'], 'driver', 'python')))
    import TimeTagger as tt
except (ImportError, KeyError):
    tt = None

from . import _TIMESTAMP_FORMAT
from .helpers import mean_and_standard_error, timed

_T = TypeVar('_T')
Point = tuple[int, Real]
Points = list[Point]


class AverageSequenceLearner1D(AverageLearner1D):
    """Learns the values of noisy functions at given points.

    New points are not chosen adaptively but taken from ``sequence`` in the given order.
    Averaging over repeated samples of each point is left to :class:`AverageLearner1D`.

    Parameters
    ----------
    function :
        The function to learn. It is called with a ``(seed, x)`` tuple.
    sequence :
        The points to sample, visited in this order.
    min_samples, max_samples, min_error :
        Forwarded to :class:`AverageLearner1D`.
    """

    def __init__(self, function: Callable[[tuple[int, Real]], Real], sequence: Sequence[Any],
                 min_samples: int = 50, max_samples: int = sys.maxsize, min_error: float = 0):
        super().__init__(function, itertools.minmax(sequence), min_samples=min_samples,
                         max_samples=max_samples, min_error=min_error)
        # Points of the sequence that have not been requested yet (bare x values).
        self._to_do_points = set(sequence)
        self._ntotal = len(sequence)
        self.sequence = list(sequence)

    def new(self):
        """Return a fresh learner with the same function, sequence and settings."""
        return AverageSequenceLearner1D(self.function, self.sequence,
                                        min_samples=self.min_samples,
                                        max_samples=self.max_samples,
                                        min_error=self.min_error)

    def tell_pending(self, seed_x: Point) -> None:
        super().tell_pending(seed_x)
        self._to_do_points.discard(seed_x[1])

    def tell(self, seed_x: Point, y: Real) -> None:
        super().tell(seed_x, y)
        self._to_do_points.discard(seed_x[1])

    def remove_unfinished(self) -> None:
        """Drop pending points and put never-evaluated ones back on the to-do list."""
        # pending_points holds (seed, x) tuples (see tell_pending), whereas _to_do_points holds
        # bare x values. Only points without any data are new points. Pending resamples of
        # points that already have data are handled by the parent's resampling logic.
        for _, x in self.pending_points:
            if x not in self.data:
                self._to_do_points.add(x)
        super().remove_unfinished()

    def done(self) -> bool:
        """True once every point was requested, nothing is pending and no point is undersampled."""
        return not (self._to_do_points or self.pending_points or self._undersampled_points)

    def _ask_points_without_adding(self, n: int) -> tuple[list[float], list[float]]:
        if n != 1:
            raise ValueError("This Learner is designed to only ask for 1 point here.")
        # The next point is the first point of the sequence that has not been requested yet.
        next_point = next((x for x in self.sequence if x in self._to_do_points), None)
        points = [] if next_point is None else [next_point]
        return points, [np.inf]
@dataclasses.dataclass
class PowerCalibration:
    """Handles calibration of a variable neutral density filter.

    Maps the ND filter angle to laser power through a smoothing spline fitted to the samples.

    Parameters
    ----------
    wavelength :
        The laser wavelength
    timestamp :
        The timestamp of the calibration.
    angle :
        The :class:`numpy:numpy.ndarray` of angle samples.
    power :
        The :class:`numpy:numpy.ndarray` of power samples.
    error :
        The :class:`numpy:numpy.ndarray` of errors of each sample.
    log :
        Log-spaced samples.
    """
    wavelength: float
    timestamp: float = dataclasses.field(repr=False)
    angle: npt.NDArray[float] = dataclasses.field(repr=False)
    power: npt.NDArray[float] = dataclasses.field(repr=False)
    error: npt.NDArray[float] = dataclasses.field(repr=False)
    log: bool = False

    def __post_init__(self):
        self._log = logging.getLogger(f'{__name__}.{self.__class__.__qualname__}')

    def __len__(self) -> int:
        """The number of calibration samples."""
        return len(self.angle)

    def __call__(self, angle: npt.ArrayLike) -> npt.NDArray[float]:
        """Evaluate the calibration, i.e., the power at the given angle(s)."""
        return self.angle_to_power(angle)

    @property
    def age(self) -> float:
        """The calibration age in seconds."""
        return time.time() - self.timestamp

    @property
    def bounds(self) -> tuple[float, float]:
        """The bounds of the calibration range."""
        return itertools.minmax(self.angle)

    @functools.cached_property
    def angle_to_power(self):
        """The calibration function: a quadratic (k=2) smoothing spline of power vs. angle.

        Samples are weighted by the inverse of their error. Because ``ext='raise'``, evaluating
        outside the sampled range raises a :class:`ValueError`.
        """
        # s=None leaves the smoothing factor to scipy; reduce s to decrease smoothing.
        # Alternatives that were tried: InterpolatedUnivariateSpline goes through every data
        # point (no smoothing); LSQUnivariateSpline with uniform knots smoothes but fits badly
        # over a large range of angles.
        spline = interpolate.UnivariateSpline(self.angle, self.power, 1 / self.error,
                                              ext='raise', k=2, s=None)
        # Attach the fitting interval, read from scipy's private `_data` tuple (entries 3:5 are
        # assumed to be xb, xe). It is copied to the derivative below.
        spline.bounds = spline._data[3:5]
        return spline

    @functools.cached_property
    def angle_to_power_deriv(self):
        """The calibration function's derivative."""
        spline = self.angle_to_power.derivative(n=1)
        spline.bounds = self.angle_to_power.bounds
        return spline

    @functools.cached_property
    def power_to_angle(self):
        """The inverted calibration function."""
        return np.vectorize(_invert_spline(self.angle_to_power))

    @timed
    def rescaled(self, wavelength: float, angle_callback: Parameter, power_callback: Parameter,
                 target_power: float | None = None, target_angle: float | None = None,
                 n_samples: int = 50):
        """The calibration function rescaled by a global factor.

        Parameters
        ----------
        wavelength :
            The wavelength at which to rescale.
        angle_callback :
            The callback to set and get the angle.
        power_callback :
            The callback to get the power.
        target_power :
            The power to reach at a given angle. Takes precedence over target_angle.
        target_angle :
            The angle at which to rescale for the given wavelength.
        n_samples :
            The number of power samples to average over.

        Returns
        -------
        PowerCalibration
            A copy of this calibration for ``wavelength`` with power and error multiplied by
            the measured scale factor. Timestamp, angles and the ``log`` flag are kept.

        Raises
        ------
        ValueError
            If the optimization towards ``target_power`` fails. The angle is then reset to
            its initial value.
        """
        def objective_fun(x):
            # Only move the filter if the optimizer asks for a different angle.
            if angle_callback.cache.get() != x[0]:
                angle_callback(x[0])
            return (power_callback() - target_power) ** 2

        msg = f'Calibrating power at {wavelength} nm and {{:.2g}}{{}} in fast mode.'
        msg_result = 'Power at angle {:.2f}º is ({:.2g} ± {:.2g}) W.'
        alpha = 0.05
        x0 = angle_callback()

        if target_power is not None:
            self._log.info(msg.format(target_power, ' W'))
            result = noisyopt.minimizeCompass(objective_fun, x0=[x0],
                                              bounds=np.array([self.bounds]),
                                              errorcontrol=True, paired=False, scaling=[0.25],
                                              deltatol=0.04, feps=500e-12**2, alpha=alpha,
                                              disp=False)
            if not result.success:
                angle_callback(x0)
                raise ValueError('Could not rescale power(angle) calibration. Optimization result:'
                                 f'\n{result}')

            angle_callback(result.x[0])
            # result.fun is the squared deviation from the target. Propagate its standard error
            # through the square root. The sign of the deviation is lost, so p is reported
            # as lying above the target.
            p = result.fun ** .5 + target_power
            pse = .5 / result.fun ** .5 * result.funse
            scale = target_power / self.angle_to_power(result.x[0])
            self._log.info(msg_result.format(result.x[0], p, pse))
        else:
            if target_angle is None:
                target_angle = angle_callback.cache.get()

            self._log.info(msg.format(target_angle, 'º'))
            angle_callback(target_angle)
            power = np.array([power_callback() for _ in range(n_samples)])
            p, pse = mean_and_standard_error(power, alpha=alpha)
            scale = p / self.angle_to_power(target_angle)
            self._log.info(msg_result.format(target_angle, p, pse))

        # This is not a "real" calibration, so keep the old timestamp. dataclasses.replace
        # also carries over the angle samples and the `log` flag, which the previous
        # positional reconstruction dropped.
        return dataclasses.replace(self, wavelength=wavelength, power=self.power * scale,
                                   error=self.error * scale)

    def plot(self):
        """Plot the calibration function.

        Returns
        -------
        fig, ax :
            The figure and its axes. NOTE(review): if a figure for this wavelength already
            exists, it is returned together with ``fig.axes`` (a list) rather than a single
            axes object.
        """
        if (num := f'Power calibration {self.wavelength} nm') in plt.get_figlabels():
            return (fig := plt.figure(num)), fig.axes

        fig, ax = plt.subplots(layout='constrained', num=num)
        x = np.linspace(self.angle.min(), self.angle.max(), 1001)
        ax.errorbar(self.angle, self.power, self.error, fmt='r.')
        ax.plot(x, self.angle_to_power(x))
        ax.grid(True)
        ax.set_yscale('log')
        ax.set_xlabel('Angle (\u00b0)')
        ax.set_ylabel('Power (W)')
        ax.set_title(rf'$\lambda = ${self.wavelength} nm')
        # The power at the sample is taken as the measured power divided by 13.
        # NOTE(review): presumably a fixed optical attenuation factor — confirm.
        ax2 = ax.secondary_yaxis('right', functions=(lambda x: x / 13, lambda x: x * 13))
        ax2.set_ylabel('Power at sample (W)')
        return fig, ax
@dataclasses.dataclass class _Handler: """Base class for handlers.""" logical_station: Station = dataclasses.field(repr=False, metadata={'json_exclude': True}) def __post_init__(self): self._log = logging.getLogger(f'{__name__}.{self.__class__.__qualname__}') def _JSONEncoder(self) -> dict[str, Any]: # noqa return _asdict(self) @property def detection_path(self): return self.logical_station.detection_path @property def excitation_path(self): return self.logical_station.excitation_path
@dataclasses.dataclass
class CalibrationHandler(abc.ABC, _Handler):
    """Abstract base of all calibration handlers.

    Concrete handlers provide a ``folder`` for serialized data, a ``calibration`` property, and
    the ``load_calibration_from_file`` and ``calibrate`` methods.
    """

    def __post_init__(self):
        super().__post_init__()
        # Subclasses declare `folder` as a dataclass field that may be a str; normalize to Path.
        self.folder = cast(Path, Path(self.folder))  # noqa

    @property
    @abc.abstractmethod
    def folder(self) -> Path:
        """Directory holding the serialized calibration data."""

    @property
    @abc.abstractmethod
    def calibration(self) -> Any:
        """The current calibration function."""

    @abc.abstractmethod
    def load_calibration_from_file(self, *args, **kwargs) -> Any:
        """Read serialized calibration data."""

    @abc.abstractmethod
    def calibrate(self, *args, **kwargs) -> Any:
        """Run a calibration."""

    @staticmethod
    def _datetime_from_filename(file: Path) -> datetime:
        """Parse the timestamp stored in the final 19 characters of the file stem."""
        stamp = file.stem[-19:]
        return datetime.strptime(stamp, _TIMESTAMP_FORMAT)

    @staticmethod
    def _extract_wavelength(file, precision):
        """Parse the wavelength from the third ``_``-separated field of the file stem.

        The field looks like ``795,123nm``. Its last two characters are dropped, the decimal
        comma becomes a point, and the result is rounded to ``precision`` digits.
        """
        field = file.stem.split('_')[2]
        number = field[:-2].replace(',', '.')
        return round(float(number), precision)
@dataclassabc
class PowerCalibrationHandler(CalibrationHandler):
    """Calibration handler for the ND filter.

    Parameters
    ----------
    sampling :
        Use a naive sequential or adaptive sampling. Defaults to the value of
        :attr:`~.instruments.logical_instruments.ExcitationPath.power_calibration_strategy`.
    loss :
        The maximum sampling loss. If None, it is not used for the goal.
    dtheta :
        The maximum spacing between adjacent sample points. Ignored if npoints is given or
        sampling is 'adaptive'.
    duration :
        The maximum duration of a calibration run. If None, there is no time limit.
    min_samples :
        The minimum number of samples taken for each average. If None, it is not used for the
        goal.
    nsamples :
        The minimum number of total samples. If None, it is not used for the goal.
    npoints :
        The minimum number of points to be sampled. If None, it is not used for the goal.
    ntasks :
        The number of points the executor visits per batch.
    wait :
        Sleep for this amount of time before taking a power measurement.
    folder :
        The folder to save learners in.
    learner_kwargs :
        Additional kwargs for the adaptive learner.
    """
    sampling: Literal['sequential', 'adaptive'] | None = None
    loss: float | None = 0.1
    dtheta: float | None = 0.5
    duration: float | None = 300.
    min_samples: int | None = 25
    nsamples: int | None = None
    npoints: int | None = None
    ntasks: int = 1
    wait: float = 0.0
    folder: str | Path = 'C:/Data/Triton/calibration/nd_filter'
    learner_kwargs: dict = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()
        # Declare private attributes. These cannot be dataclass fields because they might not be
        # picklable
        self._log: logging.Logger
        self._runner: BaseRunner
        self._learner: BaseLearner
        self._calibration: PowerCalibration
        # Time goal of the current calibration run. It is created lazily by `_time_goal` and
        # reset by `calibrate`, so the duration is measured from the start of each run.
        self._active_time_goal: Callable[[Any], bool] | None = None
        # Load last used power calibration so that self._learner is set. Otherwise
        # self.calibration errors
        self.load_calibration_from_file()
        if self.sampling is None:
            self.sampling = self.excitation_path.power_calibration_strategy()

    def _setup_adaptive(self, bounds: tuple[float, float]):
        """Create learner and executor for the configured sampling strategy."""
        if self.sampling == 'adaptive':
            if BatchExecutor is None:
                raise ValueError('qutil.concurrent is not available, so the adaptive sampling '
                                 "strategy isn't either")
            self._learner = AverageLearner1D(
                self._objective_fun,
                bounds=bounds,
                loss_per_interval=curvature_loss_function(  # type: ignore
                    **({'area_factor': 5, 'euclid_factor': 0.02, 'horizontal_factor': 0.03}
                       | self.learner_kwargs)
                ),
                min_samples=self.min_samples
            )
            self._executor = BatchExecutor({self._objective_fun: SortedArgsBatchCaller()})
        else:
            if self.npoints is None:
                npoints = int(abs(np.diff(bounds).item()) // self.dtheta) + 1
            else:
                npoints = self.npoints
            # Points are visited from bounds[1] to bounds[0].
            sequence = np.linspace(*bounds[::-1], npoints)
            self._learner = AverageSequenceLearner1D(self._objective_fun, sequence)
            self._executor = SequentialExecutor()

    def _objective_fun(self, seed_x):
        """Move the ND filter to ``x`` if necessary and return a power measurement."""
        _, x = seed_x
        position = self.excitation_path.nd_filter.position
        # Samples of one angle are not necessarily requested back-to-back. The learners may go
        # back to resample an earlier point, so compare with the last set position instead of
        # only moving for seed 0.
        if position.cache.get() != x:
            self._log.debug(f'Setting position {x:.4g}')
            position(x)
        time.sleep(self.wait)
        return self.excitation_path.powermeter.power()

    def _goal(self, learner) -> bool:
        """Return True once the time limit or the sampling goals are reached."""
        if self._time_goal(learner):
            return True
        if hasattr(learner, 'done'):
            return learner.done()

        achieved = True
        if self.min_samples is not None and hasattr(learner, 'min_samples_per_point'):
            achieved &= learner.min_samples_per_point >= self.min_samples
        if self.nsamples is not None:
            achieved &= learner.nsamples >= self.nsamples
        if self.loss is not None:
            achieved &= learner.loss() <= self.loss
        if self.npoints is not None:
            achieved &= learner.npoints >= self.npoints
        return achieved

    @property
    def _time_goal(self) -> Callable[[Any], bool]:
        """The duration goal of the current calibration run.

        adaptive's ``_TimeGoal`` keeps its start time as instance state. The same instance must
        therefore be reused for a whole run; creating a new one for every goal evaluation meant
        the time limit never triggered.
        """
        if getattr(self, '_active_time_goal', None) is None:
            if self.duration is not None:
                self._active_time_goal = _TimeGoal(self.duration)
            else:
                def _no_time_goal(_):
                    return False

                self._active_time_goal = _no_time_goal
        return self._active_time_goal

    def _power_calibration_from_learner(self) -> PowerCalibration:
        """Build a :class:`PowerCalibration` from the learner's averaged data and errors."""
        angles, powers = zip(*sorted(self._learner.data.items()))
        errors = [self._learner.error[a] for a in angles]
        return PowerCalibration(self._learner.wavelength, self._learner.timestamp,
                                np.array(angles), np.array(powers), np.array(errors))

    def _find_calibration_file(self, wavelength: float) -> tuple[Path, float]:
        """Return the most recent calibration file for ``wavelength`` and its wavelength.

        If no file matches exactly, the file with the closest wavelength is returned.

        Raises
        ------
        FileNotFoundError
            If the folder contains no power calibration files.
        """
        wavelength = round(wavelength, 3)
        files = sorted(self.folder.glob('power_calibration*.p'),
                       key=self._datetime_from_filename, reverse=True)
        if not files:
            raise FileNotFoundError(f'No power calibration files found in {self.folder}.')

        wavelengths_on_file = []
        for file in files:
            if (extracted_wavelength := self._extract_wavelength(file, 3)) == wavelength:
                return file, extracted_wavelength
            wavelengths_on_file.append(extracted_wavelength)

        # No exact match; every file was visited, so indices align with `files`.
        idx = itertools.argmin(abs(wavelength - wavelength_on_file)
                               for wavelength_on_file in wavelengths_on_file)
        return files[idx], wavelengths_on_file[idx]

    @timed
    def calibrate(self, wavelength: float | None = None,
                  file: str | os.PathLike | Path | None = None,
                  bounds: tuple[float, float] | None = None, **kwargs):
        """Calibrate the power as function of rotation angle of the ND filter using adaptive or
        sequential sampling.

        Parameters
        ----------
        wavelength :
            The target wavelength for which to calibrate power.
        file :
            A file to save the results to. If False, the calibration is not saved.
        bounds :
            Angular bounds for the calibration.
        **kwargs
            Additional settings that will overwrite the settings stored in this dataclass
            instance as fields.
        """
        if kwargs:
            for name, value in kwargs.items():
                setattr(self, name, value)
            # Update metadata of the instrument
            self.excitation_path.power_calibration_handler = self

        bounds = bounds or self.bounds
        self._setup_adaptive(bounds)

        if wavelength is not None:
            self.excitation_path.acquire_laser_lock(wavelength)
        else:
            assert self.excitation_path.laser.lock()

        # Start the duration goal afresh for this run (see _time_goal).
        self._active_time_goal = None
        with (
                self.excitation_path.nd_filter.position.set_to(bounds[0], allow_changes=True),
                self.excitation_path.open_shutter()
        ):
            self._log.info(f'Calibrating power at {round(self.excitation_path.wavelength(), 3)} '
                           'nm.')
            self.excitation_path.wait_for_power_to_settle()
            # BlockingRunner runs the learner until the goal is reached.
            self._runner = BlockingRunner(self._learner, self._goal, executor=self._executor,
                                          ntasks=self.ntasks)

        now = datetime.now()
        self._learner.timestamp = now.timestamp()
        self._learner.wavelength = round(self.excitation_path.wavelength(), 3)
        self._calibration = self._power_calibration_from_learner()
        if file is not False:
            self._learner.save(file or self.folder / "power_calibration_{}nm_{}.p".format(
                str(self._learner.wavelength).replace('.', ','),
                now.strftime(_TIMESTAMP_FORMAT)
            ))

    def update_calibration(self, target_power: float | None = None,
                           target_angle: float | None = None, n_samples: int | None = None,
                           bounds: tuple[float, float] | None = None, attempts: int = 3):
        """Update the instrument's power calibration for a given wavelength.

        The mode depends on the `power_calibration_update_mode` parameter.

        Parameters
        ----------
        target_power :
            The power level to target, if any. Else, the calibration is performed at
            target_angle.
        target_angle :
            The angle at which to calibrate if the power_calibration_update_mode parameter is
            'fast'. Ignored if target_power is given. Defaults to the current angle.
        n_samples :
            The number of power samples to average over if target_angle is given. Ignored
            otherwise as sampling is adaptive in that case.
        bounds :
            Bounds for the calibration. Only if update mode is 'full'.
        attempts :
            The maximum number of recalibrations. Each one is followed by a check of the result.

        Raises
        ------
        RuntimeError
            If the power still deviates from the target after `attempts` recalibrations.
        """
        def _calibrate():
            if self.excitation_path.power_calibration_update_mode() == 'fast':
                self._calibration = self._calibration.rescaled(
                    wavelength,
                    self.excitation_path.nd_filter.position,
                    self.excitation_path.power,
                    target_power,
                    target_angle,
                    n_samples or self.excitation_path.power_n_avg()
                )
            else:
                # Honor the requested bounds (late-bound closure variable, see below).
                self.calibrate(wavelength, bounds=bounds)

        bounds = bounds or self.bounds
        wavelength = round(self.excitation_path.wavelength(), 3)
        with (
                self.excitation_path.power_calibration_update_mode.restore_at_exit(),
                self.excitation_path.open_shutter()
        ):
            if target_power is None:
                if target_angle is None:
                    # Initial calibration without a target. No feedback necessary.
                    return _calibrate()
                self.excitation_path.nd_filter.position(target_angle)
            elif (
                    target_power < self._calibration.power.min()
                    and io.query_yes_no('Power outside of calibration range. '
                                        'Perform a full calibration with new bounds?')
            ):
                bounds = input(f'Specify calibration bounds. Previous: {self.bounds}\n')
                bounds = tuple(float(b) for b in bounds.strip('[()]').split(','))
                self.calibrate(wavelength, bounds=bounds)

            for _ in range(attempts):
                if not self.recalibration_required(target_power):
                    return
                if target_power is not None:
                    # Set angle according to current calibration
                    try:
                        self.excitation_path.nd_filter.position(
                            self._calibration.power_to_angle(target_power)
                        )
                    except ValueError:
                        # Inverting calibration failed
                        self.excitation_path.power_calibration_update_mode('full')
                _calibrate()

            # Also verify the outcome of the last recalibration before giving up.
            if self.recalibration_required(target_power):
                raise RuntimeError(f'Could not calibrate power after {attempts} attempts.')

    def recalibration_required(self, target_power: float | None = None) -> bool:
        """Return whether the measured power deviates from the expected power.

        The expected power is ``target_power`` or, if None, the calibrated power at the current
        filter position. The relative deviation is compared with
        ``power_calibration_tolerance``.
        """
        if target_power is None:
            target_power = self._calibration.angle_to_power(
                self.excitation_path.nd_filter.position()
            )
        with self.excitation_path.open_shutter():
            power = self.excitation_path.wait_for_power_to_settle().mean()
        return bool(abs(1 - power / target_power)
                    > self.excitation_path.power_calibration_tolerance())

    @property
    def bounds(self) -> tuple[float, float]:
        """The angular bounds of the current calibration."""
        return self._calibration.bounds

    @property
    def wavelength(self) -> float:
        """The wavelength of the current calibration."""
        return self._calibration.wavelength

    @property
    def calibration(self) -> PowerCalibration:
        """The power calibration for the current laser wavelength.

        If the wavelength changed, the closest calibration is loaded from file. If its bounds
        overlap less than 50% with the previous ones, a full recalibration is offered. The
        calibration is updated if it is not for the current wavelength or is older than
        ``power_calibration_max_age``.
        """
        current_wavelength = round(self.excitation_path.wavelength(), 3)
        previous_bounds = self.bounds
        if self.wavelength != current_wavelength:
            self.load_calibration_from_file(current_wavelength)
            if (
                    _calculate_overlap(previous_bounds, self._calibration.bounds) < 0.5
                    and io.query_yes_no('Overlap in calibration bounds less than 50%. '
                                        'Perform a full calibration with previous bounds?\n'
                                        f'Previous: {previous_bounds}.\n'
                                        f'Current: {self._calibration.bounds}.')
            ):
                with self.excitation_path.power_calibration_update_mode.set_to('full'):
                    self.update_calibration(bounds=previous_bounds)
        if self.wavelength != current_wavelength:
            self.load_calibration_from_file(current_wavelength)
        if (
                self.wavelength != current_wavelength
                or self._calibration.age > self.excitation_path.power_calibration_max_age()
        ):
            self.update_calibration()
        return self._calibration

    def load_calibration_from_file(self, wavelength: float = 795, file=None):
        """Load a serialized learner and build the power calibration from it.

        Parameters
        ----------
        wavelength :
            Wavelength to look up. The most recent file at this wavelength is used, or the file
            with the closest wavelength if there is none. Ignored if file is given.
        file :
            Explicit file to load. Its wavelength and timestamp are parsed from the file name.
        """
        if file is None:
            file, wavelength = self._find_calibration_file(wavelength)
        else:
            # Previously the default 795 nm was attached to an explicitly given file.
            file = Path(file)
            wavelength = self._extract_wavelength(file, 3)

        self._learner = AverageLearner1D(lambda _: 0, (-np.inf, np.inf))
        self._learner.load(str(file))
        self._learner.wavelength = wavelength
        self._learner.timestamp = self._datetime_from_filename(file).timestamp()
        self._learner.bounds = (min(self._learner.data.keys()), max(self._learner.data.keys()))
        self._calibration = self._power_calibration_from_learner()
[docs] @dataclassabc class CcdCalibrationHandler(CalibrationHandler): """Calibration handler for the CCD. Parameters ---------- exposure_time : The duration of exposure for CCD data. number_accumulations: The number of accumulations to perform for each CCD data. n_pts : The number of calibration points (wavelenghts). folder : The folder where to store calibration data. """ exposure_time: float = 0.1 number_accumulations: int = 2 n_pts: int = 9 folder: str | Path = 'C:/Data/Triton/calibration/ccd' _calibration_data: defaultdict[int, dict[int, dict[str, Any]]] = dataclasses.field( default_factory=lambda: defaultdict(lambda: defaultdict()), init=False, repr=False ) @contextmanager def _measurement_context(self, grating: int, central_wavelength: float): assert self.excitation_path.laser.lock() with ( self.detection_path.active_detection_path.set_to('ccd'), self.detection_path.active_grating.set_to(grating), self.detection_path.central_wavelength.set_to(central_wavelength), self.detection_path.ccd.exposure_time.set_to(self.exposure_time), self.detection_path.ccd.accumulation_cycle_time.set_to(0), self.detection_path.ccd.number_accumulations.set_to(self.number_accumulations), self.detection_path.ccd.acquisition_mode.set_to('accumulate'), self.detection_path.ccd.read_mode.set_to('full vertical binning'), self.detection_path.ccd.cosmic_ray_filter_mode.set_to(True), self.excitation_path.rejection_feedback.set_to(False), self.excitation_path.wavelength.restore_at_exit(), self.excitation_path.open_shutter() ): yield def _save_file(self, pixels: npt.ArrayLike, wavelengths: npt.ArrayLike, grating: int, central_wavelength: float, folder: os.PathLike | str | None = None) -> Path: parser = configparser.ConfigParser() parser.optionxform = str # Preserve case of keys parser.add_section('Manual X-Calibration') parser.set('Manual X-Calibration', 'Type', str(2)) parser.set('Manual X-Calibration', 'Number', str(self.n_pts)) for i, tup in enumerate(zip(pixels, wavelengths), start=1): 
parser.set('Manual X-Calibration', f'Point{i}', '{},{}'.format(*tup)) if folder is None: folder = self.folder now = datetime.now().strftime(_TIMESTAMP_FORMAT) file = folder / 'ccd_calibration_{}nm_{}grating_{}.cfg'.format( str(central_wavelength).replace('.', ','), int(grating), now ) with file.open('w') as configfile: parser.write(configfile, space_around_delimiters=False) return file def _find_calibration_file(self, grating, wavelength) -> tuple[Path | None, float | None]: wavelength = round(wavelength, 1) candidate_files = [] files = sorted(self.folder.glob('ccd_calibration*.cfg'), key=self._datetime_from_filename, reverse=True) for file in files: if self._extract_grating(file) == grating: if (extracted_wavelength := self._extract_wavelength(file, 2)) == wavelength: return file, extracted_wavelength else: candidate_files.append((file, extracted_wavelength)) if len(candidate_files): # Grating matched but not wavelength. Return closest match idx = itertools.argmin((abs(wavelength - wavelength_on_file) for _, wavelength_on_file in candidate_files)) return candidate_files[idx] return None, None @staticmethod def _extract_grating(file): return int(file.stem.split('_')[3][:-7]) @functools.cached_property def _domain(self) -> tuple[int, int]: return 1, self.detection_path.ccd.detector_pixels.get_latest()[0] def _calibrate_fast(self, grating, central_wavelength, old_calibration_data) -> Polynomial: if not self.excitation_path.pump_laser.enabled(): self._log.warning('Could not perform calibration in fast mode because laser is not ' 'enabled') return Polynomial([0, 1], symbol='px') self._log.info('Calibrating CCD in fast mode.') with self._measurement_context(grating, central_wavelength): center_pixel = self._calibrate_single(central_wavelength) old_wavelength = list(old_calibration_data)[0] px_shift = (2000 + 1) / 2 - center_pixel wavelength_shift = old_wavelength - central_wavelength px, wavelengths = map(np.array, zip(*old_calibration_data[old_wavelength]['points'])) # 
Store in a temporary file (it's not a real calibration) file = self._save_file(px - px_shift, wavelengths - wavelength_shift, grating, central_wavelength, folder=Path(tempfile.mkdtemp())) self.load_calibration_from_file(file=file) return Polynomial.fit(px - px_shift, wavelengths - wavelength_shift, deg=3, domain=self._domain, symbol='px') def _calibrate_dirty(self, grating, central_wavelength, old_calibration_data) -> Polynomial: self._log.info('Calibrating CCD in dirty mode.') old_wavelength = list(old_calibration_data)[0] shift = old_wavelength - central_wavelength px, wavelengths = map(np.array, zip(*old_calibration_data[old_wavelength]['points'])) # Store in a temporary file (it's not a real calibration) file = self._save_file(px, wavelengths - shift, grating, central_wavelength, folder=Path(tempfile.mkdtemp())) self.load_calibration_from_file(file=file) return Polynomial.fit(px, wavelengths - shift, deg=3, domain=self._domain, symbol='px') def _calibrate_full(self, grating, central_wavelength) -> Polynomial: if not self.excitation_path.pump_laser.enabled(): self._log.warning('Could not perform calibration in fast mode because laser is not ' 'enabled') return Polynomial([0, 1], symbol='px') self._log.info('Calibrating CCD in full mode.') bandwidth = 45 if grating == 600 else 8.5 wavelengths = np.linspace(-1, 1, self.n_pts) * (bandwidth / 2) + central_wavelength with self._measurement_context(grating, central_wavelength): px = np.array([self._calibrate_single(wavelength) for wavelength in wavelengths]) file = self._save_file(px, wavelengths, grating, central_wavelength) self.load_calibration_from_file(file=file) return Polynomial.fit(px, wavelengths, deg=3, domain=self._domain, symbol='px') def _calibrate_single(self, wavelength: float) -> int: self._log.debug(f'Measuring calibration wavelength {wavelength:.3g}.') self.excitation_path.wavelength(wavelength) vector = self.detection_path.ccd.ccd_data.get() return np.argmax(vector) def _calibrate_none(self, 
old_calibration_data) -> Polynomial: wavelength = list(old_calibration_data)[0] px, wavelengths = map(np.array, zip(*old_calibration_data[wavelength]['points'])) return Polynomial.fit(px, wavelengths, deg=3, domain=self._domain, symbol='px') @timed def calibrate(self, grating: int, central_wavelength: float, old_calibration_data: dict[float, dict[str, Any]] | None = None) -> Polynomial: """Perform a calibration... Parameters ---------- grating : ... for this grating. central_wavelength : ... for this central wavelength of the spectrometer. old_calibration_data : A dictionary containing old calibration data to update. """ if old_calibration_data is None: # If no file exists at all for the specified grating return self._calibrate_full(grating, central_wavelength) # Need to disable the get parser because it relies on the calibration! with mock.patch.object(self.detection_path.ccd.horizontal_axis, 'get_parser', None): match self.detection_path.ccd_calibration_update_mode(): case 'fast': return self._calibrate_fast(grating, central_wavelength, old_calibration_data) case 'dirty': return self._calibrate_dirty(grating, central_wavelength, old_calibration_data) case 'full': return self._calibrate_full(grating, central_wavelength) case None: return self._calibrate_none(old_calibration_data) case _: raise ValueError @property def calibration(self) -> Polynomial: """A Polynomial fit for calibrating pixels to wavelength.""" grating_channel = self.detection_path.active_grating() if grating_channel is None: # Grating not moved yet, return the identity polynomial return Polynomial([0, 1], symbol='px') grating = int(grating_channel.name_parts[-1].split('_')[-1]) current_wavelength = round(self.detection_path.central_wavelength.get_latest(), 1) if current_wavelength not in self._calibration_data[grating]: wavelength_on_file = self.load_calibration_from_file(grating, current_wavelength) else: wavelength_on_file = current_wavelength if ( current_wavelength != wavelength_on_file or 
self._calibration_data[grating][wavelength_on_file]['timestamp'] < self.detection_path.ccd_calibration_oldest_datetime().timestamp() ): return self.calibrate( grating, current_wavelength, {wavelength_on_file: self._calibration_data[grating][wavelength_on_file]} ) px, wavelength = map(np.array, zip(*self._calibration_data[grating][current_wavelength]['points'])) # Andor ccd calibration tool uses 1-based indexing return Polynomial.fit(px, wavelength, deg=3, domain=self._domain, symbol='px') def load_calibration_from_file(self, grating: int = 600, central_wavelength: float = 795, file=None) -> float: if file is None: file, extracted_wavelength = self._find_calibration_file(grating, central_wavelength) if file is None: # No calibration for this grating self.calibrate(grating, central_wavelength, None) return self.load_calibration_from_file(grating, central_wavelength) else: grating = self._extract_grating(file) extracted_wavelength = self._extract_wavelength(file, 2) parser = configparser.ConfigParser() parser.read(file) section = parser['Manual X-Calibration'] points = [tuple(map(float, val.split(','))) for name, val in section.items() if name.startswith('point')] calibration = {'points': points, 'timestamp': self._datetime_from_filename(file).timestamp()} self._calibration_data[grating].update({extracted_wavelength: calibration}) return extracted_wavelength
@dataclasses.dataclass
class RejectionFeedbackHandler(_Handler):
    """Calibration handler for the excitation rejection feedback.

    :meth:`calibrate` tunes the (masked) excitation-path rotator angles such
    that the combined count rate of Time Tagger channels 1 and 2, recorded
    with the detection path switched to the APDs, is minimized.

    Parameters
    ----------
    slit_width :
        The spectrometer exit slit width. Applied to the side exit slit
        while calibrating.
    fs :
        The Time Tagger digitization rate (per second).
    t_avg_max :
        Maximum averaging duration in seconds.
    alpha :
        Relative error to achieve when sampling.
    redfactor :
        Optimization parameter (forwarded to ``noisyopt.minimizeCompass``).
    deltainit :
        Optimization parameter (forwarded to ``noisyopt.minimizeCompass``).
    deltatol :
        Optimization parameter (forwarded to ``noisyopt.minimizeCompass``).
    feps :
        Optimization parameter (forwarded to ``noisyopt.minimizeCompass``).
    mask :
        A boolean mask defining the rotator axes to optimize.

    Attributes
    ----------
    throughput :
        Throughput estimate from the last successful calibration. Reset to
        NaN at the start of every calibration.
    cache :
        Optimal masked rotator angles per calibrated wavelength. The entry
        closest to a new wavelength serves as the optimizer's starting point.
    """
    slit_width: float = 1_000
    fs: float = 1.4e1
    t_avg_max: float = 15
    alpha: float = 5e-2
    redfactor: float = 3.0
    deltainit: float = 1e-1
    deltatol: float = 1e-2
    feps: float = 10
    mask: tuple[bool, bool, bool] = (False, True, True)
    throughput: float = dataclasses.field(default=np.nan, init=False)
    cache: dict[float, tuple[float, ...]] = dataclasses.field(default_factory=dict, init=False,
                                                            repr=False)

    @timed
    def calibrate(self, new_wav: float, old_wav: float | None = None):
        """Minimize the excitation light leaking to the APDs at *new_wav*.

        Locks the laser to *new_wav* and lowers the excitation power to
        protect the APDs. It then runs a compass search over the masked
        rotator angles that minimizes the combined APD count rate.

        On success, the rotators are left at the optimum and the result is
        stored in :attr:`cache`. On any error (including
        ``KeyboardInterrupt``) raised during the optimization, the initial
        angles are restored before the exception propagates. The Time Tagger
        virtual channel and measurements created here are always removed
        again, whichever step fails.

        Parameters
        ----------
        new_wav :
            The wavelength to calibrate at, in nm.
        old_wav :
            Unused by this handler.

        Raises
        ------
        RuntimeError
            If the TimeTagger software is not installed, no counter data
            arrives within :attr:`t_avg_max`, or the APD exit port cannot be
            confirmed. Also raised if the measured values contain NaN, or if
            the power cannot be set to 50 nW and recalibration is declined.
        ValueError
            If the compass search does not converge.
        """
        import contextlib

        if tt is None:
            raise RuntimeError('TimeTagger is not installed')

        def parser(value) -> list[float]:
            # Normalize a rotator position to the list of masked axis values.
            if isinstance(value, MultiAxisPosition):
                pass
            elif isinstance(value, Mapping):
                value = MultiAxisPosition(**value)
            else:
                value = MultiAxisPosition(*value)
            return list(itertools.compress(dataclasses.astuple(value), self.mask))

        def angle_callback(xs=None, *, cache=False):
            # Without xs: read the masked angles (from the parameter cache if
            # requested). With xs: move the masked axes to these angles.
            position = self.excitation_path.rotators.multi_axis_position
            if xs is None:
                pos = position.cache if cache is True else position
                return parser(pos.get())
            axes = itertools.compress(['axis_1', 'axis_2', 'axis_3'], self.mask)
            position.set(dict(zip(axes, xs)))

        def objective_fun(xs):
            # Only move the rotators if the requested angles differ from the
            # cached ones, then measure the count rate at that position.
            if not all(cached == x for cached, x in zip(angle_callback(cache=True), xs)):
                self._log.debug(f'Setting new angles: {xs}')
                angle_callback(xs)
            return measure_counter()

        def measure_counter() -> float:
            # Average the normalized counter data until the relative standard
            # error drops below alpha or t_avg_max has elapsed. Stores the
            # statistics in `values` and returns the mean.
            counter.start_for(int(self.t_avg_max * 10**12), clear=True)
            while counter.is_running():
                # Wait until we have at least two samples so that we can compute an std
                while np.sum(~np.isnan(data := counter.data_normalized())) < 2:
                    if not counter.is_running():
                        raise RuntimeError(f'No data within {self.t_avg_max} s')
                values['mean'] = np.nanmean(data)
                values['stderr'] = np.nanstd(data) / np.sqrt(np.sum(~np.isnan(data)))
                if values['stderr'] / values['mean'] < self.alpha:
                    counter.stop()
                    break
            return values['mean']

        def measure_countrate() -> float:
            # One count-rate acquisition lasting a single bin (1/fs).
            countrate.start_for(int(10**12 // self.fs), clear=True)
            countrate.wait_until_finished()
            return countrate.data().item()

        def assert_exit_port_selected(max_toggles=5, timeout=30):
            # If light reaches the APDs, the count rate must clearly exceed the
            # dark count rate. Otherwise, toggle the active detection path to
            # re-select the APD exit port and try again. Give up with an error
            # instead of dropping into a debugger.
            for attempt in range(max_toggles + 1):
                if attempt:
                    self.detection_path.active_detection_path('ccd')
                    self.detection_path.active_detection_path('apd')
                tic = time.perf_counter()
                while time.perf_counter() - tic < timeout:
                    if measure_countrate() > 150:  # dark count rate ~80
                        return
            raise RuntimeError('Could not confirm that the APD exit port is selected '
                               f'after re-selecting it {max_toggles} times.')

        self._log.info('Optimizing excitation rejection.')
        self.excitation_path.acquire_laser_lock(new_wav)
        assert self.excitation_path.laser.lock()

        if self.detection_path.bandwidth() < 0:
            self.detection_path.spectrometer.slits.side_exit_slit.init()
        assert self.detection_path.bandwidth() >= 0.0

        tagger = self.detection_path.tagger
        values = {}
        with contextlib.ExitStack() as cleanup:
            # Register each removal right after creation, so the tagger is
            # cleaned up whichever later step fails (including the power setup
            # below, which previously ran outside the try/finally). Callbacks
            # run in reverse order: measurements first, then the combiner.
            combiner = tagger.add_combiner_virtual_channel()
            cleanup.callback(tagger.combiner_virtual_channels.remove, combiner)
            combiner.channels([1, 2])

            countrate = tagger.add_count_rate_measurement()
            cleanup.callback(tagger.count_rate_measurements.remove, countrate)
            countrate.channels([combiner.get_channel()])

            counter = tagger.add_counter_measurement()
            cleanup.callback(tagger.counter_measurements.remove, counter)
            counter.channels([combiner.get_channel()])
            counter.binwidth(int(10**12 / self.fs))
            counter.n_values(int(self.t_avg_max * self.fs))
            counter.rolling(False)

            p0 = self.excitation_path.power_at_sample()
            x0 = angle_callback()
            # Start from the cached optimum of the closest calibrated wavelength.
            wav = min(self.cache, default=None, key=lambda x: abs(x - new_wav))
            x1 = x0 if wav is None else self.cache[wav]

            self.throughput = np.nan
            with self.excitation_path.nd_filter.position.restore_at_exit():
                # Set a low enough power to avoid damaging the APDs
                try:
                    self.excitation_path.power_at_sample(50e-9)
                except RuntimeError:
                    if not io.query_yes_no('Could not set power to 50 nW. '
                                           'Perform full recalibration?'):
                        raise
                    # NOTE(review): `bounds` is not a field of this class; confirm
                    # that _Handler defines `self.bounds`.
                    bounds = input(f'Specify calibration bounds. Previous: {self.bounds}\n')
                    bounds = tuple(float(b) for b in bounds.strip('[()]').split(','))
                    with (
                            self.excitation_path.power_calibration_update_mode.set_to('full'),
                            # Make sure APDs don't see too much light
                            self.detection_path.bandwidth.set_to(0)
                    ):
                        self.excitation_path.power_calibration_handler.calibrate(
                            new_wav, bounds=bounds
                        )
                    self.excitation_path.power_at_sample(50e-9)

                try:
                    with (
                            self.detection_path.active_detection_path.set_to(
                                'apd', allow_changes=True
                            ),
                            self.detection_path.central_wavelength.set_to(new_wav),
                            self.detection_path.spectrometer.slits.side_exit_slit.position.set_to(
                                self.slit_width
                            ),
                            self.excitation_path.rotators.axis_channels.amplitude.set_to(45),
                            self.excitation_path.rotators.axis_channels.frequency.set_to(200),
                            self.excitation_path.open_shutter()
                    ):
                        # Open the ND filter step by step until the power is back
                        # at its initial value, the filter hits its lower bound,
                        # or the APD count rate reaches 100 kHz.
                        bounds = self.excitation_path.power_calibration_handler.calibration.bounds
                        tic = time.perf_counter()
                        while (
                                self.excitation_path.nd_filter.position() > bounds[0]
                                and self.excitation_path.power_at_sample() < p0
                                and measure_countrate() < 100e3
                        ):
                            self.excitation_path.nd_filter.position.increment(-1)
                        toc = time.perf_counter()
                        self._log.info(f'Took {toc - tic:.2g} s to initialize power.')

                        assert_exit_port_selected()

                        p1 = self.excitation_path.power_at_sample()
                        y1 = measure_counter()

                        # Search within +-2.5 (angle units) of the start point.
                        result = noisyopt.minimizeCompass(
                            objective_fun,
                            x0=x1,
                            bounds=np.array(x1)[:, None] + np.array([-1, 1]) * 2.5,
                            redfactor=self.redfactor,
                            deltainit=self.deltainit,
                            deltatol=self.deltatol,
                            feps=self.feps,
                            errorcontrol=False,
                            paired=False,
                            disp=False
                        )

                        if not result.success:
                            raise ValueError('Could not optimize excitation rejection. '
                                             f'Optimization result:\n{result}')
                        if any(np.isnan(val) for val in values.values()):
                            raise RuntimeError('Something went wrong. Is the tagger overflowing?')

                        self.cache[new_wav] = result.x
                        # Photon energy (J) times detected count rate, relative to
                        # the power at the sample.
                        self.throughput = (const.lambda2eV(new_wav * 1e-9) * result.fun * const.e
                                           # something like 55% detection efficiency
                                           / p1 / 0.55)
                        delta = ', '.join(f'{d:.3g}' for d in np.subtract(result.x, x1))
                        self._log.info(f"Combined count rate at optimum: ({values['mean']:.0f} ± "
                                       f"{values['stderr']:.0f}) Hz (delta {y1 - result.fun:.0f})."
                                       f" Delta from previous position ({delta})º")
                        self._log.info(f'Throughput: {self.throughput:.2g}')
                        self._log.debug(f'Optimization result:\n{result}')
                except (Exception, KeyboardInterrupt) as exc:
                    self._log.error('Unhandled exception', exc_info=exc)
                    angle_callback(x0)
                    raise
                else:
                    angle_callback(result.x)
def _invert_spline(fun):
    """Invert a ``UnivariateSpline`` that carries a ``bounds`` attribute.

    The inverse is computed numerically. It minimizes ``(fun(x) - y)**2``
    over the interval ``fun.bounds`` with :func:`scipy.optimize.direct`.

    Parameters
    ----------
    fun :
        Callable with a ``bounds`` attribute ``(lower, upper)``. Its output
        for a single point must have exactly one element (``.item()`` is
        applied to it).

    Returns
    -------
    inverse :
        Function mapping ``y`` to the optimizer's ``x`` with ``fun(x) ≈ y``.
        It raises :class:`ValueError` if the optimization fails or if
        ``fun(x)`` is not close to ``y``, e.g. when ``y`` is not attained
        within ``fun.bounds``.
    """
    def inverse(y):
        def objective_fun(x, *_):
            return np.power(fun(x) - y, 2).item()

        result = optimize.direct(objective_fun, bounds=[fun.bounds])
        if not result.success:
            raise ValueError(f'Could not invert {fun}')
        if not np.allclose(y, fx := fun(result.x)):
            raise ValueError(f'Invalid y={y}. f(x)={fx}')
        return result.x

    return inverse


def _asdict(obj: Any, dict_factory: Callable[[dict[str, Any]], Any] = dict) -> dict[str, Any]:
    """Convert a dataclass instance to a dictionary of deep-copied values.

    Unlike :func:`dataclasses.asdict`, this function omits fields whose
    metadata contains ``'json_exclude': True``.

    Fields holding dataclass *instances* are converted recursively with the
    same *dict_factory*. All other values are deep-copied as they are,
    including dataclass *types* and containers. Dataclasses nested inside
    lists, dicts, etc. are therefore not converted.
    """
    result = {}
    for fld in dataclasses.fields(obj):
        if fld.metadata.get('json_exclude', False):
            continue
        value = getattr(obj, fld.name)
        # A dataclass *class* also has __dataclass_fields__, but it has no
        # per-instance values to read, so only recurse into instances.
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            result[fld.name] = _asdict(value, dict_factory)
        else:
            result[fld.name] = copy.deepcopy(value)
    return dict_factory(result)


def _calculate_overlap(a: tuple[float, float], b: tuple[float, float]) -> float:
    """Return the fraction of interval *a* that is covered by interval *b*.

    Intervals are expected as ``(start, end)`` with ``start <= end``. The
    result is clipped to ``[0, 1]``.

    A zero-width interval *a* is treated as a point: the overlap is 1 if the
    point lies within *b* (inclusive) and 0 otherwise. This avoids a division
    by zero.
    """
    width = a[1] - a[0]
    if width == 0:
        return 1.0 if b[0] <= a[0] <= b[1] else 0.0
    overlap = min(a[1], b[1]) - max(a[0], b[0])
    return min(1.0, max(0.0, overlap / width))