% Source: www.gusucode.com > SAE RBM program: MATLAB source code implementing an SAE example > minFunc/WolfeLineSearch.m
function [t,f_new,g_new,funEvals,H] = WolfeLineSearch(...
    x,t,d,f,g,gtd,c1,c2,LS,maxLS,tolX,debug,doPlot,saveHessianComp,funObj,varargin)
%
% Bracketing Line Search to Satisfy (strong) Wolfe Conditions
%
% A trial step t is accepted when both
%   f(x+t*d) <= f + c1*t*gtd          (sufficient decrease / Armijo)
%   |g(x+t*d)'*d| <= -c2*gtd          (curvature)
% hold. The search first extrapolates until it brackets such a point,
% then shrinks the bracket ("zoom" phase) until the point is found, the
% bracket stops making progress, or maxLS iterations are used.
%
% Inputs:
%   x: starting location
%   t: initial step size
%   d: descent direction
%   f: function value at starting location
%   g: gradient at starting location
%   gtd: directional derivative at starting location (g'*d)
%   c1: sufficient decrease parameter
%   c2: curvature parameter
%   LS: type of interpolation/extrapolation:
%         3 = extrapolate by 10x, bisect during zoom
%         4 = cubic interpolation/extrapolation using gradients
%         otherwise = mixed strategy (LS == 5 additionally tracks a
%                     removed LO point for the mixed interpolation)
%       max(0,min(LS-2,2)) is forwarded to ArmijoBacktrack if the search
%       falls back to backtracking.
%   maxLS: maximum number of line search iterations
%   tolX: progress tolerance; the zoom phase stops when
%         |(bracket width)*g_new'*d| < tolX
%   debug: display debugging information
%   doPlot: do a graphical display of interpolation
%   saveHessianComp: if true, H is only requested from funObj at the
%         initial step and (if more than one evaluation was made) again at
%         the returned step; if false, H is requested at every trial step
%   funObj: objective function, called as
%         [f,g] = funObj(x,varargin{:}) or [f,g,H] = funObj(x,varargin{:})
%   varargin: parameters of objective function
%
% Outputs:
%   t: step length
%   f_new: function value at x+t*d
%   g_new: gradient value at x+t*d
%   funEvals: number of function evaluations performed by the line search
%   H: Hessian (only computed if requested, i.e. nargout == 5)
%      NOTE(review): when saveHessianComp is false, H comes from the last
%      trial evaluation, which is not necessarily the returned step t.
%
% If a trial point has an illegal function value or gradient (as judged
% by isLegal), the search falls back to ArmijoBacktrack and returns its
% result directly.

% Evaluate the Objective and Gradient at the Initial Step
if nargout == 5
    [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
else
    [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
end
funEvals = 1;
gtd_new = g_new'*d;

% Bracket an Interval containing a point satisfying the
% Wolfe criteria

LSiter = 0;
t_prev = 0;
f_prev = f;
g_prev = g;
gtd_prev = gtd;
done = 0;

while LSiter < maxLS

    %% Bracketing Phase
    if ~isLegal(f_new) || ~isLegal(g_new)
        if 0
            % Bisection recovery branch; intentionally disabled by the
            % constant condition, so the Armijo fallback below is used.
            if debug
                fprintf('Extrapolated into illegal region, Bisecting\n');
            end
            t = (t + t_prev)/2;
            if ~saveHessianComp && nargout == 5
                [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
            else
                [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
            end
            funEvals = funEvals + 1;
            gtd_new = g_new'*d;
            LSiter = LSiter+1;
            continue;
        else
            if debug
                fprintf('Extrapolated into illegal region, switching to Armijo line-search\n');
            end
            t = (t + t_prev)/2;
            % Do Armijo
            if nargout == 5
                [t,x_new,f_new,g_new,armijoFunEvals,H] = ArmijoBacktrack(...
                    x,t,d,f,f,g,gtd,c1,max(0,min(LS-2,2)),tolX,debug,doPlot,saveHessianComp,...
                    funObj,varargin{:});
            else
                [t,x_new,f_new,g_new,armijoFunEvals] = ArmijoBacktrack(...
                    x,t,d,f,f,g,gtd,c1,max(0,min(LS-2,2)),tolX,debug,doPlot,saveHessianComp,...
                    funObj,varargin{:});
            end
            funEvals = funEvals + armijoFunEvals;
            return;
        end
    end

    if f_new > f + c1*t*gtd || (LSiter > 1 && f_new >= f_prev)
        % Sufficient decrease failed (or f increased): minimizer is
        % bracketed between the previous and current step.
        bracket = [t_prev t];
        bracketFval = [f_prev f_new];
        bracketGval = [g_prev g_new];
        break;
    elseif abs(gtd_new) <= -c2*gtd
        % Both Wolfe conditions hold at the current step.
        bracket = t;
        bracketFval = f_new;
        bracketGval = g_new;
        done = 1;
        break;
    elseif gtd_new >= 0
        % Directional derivative changed sign: bracket found.
        bracket = [t_prev t];
        bracketFval = [f_prev f_new];
        bracketGval = [g_prev g_new];
        break;
    end

    % No bracket yet: extrapolate to a larger step in [minStep, maxStep].
    temp = t_prev;
    t_prev = t;
    minStep = t + 0.01*(t-temp);
    maxStep = t*10;
    if LS == 3
        if debug
            fprintf('Extending Braket\n');
        end
        t = maxStep;
    elseif LS ==4
        if debug
            fprintf('Cubic Extrapolation\n');
        end
        t = polyinterp([temp f_prev gtd_prev; t f_new gtd_new],doPlot,minStep,maxStep);
    else
        t = mixedExtrap(temp,f_prev,gtd_prev,t,f_new,gtd_new,minStep,maxStep,debug,doPlot);
    end

    f_prev = f_new;
    g_prev = g_new;
    gtd_prev = gtd_new;

    if ~saveHessianComp && nargout == 5
        [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
    else
        [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
    end
    funEvals = funEvals + 1;
    gtd_new = g_new'*d;
    LSiter = LSiter+1;
end

if LSiter == maxLS
    % Iteration budget exhausted while extrapolating: fall back to the
    % interval between the start point and the last trial step.
    bracket = [0 t];
    bracketFval = [f f_new];
    bracketGval = [g g_new];
end

%% Zoom Phase

% We now either have a point satisfying the criteria, or a bracket
% surrounding a point satisfying the criteria
% Refine the bracket until we find a point satisfying the criteria
insufProgress = 0;
Tpos = 2;
LOposRemoved = 0;
while ~done && LSiter < maxLS

    % Find High and Low Points in bracket
    [f_LO LOpos] = min(bracketFval);
    HIpos = -LOpos + 3;

    % Compute new trial value
    if LS == 3 || ~isLegal(bracketFval) || ~isLegal(bracketGval)
        if debug
            fprintf('Bisecting\n');
        end
        t = mean(bracket);
    elseif LS == 4
        if debug
            fprintf('Grad-Cubic Interpolation\n');
        end
        % Two interpolation points, one row [step, f, g'*d] per endpoint.
        t = polyinterp([bracket(1) bracketFval(1) bracketGval(:,1)'*d; ...
            bracket(2) bracketFval(2) bracketGval(:,2)'*d],doPlot);
    else
        % Mixed Case %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        nonTpos = -Tpos+3;
        if LOposRemoved == 0
            oldLOval = bracket(nonTpos);
            oldLOFval = bracketFval(nonTpos);
            oldLOGval = bracketGval(:,nonTpos);
        end
        t = mixedInterp(bracket,bracketFval,bracketGval,d,Tpos,oldLOval,oldLOFval,oldLOGval,debug,doPlot);
    end

    % Test that we are making sufficient progress
    if min(max(bracket)-t,t-min(bracket))/(max(bracket)-min(bracket)) < 0.1
        if debug
            fprintf('Interpolation close to boundary');
        end
        if insufProgress || t>=max(bracket) || t <= min(bracket)
            if debug
                fprintf(', Evaluating at 0.1 away from boundary\n');
            end
            if abs(t-max(bracket)) < abs(t-min(bracket))
                t = max(bracket)-0.1*(max(bracket)-min(bracket));
            else
                t = min(bracket)+0.1*(max(bracket)-min(bracket));
            end
            insufProgress = 0;
        else
            if debug
                fprintf('\n');
            end
            insufProgress = 1;
        end
    else
        insufProgress = 0;
    end

    % Evaluate new point
    if ~saveHessianComp && nargout == 5
        [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
    else
        [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
    end
    funEvals = funEvals + 1;
    gtd_new = g_new'*d;
    LSiter = LSiter+1;

    if f_new > f + c1*t*gtd || f_new >= f_LO
        % Armijo condition not satisfied or not lower than lowest
        % point: the new point replaces the HI end of the bracket
        bracket(HIpos) = t;
        bracketFval(HIpos) = f_new;
        bracketGval(:,HIpos) = g_new;
        Tpos = HIpos;
    else
        if abs(gtd_new) <= - c2*gtd
            % Wolfe conditions satisfied
            done = 1;
        elseif gtd_new*(bracket(HIpos)-bracket(LOpos)) >= 0
            % Old LO becomes new HI (its values are copied into the HI slot)
            bracket(HIpos) = bracket(LOpos);
            bracketFval(HIpos) = bracketFval(LOpos);
            bracketGval(:,HIpos) = bracketGval(:,LOpos);
            if LS == 5
                if debug
                    fprintf('LO Pos is being removed!\n');
                end
                % Remember the LO point being overwritten for mixedInterp.
                LOposRemoved = 1;
                oldLOval = bracket(LOpos);
                oldLOFval = bracketFval(LOpos);
                oldLOGval = bracketGval(:,LOpos);
            end
        end

        % New point becomes new LO
        bracket(LOpos) = t;
        bracketFval(LOpos) = f_new;
        bracketGval(:,LOpos) = g_new;
        Tpos = LOpos;
    end

    if ~done && abs((bracket(1)-bracket(2))*gtd_new) < tolX
        if debug
            fprintf('Line Search can not make further progress\n');
        end
        break;
    end
end

%%
if LSiter == maxLS
    if debug
        fprintf('Line Search Exceeded Maximum Line Search Iterations\n');
    end
end

% Return the lowest point found in the bracket
[f_LO LOpos] = min(bracketFval);
t = bracket(LOpos);
f_new = bracketFval(LOpos);
g_new = bracketGval(:,LOpos);

% Evaluate Hessian at new point
if nargout == 5 && funEvals > 1 && saveHessianComp
    [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
    funEvals = funEvals + 1;
end

end

%%
function [t] = mixedExtrap(x0,f0,g0,x1,f1,g1,minStep,maxStep,debug,doPlot)
% Choose an extrapolation step in [minStep, maxStep] from two points.
%
% Rows passed to polyinterp are [step, f, directional derivative].
% The secant candidate passes sqrt(-1) in place of f1; presumably
% polyinterp treats an imaginary entry as "unknown" -- confirm against
% polyinterp.
% Returns the cubic step if it exceeds minStep and is closer to x1 than
% the secant step; otherwise returns the secant step.
alpha_c = polyinterp([x0 f0 g0; x1 f1 g1],doPlot,minStep,maxStep);
alpha_s = polyinterp([x0 f0 g0; x1 sqrt(-1) g1],doPlot,minStep,maxStep);
if alpha_c > minStep && abs(alpha_c - x1) < abs(alpha_s - x1)
    if debug
        fprintf('Cubic Extrapolation\n');
    end
    t = alpha_c;
else
    if debug
        fprintf('Secant Extrapolation\n');
    end
    t = alpha_s;
end
end

%%
function [t] = mixedInterp(bracket,bracketFval,bracketGval,d,Tpos,oldLOval,oldLOFval,oldLOGval,debug,doPlot)
% Choose a new trial step inside the bracket for the mixed strategy.
%
% Tpos indexes the most recently evaluated bracket end; (oldLOval,
% oldLOFval, oldLOGval) describe the previous LO point. Each polyinterp
% call receives a 2x3 matrix with one row [step, f, g'*d] per point;
% sqrt(-1) presumably marks an unknown entry for polyinterp (confirm).

% Mixed Case %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
nonTpos = -Tpos+3;

gtdT = bracketGval(:,Tpos)'*d;
gtdNonT = bracketGval(:,nonTpos)'*d;
oldLOgtd = oldLOGval'*d;
if bracketFval(Tpos) > oldLOFval
    alpha_c = polyinterp([oldLOval oldLOFval oldLOgtd; ...
        bracket(Tpos) bracketFval(Tpos) gtdT],doPlot);
    alpha_q = polyinterp([oldLOval oldLOFval oldLOgtd; ...
        bracket(Tpos) bracketFval(Tpos) sqrt(-1)],doPlot);
    if abs(alpha_c - oldLOval) < abs(alpha_q - oldLOval)
        if debug
            fprintf('Cubic Interpolation\n');
        end
        t = alpha_c;
    else
        if debug
            fprintf('Mixed Quad/Cubic Interpolation\n');
        end
        t = (alpha_q + alpha_c)/2;
    end
elseif gtdT'*oldLOgtd < 0
    alpha_c = polyinterp([oldLOval oldLOFval oldLOgtd; ...
        bracket(Tpos) bracketFval(Tpos) gtdT],doPlot);
    alpha_s = polyinterp([oldLOval oldLOFval oldLOgtd; ...
        bracket(Tpos) sqrt(-1) gtdT],doPlot);
    if abs(alpha_c - bracket(Tpos)) >= abs(alpha_s - bracket(Tpos))
        if debug
            fprintf('Cubic Interpolation\n');
        end
        t = alpha_c;
    else
        if debug
            fprintf('Quad Interpolation\n');
        end
        t = alpha_s;
    end
elseif abs(gtdT) <= abs(oldLOgtd)
    alpha_c = polyinterp([oldLOval oldLOFval oldLOgtd; ...
        bracket(Tpos) bracketFval(Tpos) gtdT],...
        doPlot,min(bracket),max(bracket));
    alpha_s = polyinterp([oldLOval sqrt(-1) oldLOgtd; ...
        bracket(Tpos) bracketFval(Tpos) gtdT],...
        doPlot,min(bracket),max(bracket));
    if alpha_c > min(bracket) && alpha_c < max(bracket)
        if abs(alpha_c - bracket(Tpos)) < abs(alpha_s - bracket(Tpos))
            if debug
                fprintf('Bounded Cubic Extrapolation\n');
            end
            t = alpha_c;
        else
            if debug
                fprintf('Bounded Secant Extrapolation\n');
            end
            t = alpha_s;
        end
    else
        if debug
            fprintf('Bounded Secant Extrapolation\n');
        end
        t = alpha_s;
    end

    % Keep the step at most 66% of the way from bracket(Tpos) toward the
    % other end of the bracket.
    if bracket(Tpos) > oldLOval
        t = min(bracket(Tpos) + 0.66*(bracket(nonTpos) - bracket(Tpos)),t);
    else
        t = max(bracket(Tpos) + 0.66*(bracket(nonTpos) - bracket(Tpos)),t);
    end
else
    t = polyinterp([bracket(nonTpos) bracketFval(nonTpos) gtdNonT; ...
        bracket(Tpos) bracketFval(Tpos) gtdT],doPlot);
end
end