From 6b446abe4ba17b4619b3e27b295491479dc58fa6 Mon Sep 17 00:00:00 2001 From: Luis Pineda <4759586+luisenp@users.noreply.github.com> Date: Fri, 7 Jul 2023 18:18:40 -0400 Subject: [PATCH] Removed semantic version dependency. (#574) * Removed semantic version dependency. * Fix mypy error for Python <= 3.8. * Add missing copyright header. --- tests/theseus_tests/test_misc.py | 19 +++++++++++++++++++ theseus/_version.py | 25 ++++++++++++++++++++++--- theseus/third_party/lml.py | 7 +++---- 3 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 tests/theseus_tests/test_misc.py diff --git a/tests/theseus_tests/test_misc.py b/tests/theseus_tests/test_misc.py new file mode 100644 index 00000000..179bd7e3 --- /dev/null +++ b/tests/theseus_tests/test_misc.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +from theseus._version import lt_version + + +def test_lt_version(): + assert not lt_version("2.0.0", "0.4.0") + assert not lt_version("1.13.0abcd", "0.4.0") + assert not lt_version("0.4.1+yzx", "0.4.0") + assert lt_version("1.13.0.1.2.3.4", "2.0.0") + assert lt_version("1.13.0.1.2+abc", "2.0.0") + with pytest.raises(ValueError): + lt_version("1.2", "0.4.0") + lt_version("1", "0.4.0") + lt_version("1.", "0.4.0") diff --git a/theseus/_version.py b/theseus/_version.py index cbcd4953..ebe694a4 100644 --- a/theseus/_version.py +++ b/theseus/_version.py @@ -2,12 +2,31 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import re import warnings +from typing import Tuple -from semantic_version import Version -from torch import __version__ as _torch_version +import torch -if Version(_torch_version) < Version("2.0.0"): + +# Returns True/False if version string v1 is less than version string v2 +def lt_version(v1: str, v2: str) -> bool: + def _as_tuple(s: str) -> Tuple[int, int, int]: + pattern = r"^[\d.]+" + match = re.match(pattern, s) + try: + return tuple(int(x) for x in match.group().split(".")[:3]) # type: ignore + except Exception: + raise ValueError( + f"String {s} cannot be converted to (mayor, minor, micro) format." + ) + + x1, y1, z1 = _as_tuple(v1) + x2, y2, z2 = _as_tuple(v2) + return x1 < x2 or (x1 == x2 and y1 < y2) or (x1 == x2 and y1 == y2 and z1 < z2) + + +if lt_version(torch.__version__, "2.0.0"): warnings.warn( "Using torch < 2.0 for theseus is deprecated and compatibility will be " "discontinued in future releases.", diff --git a/theseus/third_party/lml.py b/theseus/third_party/lml.py index 300aaee1..b4d4b05f 100755 --- a/theseus/third_party/lml.py +++ b/theseus/third_party/lml.py @@ -20,16 +20,15 @@ # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - import numpy as np import numpy.random as npr import torch -from semantic_version import Version from torch.autograd import Function, Variable, grad from torch.nn import Module -version = Version(".".join(torch.__version__.split(".")[:3])) -old_torch = version < Version("0.4.0") +from theseus._version import lt_version + +old_torch = lt_version(torch.__version__, "0.4.0") def bdot(x, y):