"""
Nearly exact trust-region optimization subproblem.
Code ported from SciPy to PyTorch
Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
All rights reserved.
"""
from typing import Tuple
from torch import Tensor
import torch
from torch.linalg import norm
from scipy.linalg import get_lapack_funcs
from .base import _minimize_trust_region, BaseQuadraticSubproblem
def _minimize_trust_exact(fun, x0, **trust_region_options):
    """Minimization of scalar function of one or more variables using
    a nearly exact trust-region algorithm.

    All keyword options are forwarded unchanged to
    ``_minimize_trust_region``, with ``IterativeSubproblem`` selected as
    the subproblem solver.

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    initial_tr_radius : float
        Initial trust-region radius.
    max_tr_radius : float
        Maximum value of the trust-region radius. No steps that are longer
        than this value will be proposed.
    eta : float
        Trust region related acceptance stringency for proposed steps.
    gtol : float
        Gradient norm must be less than ``gtol`` before successful
        termination.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine.

    Notes
    -----
    This trust-region solver was based on [1]_, [2]_ and [3]_,
    which implement similar algorithms. The algorithm is basically
    that of [1]_ but ideas from [2]_ and [3]_ were also used.

    References
    ----------
    .. [1] A.R. Conn, N.I. Gould, and P.L. Toint, "Trust region methods",
           SIAM, pp. 169-200, 2000.
    .. [2] J. Nocedal and S. Wright, "Numerical optimization",
           Springer Science & Business Media. pp. 83-91, 2006.
    .. [3] J.J. More and D.C. Sorensen, "Computing a trust region step",
           SIAM Journal on Scientific and Statistical Computing, vol. 4(3),
           pp. 553-572, 1983.
    """
    return _minimize_trust_region(fun, x0,
                                  subproblem=IterativeSubproblem,
                                  **trust_region_options)
def solve_triangular(A, b, **kwargs):
    """Solve the triangular system ``op(A) x = b`` for a single vector.

    Parameters
    ----------
    A : Tensor
        Square triangular coefficient matrix.
    b : Tensor
        One-dimensional right-hand side.
    **kwargs
        ``upper`` (bool, default True): whether ``A`` is upper triangular.
        ``transpose`` (bool, default False): solve ``A.T x = b`` instead
        of ``A x = b``. Any remaining keywords (e.g. ``unitriangular``)
        are forwarded to ``torch.linalg.solve_triangular``.

    Returns
    -------
    x : Tensor
        One-dimensional solution with the same length as ``b``.
    """
    # ``torch.linalg.solve_triangular`` requires ``upper`` as a keyword and
    # has no ``transpose`` option, so both are handled here.
    upper = kwargs.pop('upper', True)
    if kwargs.pop('transpose', False):
        # The transpose of an upper triangular matrix is lower triangular
        # (and vice versa), so flip the flag together with the matrix.
        A = A.transpose(-2, -1)
        upper = not upper
    # The solver works on matrices: lift ``b`` to a column, then drop the
    # extra dimension from the (n, 1) result.
    x = torch.linalg.solve_triangular(A, b.unsqueeze(1), upper=upper, **kwargs)
    return x.squeeze(1)
def solve_cholesky(A, b, **kwargs):
    """Solve a linear system for one vector given a Cholesky factor.

    ``A`` is the Cholesky factor handed to ``torch.cholesky_solve`` and
    ``b`` is a one-dimensional right-hand side. Extra keywords (such as
    ``upper``) are forwarded to ``torch.cholesky_solve``. The result is
    returned as a one-dimensional tensor.
    """
    # ``cholesky_solve`` expects a matrix right-hand side.
    rhs_column = b.unsqueeze(1)
    solution = torch.cholesky_solve(rhs_column, A, **kwargs)
    return solution.squeeze(1)
@torch.jit.script
def estimate_smallest_singular_value(U) -> Tuple[Tensor, Tensor]:
    """Estimate the smallest singular value of an upper triangular matrix.

    Given upper triangular ``U``, return an estimate of its smallest
    singular value together with the corresponding right singular
    vector, using O(n**2) operations.

    A vector `e` with entries chosen from {+1, -1} is built greedily so
    that the solution `w` of `U.T w = e` is as large as possible. The
    implementation follows algorithm 3.5.1, p. 142, of reference [1]_,
    adapted for a lower triangular matrix.

    Returns
    -------
    s_min : Tensor
        Estimate of the smallest singular value.
    z_min : Tensor
        Unit-norm vector associated with ``s_min``.

    Raises
    ------
    ValueError
        If ``U`` is not square.

    References
    ----------
    .. [1] G.H. Golub, C.F. Van Loan. "Matrix computations".
           Fourth Edition. JHU press. pp. 140-142.
    """
    U = torch.atleast_2d(U)
    rows, cols = U.shape
    if rows != cols:
        raise ValueError("A square triangular matrix should be provided.")

    lower = U.T
    p = torch.zeros(cols, dtype=U.dtype, device=U.device)
    w = torch.empty(cols, dtype=U.dtype, device=U.device)

    for k in range(cols):
        pivot = lower[k, k]
        below_pivot = lower[k+1:, k]
        # Candidate components for the two possible signs of e[k].
        w_plus = (1-p[k]) / pivot
        w_minus = (-1-p[k]) / pivot
        p_plus = p[k+1:] + below_pivot * w_plus
        p_minus = p[k+1:] + below_pivot * w_minus
        # Keep whichever sign yields the larger growth.
        if w_plus.abs() + norm(p_plus, 1) >= w_minus.abs() + norm(p_minus, 1):
            w[k] = w_plus
            p[k+1:] = p_plus
        else:
            w[k] = w_minus
            p[k+1:] = p_minus

    # Backward substitution for `U v = w`.
    v = torch.triangular_solve(w.view(-1,1), U)[0].view(-1)
    v_length = norm(v)
    s_min = norm(w) / v_length  # Smallest singular value
    z_min = v / v_length  # Associated vector
    return s_min, z_min
def gershgorin_bounds(H):
    """
    Return lower and upper bounds on the eigenvalues of the square
    matrix ``H`` (Gershgorin bounds).

    Returns
    -------
    lb, ub : Tensor
        Scalar tensors holding the lower and the upper bound.
    """
    diag = H.diagonal()
    diag_magnitude = diag.abs()
    # Sum of absolute values along each row, diagonal entry included.
    row_magnitude = H.abs().sum(dim=1)
    lower = (diag + diag_magnitude - row_magnitude).min()
    upper = (diag - diag_magnitude + row_magnitude).max()
    return lower, upper
def singular_leading_submatrix(A, U, k):
    """
    Compute the term that makes the leading ``k`` by ``k``
    submatrix of ``A`` singular.

    Returns
    -------
    delta : Tensor
        Squared norm of the first ``k-1`` entries of column ``k-1`` of
        ``U`` minus ``A[k-1, k-1]``.
    v : Tensor
        Vector of length ``A.shape[0]`` with ``v[k-1] = 1``, the first
        ``k-1`` entries obtained from a triangular solve and zeros
        elsewhere.
    """
    last = k - 1
    column = U[:last, last]

    delta = column.dot(column) - A[last, last]

    v = A.new_zeros(A.shape[0])
    v[last] = 1
    if k != 1:
        # The leading entries come from a triangular system.
        v[:last] = solve_triangular(U[:last, :last], -column)
    return delta, v
class IterativeSubproblem(BaseQuadraticSubproblem):
    """Quadratic subproblem solved by nearly exact iterative method.

    Notes
    -----
    ``self.hess``, ``self.jac``, ``self.jac_mag`` and
    ``self.get_boundaries_intersections`` are used here but not defined
    in this class; they are presumably provided by
    ``BaseQuadraticSubproblem`` (confirm against the base class).
    """

    # UPDATE_COEFF appears in reference [1]_
    # in formula 7.3.14 (p. 190) named as "theta".
    # As recommended there its value is fixed at 0.01.
    UPDATE_COEFF = 0.01
    hess_prod = False

    def __init__(self, x, fun, k_easy=0.1, k_hard=0.2):
        """Set up the subproblem at point ``x`` for objective ``fun``.

        Parameters
        ----------
        x : Tensor
            Current point; its dtype determines the machine epsilon used.
        fun : callable
            Objective, passed through to the base class constructor.
        k_easy, k_hard : float
            Tolerances of the stopping criteria in the "easy" and "hard"
            cases of the iterative solver.
        """
        super().__init__(x, fun)

        # When the trust-region shrinks in two consecutive
        # calculations (``tr_radius < previous_tr_radius``)
        # the lower bound ``lambda_lb`` may be reused,
        # facilitating the convergence. To indicate no
        # previous value is known at first ``previous_tr_radius``
        # is set to -1 and ``lambda_lb`` to None.
        self.previous_tr_radius = -1
        self.lambda_lb = None

        self.niter = 0
        self.EPS = torch.finfo(x.dtype).eps

        # ``k_easy`` and ``k_hard`` are parameters used
        # to determine the stop criteria to the iterative
        # subproblem solver. Take a look at pp. 194-197
        # from reference [1]_ for a more detailed description.
        self.k_easy = k_easy
        self.k_hard = k_hard

        # Select the Cholesky backend.
        # NOTE: cholesky_ex requires pytorch >= 1.9.0
        if hasattr(torch.linalg, 'cholesky_ex'):
            self.torch_cholesky = True
        else:
            # if we don't have torch cholesky, use potrf from scipy
            self.cholesky, = get_lapack_funcs(('potrf',),
                                              (self.hess.cpu().numpy(),))
            self.torch_cholesky = False

        # Get info about Hessian
        self.dimension = len(self.hess)
        self.hess_gershgorin_lb, self.hess_gershgorin_ub = gershgorin_bounds(self.hess)
        self.hess_inf = norm(self.hess, float('inf'))
        self.hess_fro = norm(self.hess, 'fro')

        # A constant such that for vectors smaller than that
        # backward substitution is not reliable. It was established
        # based on Golub, G. H., Van Loan, C. F. (2013).
        # "Matrix computations". Fourth Edition. JHU press., p.165.
        self.CLOSE_TO_ZERO = self.dimension * self.EPS * self.hess_inf

    def _shifted_hessian(self, shift):
        """Return a fresh copy of the Hessian with ``shift`` added to its diagonal."""
        H = self.hess.clone()
        H.diagonal().add_(shift)
        return H

    def _cholesky_upper(self, H):
        """Attempt ``H = U.T @ U``; return ``(U, info)`` with ``U`` as a tensor.

        ``info`` is zero when the factorization succeeded.
        """
        if self.torch_cholesky:
            L, info = torch.linalg.cholesky_ex(H)
            # cholesky_ex yields the lower factor; the solver works with
            # the upper one.
            return L.t().contiguous(), info
        U, info = self.cholesky(H.cpu().numpy(),
                                lower=False,
                                overwrite_a=False,
                                clean=True)
        return H.new_tensor(U), info

    def _safeguarded_damping(self, lambda_lb, lambda_ub):
        """Pick a damping factor inside ``[lambda_lb, lambda_ub]``.

        Takes the larger of the geometric mean of the bounds and a point
        a fraction ``UPDATE_COEFF`` of the way from the lower bound to the
        upper bound.
        """
        return torch.max(
            torch.sqrt(lambda_lb * lambda_ub),
            lambda_lb + self.UPDATE_COEFF*(lambda_ub-lambda_lb))

    def _initial_values(self, tr_radius):
        """Given a trust radius, return a good initial guess for
        the damping factor, the lower bound and the upper bound.
        The values were chosen accordingly to the guidelines on
        section 7.3.8 (p. 192) from [1]_.
        """
        hess_norm = torch.min(self.hess_fro, self.hess_inf)

        # Upper bound for the damping factor
        lambda_ub = self.jac_mag / tr_radius + torch.min(-self.hess_gershgorin_lb, hess_norm)
        lambda_ub = torch.clamp(lambda_ub, min=0)

        # Lower bound for the damping factor
        lambda_lb = self.jac_mag / tr_radius - torch.min(self.hess_gershgorin_ub, hess_norm)
        lambda_lb = torch.max(lambda_lb, -self.hess.diagonal().min())
        lambda_lb = torch.clamp(lambda_lb, min=0)

        # Improve bounds with previous info
        if tr_radius < self.previous_tr_radius:
            lambda_lb = torch.max(self.lambda_lb, lambda_lb)

        # Initial guess for the damping factor
        if lambda_lb == 0:
            lambda_initial = lambda_lb.clone()
        else:
            lambda_initial = self._safeguarded_damping(lambda_lb, lambda_ub)

        return lambda_initial, lambda_lb, lambda_ub

    def solve(self, tr_radius):
        """Solve quadratic subproblem.

        Parameters
        ----------
        tr_radius : float or Tensor
            Trust-region radius.

        Returns
        -------
        p : Tensor
            Proposed step.
        hits_boundary : bool
            False when the loop stops at an interior solution with a zero
            damping factor, True otherwise.

        Notes
        -----
        Updates ``self.niter``, ``self.lambda_lb``, ``self.lambda_current``
        and ``self.previous_tr_radius`` as a side effect.
        """
        lambda_current, lambda_lb, lambda_ub = self._initial_values(tr_radius)
        n = self.dimension
        hits_boundary = True
        already_factorized = False
        self.niter = 0

        while True:
            # Compute Cholesky factorization, unless the previous
            # iteration already produced ``H``, ``U`` and ``info`` for the
            # current damping factor.
            if already_factorized:
                already_factorized = False
            else:
                H = self._shifted_hessian(lambda_current)
                U, info = self._cholesky_upper(H)

            self.niter += 1

            # Check if factorization succeeded
            if info == 0 and self.jac_mag > self.CLOSE_TO_ZERO:
                # Successful factorization

                # Solve `U.T U p = s`
                p = solve_cholesky(U, -self.jac, upper=True)
                p_norm = norm(p)

                # Check for interior convergence
                if p_norm <= tr_radius and lambda_current == 0:
                    hits_boundary = False
                    break

                # Solve `U.T w = p`
                w = solve_triangular(U, p, transpose=True)
                w_norm = norm(w)

                # Compute Newton step accordingly to
                # formula (4.44) p.87 from ref [2]_.
                delta_lambda = (p_norm/w_norm)**2 * (p_norm-tr_radius)/tr_radius
                lambda_new = lambda_current + delta_lambda

                if p_norm < tr_radius:  # Inside boundary
                    s_min, z_min = estimate_smallest_singular_value(U)

                    ta, tb = self.get_boundaries_intersections(p, z_min, tr_radius)

                    # Choose `step_len` with the smallest magnitude.
                    # The reason for this choice is explained at
                    # ref [3]_, p. 6 (Immediately before the formula
                    # for `tau`).
                    step_len = min(ta, tb, key=torch.abs)

                    # Compute the quadratic term (p.T*H*p)
                    quadratic_term = p.dot(H.mv(p))

                    # Check stop criteria
                    relative_error = ((step_len**2 * s_min**2) /
                                      (quadratic_term + lambda_current*tr_radius**2))
                    if relative_error <= self.k_hard:
                        p.add_(step_len * z_min)
                        break

                    # Update uncertainty bounds
                    lambda_ub = lambda_current
                    lambda_lb = torch.max(lambda_lb, lambda_current - s_min**2)

                    # Try the Newton update: factorize at ``lambda_new``.
                    H = self._shifted_hessian(lambda_new)
                    U_new, info = self._cholesky_upper(H)

                    if info == 0:
                        # Keep the factor that belongs to ``lambda_new``.
                        # The next iteration skips the factorization, so
                        # ``U`` must match ``H`` and ``lambda_current``;
                        # reusing the previous factor would pair a stale
                        # ``U`` with the new damping factor.
                        U = U_new
                        lambda_current = lambda_new
                        already_factorized = True
                    else:
                        lambda_lb = torch.max(lambda_lb, lambda_new)
                        lambda_current = self._safeguarded_damping(lambda_lb, lambda_ub)

                else:  # Outside boundary
                    # Check stop criteria
                    relative_error = torch.abs(p_norm - tr_radius) / tr_radius
                    if relative_error <= self.k_easy:
                        break

                    # Update uncertainty bounds
                    lambda_lb = lambda_current

                    # Update damping factor
                    lambda_current = lambda_new

            elif info == 0 and self.jac_mag <= self.CLOSE_TO_ZERO:
                # jac_mag very close to zero

                # Check for interior convergence
                if lambda_current == 0:
                    p = self.jac.new_zeros(n)
                    hits_boundary = False
                    break

                s_min, z_min = estimate_smallest_singular_value(U)
                step_len = tr_radius

                # Check stop criteria
                if step_len**2 * s_min**2 <= self.k_hard * lambda_current * tr_radius**2:
                    p = step_len * z_min
                    break

                # Update uncertainty bounds and damping factor
                lambda_ub = lambda_current
                lambda_lb = torch.max(lambda_lb, lambda_current - s_min**2)
                lambda_current = self._safeguarded_damping(lambda_lb, lambda_ub)

            else:
                # Unsuccessful factorization
                delta, v = singular_leading_submatrix(H, U, info)
                v_norm = norm(v)

                lambda_lb = torch.max(lambda_lb, lambda_current + delta/v_norm**2)

                # Update damping factor
                lambda_current = self._safeguarded_damping(lambda_lb, lambda_ub)

        # Remember state so a later call with a smaller radius can reuse
        # the lower bound (see ``_initial_values``).
        self.lambda_lb = lambda_lb
        self.lambda_current = lambda_current
        self.previous_tr_radius = tr_radius

        return p, hits_boundary