diff --git a/light_the_torch/_patch.py b/light_the_torch/_patch.py index 9c1ed5b..950f6f6 100644 --- a/light_the_torch/_patch.py +++ b/light_the_torch/_patch.py @@ -5,7 +5,6 @@ import itertools import optparse import os -import platform import re import sys import unittest.mock @@ -253,7 +252,7 @@ def postprocessing(input, output): yield -def get_extra_index_urls(computation_backends, channel): +def get_index_urls(computation_backends, channel): if channel == Channel.STABLE: channel_paths = [""] else: @@ -284,9 +283,7 @@ def context(input): { requirement.name for requirement in input.root_reqs - if requirement.user_supplied - and not is_pinned(requirement) - and requirement.name in THIRD_PARTY_PACKAGES + if requirement.user_supplied and is_pinned(requirement) }, ): yield @@ -305,12 +302,10 @@ def context(input): @contextlib.contextmanager -def patch_link_collection( - computation_backends, channel, user_supplied_third_party_packages -): +def patch_link_collection(computation_backends, channel, user_supplied_pinned_packages): search_scope = SearchScope( find_links=[], - index_urls=get_extra_index_urls(computation_backends, channel), + index_urls=get_index_urls(computation_backends, channel), no_index=False, ) @@ -319,9 +314,8 @@ def context(input): if not ( input.project_name in PYTORCH_DISTRIBUTIONS or ( - channel == Channel.NIGHTLY - and platform.system() == "Linux" - and input.project_name not in user_supplied_third_party_packages + input.project_name in THIRD_PARTY_PACKAGES + and input.project_name not in user_supplied_pinned_packages ) ): yield diff --git a/scripts/check_pytorch_package_indices.py b/scripts/check_pytorch_package_indices.py index 6cd51d7..b946fd0 100755 --- a/scripts/check_pytorch_package_indices.py +++ b/scripts/check_pytorch_package_indices.py @@ -8,7 +8,7 @@ from light_the_torch._cb import _MINIMUM_DRIVER_VERSIONS, CPUBackend, CUDABackend from light_the_torch._patch import ( Channel, - get_extra_index_urls, + get_index_urls, PYTORCH_DISTRIBUTIONS, THIRD_PARTY_PACKAGES, ) @@ -40,11 +40,10 @@ } COMPUTATION_BACKENDS.add(CPUBackend()) -EXTRA_INDEX_URLS = sorted( +INDEX_URLS = sorted( set( itertools.chain.from_iterable( - get_extra_index_urls(COMPUTATION_BACKENDS, channel) - for channel in iter(Channel) + get_index_urls(COMPUTATION_BACKENDS, channel) for channel in iter(Channel) ) ) ) @@ -52,7 +51,7 @@ def main(): available = set() - for url in tqdm.tqdm(EXTRA_INDEX_URLS): + for url in tqdm.tqdm(INDEX_URLS): response = requests.get(url) if not response.ok: continue