Source code for lightlab.util.search

''' Searching with actuate-measure functions,
    usually around peaks and monotonic functions
'''
import matplotlib.pyplot as plt
import numpy as np
from IPython import display

from lightlab.util.data import MeasuredFunction
from lightlab.util.io import RangeError
from lightlab import logger


[docs]class SearchRangeError(RangeError): ''' The first argument is direction, the second is a best guess ''' pass
[docs]def plotAfterPointMeasurement(trackerMF, yTarget=None): ''' This mutates trackerMF Args: trackerMF (MeasuredFunction): function that will be plotted yTarget (float): plotted as dashed line if not None ''' display.clear_output(wait=True) plt.cla() trackerMF.simplePlot('.-') if yTarget is not None: targLineSpan = plt.xlim() plt.plot(targLineSpan, 2 * [yTarget], '--k', lw=.5) display.display(plt.gcf())
[docs]def peakSearch(evalPointFun, startBounds, nSwarm=3, xTol=0., yTol=0., livePlot=False): ''' Returns the optimal input that gives you the peak, and the peak value You must set either xTol or yTol. Be careful with yTol! It is best used with a big swarm. It does not guarantee that you are that close to peak, just that the swarm is that flat This algorithm is a modified swarm that is robust to outliers, sometimes. Each iteration, it takes <nSwarm> measurements and looks at the best (highest). The update is calculated by shrinking the swarm around the index of the best value. It does not compare between iterations: that makes it robust to one-time outliers. It attributes weight only by order of y values in an iteration, not the value between iterations or the magnitude of differences between y's within an iteration Not designed to differentiate global vs. local maxima Args: evalPointFun (function): y=f(x) one argument, one return. The function that we want to find the peak of startBounds (list, ndarray): minimum and maximum x values that bracket the peak of interest nSwarm (int): number of evaluations per iteration. Use more if it's a narrow peak in a big bounding area xTol (float): if the swarm x's fall within this range, search returns successfully yTol (float): if the swarm y's fall within this range, search returns successfully livePlot (bool): for notebook plotting Returns: (float, float): best (x,y) point of the peak ''' # Argument checking if xTol is None and yTol is None: raise ValueError('Must specify either xTol or yTol, ' + 'or peak search will never converge.') nSwarm += (nSwarm + 1) % 2 tracker = MeasuredFunction([], []) def shrinkAround(arr, bestInd, shrinkage=.6): fulcrumVal = 2 * arr[bestInd] - np.mean(arr) return fulcrumVal + (arr - fulcrumVal) * shrinkage offsToMeasure = np.linspace(*startBounds, nSwarm) for iIter in range(20): # pylint: disable=unused-variable # Take measurements of the points measuredVals = np.zeros(nSwarm) for iPt, offs in enumerate(offsToMeasure): meas = evalPointFun(offs) measuredVals[iPt] = meas tracker.addPoint((offs, meas)) if livePlot: plotAfterPointMeasurement(tracker) # Move the lowest point closer bestInd = np.argmax(measuredVals) # print('iter =', iIter, '; offArr =', offsToMeasure, '; best =', np.max(measuredVals)) worstInd = np.argmin(measuredVals) if measuredVals[bestInd] - measuredVals[worstInd] < yTol \ or offsToMeasure[-1] - offsToMeasure[0] < xTol: # logger.debug('Converged on peak') break if worstInd == float(nSwarm - 1) / 2: logger.debug('Detected positive curvature') # break offsToMeasure = shrinkAround(offsToMeasure, bestInd) return (offsToMeasure[bestInd], measuredVals[bestInd])
[docs]def doesMFbracket(targetY, twoPointMF): yRange = twoPointMF.getRange() if targetY < yRange[0]: outOfRangeDirection = 'low' elif targetY > yRange[1]: outOfRangeDirection = 'high' else: outOfRangeDirection = 'in-range' return outOfRangeDirection
[docs]def bracketSearch(evalPointFun, targetY, startBounds, xTol, hardConstrain=False, livePlot=False): ''' Searches outwards until it finds two X values whose Y values are above and below the targetY. Stop conditions * brackets it: returns new bracketing x values * step decreases until below xTol: raises RangeError * 30 iterations: raises RangeError Args: evalPointFun (function): y=f(x) one argument, one return. The function that we want to find the target Y value of startBounds (list, ndarray): x values that usually do not bracket the value of interest xTol (float): if *domain* shifts become less than this, raises RangeError hardConstrain (bool, list): If list, will stay within those livePlot (bool): for notebook plotting Returns: ([float, float]): the bracketing range ''' startBounds = sorted(startBounds) if hardConstrain is True: constrainBounds = startBounds elif hardConstrain is False: constrainBounds = [-np.infty, np.infty] else: constrainBounds = hardConstrain tracker = MeasuredFunction([], []) def measureError(xVal): yVal = evalPointFun(xVal) tracker.addPoint((xVal, yVal)) err = yVal - targetY if livePlot: plotAfterPointMeasurement(tracker, targetY) return err # First check out what happens at the edges for x in startBounds: measureError(x) isIncreasing = tracker.ordi[1] > tracker.ordi[0] # Did it start bracketed? outOfRangeDirection = doesMFbracket(targetY, tracker) if outOfRangeDirection == 'in-range': return startBounds # Which way to go? We know that tracker has 2 points in it right now if ((isIncreasing and outOfRangeDirection == 'high') or (not isIncreasing and outOfRangeDirection == 'low')): searchDirection = 1 else: searchDirection = 0 lastX = tracker.absc[-searchDirection - 1] twoAgoX = lastX lastErr = tracker.ordi[-searchDirection - 1] - targetY twoAgoErr = lastErr absStep = np.diff(startBounds)[0] / 2 # this is tricky signedStep = absStep * (1 if (searchDirection == 1) else -1) newX = lastX + signedStep # Feel out in that direction constraintViolation = None for iIter in range(100): # pylint: disable=unused-variable # Case -2: went outside of constraints, fail if newX < constrainBounds[0]: constraintViolation = 'low' if isIncreasing else 'high' elif newX > constrainBounds[1]: constraintViolation = 'high' if isIncreasing else 'low' if constraintViolation is not None: raise SearchRangeError('Violated domain constraints', constraintViolation, lastX) # Case -1: step is too small, fail if abs(signedStep) < xTol: raise SearchRangeError('Target value out of range! ' + 'This is the best guess.', outOfRangeDirection, newX) # ### Do the actual measurement #### newErr = measureError(newX) # Case 0: definitely bracketed it if np.sign(lastErr * newErr) < 0: return [twoAgoX, newX] # Case 1: climbing up a peak # Case 2: skipped over a peak, didn't bracket it, # but it looks like still going up elif abs(newErr) < abs(lastErr): # Shift everything back into the past twoAgoX = lastX twoAgoErr = lastErr lastX = newX lastErr = newErr # Case 3: skipped over a peak and started going down else: # Drop back to the past and reduce the step lastX = twoAgoX lastErr = twoAgoErr signedStep /= 2 newX = lastX + signedStep raise Exception('Bracket search did 100 iterations and still did not converge')
[docs]def binarySearch(evalPointFun, targetY, startBounds, hardConstrain=False, xTol=0, yTol=0, livePlot=False): ''' Gives the x where ``evalPointFun(x) == targetY``, approximately. The final call to evalPointFun will be of this value, so no need to call it again, if your goal is to set to the target. xTol and yTol are OR-ed conditions. If one is satisfied, it will terminate successfully. You must specify at least one. Assumes that the function is monotonic in any direction It often works when there is a peak inside the ``startBounds``, although not always. Args: evalPointFun (function): y=f(x) one argument, one return. The function that we want to find the target Y value of startBounds (list, ndarray): minimum and maximum x values that bracket the peak of interest hardConstrain (bool, list): if not True, will do a bracketSearch. If list, will stay within those xTol (float): if *domain* shifts become less than this, terminates successfully yTol (float): if *range* shifts become less than this, terminates successfully livePlot (bool): for notebook plotting Returns: (float): the optimal X value ''' # Argument checking if xTol is None and yTol is None: raise ValueError('Must specify either xTol or yTol, ' + 'or binary search will never converge.') startBounds = sorted(startBounds) tracker = MeasuredFunction([], []) def measureError(xVal): yVal = evalPointFun(xVal) tracker.addPoint((xVal, yVal)) err = yVal - targetY if livePlot: plotAfterPointMeasurement(tracker, targetY) return err # First check out what happens at the edges for x in startBounds: measureError(x) isIncreasing = tracker.ordi[1] > tracker.ordi[0] outOfRangeDirection = doesMFbracket(targetY, tracker) if outOfRangeDirection != 'in-range': # Case 1: we won't tolerate it if hardConstrain is True: if outOfRangeDirection == 'high': bestGuess = tracker.absc[np.argmax(tracker.ordi)] else: bestGuess = tracker.absc[np.argmin(tracker.ordi)] raise SearchRangeError('binarySearch function value ' + 'outside of hard constraints! ' + 'Results invalid.', outOfRangeDirection, bestGuess) # Case 2: try to get it in range else: try: newStartBounds = bracketSearch(evalPointFun=evalPointFun, targetY=targetY, startBounds=startBounds, xTol=xTol, hardConstrain=hardConstrain, livePlot=livePlot) except SearchRangeError as err: logger.debug('Failed to bracket targetY=%s. Returning best guess', targetY) return err.args[2] try: return binarySearch(evalPointFun=evalPointFun, targetY=targetY, startBounds=newStartBounds, # important change xTol=xTol, yTol=yTol, hardConstrain=True, # important change livePlot=livePlot) except SearchRangeError as err: raise SearchRangeError('It was in range and then not, ' + 'so probably noise.', err.args[1], err.args[2]) # By now we are certain that the target is bounded by the start points thisX = np.mean(startBounds) absStep = np.diff(startBounds)[0] / 4 for _ in range(30): newErr = measureError(thisX) # Case 1: converged within tolerance if abs(newErr) < yTol or absStep < xTol: return thisX # Case 2: guess new point and reduce step by factor of 2 if isIncreasing: thisX -= np.sign(newErr) * absStep else: thisX += np.sign(newErr) * absStep absStep /= 2 raise Exception('Binary search did 30 iterations and still did not converge')