Source code for lightlab.util.sweep

''' Generalized sweep classes
'''

import matplotlib.pyplot as plt
import numpy as np
import time
from IPython import display
import matplotlib.cm
from collections import OrderedDict

from lightlab.util.data import argFlatten, rms
from lightlab.util.plot import plotCovEllipse
import lightlab.util.io as io
from lightlab import logger


def savePickle(savefile, data, compress=True):
    ''' Pickle ``data`` into ``savefile``.

    Args:
        savefile (str/Path): destination file
        data: object to pickle
        compress (bool): if truthy, write with ``io.savePickleGzip``;
            otherwise write with ``io.savePickle``
    '''
    writer = io.savePickleGzip if compress else io.savePickle
    writer(savefile, data)
def loadPickle(savefile):
    ''' Load pickled data from ``savefile``.

    The gzip loader (``io.loadPickleGzip``) is tried first. If it raises
    FileNotFoundError, the uncompressed loader (``io.loadPickle``) is used.

    Args:
        savefile (str/Path): file to load

    Returns:
        the unpickled object
    '''
    try:
        return io.loadPickleGzip(savefile)
    except FileNotFoundError:
        return io.loadPickle(savefile)
[docs]class Sweeper(object): plotOptions = None monitorOptions = None def __init__(self): self.data = None self.savefile = None self.plotOptions = dict() self.monitorOptions = dict()
[docs] def gather(self): print('gather method must be overloaded in subclass')
[docs] def save(self, savefile=None): ''' Save data only Args: savefile (str/Path): file to save ''' if savefile is None: if self.savefile is not None: savefile = self.savefile else: raise ValueError('No save file specified') return savePickle(savefile, self.data)
[docs] def load(self, savefile=None): ''' This is basically make it so that gather() and load() have the same effect. It does not keep actuation or measurement members, only whatever was put in self.data Args: savefile (str/Path): file to load ''' if savefile is None: if self.savefile is not None: savefile = self.savefile else: raise ValueError('No save file specified') self.data = loadPickle(savefile)
[docs] def setPlotOptions(self, **kwargs): ''' Valid options for NdSweeper * plType * xKey * yKey * axArr * cmap-surf * cmap-curves Valid options for CommandControlSweeper * plType ''' for k, v in kwargs.items(): if k not in self.plotOptions.keys(): logger.warning(k, '%s is not a valid plot option.') logger.warning('Valid ones are %s', self.plotOptions.keys()) else: self.plotOptions[k] = v return self.plotOptions
[docs] def setMonitorOptions(self, **kwargs): ''' Valid options for NdSweeper * livePlot * plotEvery * stdoutPrint * runServer Valid options for CommandControlSweeper * livePlot * plotEvery * stdoutPrint * runServer * cmdCtrlPrint ''' for k, v in kwargs.items(): if k not in self.monitorOptions.keys(): logger.warning(k, '%s is not a valid monitor option.') logger.warning('Valid ones are %s', self.monitorOptions.keys()) else: self.monitorOptions[k] = v return self.monitorOptions
[docs] @classmethod def fromFile(cls, filename): new = cls() new.load(filename) return new
class Actuation(object):
    ''' One actuation of a sweep: a function plus the values it is called with.

    Attributes:
        function (callable): called with a single argument
        domain (iterable, None): the values passed to ``function``
        doOnEveryPoint (bool): request that ``function`` be called at every
            sweep point rather than only at the start of rows
    '''
    function = None
    domain = None
    doOnEveryPoint = None

    def __init__(self, function=None, domain=None, doOnEveryPoint=False):
        self.function, self.domain, self.doOnEveryPoint = function, domain, doOnEveryPoint
class NdSweeper(Sweeper):
    ''' Generic sweeper.

    Here's the difference between measure and parse:

        measure is a call to something, usually an instrument, and some
        simple post processing, like peak finding.

            * It is stored in data
            * When subsuming, only unique measurements are kept

        parse gets this in a form to visualize interactively, perhaps save
        and/or score along the way

            * When subsuming, all parse functions are maintained

    Make sure that measure is *bound* if it is a method
    '''
    measure = None
    actuate = None
    parse = None
    static = None

    def __init__(self):
        ''' Create an empty sweep.

        Actuations, measurements, parsers and static data all start empty; add
        them with :meth:`addActuation`, :meth:`addMeasurement`,
        :meth:`addParser` and :meth:`addStaticData`. The sweep dimension order
        is major first, so add slow actuations (e.g. tuning lasers) before fast
        actuations (e.g. tuning a current source).
        '''
        super().__init__()
        self.reinitActuation()
        self.measure, self.actuate = OrderedDict(), OrderedDict()
        self.parse, self.static = OrderedDict(), OrderedDict()
        self.monitorOptions = {
            'livePlot': False,
            'plotEvery': 1,
            'stdoutPrint': True,
            'runServer': False,
        }
        self.plotOptions = {
            'plType': 'curves',
            'xKey': None,
            'yKey': None,
            'axArr': None,
            'cmap-surf': matplotlib.cm.inferno,  # pylint: disable=no-member
            'cmap-curves': matplotlib.cm.viridis,  # pylint: disable=no-member
        }
[docs] @classmethod def repeater(cls, nTrials): new = cls() new.addActuation('trial', lambda a: None, np.arange(nTrials)) return new
[docs] def gather(self, soakTime=None, autoSave=False, returnToStart=False): # pylint: disable=arguments-differ ''' Perform the sweep Args: soakTime (None, float): wait this many seconds at the first point to let things settle autoSave (bool): save data on completion, if savefile is specified returnToStart (bool): If True, actuates everything to the first point after the sweep completes Returns: None ''' # Initialize builders that start off with None grids if self.data is None: # oldData = None self.data = OrderedDict() else: # oldData = self.data.copy() for dKeySrc in (self.actuate, self.measure, self.parse): for dKey in dKeySrc.keys(): try: del self.data[dKey] except KeyError: pass try: swpName = 'Generic sweep in ' + ', '.join(self.actuate.keys()) prog = io.ProgressWriter(swpName, self.swpShape, **self.monitorOptions) # Soak at the first point if soakTime is not None: logger.debug('Soaking for %s seconds.', soakTime) for actuObj in self.actuate.values(): actuObj.function(actuObj.domain[0]) time.sleep(soakTime) for index in np.ndindex(self.swpShape): pointData = OrderedDict() # Everything that will be measured *at this index* for statKey, statMat in self.static.items(): pointData[statKey] = statMat[index] # Do the actuation, storing domain args and return values (if present) for iDim, actu in enumerate(self.actuate.items()): actuKey, actuObj = actu if actuObj.domain is None: x = None else: x = actuObj.domain[index[iDim]] pointData[actuKey] = x if iDim == self.actuDims - 1 or index[iDim + 1] == 0 or actuObj.doOnEveryPoint: y = actuObj.function(x) # The actual function call occurs here if y is not None: pointData[actuKey + '-return'] = y # Do the measurement, store return values for measKey, measFun in self.measure.items(): pointData[measKey] = measFun() # print(' Meas', measKey, ':', pointData[measKey]) # Parse and store for parseKey, parseFun in self.parse.items(): try: pointData[parseKey] = parseFun(pointData) except KeyError as err: if parseKey in 
self.parse.keys(): print('Parsing out of order.', 'Parser', parseKey, 'depends on parser', err, 'but is being executed first') raise err # Insert point data into the full matrix data builder # On the first go through, initialize array of correct datatype for k, v in pointData.items(): if all(i == 0 for i in index): if np.isscalar(v): self.data[k] = np.zeros(self.swpShape, dtype=float) else: self.data[k] = np.empty(self.swpShape, dtype=object) self.data[k][index] = v # Plotting during the sweep if self.monitorOptions['livePlot']: if all(i == 0 for i in index): axArr = None axArr = self.plot(axArr=axArr, index=index) flatIndex = np.ravel_multi_index(index, self.swpShape) if flatIndex % self.monitorOptions['plotEvery'] == 0: display.display(plt.gcf()) display.clear_output(wait=True) # Progress report prog.update() # End of the main loop except Exception as err: logger.error('Error while sweeping. Keeping data. %s', err) raise if returnToStart: for actuObj in self.actuate.values(): actuObj.function(actuObj.domain[0]) if autoSave: self.save()
[docs] def addActuation(self, name, function, domain, doOnEveryPoint=False): ''' Specify an actuation dimension: what is called, the domain values to use as arguments. Args: name (str): key for accessing this actuator's value data function (func): actuation function, usually linked to hardware. One argument. domain (ndarray, None): 1D array of arguments that will be passed to the function. If None, the function is called with a None argument every point (if doOnEveryPoint is True). doOnEveryPoint (bool): call this function in the inner loop (True) or once before the corresponding rows(False) ''' newActu = Actuation(function, domain, doOnEveryPoint) self.addActuationObject(name, newActu)
[docs] def addActuationObject(self, name, actuationObj): self.actuate[name] = actuationObj self._recalcSwpShape() # If any static data is present, expand it into the new dimension # when you add an actuation, it goes in the lowest index (highest number) # so if you do data_new[..., 0], you get data_old for statKey, statVal in self.static.items(): tileBy = (len(actuationObj.domain),) + statVal.ndim * (1,) self.static[statKey] = np.tile(statVal, tileBy).T
[docs] def reinitActuation(self): self.actuate = OrderedDict() self.static = OrderedDict() self._recalcSwpShape()
def _recalcSwpShape(self): self.actuDims = 0 # pylint: disable=attribute-defined-outside-init self.swpShape = () # pylint: disable=attribute-defined-outside-init for actu in self.actuate.values(): if actu.domain is not None: self.actuDims += 1 self.swpShape += (len(actu.domain), )
[docs] def addMeasurement(self, name, function): ''' Specify a measurement to be taken at every sweep point. Args: name (str): key for accessing this measurement's value data function (func): measurement function, usually linked to hardware. No arguments. ''' self.measure.update([(name, function)])
[docs] def addParser(self, name, function): ''' Adds additional parsing formulas to data, and reparses, if data has been taken Args: name (str): key for accessing this parser's value data function (func): parsing function, not linked to hardware. One dictionary argument. ''' self.parse.update([(name, function)]) try: self._reparse(name) except KeyError: self.data[name] = None
def _reparse(self, parseKeys=None): ''' Reprocess measured data into parsed data. If there is not enough data present, it does nothing. If the parser depends on Args: parseKeys (tuple, str, None): which parsers to recalculate. If None, does all. Execution order depends on addParser calls, not parseKeys Returns: None ''' if self.data is None: return if parseKeys is None: parseKeys = tuple(self.parse.keys()) else: parseKeys = argFlatten(parseKeys, typs=tuple) for pk, pFun in self.parse.items(): # We're indexing this way to make sure parsing is done in the order of parse attribute, not the order of parseKeys if pk not in parseKeys: continue tempDataMat = np.zeros(self.swpShape) for index in np.ndindex(self.swpShape): dataOfPt = OrderedDict() for datKey, datVal in self.data.items(): if np.any(datVal.shape != self.swpShape): logger.warning( 'Data member %s is wrong size for reparsing %s. Skipping.', datKey, pk) else: dataOfPt[datKey] = datVal[index] try: tempDataMat[index] = pFun(dataOfPt) except KeyError: logger.warning('Parser %s depends on unpresent data. Skipping.', pk) break else: self.data[pk] = tempDataMat
[docs] def addStaticData(self, name, contents): ''' Add a ndarray or scalar that can be referenced by parsers The array's shape must match that of the currently loaded actuation grid. If you then :meth:`addActuation`, the static data automatically expands in dimension to accomodate. Values are filled by tiling in the new dimension. Add static data after the actuations that have different static data, but before the actuations for which you want that data to be constant. Args: name (str): key for accessing this data contents (scalar, ndarray): data contents ''' if np.isscalar(contents): contents *= np.ones(self.swpShape) if np.any(contents.shape != self.swpShape): raise ValueError('Static data ' + name + ' is wrong shape for sweep.' + 'Need ' + str(self.swpShape) + '. Got ' + str(contents.shape) + 'The order that actuations and static data are added matter.') self.static[name] = contents
[docs] def subsume(self, other, useMinorOptions=False): ''' Makes the argument sweep a minor sweep within this one The new measurement dictionary will contain all measurements of both. If there is a duplicate key, the self measurement will take precedence Existing data is discarded. Args: other (NdSweeper): the minor sweep useMinorOptions (bool): where do the options come from? If False, they come from the major (i.e. self) ''' if issubclass(type(other), type(self)): new = self.copy(includeData=False) for aNam, aObj in other.actuate.items(): if aNam not in new.actuate.keys(): new.addActuationObject(aNam, aObj) for mNam, mVal in other.measure.items(): if mNam not in new.measure.keys(): new.addMeasurement(mNam, mVal) for pNam, pVal in other.parse.items(): if pNam not in new.parse.keys(): new.addParser(pNam, pVal) if useMinorOptions: new.plotOptions.update(other.plotOptions) new.monitorOptions.update(other.monitorOptions) return new elif type(other) is CommandControlSweeper: # do something else entirely raise NotImplementedError('subsuming CommandControlSweeper')
[docs] def copy(self, includeData=True): ''' Shallow copy, which means function pointers are maintained If includeData, it does a deep copy of data ''' new = NdSweeper() for aNam, aObj in self.actuate.items(): new.addActuationObject(aNam, aObj) for mNam, mVal in self.measure.items(): new.addMeasurement(mNam, mVal) for pNam, pVal in self.parse.items(): new.addParser(pNam, pVal) if includeData: from copy import deepcopy new.data = deepcopy(self.data) new.plotOptions.update(self.plotOptions) new.monitorOptions.update(self.monitorOptions) return new
    def plot(self, slicer=None, tempData=None, index=None, axArr=None, pltKwargs=None):
        ''' Plots

            Much of the behavior to figure out labels and numbers for axes comes from
            the plotOptions attribute. The xKeys and yKeys are keys within this object's
            **data** dictionary (actuation, measurement, and parsers)

            The total number of plots will be the product of len('xKey') and len('yKey').

            xKeys can be anything, including parsed data members. By default it is the minor actuation variable

            yKeys can also be anything that has scalar elements. By default it is everything
            that is currently present, except xKeys, actuation keys, and non-scalars

            When doing line plots in 2D sweeps, the legend does automatic labelling.
            Each line must correspond to an actuation dimension, otherwise it doesn't make sense.
            This is despite the fact that the xKeys can still be anything.
            Usually, each line corresponds to a particular domain value of the major sweep axis;
            however, if that is specified as an xKey, the lines will correspond to the minor axis.

            Surface plotting:
                Ignores whatever is in xKeys. The plotting domain is locked to the actuation domain
                in order to keep a rectangular grid. The values indicated in yKeys will become color data.

            Args:
                slicer (tuple, slice): domain slices
                tempData (dict, None): data dictionary to plot instead of self.data
                index (tuple, None): sweep index of the most recent point when plotting
                    during a sweep; only data up to this point is drawn. None plots everything.
                axArr (ndarray, plt.axis): axes to plot on. Equivalent to what is returned by this method
                pltKwargs: passed through to plotting function.
                    NOTE(review): this dict is modified in place ('color', 'label', 'cmap', 'shading').

            Returns:
                ndarray: 2D array of axes, shape (len(yKeys), len(xKeys))

            Raises:
                ValueError: if no x/y keys are found or axArr has the wrong shape
                KeyError: if a requested key is not in the data

            Todo:
                * Graphics caching for 2D line plots
        '''
        # Module-level cache of 1D line handles so live plots can remove the previous line
        global hCurves  # pylint: disable=global-statement
        if index is None or np.all(np.array(index) == 0):
            hCurves = None
        if pltKwargs is None:
            pltKwargs = {}

        # Which data dict to use and its dimensionality
        if tempData is None:
            fullData = self.data
        else:
            fullData = tempData
        if fullData is not None:
            plotDims = list(fullData.values())[0].ndim  # Instead of self.actuDims
        else:
            plotDims = self.actuDims
        assertValidPlotType(self.plotOptions['plType'], plotDims, type(self))

        # Cuts down the domain to the region of interest
        if slicer is None:
            slicer = (slice(None),) * plotDims
        else:
            slicer = argFlatten(slicer, typs=tuple)

        # Figure out what the keys of data are
        actuationKeys = list(self.actuate.keys())
        xKeys = argFlatten(self.plotOptions['xKey'], typs=tuple)
        yKeys = argFlatten(self.plotOptions['yKey'], typs=tuple)
        if len(xKeys) == 0:
            # default is the most minor sweep domain
            xKeys = (actuationKeys[-1], )
        if len(yKeys) == 0:
            # default is all scalar ranges
            for datKey, datVal in fullData.items():
                if (datKey not in xKeys and
                        datKey not in actuationKeys and
                        np.isscalar(datVal.item(0))):
                    yKeys += (datKey, )
        # Check it
        if (len(xKeys) == 0 or len(yKeys) == 0):
            raise ValueError('No axis key specified explicitly or found in self.actuate')
        for k in xKeys + yKeys:
            if k not in fullData.keys():
                raise KeyError(k + ' not found in data keys. ' +
                               'Available data are ' + ', '.join(fullData.keys()))

        # Make grid of axes based on number of pairs of variables
        # (rows = yKeys, columns = xKeys)
        plotArrShape = np.array([len(yKeys), len(xKeys)])
        if axArr is not None:
            pass
        elif self.plotOptions['axArr'] is not None:
            axArr = self.plotOptions['axArr']
        else:
            _, axArr = plt.subplots(nrows=plotArrShape[0], ncols=plotArrShape[1],
                                    sharex='col', figsize=(10, plotArrShape[0] * 2.5))  # pylint: disable=unused-variable
        axArr = np.array(axArr)
        # Force into a two dimensional array
        if axArr.ndim == 2:
            pass
        elif axArr.ndim == 1:
            if np.all(plotArrShape == 1):
                axArr = np.expand_dims(axArr, 0)
            elif plotArrShape[0] == 1:
                axArr = np.expand_dims(axArr, 0)
            elif plotArrShape[1] == 1:
                axArr = np.expand_dims(axArr, 1)
        elif axArr.ndim == 0:
            if np.all(plotArrShape == 1):
                axArr = np.expand_dims(np.expand_dims(axArr, 0), 0)
        # Check it
        if np.any(axArr.shape != plotArrShape):
            raise ValueError('Shape of axArray does not match plotArrShape')

        # Prepare options for plotting that do not depend on index or line no.
        sample_xK = xKeys[0]
        sample_xData = fullData[sample_xK][slicer]
        if self.plotOptions['plType'] == 'curves':
            pltArgs = ('.-', )
            if plotDims == 1:
                if hCurves is None:
                    hCurves = np.empty(axArr.shape, dtype=object)
            elif plotDims == 2:
                invertDomainPriority = False
                autoLabeling = (plotDims == self.actuDims)
                if autoLabeling:
                    if actuationKeys[0] != sample_xK:
                        curveKey = actuationKeys[0]
                    else:
                        # The major actuation is on the x axis, so lines follow the
                        # minor one: swap data axes and the index order to match
                        curveKey = actuationKeys[1]
                        if index is not None:
                            index = index[::-1]
                        invertDomainPriority = True
                nLines = sample_xData.shape[0 if not invertDomainPriority else 1]
                colors = self.plotOptions['cmap-curves'](np.linspace(0, 1, nLines))

        # Loop over axes (i.e. axis key variables) and plot
        for iAx, ax in np.ndenumerate(axArr):
            xK = xKeys[iAx[1]]
            yK = yKeys[iAx[0]]
            # dereference and slice
            xData = fullData[xK][slicer]
            yData = fullData[yK][slicer]

            if self.plotOptions['plType'] == 'curves':
                if plotDims == 1:
                    # slice it
                    if index is not None:
                        xData = xData[:index[0] + 1]
                        yData = yData[:index[0] + 1]
                    ax.cla()
                    curv = ax.plot(xData, yData, *pltArgs, **pltKwargs)
                    # caching the part of the line that has already been drawn
                    if hCurves[iAx] is not None:  # pylint:disable=unsubscriptable-object
                        try:
                            hCurves[iAx][0].remove()
                        except ValueError:  # it was probably an old one
                            pass
                    hCurves[iAx] = curv
                elif plotDims == 2:
                    ax.cla()  # no caching, just clear
                    if invertDomainPriority:
                        xData = xData.T
                        yData = yData.T
                    for iLine in range(nLines):
                        # slicing data based on what the line and index are
                        xLine = xData[iLine, :]
                        yLine = yData[iLine, :]
                        if index is None:
                            pass
                        elif iLine < index[-2]:  # these lines are complete
                            pass
                        elif iLine == index[-2]:  # these lines are in-progress
                            xLine = xLine[slice(index[-1] + 1)]
                            yLine = yLine[slice(index[-1] + 1)]
                        elif iLine > index[-2]:  # these have not been started
                            break
                        # line options
                        pltKwargs['color'] = colors[iLine][:3]  # RGB only, drop alpha
                        if autoLabeling:
                            curveValue = self.actuate[curveKey].domain[iLine]
                            pltKwargs['label'] = '{} = {:.2f}'.format(curveKey, curveValue)
                        ax.plot(xLine, yLine, *pltArgs, **pltKwargs)
                    # legend
                    if autoLabeling and iAx[0] == 0 and iAx[1] == plotArrShape[1] - 1:  # AND it is the top right
                        ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
                else:
                    raise ValueError('Too many dimensions in sweep to plot. '
                                     'This should have been caught by assertValidPlotType.')
                # Only the bottom row gets x labels and only the left column gets y labels
                if iAx[0] == plotArrShape[0] - 1:
                    ax.set_xlabel(xK)
                else:
                    ax.tick_params(labelbottom=False)
                if iAx[1] == 0:
                    ax.set_ylabel(yK)
                else:
                    ax.tick_params(labelleft=False)

            elif self.plotOptions['plType'] == 'surf':
                # xKeys we treat as meaningless. just use the actuation domains
                # We treat yData as color data
                doms = [None] * 2
                for iDim, actuObj in enumerate(self.actuate.values()):
                    doms[iDim] = actuObj.domain[slicer[iDim]]
                # Reversed so the minor actuation is the horizontal axis
                domainGrids = np.meshgrid(*doms[::-1], indexing='xy')
                pltKwargs['cmap'] = pltKwargs.pop('cmap', self.plotOptions['cmap-surf'])
                pltKwargs['shading'] = pltKwargs.pop('shading', 'gouraud')
                cax = ax.pcolormesh(*domainGrids, yData, **pltKwargs)
                plt.gcf().colorbar(cax, ax=ax)
                ax.autoscale(tight=True)
                ax.set_title(yK)
                if iAx[0] == plotArrShape[0] - 1:
                    ax.set_xlabel(actuationKeys[1])
                else:
                    ax.tick_params(labelbottom=False)
                ax.set_ylabel(actuationKeys[0])
        return axArr
[docs] def saveObj(self, savefile=None): ''' Also saves what are the actuation keys. This is important for plotting when you reload ''' if savefile is None: if self.savefile is not None: savefile = self.savefile else: raise ValueError('No save file specified') self.data['actuation-keys'] = list(self.actuate.keys()) super().save(savefile) self.data.pop('actuation-keys')
[docs] @classmethod def loadObj(cls, savefile, functionSource=None): ''' savefile must have been saved with saveObj. It restores actuation names and domains to help with plotting. Functions referring to actuation and measurement cannot be saved. functionSource: an instantiated object of class `cls` If you give it a functionSource, then those can be restored as well. This is very useful if you have a parser such as live plot spectra, or move stuff here or there. Also useful if you want to re-gather for some reason. ''' newObj = cls.fromFile(savefile) # Restore actuations try: actuationKeys = newObj.data.pop('actuation-keys') except KeyError: pass else: for iAct, actuName in enumerate(actuationKeys): # Try to extract the actuation function (not domain) if functionSource is not None: actuObj = functionSource.actuate[actuName] else: # no function was given, so gathering won't work actuObj = Actuation() # Full data as taken, which is N-dimensional actData = newObj.data[actuName] # Extract one vector along the right direction to serve as domain sliceOneDim = [0] * len(actuationKeys) sliceOneDim[iAct] = slice(None) actuObj.domain = actData[tuple(sliceOneDim)] # Do the full add newObj.addActuationObject(actuName, actuObj) newObj._recalcSwpShape() # pylint: disable=protected-access if functionSource is not None: # Restore parsers. Do not reparse them newObj.parse = functionSource.parse # Restore measurements functions newObj.measure = functionSource.measure return newObj
[docs] def load(self, savefile=None): super().load(savefile) self._recalcSwpShape()
def __repr__(self):
    ''' Multi-line, human-readable summary of the sweeper.

        Lists the keys of the measurements, actuations, parsers and static
        entries, the monitor and plot options, and one line per data entry
        showing its shape, dtype and key.

        Entries without ``shape``/``dtype`` are shown with their type name
        instead. An example is the ``'actuation-keys'`` list left in data
        by ``saveObj`` and then reloaded with ``load``. A ``None`` data
        attribute is reported as "None" rather than raising.
    '''
    def quotedKeys(members):
        # Comma-separated double-quoted keys, or "None" when empty/unset
        if not members:
            return "None"
        return ", ".join("\"{}\"".format(k) for k in members.keys())

    def describeData(data):
        # One line per entry; entries are joined with a tab-aligned newline
        if data is None:
            return "None"
        entries = []
        for dataKey, dataVal in data.items():
            shape = getattr(dataVal, 'shape', None)
            dtype = getattr(dataVal, 'dtype', None)
            if shape is None or dtype is None:
                entries.append("<{}> \"{}\"".format(type(dataVal).__name__, dataKey))
            else:
                entries.append("<{} {}> \"{}\"".format(shape, dtype, dataKey))
        return "\n\t\t".join(entries) if entries else "None"

    sections = [
        "NdSweeper object\n----------------\n",
        "Measurements:\t" + quotedKeys(self.measure) + "\n",
        "Actuations:\t" + quotedKeys(self.actuate) + "\n",
        "Parse:\t\t" + quotedKeys(self.parse) + "\n",
        "Static:\t\t" + quotedKeys(self.static) + "\n",
        "Monitor Opt.:\t" + str(self.monitorOptions) + "\n",
        "Plot Options:\t" + str(self.plotOptions) + "\n",
        "Data:\t\t" + describeData(self.data),
    ]
    return "".join(sections)
def simpleSweep(actuate, domain, measure=None):
    ''' One-dimensional sweep with no function keys, parsing, or plotting.

        Args:
            actuate (function): called with each element of ``domain``, one
                call per point
            domain (ndarray): values handed, one at a time, to ``actuate``
            measure (function, None): no-argument function called at every
                point. If None, whatever ``actuate`` returns is used as the
                measurement.

        Returns:
            (ndarray): the measurement at each point, same length as domain
    '''
    hasMeasure = measure is not None
    sweeper = NdSweeper()
    sweeper.addActuation('act0', actuate, domain)  # pylint:disable=no-member
    if hasMeasure:
        sweeper.addMeasurement('meas0', measure)
    sweeper.setMonitorOptions(stdoutPrint=False)
    sweeper.gather()
    # Without a measure function, the actuation's own return values are the result
    resultKey = 'meas0' if hasMeasure else 'act0-return'
    return sweeper.data[resultKey]
class CommandControlSweeper(Sweeper):
    ''' Generic command-control sweep for evaluating a controller.

        The command function called at each point takes one argument that is
        an array (length M) and returns an array **of equal length**.
        The sweep is N (<= M) dimensional.

        * The user specifies the mapping between the sweep domain and the
          argument/return array indices
        * The user specifies defaults for the other (M-N) arguments
        * Some of the uncontrolled arguments can be monitored

        Todo:
            How can we get this subsumed by a NdSweeper for trial repetition.
            CommandControlSweeper shouldn't be able to subsume as major
    '''

    def __init__(self, evaluate, defaultArg, swpInds, domain, nTrials=1):
        '''
            Args:
                evaluate (function): called at each point with array args/returns of equal length
                defaultArg (ndarray, scalar): default value that will be sent to the evaluate
                    function. A scalar makes ``evaluate`` receive and return scalars.
                swpInds (tuple, int): which channels to sweep
                domain (tuple, iterable): the values over which the sweep channels
                    will be swept, one iterable per entry of ``swpInds``
                nTrials (int): number of repetitions of the whole grid

            Raises:
                ValueError: if ``domain`` and ``swpInds`` have different lengths
        '''
        super().__init__()
        self.evaluate = evaluate
        self.isScalar = np.isscalar(defaultArg)
        if self.isScalar:
            defaultArg = [defaultArg]
        self.defaultArg = np.array(defaultArg, dtype=float)
        self.allDims = len(self.defaultArg)

        self.swpInds = argFlatten(swpInds, typs=tuple)
        self.swpDims = len(self.swpInds)
        self.domain = argFlatten(domain, typs=tuple)
        self.swpShape = tuple(len(dom) for dom in self.domain)
        if len(self.domain) != self.swpDims:
            raise ValueError('domain and swpInds must have the same dimension. '
                             'Got {} and {}'.format(len(self.domain), len(self.swpInds)))

        self.plotOptions = {'plType': 'curves'}
        self.monitorOptions = {'livePlot': False,
                               'plotEvery': 1,
                               'stdoutPrint': True,
                               'runServer': False,
                               'cmdCtrlPrint': True}

        # Generate actuation sweep grid: cmdGrid[i0, i1, ..., c] == domain[c][ic].
        # 'ij' indexing keeps grid axes in swpShape order for any number of
        # swept dimensions ('xy' followed by a full transpose was only
        # correct for one or two dimensions).
        self.cmdGrid = np.moveaxis(np.array(np.meshgrid(*self.domain, indexing='ij')), 0, -1)
        assert self.cmdGrid.shape == self.swpShape + (self.swpDims,)
        self.nTrials = nTrials

    def saveObj(self, savefile=None):
        ''' Pickle the whole sweeper object, not just its data.

            ``evaluate`` is not saved: it is detached while pickling and is
            always reattached afterwards, even if the write fails.

            Args:
                savefile (str/Path): file to save. Defaults to ``self.savefile``

            Raises:
                ValueError: if neither ``savefile`` nor ``self.savefile`` is set
        '''
        if savefile is None:
            if self.savefile is None:
                raise ValueError('No save file specified')
            savefile = self.savefile
        evaluateRef = self.evaluate
        self.evaluate = None
        try:
            savePickle(savefile, self)
        finally:
            self.evaluate = evaluateRef

    @classmethod
    def loadObj(cls, savefile):
        ''' Load a whole sweeper object previously written with ``saveObj``.

            The returned object's ``evaluate`` is ``None`` because it is not
            saved. Assign it again before calling ``gather``.

            Args:
                savefile (str/Path): file to load
        '''
        return loadPickle(savefile)

    def gather(self, autoSave=False, randomize=False):  # pylint: disable=arguments-differ
        ''' Executes the sweep

            Every one of the ``nTrials`` repetitions visits every point of the
            command grid. Results go into ``self.data``, which has shape
            ``(nTrials,) + swpShape + (allDims,)``.

            Args:
                autoSave (bool): call ``self.save()`` after every point
                randomize (bool): visit each swept dimension in a random
                    (independently permuted) order

            Todo:
                Store all outputs, but provide a way just to get the controlled ones
        '''
        fullShape = (self.nTrials,) + self.swpShape
        measGrid = np.zeros(fullShape + (self.allDims,))
        swpName = 'Generic command-control sweep'
        prog = io.ProgressWriter(swpName, fullShape, **self.monitorOptions)

        if randomize:
            randizers = [np.random.permutation(self.swpShape[iDim])
                         for iDim in range(self.swpDims)]

        swpIndArr = np.array(self.swpInds)
        for index in np.ndindex(fullShape):
            if randomize:
                # Keep the trial index; permute each grid coordinate
                index = (index[0],) + tuple(randizers[iDim][index[iDim + 1]]
                                            for iDim in range(self.swpDims))
            gridIndex = index[1:]
            cmdArr = self.defaultArg.copy()
            cmdArr[swpIndArr] = self.cmdGrid[gridIndex]
            if self.isScalar:
                cmdArr = cmdArr[0]
            measArr = self.evaluate(cmdArr)
            if self.isScalar:
                measArr = np.array([measArr])
            measGrid[index] = measArr
            self.data = measGrid

            if self.monitorOptions['livePlot']:
                self.plot(index)
                flatIndex = np.ravel_multi_index(index, fullShape)
                if flatIndex % self.monitorOptions['plotEvery'] == 0:
                    display.clear_output(wait=True)
                    display.display(plt.gcf())  # Note this may have to be interAx instead of gcf
            if self.monitorOptions['cmdCtrlPrint']:
                print('(trial, gridIndex) =', index)
                # cmdArr is a bare scalar for scalar sweeps; atleast_1d makes it iterable
                print('  cmdArr = ' + ' '.join('{:.3f}'.format(v) for v in np.atleast_1d(cmdArr)))
                print('  measArr = ' + ' '.join('{:.3f}'.format(v) for v in np.atleast_1d(measArr)))
            prog.update()

            if autoSave:
                self.save()

    def toSweepData(self):
        ''' Using the old school temporary definition from conductor

            This will eventually be deprecated

            Returns:
                tuple: ``(cmdMat, measMat, monitMat)`` -- the command grid, the
                measured values of the swept channels, and the values of the
                remaining (monitor) channels, or None if there are none
        '''
        monitorInds = tuple(i for i in range(self.allDims) if i not in self.swpInds)
        cmdMat = self.cmdGrid
        measMat = self.data[..., self.swpInds]
        monitMat = self.data[..., monitorInds] if monitorInds else None
        return (cmdMat, measMat, monitMat)

    def plot(self, index=None, axArr=None):
        ''' Plot the sweep according to ``plotOptions['plType']``.

            Args:
                index (tuple, None): the (trial, gridIndex...) point just
                    measured during a live sweep; None means the data is complete
                axArr (Axes, None): axes to draw on
        '''
        plType = self.plotOptions['plType']
        assertValidPlotType(plType, self.swpDims, type(self))
        if plType == 'cmdErr':
            plotCmdCtrl(self.toSweepData(), index=index, interactive=True, ax=axArr)
        elif plType == 'curves' and self.swpDims == 1:
            # currently only works in 1 dimension; other dimensions draw nothing
            if axArr is not None:
                plt.sca(axArr)
            elif index is None or all(i == 0 for i in index):
                # Fresh figure for a complete plot or the very first live point
                plt.subplots(figsize=(6, 6))
            else:
                plt.cla()

            cmdMat, measMat, _ = self.toSweepData()
            xFull = cmdMat[:, 0]
            # All points over trials and the sweep parameter
            allPts = measMat[..., 0]

            if index is not None:
                # Plot the in-progress line
                if index[1] > 0:
                    xInProgress = cmdMat[:index[1] + 1, 0]
                    yInProgress = allPts[index[0], :index[1] + 1]
                    plt.plot(xInProgress, yInProgress, 'g.-')
                # Only finished trials contribute to the statistics
                if index[0] == 0:
                    return
                allPts = measMat[:index[0], :, 0]

            means = np.mean(allPts, axis=0)
            stddevs = np.std(allPts, axis=0)
            # Plot dots
            plt.plot(xFull, allPts.T, 'k.')
            # Plot means
            plt.plot(xFull, means, 'r', lw=2)
            # Plot error bars
            for upDown in [-1, 1]:
                plt.plot(xFull, means + upDown * stddevs, 'b')

            # Make the axes pretty
            if axArr is None:
                plt.axis('square')
            # Expand the window to fit all points (float, so measured extremes are not truncated)
            minMaxCmd = np.array([min(np.min(cmdMat), np.min(allPts)),
                                  max(np.max(cmdMat), np.max(allPts))], dtype=float)
            plt.xlim(minMaxCmd)
            plt.ylim(minMaxCmd)
            plt.xlabel('Command value')
            plt.ylabel('Evaluated value')
            plt.plot(minMaxCmd, minMaxCmd, '--k')

    def score(self, bits=False, worstCase=False):
        ''' Takes full sweep data and returns accuracy and precision.

            Accuracy is the error of the trial-mean from the command.
            Precision is the trial-to-trial spread. Both are RMS-normed over
            channels at each grid point, then consolidated over the grid.

            Args:
                bits (bool): if true, returns values as bits of dynamic range
                worstCase (bool): if true, takes the performance at the worst weight, else averages via RMS

            Returns:
                tuple: (accuracy, precision)
        '''
        cmdWeights = self.cmdGrid
        measWeights = self.data[..., np.array(self.swpInds)]

        # Statistics of every dimension at every grid point (norming over trials)
        meanVsWeight = np.mean(measWeights, axis=0)
        errMeanVsWeight = meanVsWeight - cmdWeights
        errStddevVsWeight = rms(measWeights - meanVsWeight, axis=0)

        # Statistics normed over channels at every grid point
        netErrMeanVsWeight = rms(errMeanVsWeight, axis=-1)
        netErrStddevVsWeight = rms(errStddevVsWeight, axis=-1)

        def consolidate(errVsWeight):
            # Worst grid point, or RMS over the whole grid
            if worstCase:
                return np.max(np.abs(errVsWeight))
            return rms(errVsWeight, axis=None)

        accuracy = consolidate(netErrMeanVsWeight)
        precision = consolidate(netErrStddevVsWeight)
        if not bits:
            return accuracy, precision
        domainSpan = np.abs(np.max(cmdWeights) - np.min(cmdWeights))
        return np.log2(domainSpan / accuracy), np.log2(domainSpan / precision)
# Module-level plotting handles. plotCmdCtrl declares interAx, hArrow and
# hEllipse as globals so a live sweep can update one figure point by point.
# hCurves is not referenced anywhere in the visible code.
interAx = hCurves = hArrow = hEllipse = None
def plotCmdCtrl(sweepData, index=None, ax=None, interactive=False):
    ''' Plot command-vs-measurement errors of a command-control sweep.

        sweepData should have ALL the command weights specified.

        If the command grid's last axis has length 2, the sweep is drawn as 2D.
        The command grid is laid down. At the given grid point the code draws:
        the raw measurements, a red line from the commanded target to the
        mean measurement, and (from the second trial on) a covariance
        ellipse via ``plotCovEllipse``.

        Otherwise it is drawn as 1D, using completed trials only. The code
        draws: measured-minus-command error dots, the mean error (red), and
        mean error +/- standard deviation (blue). If monitor data is given,
        its first channel is drawn in magenta with green +/- std curves.

        Plotting state persists between calls in the module-level globals
        ``interAx``, ``hArrow`` and ``hEllipse``, so a sweep can call this
        once per measured point. Returns None.

        Args:
            sweepData (tuple): cmdWeights, measWeights, monitWeights (may be None)
                measWeights has shape (nTrials, len(swp1), len(sp2) or 1, len(sweepingChannels))
                NOTE(review): the 1D branch indexes measWeights/monitWeights as
                3-D ``[trial, point, channel]`` -- confirm the documented shape
            index (tuple): tells which parts of measured weights are valid. If None, assumes sweepData is complete
            ax (Axes, None): axes used when plotting objects are (re)initialized;
                if None at that moment, a new 5x5 figure is created
            interactive (bool): only used in 2D. When true, the line and
                ellipse previously drawn at this grid point are removed and
                replaced, instead of accumulating.

        Todo:
            Fix the global hack for persistent plots -- actually, this is fine
    '''
    global interAx  # pylint: disable=global-statement
    global hArrow  # pylint: disable=global-statement
    global hEllipse  # pylint: disable=global-statement
    cmdWeights, measWeights, monitWeights = sweepData
    gridShape = cmdWeights.shape[:-1]
    # Two swept channels -> 2D arrows/ellipses; anything else -> 1D error curves
    is2D = (cmdWeights.shape[-1] == 2)

    if index is None:
        # Just do a bunch of incrementals over the last trial without refresh
        interAx = ax
        if is2D:
            for gridIndex in np.ndindex(gridShape):
                fullIndex = (measWeights.shape[0] - 1, *gridIndex)
                plotCmdCtrl(sweepData, index=fullIndex, interactive=False)
        else:
            # index[0] == nTrials, so the 1D branch below uses every trial
            plotCmdCtrl(sweepData, index=(measWeights.shape[0], 0, 0), interactive=False)
    else:
        # This is done only on the first point of initialization
        if interAx is None or all(i == 0 for i in index):
            # Initialize plotting objects
            if ax is None:
                fig, ax = plt.subplots(figsize=(5, 5))  # pylint: disable=unused-variable
            interAx = ax
            plt.cla()
            if is2D:
                # One slot per grid point for the most recent line/ellipse handles
                hArrow = np.empty(gridShape, dtype=object)
                hEllipse = np.empty(gridShape, dtype=object)

            # Lay down grid
            if is2D:
                xSweep = cmdWeights[:, 0, 0]
                xRange = (np.min(xSweep), np.max(xSweep))
                ySweep = cmdWeights[0, :, 1]
                yRange = (np.min(ySweep), np.max(ySweep))
                for sPt in xSweep:
                    interAx.plot(2 * [sPt], yRange, 'k-')
                for sPt in ySweep:
                    interAx.plot(xRange, 2 * [sPt], 'k-')
                # plt.xlim(xRange + np.array([-1, 1]) * np.diff(xRange)[0] * .1)
                # plt.ylim(yRange + np.array([-1, 1]) * np.diff(yRange)[0] * .1)
                # interAx.set(aspect='equal')
            else:
                xSweep = cmdWeights[:, 0]
                xRange = (np.min(xSweep), np.max(xSweep))
                # Zero-error reference line
                interAx.plot(xRange, 2 * [0], 'k-')
                plt.xlim(xRange)
                plt.ylim(xRange - np.mean(xRange))
                # interAx.set(aspect='equal')

        # This is done for every point
        if not is2D:
            # No trial has completed yet, so there are no statistics to draw
            if index[0] == 0:
                return
            # NOTE(review): plt.cla() acts on the current axes (not necessarily
            # interAx) and also erases the zero line drawn during initialization
            plt.cla()
            x = cmdWeights[:, 0]
            # All points over trials and the sweep parameter
            allPts = measWeights[:index[0], :, 0]
            allErrors = allPts - x
            meanErrors = np.mean(allErrors, axis=0)
            # std of the raw points equals std of the errors (x is fixed per column)
            stddevs = np.std(allPts, axis=0)
            # Plot dots
            for iv, v in np.ndenumerate(allErrors):
                interAx.plot(x[iv[1]], v, '.k')
            # Plot means
            interAx.plot(x, meanErrors, 'r', lw=2)
            # Plot error bars
            for upDown in [-1, 1]:
                y = meanErrors + upDown * stddevs
                interAx.plot(x, y, 'b')
            # Monitor plotting if it's there:
            if monitWeights is not None:
                # Only the first monitor channel is plotted
                allPts = monitWeights[:index[0], :, 0]
                for iv, v in np.ndenumerate(allPts):
                    interAx.plot(x[iv[1]], v, '.m')
                monitMean = np.mean(allPts, axis=0)
                interAx.plot(x, monitMean, 'm', lw=2)
                stddevs = np.std(allPts, axis=0)
                y = monitMean[:, None] + np.array([[-1, 1]]) * stddevs[:, None]
                interAx.plot(x, y, 'g')
        else:  # 2D
            gridIndex = index[1:]
            # Measurements from trials 0..index[0] (inclusive) at this grid point
            valsAtThisGridPt = measWeights[(slice(index[0] + 1), *index[1:])]
            # plot newest raw point itself (and others at this grid point)
            for pt in valsAtThisGridPt:
                interAx.plot(*pt, '.k')
            # plot mean error line
            mean = np.mean(valsAtThisGridPt, axis=0)
            target = cmdWeights[gridIndex]
            arro = interAx.plot(*zip(target, mean), 'r', lw=2)  # pylint: disable=zip-builtin-not-iterating
            # plot variance ellipse (needs at least two trials for a covariance)
            if index[0] > 0:
                cov = np.cov(valsAtThisGridPt, rowvar=False)
                elli = plotCovEllipse(cov, mean, volume=0.5, ax=interAx, ec='b', fc='none')
            else:
                elli = None
            if interactive:
                # Replace previous graphics objects from this grid point.
                # NOTE(review): assumes hArrow/hEllipse were created by an earlier initialization
                if hArrow[gridIndex] is not None:
                    hArrow[gridIndex][0].remove()
                hArrow[gridIndex] = arro
                if hEllipse[gridIndex] is not None:
                    hEllipse[gridIndex].remove()
                hEllipse[gridIndex] = elli
# Information on types of plots:
#   plot name -> (set of allowed sweep dimensions, set of allowed sweeper class names)
pTypes = {
    'curves': ({1, 2}, {NdSweeper.__name__, CommandControlSweeper.__name__}),
    'surf': ({2}, {NdSweeper.__name__}),
    'cmdErr': ({1, 2}, {CommandControlSweeper.__name__}),
}
def availablePlots(dims=None, swpType=None):
    ''' List the plot types usable for a sweep of a given dimension and type.

        Args:
            dims (int, None): number of swept dimensions; None does not filter
            swpType (type, str, None): sweeper class or class name; None does
                not filter. A class also matches the plot types of its base
                classes, so subclasses of NdSweeper/CommandControlSweeper
                inherit their parent's plots.

        Returns:
            list: names of the matching plot types, in ``pTypes`` order
    '''
    def typeMatches(allowedNames):
        if swpType is None:
            return True
        if isinstance(swpType, str):
            return swpType in allowedNames
        if isinstance(swpType, type):
            # Walk the MRO so subclasses are accepted wherever their base is
            return any(klass.__name__ in allowedNames for klass in swpType.__mro__)
        return False

    return [plName for plName, (allowedDims, allowedNames) in pTypes.items()
            if (dims is None or dims in allowedDims) and typeMatches(allowedNames)]
def assertValidPlotType(plType, dims=None, swpClass=None):
    ''' Check that ``plType`` can be drawn for the given kind of sweep.

        Args:
            plType (str): requested plot type, e.g. 'curves'
            dims (int, None): number of swept dimensions; None means any
            swpClass (type, str, None): sweeper class or class name; None means any

        Raises:
            KeyError: if ``plType`` is not available for ``dims``/``swpClass``.
                The reason and the available alternatives are logged first.
    '''
    validHere = availablePlots(dims, swpClass)
    if plType in validHere:
        return

    # swpClass may legitimately be None or a name string (both accepted by
    # availablePlots); only classes have __name__
    if swpClass is None:
        swpName = 'sweep'
    elif isinstance(swpClass, str):
        swpName = swpClass
    else:
        swpName = getattr(swpClass, '__name__', repr(swpClass))

    errStr = ['Invalid plot type.']
    if dims is None:
        errStr.append(f'This sweep is a {swpName}.')
    else:
        errStr.append(f'This sweep is a {dims}-dimensional {swpName}.')
    if plType not in availablePlots():
        errStr.append(f'{plType} is not a valid plot type at all.')
    else:
        errStr.append(f'{plType} is not a valid plot type for this kind of sweep.')
    errStr.append('Available plots are: {}'.format(', '.join(validHere)))
    logger.error('\n'.join(errStr))
    raise KeyError(plType)