Skip to content

Commit

Permalink
enable import of _status_message for scipy > 1.7.1
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinmccarter committed Mar 23, 2022
1 parent c6b54ae commit 48ceb25
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 7 deletions.
7 changes: 5 additions & 2 deletions torchmin/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import torch
from torch import Tensor
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

class HessianUpdateStrategy(ABC):
def __init__(self):
Expand Down Expand Up @@ -378,4 +381,4 @@ def _minimize_lbfgs(
return _minimize_bfgs_core(
fun, x0, lr, low_mem=True, history_size=history_size,
max_iter=max_iter, line_search=line_search, gtol=gtol, xtol=xtol,
normp=normp, callback=callback, disp=disp, return_all=return_all)
normp=normp, callback=callback, disp=disp, return_all=return_all)
7 changes: 5 additions & 2 deletions torchmin/cg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import torch
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

dot = lambda u,v: torch.dot(u.view(-1), v.view(-1))

Expand Down Expand Up @@ -140,4 +143,4 @@ def descent_condition(t, f_next, g_next):
message=msg, nit=niter, nfev=sf.nfev)
if return_all:
result['allvecs'] = allvecs
return result
return result
6 changes: 5 additions & 1 deletion torchmin/newton.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message
from scipy.sparse.linalg import eigsh
from torch import Tensor
import torch

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

_status_message['cg_warn'] = "Warning: CG iterations didn't converge. The " \
"Hessian is not positive definite."

Expand Down
9 changes: 7 additions & 2 deletions torchmin/trustregion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
from abc import ABC, abstractmethod
import torch
from torch.linalg import norm
from scipy.optimize.optimize import OptimizeResult, _status_message
from scipy.optimize import OptimizeResult

from ..function import ScalarFunction
from ..optim.minimizer import Minimizer

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

status_messages = (
_status_message['success'],
_status_message['maxiter'],
Expand Down Expand Up @@ -262,4 +267,4 @@ def _minimize_trust_region(fun, x0, subproblem=None, initial_trust_radius=1.,
if return_all:
result['allvecs'] = allvecs

return result
return result

0 comments on commit 48ceb25

Please sign in to comment.