www.gusucode.com > SAE RBM 程序MATLAB源码代码实现的一个关于sae的例子 > minFunc/WolfeLineSearch.m

    function [t,f_new,g_new,funEvals,H] = WolfeLineSearch(...
    x,t,d,f,g,gtd,c1,c2,LS,maxLS,tolX,debug,doPlot,saveHessianComp,funObj,varargin)
%
% Bracketing Line Search to Satisfy Wolfe Conditions
%
% The search runs in two phases. The bracketing phase extends the step
% until a trial point either satisfies the Wolfe conditions or an interval
% is found that is then refined by the zoom phase.
%
% Inputs:
%   x: starting location
%   t: initial step size
%   d: descent direction
%   f: function value at starting location
%   g: gradient at starting location
%   gtd: directional derivative at starting location
%   c1: sufficient decrease parameter
%   c2: curvature parameter
%   LS: type of interpolation
%       (3: multiply step by 10 when bracketing / bisect when zooming,
%        4: polyinterp on function values and directional derivatives,
%        otherwise: mixedExtrap / mixedInterp)
%   maxLS: maximum number of iterations
%   tolX: minimum allowable step length
%   debug: display debugging information
%   doPlot: do a graphical display of interpolation
%   saveHessianComp: if true, intermediate evaluations request only the
%       function value and gradient, and the third output of funObj is
%       requested again only for the returned step
%   funObj: objective function
%   varargin: parameters of objective function
%
% Outputs:
%   t: step length
%   f_new: function value at x+t*d
%   g_new: gradient value at x+t*d
%   funEvals: number function evaluations performed by line search
%   H: third output of funObj at x+t*d for the returned t
%      (only computed if requested)

wantH = (nargout == 5);
% Request H on every trial evaluation only if the caller wants it and did
% not ask to defer it to the end of the search.
evalHEveryStep = wantH && ~saveHessianComp;

% Step length at which H was most recently computed. Used at the end to
% decide whether H matches the returned step.
t_H = t;

% Evaluate the Objective and Gradient at the Initial Step
if wantH
    [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
else
    [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
end
funEvals = 1;
gtd_new = g_new'*d;

% Bracket an Interval containing a point satisfying the
% Wolfe criteria

LSiter = 0;
t_prev = 0;
f_prev = f;
g_prev = g;
gtd_prev = gtd;
done = 0;

while LSiter < maxLS

    %% Bracketing Phase
    if ~isLegal(f_new) || ~isLegal(g_new)
        if debug
            fprintf('Extrapolated into illegal region, switching to Armijo line-search\n');
        end
        t = (t + t_prev)/2;
        % Do Armijo: hand the whole search over to the backtracking
        % routine and return its result directly.
        if wantH
            [t,x_new,f_new,g_new,armijoFunEvals,H] = ArmijoBacktrack(...
              x,t,d,f,f,g,gtd,c1,max(0,min(LS-2,2)),tolX,debug,doPlot,saveHessianComp,...
              funObj,varargin{:}); %#ok<ASGLU>
        else
            [t,x_new,f_new,g_new,armijoFunEvals] = ArmijoBacktrack(...
              x,t,d,f,f,g,gtd,c1,max(0,min(LS-2,2)),tolX,debug,doPlot,saveHessianComp,...
              funObj,varargin{:}); %#ok<ASGLU>
        end
        funEvals = funEvals + armijoFunEvals;
        return;
    end


    if f_new > f + c1*t*gtd || (LSiter > 1 && f_new >= f_prev)
        % Sufficient decrease violated, or no improvement over the
        % previous trial: the interval [t_prev t] is the bracket.
        bracket = [t_prev t];
        bracketFval = [f_prev f_new];
        bracketGval = [g_prev g_new];
        break;
    elseif abs(gtd_new) <= -c2*gtd
        % Curvature condition also holds: accept this point as is.
        bracket = t;
        bracketFval = f_new;
        bracketGval = g_new;
        done = 1;
        break;
    elseif gtd_new >= 0
        % Directional derivative is no longer negative: bracket found.
        bracket = [t_prev t];
        bracketFval = [f_prev f_new];
        bracketGval = [g_prev g_new];
        break;
    end

    % No bracket yet: pick a larger trial step within [minStep, maxStep].
    temp = t_prev;
    t_prev = t;
    minStep = t + 0.01*(t-temp);
    maxStep = t*10;
    if LS == 3
        if debug
            fprintf('Extending Braket\n');
        end
        t = maxStep;
    elseif LS ==4
        if debug
            fprintf('Cubic Extrapolation\n');
        end
        t = polyinterp([temp f_prev gtd_prev; t f_new gtd_new],doPlot,minStep,maxStep);
    else
        t = mixedExtrap(temp,f_prev,gtd_prev,t,f_new,gtd_new,minStep,maxStep,debug,doPlot);
    end

    f_prev = f_new;
    g_prev = g_new;
    gtd_prev = gtd_new;
    if evalHEveryStep
        [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
        t_H = t;
    else
        [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
    end
    funEvals = funEvals + 1;
    gtd_new = g_new'*d;
    LSiter = LSiter+1;
end

if LSiter == maxLS
    % Iteration budget used up without a bracket: fall back to the
    % interval between the starting point and the last trial point.
    bracket = [0 t];
    bracketFval = [f f_new];
    bracketGval = [g g_new];
end

%% Zoom Phase

% We now either have a point satisfying the criteria, or a bracket
% surrounding a point satisfying the criteria
% Refine the bracket until we find a point satisfying the criteria
insufProgress = 0;
Tpos = 2;
LOposRemoved = 0;
while ~done && LSiter < maxLS

    % Find High and Low Points in bracket
    [f_LO, LOpos] = min(bracketFval);
    HIpos = -LOpos + 3;

    % Compute new trial value
    if LS == 3 || ~isLegal(bracketFval) || ~isLegal(bracketGval)
        if debug
            fprintf('Bisecting\n');
        end
        t = mean(bracket);
    elseif LS == 4
        if debug
            fprintf('Grad-Cubic Interpolation\n');
        end
        t = polyinterp([bracket(1) bracketFval(1) bracketGval(:,1)'*d
            bracket(2) bracketFval(2) bracketGval(:,2)'*d],doPlot);
    else
        % Mixed Case %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        nonTpos = -Tpos+3;
        if LOposRemoved == 0
            % Until a low point has been displaced (LS == 5 branch below),
            % the endpoint that is not the latest trial is the reference.
            oldLOval = bracket(nonTpos);
            oldLOFval = bracketFval(nonTpos);
            oldLOGval = bracketGval(:,nonTpos);
        end
        t = mixedInterp(bracket,bracketFval,bracketGval,d,Tpos,oldLOval,oldLOFval,oldLOGval,debug,doPlot);
    end


    % Test that we are making sufficient progress
    if min(max(bracket)-t,t-min(bracket))/(max(bracket)-min(bracket)) < 0.1
        if debug
            fprintf('Interpolation close to boundary');
        end
        if insufProgress || t>=max(bracket) || t <= min(bracket)
            % Second near-boundary trial in a row, or a trial outside the
            % bracket: move 10% of the bracket width in from the nearer end.
            if debug
                fprintf(', Evaluating at 0.1 away from boundary\n');
            end
            if abs(t-max(bracket)) < abs(t-min(bracket))
                t = max(bracket)-0.1*(max(bracket)-min(bracket));
            else
                t = min(bracket)+0.1*(max(bracket)-min(bracket));
            end
            insufProgress = 0;
        else
            if debug
                fprintf('\n');
            end
            insufProgress = 1;
        end
    else
        insufProgress = 0;
    end

    % Evaluate new point
    if evalHEveryStep
        [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
        t_H = t;
    else
        [f_new,g_new] = feval(funObj, x + t*d, varargin{:});
    end
    funEvals = funEvals + 1;
    gtd_new = g_new'*d;
    LSiter = LSiter+1;

    if f_new > f + c1*t*gtd || f_new >= f_LO
        % Armijo condition not satisfied or not lower than lowest
        % point
        bracket(HIpos) = t;
        bracketFval(HIpos) = f_new;
        bracketGval(:,HIpos) = g_new;
        Tpos = HIpos;
    else
        if abs(gtd_new) <= - c2*gtd
            % Wolfe conditions satisfied
            done = 1;
        elseif gtd_new*(bracket(HIpos)-bracket(LOpos)) >= 0
            % Old HI becomes new LO
            bracket(HIpos) = bracket(LOpos);
            bracketFval(HIpos) = bracketFval(LOpos);
            bracketGval(:,HIpos) = bracketGval(:,LOpos);
            if LS == 5
                if debug
                    fprintf('LO Pos is being removed!\n');
                end
                LOposRemoved = 1;
                oldLOval = bracket(LOpos);
                oldLOFval = bracketFval(LOpos);
                oldLOGval = bracketGval(:,LOpos);
            end
        end
        % New point becomes new LO
        bracket(LOpos) = t;
        bracketFval(LOpos) = f_new;
        bracketGval(:,LOpos) = g_new;
        Tpos = LOpos;
    end

    if ~done && abs((bracket(1)-bracket(2))*gtd_new) < tolX
        if debug
            fprintf('Line Search can not make further progress\n');
        end
        break;
    end

end

%%
if LSiter == maxLS
    if debug
        fprintf('Line Search Exceeded Maximum Line Search Iterations\n');
    end
end

% Return the bracket endpoint with the lowest function value. This is not
% necessarily the point that was evaluated last.
[f_LO, LOpos] = min(bracketFval); %#ok<ASGLU>
t = bracket(LOpos);
f_new = bracketFval(LOpos);
g_new = bracketGval(:,LOpos);



% Evaluate Hessian at new point
% H is only valid if it was computed at the returned step. Re-evaluate
% whenever the step it was computed at differs from the step being
% returned (deferred computation, or search ended on a different endpoint).
if wantH && t ~= t_H
    [f_new,g_new,H] = feval(funObj, x + t*d, varargin{:});
    funEvals = funEvals + 1;
end

end


%%
function [t] = mixedExtrap(x0,f0,g0,x1,f1,g1,minStep,maxStep,debug,doPlot);
% Choose an extrapolated step from two polyinterp candidates.
%
%   (x0,f0,g0), (x1,f1,g1): position, function value and directional
%       derivative at the two most recent trial steps
%   minStep, maxStep: bounds forwarded to polyinterp
%   debug: print which candidate was selected
%   doPlot: forwarded to polyinterp
%
% Returns the candidate built from both function values ("cubic") only if
% it exceeds minStep and lies closer to x1 than the other candidate
% ("secant"); otherwise returns the secant candidate.

cubicStep = polyinterp([x0 f0 g0; x1 f1 g1],doPlot,minStep,maxStep);

% sqrt(-1) replaces f1 in the second sample.
% NOTE(review): presumably polyinterp treats an imaginary entry as
% "value not available" - confirm against polyinterp.
secantStep = polyinterp([x0 f0 g0; x1 sqrt(-1) g1],doPlot,minStep,maxStep);

preferCubic = cubicStep > minStep && abs(cubicStep - x1) < abs(secantStep - x1);

if ~preferCubic
    if debug
        fprintf('Secant Extrapolation\n');
    end
    t = secantStep;
else
    if debug
        fprintf('Cubic Extrapolation\n');
    end
    t = cubicStep;
end
end

%%
function [t] = mixedInterp(bracket,bracketFval,bracketGval,d,Tpos,oldLOval,oldLOFval,oldLOGval,debug,doPlot);
% Choose the next trial step inside a bracket from several polyinterp
% candidates.
%
%   bracket, bracketFval, bracketGval: the two bracket endpoints, their
%       function values and their gradients (one gradient per column)
%   d: search direction, used to form directional derivatives
%   Tpos: index (1 or 2) of the endpoint that holds the latest trial
%   oldLOval, oldLOFval, oldLOGval: step, function value and gradient of
%       the reference point supplied by the caller
%   debug: print which candidate was selected
%   doPlot: forwarded to polyinterp
%
% Four cases are distinguished, tested in this order:
%   1. the trial value is higher than the reference value
%   2. the directional derivatives at trial and reference differ in sign
%   3. the trial derivative is no larger in magnitude than the reference
%   4. anything else: interpolate between the two bracket endpoints
%
% NOTE(review): sqrt(-1) entries presumably tell polyinterp that the
% corresponding quantity is not available - confirm against polyinterp.

% Mixed Case %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
otherPos = -Tpos+3;

stepT = bracket(Tpos);
valT = bracketFval(Tpos);
slopeT = bracketGval(:,Tpos)'*d;
slopeOther = bracketGval(:,otherPos)'*d;
slopeRef = oldLOGval'*d;

% Full samples (step, value, slope) at the reference and trial points.
refRow = [oldLOval oldLOFval slopeRef];
trialRow = [stepT valT slopeT];

if valT > oldLOFval
    % Case 1
    cubicStep = polyinterp([refRow; trialRow],doPlot);
    quadStep = polyinterp([refRow; stepT valT sqrt(-1)],doPlot);
    if abs(cubicStep - oldLOval) < abs(quadStep - oldLOval)
        if debug
            fprintf('Cubic Interpolation\n');
        end
        t = cubicStep;
    else
        if debug
            fprintf('Mixed Quad/Cubic Interpolation\n');
        end
        t = (quadStep + cubicStep)/2;
    end
elseif slopeT'*slopeRef < 0
    % Case 2
    cubicStep = polyinterp([refRow; trialRow],doPlot);
    secantStep = polyinterp([refRow; stepT sqrt(-1) slopeT],doPlot);
    if abs(cubicStep - stepT) >= abs(secantStep - stepT)
        if debug
            fprintf('Cubic Interpolation\n');
        end
        t = cubicStep;
    else
        if debug
            fprintf('Quad Interpolation\n');
        end
        t = secantStep;
    end
elseif abs(slopeT) <= abs(slopeRef)
    % Case 3: candidates are bounded by the current bracket.
    lowerEnd = min(bracket);
    upperEnd = max(bracket);
    cubicStep = polyinterp([refRow; trialRow],doPlot,lowerEnd,upperEnd);
    secantStep = polyinterp([oldLOval sqrt(-1) slopeRef; trialRow],...
        doPlot,lowerEnd,upperEnd);

    % The cubic candidate is used only when it is strictly inside the
    % bracket and closer to the trial point than the secant candidate.
    cubicInside = cubicStep > lowerEnd && cubicStep < upperEnd;
    if cubicInside && abs(cubicStep - stepT) < abs(secantStep - stepT)
        if debug
            fprintf('Bounded Cubic Extrapolation\n');
        end
        t = cubicStep;
    else
        if debug
            fprintf('Bounded Secant Extrapolation\n');
        end
        t = secantStep;
    end

    % Do not move more than 66% of the way from the trial endpoint toward
    % the other endpoint.
    limitStep = stepT + 0.66*(bracket(otherPos) - stepT);
    if stepT > oldLOval
        t = min(limitStep,t);
    else
        t = max(limitStep,t);
    end
else
    % Case 4
    t = polyinterp([bracket(otherPos) bracketFval(otherPos) slopeOther; trialRow],doPlot);
end
end