Skip to content

Commit

Permalink
always use --index-url (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Sep 1, 2023
1 parent d914a8e commit 5ad5ff3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
18 changes: 6 additions & 12 deletions light_the_torch/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import itertools
import optparse
import os
import platform
import re
import sys
import unittest.mock
Expand Down Expand Up @@ -253,7 +252,7 @@ def postprocessing(input, output):
yield


def get_extra_index_urls(computation_backends, channel):
def get_index_urls(computation_backends, channel):
if channel == Channel.STABLE:
channel_paths = [""]
else:
Expand Down Expand Up @@ -284,9 +283,7 @@ def context(input):
{
requirement.name
for requirement in input.root_reqs
if requirement.user_supplied
and not is_pinned(requirement)
and requirement.name in THIRD_PARTY_PACKAGES
if requirement.user_supplied and is_pinned(requirement)
},
):
yield
Expand All @@ -305,12 +302,10 @@ def context(input):


@contextlib.contextmanager
def patch_link_collection(
computation_backends, channel, user_supplied_third_party_packages
):
def patch_link_collection(computation_backends, channel, user_supplied_pinned_packages):
search_scope = SearchScope(
find_links=[],
index_urls=get_extra_index_urls(computation_backends, channel),
index_urls=get_index_urls(computation_backends, channel),
no_index=False,
)

Expand All @@ -319,9 +314,8 @@ def context(input):
if not (
input.project_name in PYTORCH_DISTRIBUTIONS
or (
channel == Channel.NIGHTLY
and platform.system() == "Linux"
and input.project_name not in user_supplied_third_party_packages
input.project_name in THIRD_PARTY_PACKAGES
and input.project_name not in user_supplied_pinned_packages
)
):
yield
Expand Down
9 changes: 4 additions & 5 deletions scripts/check_pytorch_package_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from light_the_torch._cb import _MINIMUM_DRIVER_VERSIONS, CPUBackend, CUDABackend
from light_the_torch._patch import (
Channel,
get_extra_index_urls,
get_index_urls,
PYTORCH_DISTRIBUTIONS,
THIRD_PARTY_PACKAGES,
)
Expand Down Expand Up @@ -40,19 +40,18 @@
}
COMPUTATION_BACKENDS.add(CPUBackend())

EXTRA_INDEX_URLS = sorted(
INDEX_URLS = sorted(
set(
itertools.chain.from_iterable(
get_extra_index_urls(COMPUTATION_BACKENDS, channel)
for channel in iter(Channel)
get_index_urls(COMPUTATION_BACKENDS, channel) for channel in iter(Channel)
)
)
)


def main():
available = set()
for url in tqdm.tqdm(EXTRA_INDEX_URLS):
for url in tqdm.tqdm(INDEX_URLS):
response = requests.get(url)
if not response.ok:
continue
Expand Down

0 comments on commit 5ad5ff3

Please sign in to comment.