Skip to content

Commit

Permalink
Merge pull request rfeinman#10 from calvinmccarter/master
Browse files Browse the repository at this point in the history
Enable import of _status_message for scipy >= 1.8.0
  • Loading branch information
rfeinman authored Mar 30, 2022
2 parents c6b54ae + 4397f2c commit 72ae847
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
5 changes: 4 additions & 1 deletion torchmin/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import torch
from torch import Tensor
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

class HessianUpdateStrategy(ABC):
def __init__(self):
Expand Down
5 changes: 4 additions & 1 deletion torchmin/cg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import torch
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

dot = lambda u,v: torch.dot(u.view(-1), v.view(-1))

Expand Down
6 changes: 5 additions & 1 deletion torchmin/newton.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message
from scipy.sparse.linalg import eigsh
from torch import Tensor
import torch

from .function import ScalarFunction
from .line_search import strong_wolfe

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

_status_message['cg_warn'] = "Warning: CG iterations didn't converge. The " \
"Hessian is not positive definite."

Expand Down
7 changes: 6 additions & 1 deletion torchmin/trustregion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
from abc import ABC, abstractmethod
import torch
from torch.linalg import norm
from scipy.optimize.optimize import OptimizeResult, _status_message
from scipy.optimize import OptimizeResult

from ..function import ScalarFunction
from ..optim.minimizer import Minimizer

try:
from scipy.optimize.optimize import _status_message
except ImportError:
from scipy.optimize._optimize import _status_message

status_messages = (
_status_message['success'],
_status_message['maxiter'],
Expand Down

0 comments on commit 72ae847

Please sign in to comment.