# Source code for torchmin.trustregion.ncg

"""
Newton-CG trust-region optimization.

Code ported from SciPy to PyTorch

Copyright (c) 2001-2002 Enthought, Inc.  2003-2019, SciPy Developers.
All rights reserved.
"""
import torch
from torch.linalg import norm

from .base import _minimize_trust_region, BaseQuadraticSubproblem


def _minimize_trust_ncg(
        fun, x0, **trust_region_options):
    """Minimization of scalar function of one or more variables using
    the Newton conjugate gradient trust-region algorithm.

    All keyword options are forwarded unchanged to
    ``_minimize_trust_region``, with the subproblem solver fixed to
    :class:`CGSteihaugSubproblem`.

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    initial_trust_radius : float
        Initial trust-region radius.
    max_trust_radius : float
        Maximum value of the trust-region radius. No steps that are longer
        than this value will be proposed.
    eta : float
        Trust region related acceptance stringency for proposed steps.
    gtol : float
        Gradient norm must be less than ``gtol`` before successful
        termination.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine.

    Notes
    -----
    This is algorithm (7.2) of Nocedal and Wright 2nd edition.
    Only the function that computes the Hessian-vector product is required.
    The Hessian itself is not required, and the Hessian does not need to be
    positive semidefinite.
    """
    return _minimize_trust_region(fun, x0,
                                  subproblem=CGSteihaugSubproblem,
                                  **trust_region_options)
class CGSteihaugSubproblem(BaseQuadraticSubproblem):
    """Quadratic subproblem solved by a conjugate gradient method
    (Steihaug's truncated CG)."""

    # ``solve`` only uses Hessian-vector products (``self.hessp``), never an
    # explicit Hessian; presumably the base class reads this flag to decide
    # what to compute -- NOTE(review): confirm against BaseQuadraticSubproblem.
    hess_prod = True

    def solve(self, trust_radius):
        """Solve the subproblem using a conjugate gradient method.

        Runs CG on the quadratic model starting from the origin and stops
        when one of the following happens:

        * a direction of nonpositive curvature is found, in which case the
          better of the two trust-region boundary points along it is
          returned;
        * the next iterate would leave the trust region, in which case the
          step is truncated onto the boundary;
        * the residual norm drops below the forcing tolerance
          ``min(0.5, sqrt(jac_mag)) * jac_mag``.

        Parameters
        ----------
        trust_radius : float
            We are allowed to wander only this far away from the origin.

        Returns
        -------
        p : Tensor
            The proposed step.
        hits_boundary : bool
            True if the proposed step is on the boundary of the trust region.
        """
        # The step is measured from the current point, so start at the origin.
        p_origin = torch.zeros_like(self.jac)

        # Forcing tolerance: min(0.5, sqrt(||g||)) * ||g||.
        tolerance = self.jac_mag * self.jac_mag.sqrt().clamp(max=0.5)

        # If the gradient is small, return the origin.
        # For jac_mag > 0 we have tolerance <= 0.5 * jac_mag, so the
        # ``jac_mag < tolerance`` test alone can never fire. The explicit
        # zero check covers a vanishing gradient. There, d = -jac would be
        # zero (assuming jac_mag is the norm of jac), and no t satisfies
        # ||z + t d|| == trust_radius in the curvature branch below.
        if self.jac_mag == 0 or self.jac_mag < tolerance:
            hits_boundary = False
            return p_origin, hits_boundary

        # init the state for the first iteration
        z = p_origin
        r = self.jac
        d = -r

        # Search for the min of the approximation of the objective function.
        while True:
            Bd = self.hessp(d)
            dBd = d.dot(Bd)
            if dBd <= 0:
                # Nonpositive curvature along d: the model is unbounded below
                # along this direction inside the region. Find both values
                # of t such that ||z + t d|| == trust_radius, then keep the
                # boundary point with the smaller predicted model value.
                ta, tb = self.get_boundaries_intersections(z, d, trust_radius)
                pa = z + ta * d
                pb = z + tb * d
                # Tensor-valued selection instead of a Python ``if`` on the
                # comparison result.
                p_boundary = torch.where(self(pa).lt(self(pb)), pa, pb)
                hits_boundary = True
                return p_boundary, hits_boundary

            r_squared = r.dot(r)
            alpha = r_squared / dBd
            z_next = z + alpha * d
            if norm(z_next) >= trust_radius:
                # Find t >= 0 to get the boundary point such that
                # ||z + t d|| == trust_radius
                _, tb = self.get_boundaries_intersections(z, d, trust_radius)
                p_boundary = z + tb * d
                hits_boundary = True
                return p_boundary, hits_boundary

            r_next = r + alpha * Bd
            r_next_squared = r_next.dot(r_next)
            if r_next_squared.sqrt() < tolerance:
                # Residual small enough: accept the interior iterate.
                hits_boundary = False
                return z_next, hits_boundary

            # Fletcher-Reeves style update of the conjugate direction.
            beta_next = r_next_squared / r_squared
            d_next = -r_next + beta_next * d

            # update the state for the next iteration
            z = z_next
            r = r_next
            d = d_next