Source code for torchmin.cg

import torch
from scipy.optimize import OptimizeResult

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
    from scipy.optimize.optimize import _status_message
except ImportError:
    from scipy.optimize._optimize import _status_message

dot = lambda u, v: torch.dot(u.view(-1), v.view(-1))


@torch.no_grad()
def _minimize_cg(fun, x0, max_iter=None, gtol=1e-5, normp=float('inf'),
                 callback=None, disp=0, return_all=False):
    """Minimize a scalar function of one or more variables using
    nonlinear conjugate gradient.

    The algorithm is described in Nocedal & Wright (2006) chapter 5.2.

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    max_iter : int
        Maximum number of iterations to perform. Defaults to
        ``200 * x0.numel()``.
    gtol : float
        Termination tolerance on 1st-order optimality (gradient norm).
    normp : float
        The norm type to use for termination conditions. Can be any value
        supported by :func:`torch.norm`.
    callback : callable, optional
        Function to call after each iteration with the current parameter
        state, e.g. ``callback(x)``.
    disp : int or bool
        Display (verbosity) level. Set to >0 to print status messages.
    return_all : bool, optional
        Set to True to return a list of the best solution at each of the
        iterations.

    """
    disp = int(disp)
    if max_iter is None:
        max_iter = x0.numel() * 200

    # Construct scalar objective function
    sf = ScalarFunction(fun, x_shape=x0.shape)
    closure = sf.closure
    dir_evaluate = sf.dir_evaluate

    # initialize
    x = x0.detach().flatten()
    f, g, _, _ = closure(x)
    if disp > 1:
        print('initial fval: %0.4f' % f)
    if return_all:
        allvecs = [x]
    d = g.neg()
    grad_norm = g.norm(p=normp)
    old_f = f + g.norm() / 2  # Sets the initial step guess to dx ~ 1

    for niter in range(1, max_iter + 1):
        # delta/gtd
        delta = dot(g, g)
        gtd = dot(g, d)

        # compute initial step guess based on (f - old_f) / gtd
        t0 = torch.clamp(2.02 * (f - old_f) / gtd, max=1.0)
        if t0 <= 0:
            warnflag = 4
            msg = 'Initial step guess is negative.'
            break
        old_f = f

        # buffer to store next direction vector
        cached_step = [None]

        def polak_ribiere_powell_step(t, g_next):
            y = g_next - g
            beta = torch.clamp(dot(y, g_next) / delta, min=0)
            d_next = -g_next + d.mul(beta)
            torch.norm(g_next, p=normp, out=grad_norm)
            return t, d_next

        def descent_condition(t, f_next, g_next):
            # Polak-Ribiere+ needs an explicit check of a sufficient
            # descent condition, which is not guaranteed by strong Wolfe.
            cached_step[:] = polak_ribiere_powell_step(t, g_next)
            t, d_next = cached_step

            # Accept step if it leads to convergence.
            cond1 = grad_norm <= gtol

            # Accept step if sufficient descent condition applies.
            cond2 = dot(d_next, g_next) <= -0.01 * dot(g_next, g_next)

            return cond1 | cond2

        # Perform CG step
        f, g, t, ls_evals = \
            strong_wolfe(dir_evaluate, x, t0, d, f, g, gtd, c2=0.4,
                         extra_condition=descent_condition)

        # Update x and then update d (in that order)
        x = x + d.mul(t)
        if t == cached_step[0]:
            # Reuse already computed results if possible
            d = cached_step[1]
        else:
            d = polak_ribiere_powell_step(t, g)[1]

        if disp > 1:
            print('iter %3d - fval: %0.4f' % (niter, f))
        if return_all:
            allvecs.append(x)
        if callback is not None:
            callback(x)

        # check optimality
        if grad_norm <= gtol:
            warnflag = 0
            msg = _status_message['success']
            break
    else:
        # if we get to the end, the maximum iterations was reached
        warnflag = 1
        msg = _status_message['maxiter']

    if disp:
        print("%s%s" % ("Warning: " if warnflag != 0 else "", msg))
        print("         Current function value: %f" % f)
        print("         Iterations: %d" % niter)
        print("         Function evaluations: %d" % sf.nfev)

    result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0),
                            status=warnflag, success=(warnflag == 0),
                            message=msg, nit=niter, nfev=sf.nfev)
    if return_all:
        result['allvecs'] = allvecs
    return result
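

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It minimizes
# the 2-D Rosenbrock test function by calling ``_minimize_cg`` directly; in
# normal use this routine would typically be reached through the package's
# public ``minimize(..., method='cg')`` interface. The sketch assumes that
# ``ScalarFunction`` accepts any callable mapping a Tensor to a scalar Tensor.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    def _rosenbrock(x):
        # Classic banana-shaped test function with its minimum at (1, 1).
        return (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2

    x0 = torch.tensor([-1.0, 1.0], dtype=torch.float64)
    res = _minimize_cg(_rosenbrock, x0, disp=1)
    print(res.x)  # expected to be close to tensor([1., 1.])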