from scipy.optimize import OptimizeResult
from scipy.sparse.linalg import eigsh
from torch import Tensor
import torch
from ._optimize import _status_message
from .function import ScalarFunction
from .line_search import strong_wolfe
# Message reported by the Newton-CG driver when the inner conjugate-gradient
# solve uses up its iteration budget.
_status_message['cg_warn'] = (
    "Warning: CG iterations didn't converge. The "
    "Hessian is not positive definite."
)
def _cg_iters(grad, hess, max_iter, normp=1):
    """A CG solver specialized for the NewtonCG sub-problem.

    Runs conjugate-gradient updates on the system ``hess @ x = -grad``
    starting from ``x = 0``. Derived from Algorithm 7.1 of "Numerical
    Optimization (2nd Ed.)" (Nocedal & Wright, 2006; pp. 169)

    Parameters
    ----------
    grad : Tensor
        Gradient (right-hand side of the system). 1-D input uses
        ``torch.dot``; for higher-dimensional input the inner products are
        taken over the last dimension.
    hess : object
        Hessian operator. Only ``hess.mv(p)`` is called on it.
    max_iter : int
        Maximum number of CG updates to perform.
    normp : Number or str
        Norm type passed to ``Tensor.norm`` for the residual test.

    Returns
    -------
    x : Tensor
        The computed step, same shape as `grad`.
    n_iter : int
        Number of CG updates that were completed.
    maxiter_reached : bool
        True if all ``max_iter`` updates ran without triggering one of the
        early exits (residual below tolerance or non-positive curvature).
    """
    # Pick the cheapest inner product for the shape of the problem
    if grad.dim() == 1:
        # standard dot product
        dot = torch.dot
    elif grad.dim() == 2:
        def dot(u, v):
            # batched (row-wise) dot product
            return torch.bmm(u.unsqueeze(1), v.unsqueeze(2)).view(-1, 1)
    else:
        def dot(u, v):
            # generalized dot product that supports batch inputs
            return u.mul(v).sum(-1, keepdim=True)

    # residual tolerance: ||g|| * min(sqrt(||g||), 0.5)
    g_norm = grad.norm(p=normp)
    tol = g_norm * g_norm.sqrt().clamp(0, 0.5)
    eps = torch.finfo(grad.dtype).eps
    n_iter = 0  # reported as-is when the loop body never runs
    maxiter_reached = False

    # initialize state and iterate
    x = torch.zeros_like(grad)
    r = grad.clone()
    p = grad.neg()
    rs = dot(r, r)
    for n_iter in range(max_iter):
        if r.norm(p=normp) < tol:
            break
        Bp = hess.mv(p)
        curv = dot(p, Bp)
        curv_sum = curv.sum()
        if curv_sum < 0:
            # hessian is not positive-definite
            if n_iter == 0:
                # On the very first step there is no accumulated solution to
                # fall back on, so return the gradient scaled by rs / curv.
                # For an unbatched problem that factor is negative here, so
                # the step points along -grad.
                x = grad.mul(rs / curv)
            break
        elif curv_sum <= 3 * eps:
            # curvature is numerically zero; dividing by it is meaningless
            break
        alpha = rs / curv
        x.addcmul_(alpha, p)
        r.addcmul_(alpha, Bp)
        rs_new = dot(r, r)
        p.mul_(rs_new / rs).sub_(r)
        rs = rs_new
    else:
        # curvature keeps increasing; bail.
        # Every break above leaves n_iter equal to the number of completed
        # updates. Falling off the end would leave it one short, so set it
        # to the full budget explicitly.
        n_iter = max(max_iter, 0)
        maxiter_reached = True

    return x, n_iter, maxiter_reached
@torch.no_grad()
def _minimize_newton_cg(
        fun, x0, lr=1., max_iter=None, cg_max_iter=None,
        twice_diffable=True, line_search='strong-wolfe', xtol=1e-5,
        normp=1, callback=None, disp=0, return_all=False):
    """Minimize a scalar function of one or more variables using the
    Newton-Raphson method, with Conjugate Gradient for the linear inverse
    sub-problem.

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    lr : float
        Step size for parameter updates. If using line search, this will be
        used as the initial step size for the search.
    max_iter : int, optional
        Maximum number of iterations to perform. Defaults to
        ``200 * x0.numel()``.
    cg_max_iter : int, optional
        Maximum number of iterations for CG subproblem. Recommended to
        leave this at the default of ``20 * x0.numel()``.
    twice_diffable : bool
        Whether to assume the function is twice continuously differentiable.
        If True, hessian-vector products will be much faster.
    line_search : str
        Line search specifier. Currently the available options are
        {'none', 'strong-wolfe'}.
    xtol : float
        Average relative error in solution `xopt` acceptable for
        convergence. The run stops successfully once the norm of the last
        update is at most ``xtol * x0.numel()``.
    normp : Number or str
        The norm type to use for termination conditions. Can be any value
        supported by :func:`torch.norm`.
    callback : callable, optional
        Function to call after each iteration with the current parameter
        state, e.g. ``callback(x)``. A truthy return value stops the
        optimization.
    disp : int or bool
        Display (verbosity) level. Set to >0 to print status messages;
        values >1 additionally print the function value every iteration.
    return_all : bool
        Set to True to return a list of the best solution at each of the
        iterations.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine. Carries ``fun``, ``x``, ``grad``,
        ``status``, ``success``, ``message``, ``nit``, ``nfev``, ``ncg`` and,
        when `return_all` is set, ``allvecs``. ``status`` is 0 on
        convergence, 1 when `max_iter` is reached, 3 when the CG solve fails
        or the function value becomes non-finite, and 5 when the callback
        requested a stop.

    Raises
    ------
    ValueError
        If `line_search` is not one of the supported options.
    """
    # Reject a bad option before spending any function evaluations on it.
    if line_search not in ('none', 'strong-wolfe'):
        raise ValueError('invalid line_search option {}.'.format(line_search))
    use_line_search = line_search == 'strong-wolfe'

    lr = float(lr)
    disp = int(disp)
    xtol = x0.numel() * xtol
    if max_iter is None:
        max_iter = x0.numel() * 200
    if cg_max_iter is None:
        cg_max_iter = x0.numel() * 20

    # construct scalar objective function
    sf = ScalarFunction(fun, x0.shape, hessp=True, twice_diffable=twice_diffable)
    closure = sf.closure
    if use_line_search:
        dir_evaluate = sf.dir_evaluate

    # initial settings
    x = x0.detach().clone(memory_format=torch.contiguous_format)
    f, g, hessp, _ = closure(x)
    if disp > 1:
        print('initial fval: %0.4f' % f)
    if return_all:
        allvecs = [x]
    ncg = 0   # number of cg iterations
    n_iter = 0

    # begin optimization loop
    for n_iter in range(1, max_iter + 1):

        # ============================================================
        #  Compute a search direction pk by applying the CG method to
        #       H_f(xk) p = - J_f(xk) starting from 0.
        # ============================================================

        # Compute search direction with conjugate gradient (CG)
        d, cg_iters, cg_fail = _cg_iters(g, hessp, cg_max_iter, normp)
        ncg += cg_iters

        if cg_fail:
            warnflag = 3
            msg = _status_message['cg_warn']
            break

        # =====================================================
        #  Perform variable update (with optional line search)
        # =====================================================

        if use_line_search:
            # strong-wolfe line search picks the step length
            _, _, t, _ = strong_wolfe(dir_evaluate, x, lr, d, f, g)
            update = d.mul(t)
        else:
            update = d.mul(lr)
        x = x + update

        # re-evaluate function
        f, g, hessp, _ = closure(x)

        if disp > 1:
            print('iter %3d - fval: %0.4f' % (n_iter, f))
        # Record the iterate before consulting the callback so that the
        # returned ``x`` is always the last entry of ``allvecs``.
        if return_all:
            allvecs.append(x)
        if callback is not None:
            if callback(x):
                warnflag = 5
                msg = _status_message['callback_stop']
                break

        # ==========================
        #  check for convergence
        # ==========================

        if update.norm(p=normp) <= xtol:
            warnflag = 0
            msg = _status_message['success']
            break

        if not f.isfinite():
            warnflag = 3
            msg = _status_message['nan']
            break

    else:
        # if we get to the end, the maximum num. iterations was reached
        warnflag = 1
        msg = _status_message['maxiter']

    if disp:
        print(msg)
        print(" Current function value: %f" % f)
        print(" Iterations: %d" % n_iter)
        print(" Function evaluations: %d" % sf.nfev)
        print(" CG iterations: %d" % ncg)
    result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0),
                            status=warnflag, success=(warnflag==0),
                            message=msg, nit=n_iter, nfev=sf.nfev, ncg=ncg)
    if return_all:
        result['allvecs'] = allvecs
    return result
@torch.no_grad()
def _minimize_newton_exact(
        fun, x0, lr=1., max_iter=None, line_search='strong-wolfe', xtol=1e-5,
        normp=1, tikhonov=0., handle_npd='grad', callback=None, disp=0,
        return_all=False):
    """Minimize a scalar function of one or more variables using the
    Newton-Raphson method.

    This variant uses an "exact" Newton routine based on Cholesky factorization
    of the explicit Hessian matrix.

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    lr : float
        Step size for parameter updates. If using line search, this will be
        used as the initial step size for the search.
    max_iter : int, optional
        Maximum number of iterations to perform. Defaults to
        ``200 * x0.numel()``.
    line_search : str
        Line search specifier. Currently the available options are
        {'none', 'strong-wolfe'}.
    xtol : float
        Average relative error in solution `xopt` acceptable for
        convergence. The run stops successfully once the norm of the last
        update is at most ``xtol * x0.numel()``.
    normp : Number or str
        The norm type to use for termination conditions. Can be any value
        supported by :func:`torch.norm`.
    tikhonov : float
        Optional diagonal regularization (Tikhonov) parameter for the Hessian.
    handle_npd : str
        Mode for handling non-positive definite hessian matrices. Can be one
        of the following:

            * 'grad' : use steepest descent direction (gradient)
            * 'cauchy' : steepest descent direction rescaled to the Cauchy
              point for a trust radius of 1
            * 'lu' : solve the inverse hessian with LU factorization
            * 'eig' : use symmetric eigendecomposition to determine a
              diagonal regularization parameter (experimental)
    callback : callable, optional
        Function to call after each iteration with the current parameter
        state, e.g. ``callback(x)``. A truthy return value stops the
        optimization.
    disp : int or bool
        Display (verbosity) level. Set to >0 to print status messages;
        values >1 additionally print the function value every iteration.
    return_all : bool
        Set to True to return a list of the best solution at each of the
        iterations.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine. Carries ``fun``, ``x``, ``grad``,
        ``hess``, ``status``, ``success``, ``message``, ``nit``, ``nfev``,
        ``nfail`` (number of failed Cholesky factorizations) and, when
        `return_all` is set, ``allvecs``. ``status`` is 0 on convergence,
        1 when `max_iter` is reached, 3 when the function value becomes
        non-finite, and 5 when the callback requested a stop.

    Raises
    ------
    ValueError
        If `line_search` is not one of the supported options.
    RuntimeError
        If a Cholesky factorization fails and `handle_npd` is not one of the
        supported modes.
    """
    # Reject a bad option before spending any function evaluations on it.
    if line_search not in ('none', 'strong-wolfe'):
        raise ValueError('invalid line_search option {}.'.format(line_search))
    use_line_search = line_search == 'strong-wolfe'

    lr = float(lr)
    disp = int(disp)
    xtol = x0.numel() * xtol
    if max_iter is None:
        max_iter = x0.numel() * 200

    # Construct scalar objective function
    sf = ScalarFunction(fun, x0.shape, hess=True)
    closure = sf.closure
    if use_line_search:
        dir_evaluate = sf.dir_evaluate

    def evaluate(point):
        # Evaluate value/gradient/Hessian and apply the optional Tikhonov
        # shift to the Hessian diagonal in place.
        fval, grad, _, hessian = closure(point)
        if tikhonov > 0:
            hessian.diagonal().add_(tikhonov)
        return fval, grad, hessian

    # initial settings
    # reshape (rather than view) so that a non-contiguous x0 is accepted
    x = x0.detach().reshape(-1).clone(memory_format=torch.contiguous_format)
    f, g, hess = evaluate(x)
    if disp > 1:
        print('initial fval: %0.4f' % f)
    if return_all:
        allvecs = [x]
    nfail = 0
    n_iter = 0

    # begin optimization loop
    for n_iter in range(1, max_iter + 1):

        # ==================================================
        #  Compute a search direction d by solving
        #          H_f(x) d = - J_f(x)
        #  with the true Hessian and Cholesky factorization
        # ===================================================

        # Compute search direction with Cholesky solve
        L, info = torch.linalg.cholesky_ex(hess)

        if info == 0:
            d = torch.cholesky_solve(g.neg().unsqueeze(1), L).squeeze(1)
        else:
            # factorization failed: Hessian is not positive definite
            nfail += 1
            if handle_npd == 'lu':
                d = torch.linalg.solve(hess, g.neg())
            elif handle_npd in ('grad', 'cauchy'):
                d = g.neg()
                if handle_npd == 'cauchy':
                    # cauchy point for a trust radius of delta=1.
                    # equivalent to 'grad' with a scaled lr
                    gnorm = g.norm(p=2)
                    scale = 1 / gnorm
                    gHg = g.dot(hess.mv(g))
                    if gHg > 0:
                        scale *= torch.clamp_(gnorm.pow(3) / gHg, max=1)
                    d *= scale
            elif handle_npd == 'eig':
                # this setting is experimental! use with caution
                # TODO: why use the factor 1.5 here? Seems to work best
                eig0 = eigsh(hess.cpu().numpy(), k=1, which="SA", tol=1e-4)[0].item()
                tau = max(1e-3 - 1.5 * eig0, 0)
                # shift the diagonal in place, then retry the factorization
                hess.diagonal().add_(tau)
                L = torch.linalg.cholesky(hess)
                d = torch.cholesky_solve(g.neg().unsqueeze(1), L).squeeze(1)
            else:
                raise RuntimeError('invalid handle_npd encountered.')

        # =====================================================
        #  Perform variable update (with optional line search)
        # =====================================================

        if use_line_search:
            # strong-wolfe line search picks the step length
            _, _, t, _ = strong_wolfe(dir_evaluate, x, lr, d, f, g)
            update = d.mul(t)
        else:
            update = d.mul(lr)
        x = x + update

        # ===================================
        #  Re-evaluate func/Jacobian/Hessian
        # ===================================

        f, g, hess = evaluate(x)

        if disp > 1:
            print('iter %3d - fval: %0.4f - info: %d' % (n_iter, f, info))
        # Record the iterate before consulting the callback so that the
        # returned ``x`` is always the last entry of ``allvecs``.
        if return_all:
            allvecs.append(x)
        if callback is not None:
            if callback(x):
                warnflag = 5
                msg = _status_message['callback_stop']
                break

        # ==========================
        #  check for convergence
        # ==========================

        if update.norm(p=normp) <= xtol:
            warnflag = 0
            msg = _status_message['success']
            break

        if not f.isfinite():
            warnflag = 3
            msg = _status_message['nan']
            break

    else:
        # if we get to the end, the maximum num. iterations was reached
        warnflag = 1
        msg = _status_message['maxiter']

    if disp:
        print(msg)
        print(" Current function value: %f" % f)
        print(" Iterations: %d" % n_iter)
        print(" Function evaluations: %d" % sf.nfev)
    result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0),
                            hess=hess.view(2 * x0.shape),
                            status=warnflag, success=(warnflag==0),
                            message=msg, nit=n_iter, nfev=sf.nfev, nfail=nfail)
    if return_all:
        result['allvecs'] = allvecs
    return result