-
Notifications
You must be signed in to change notification settings - Fork 105
/
setup.py
141 lines (123 loc) · 5.11 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
# and limitations under the License.
import codecs
import os
import setuptools
import subprocess
import re
import warnings
class VersionHelper(object):
    """Helper class to figure out current package version from git tag.

    The version is taken from the most recent Git tag (``vX.Y.Z``), written to
    ``pecos/_version.py``, and then read back from that file. All paths are
    resolved relative to the directory containing this setup script, so the
    result does not depend on the current working directory.
    """

    # Generated version file, relative to the project root.
    __VERSION_FP = "pecos/_version.py"
    # Template for the generated version file; "%s" receives the version string.
    __VERSION_PY = (
        "\n"
        "# This file is automatically generated from Git version tag by running setup.\n"
        "# Only distribution/installed packages contain this file.\n"
        '__version__ = "%s"\n'
    )
    # Fallback used for non-Git installs or when no tag can be found.
    __DUMMY_VERSION = "0.0.0"

    @staticmethod
    def _project_root():
        """Return the absolute directory that contains this setup script."""
        return os.path.abspath(os.path.dirname(__file__))

    @staticmethod
    def _parse_git_tag(git_tag):
        """Convert a ``git describe`` tag such as ``"v0.1.0\\n"`` into ``"0.1.0"``.

        Only the prefix ``v<major>.<minor>.<patch>`` is checked, so suffixed
        tags (e.g. ``v0.1.0rc1``) are still accepted.

        Raises:
            ValueError: if the tag does not start with ``v<major>.<minor>.<patch>``.
        """
        # Dots are escaped: an unescaped "." would also accept tags like "v1x2x3".
        if not re.match(r"v\d+\.\d+\.\d+", git_tag):
            raise ValueError(f"We use tags like v0.1.0, but got {git_tag}")
        return git_tag[len("v"):].strip()

    @classmethod
    def _git_version(cls, root):
        """Return the version from the latest Git tag under ``root``, or None.

        None is returned when ``root`` is not a Git checkout, the ``git``
        executable is unavailable, or ``git describe`` fails (e.g. no tags).

        Raises:
            ValueError: if a tag is found but is not of the form ``vX.Y.Z...``.
        """
        # ".git" is a directory in a normal clone but a plain file in
        # worktrees and submodules, so test for existence, not isdir.
        if not os.path.exists(os.path.join(root, ".git")):
            return None
        try:
            # Ask only for the tag name; the commit hash is not included.
            git_desc = subprocess.run(
                ["git", "describe", "--tags", "--abbrev=0"],
                stdout=subprocess.PIPE,
                cwd=root,
            )
        except OSError:
            # git is not installed / not on PATH: fall back to the dummy version.
            return None
        if git_desc.returncode != 0:
            return None
        return cls._parse_git_tag(git_desc.stdout.decode("utf-8"))

    @classmethod
    def __update_version_py(cls):
        """Update the version file from git tag information.

        If not in a Git repository, git is unavailable, or tag info is missing,
        the dummy version 0.0.0 is written and a warning is emitted.
        """
        root = cls._project_root()
        ver = cls._git_version(root)
        if ver is None:
            ver = cls.__DUMMY_VERSION
            warnings.warn(f"Unable to retrieve version from git info, "
                          f"maybe not in a Git repository, or tag info missing? "
                          f"Will write dummy version {ver} to {cls.__VERSION_FP}")
        with open(os.path.join(root, cls.__VERSION_FP), "w", encoding="utf-8") as ver_fp:
            ver_fp.write(cls.__VERSION_PY % ver)
        print(f"Set version to {ver}")

    @classmethod
    def __read_version_file(cls):
        """Return the full text of the generated version file."""
        with open(os.path.join(cls._project_root(), cls.__VERSION_FP), "r", encoding="utf-8") as fp:
            return fp.read()

    @classmethod
    def get_version(cls):
        """Get version from git tag and write it to the version file.

        Returns:
            The version string parsed from the ``__version__`` line.

        Raises:
            RuntimeError: if the version file contains no ``__version__`` line.
        """
        cls.__update_version_py()
        for line in cls.__read_version_file().splitlines():
            if line.startswith("__version__"):
                delim = '"' if '"' in line else "'"
                return line.split(delim)[1]
        raise RuntimeError("Unable to find version string.")
# README.md is passed to setup() below as the Markdown long description.
with open("README.md", mode="r", encoding="utf-8") as readme:
    long_description = readme.read()
# Requirements
# NumPy is shared by the build-time and install-time requirement lists.
numpy_requires = ['numpy>=1.19.5,<2.0.0; python_version>="3.9"']

# NOTE(review): pytest-runner / setup_requires are legacy setuptools features — consider dropping.
setup_requires = [*numpy_requires, 'pytest-runner']

install_requires = [
    *numpy_requires,
    'scipy>=1.4.1,<1.14.0',
    'scikit-learn>=0.24.1',
    'torch>=2.0; python_version>="3.9"',
    # 0.1.92 results in error for transformers
    'sentencepiece>=0.1.86,!=0.1.92',
    # the minimal version supporting py3.9
    'transformers>=4.31.0; python_version>="3.9"',
    'peft>=0.11.0; python_version>="3.9"',
    'datasets>=2.19.1; python_version>="3.9"',
]
# Fetch NumPy before building the NumPy-dependent extension, in case the required
# NumPy version is not already installed. This runs at module import time, i.e. on
# every invocation of this script.
# NOTE(review): `setuptools.distutils` only exists because setuptools imports distutils
# internally, and `fetch_build_eggs` is a legacy setuptools API — confirm both still
# work with the setuptools version used for builds.
setuptools.distutils.core.Distribution().fetch_build_eggs(numpy_requires)
# Get extra manual compile args if any.
# PECOS_MANUAL_COMPILE_ARGS holds a comma-separated list of compiler flags.
# Example usage:
# > PECOS_MANUAL_COMPILE_ARGS="-Werror" python3 -m pip install --editable .
# > PECOS_MANUAL_COMPILE_ARGS="-Werror,-Wall" python3 -m pip install --editable .
def _parse_manual_compile_args(raw):
    """Split a comma-separated flag string into a list of compiler arguments.

    Whitespace around each flag is stripped and empty entries are dropped, so
    "-Werror, -Wall," yields ["-Werror", "-Wall"]. None or "" yields [].
    Without this, " -Wall" or "" would reach the compiler as bogus arguments.
    """
    if not raw:
        return []
    return [arg.strip() for arg in raw.split(",") if arg.strip()]


manual_compile_args = _parse_manual_compile_args(
    os.environ.get('PECOS_MANUAL_COMPILE_ARGS', default=None))
# Compile C/C++ extension: one C++17 source file built with OpenMP enabled.
ext_module = setuptools.Extension(
    name="pecos.core.libpecos_float32",
    sources=["pecos/core/libpecos.cpp"],
    include_dirs=["pecos/core", "/usr/include/", "/usr/local/include"],
    # gomp pairs with -fopenmp below; gcc and stdc++ are linked explicitly.
    libraries=["gomp", "gcc", "stdc++"],
    # Default flags first, then any user-supplied flags from the environment.
    extra_compile_args=["-fopenmp", "-O3", "-std=c++17", *manual_compile_args],
)
# Resolve the version before package discovery, matching the original argument
# evaluation order: get_version() (re)writes pecos/_version.py and prints it.
_package_version = VersionHelper.get_version()
_packages = setuptools.find_packages(where=".")

setuptools.setup(
    # Distribution metadata
    name="libpecos",
    version=_package_version,
    description="PECOS - Predictions for Enormous and Correlated Output Spaces",
    url="https://github.com/amzn/pecos",
    author="Amazon.com, Inc.",
    license="Apache 2.0",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # Package layout and compiled extension
    packages=_packages,
    package_dir={"": "."},
    include_package_data=True,
    ext_modules=[ext_module],
    # Dependencies
    setup_requires=setup_requires,
    install_requires=install_requires,
    tests_require=["pytest"],
)