Skip to content

Commit

Permalink
move PyTorch version processing out of sort key (#68)
Browse files Browse the repository at this point in the history
* move PyTorch version processing out of sort key

* refactor to not change candidate version
  • Loading branch information
pmeier authored May 4, 2022
1 parent 453b02b commit 50dbc37
Showing 1 changed file with 40 additions and 19 deletions.
59 changes: 40 additions & 19 deletions light_the_torch/_patch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import contextlib
import dataclasses

import enum
import functools

import itertools
import optparse
import re
import sys
Expand Down Expand Up @@ -253,34 +252,56 @@ def context(input):

@contextlib.contextmanager
def patch_candidate_selection(computation_backends):
allowed_locals = {None, *computation_backends}
computation_backend_pattern = re.compile(
r"^/whl/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
r"/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
)

def extract_local_specifier(candidate):
local = candidate.version.local

if local is None:
match = computation_backend_pattern.search(candidate.link.path)
local = match["computation_backend"] if match else "any"

# Early PyTorch distributions used the "any" local specifier to indicate a
# pure Python binary. This was changed to no local specifier later.
# Setting this to "cpu" is technically not correct as it will exclude this
# binary if a non-CPU backend is requested. Still, this is probably the
# right thing to do, since the user requested a specific backend and
# although this binary will work with it, it was not compiled against it.
if local == "any":
local = "cpu"

return local

def preprocessing(input):
candidates = iter(input.candidates)
candidate = next(candidates)

if candidate.name not in PYTORCH_DISTRIBUTIONS:
# At this stage all candidates have the same name. Thus, if the first is
# not a PyTorch distribution, we don't need to check the rest and can
# return without changes.
return

input.candidates = [
candidate
for candidate in input.candidates
if candidate.name not in PYTORCH_DISTRIBUTIONS
or candidate.version.local in allowed_locals
for candidate in itertools.chain([candidate], candidates)
if extract_local_specifier(candidate) in computation_backends
]

sort_key = CandidateEvaluator._sort_key
vanilla_sort_key = CandidateEvaluator._sort_key

def patched_sort_key(candidate_evaluator, candidate):
if candidate.name not in PYTORCH_DISTRIBUTIONS:
return sort_key(candidate_evaluator, candidate)

if candidate.version.local is not None:
computation_backend_str = candidate.version.local.replace("any", "cpu")
else:
match = computation_backend_pattern.match(candidate.link.path)
computation_backend_str = match["computation_backend"] if match else "cpu"

# At this stage all candidates have the same name. Thus, we don't need to
# mirror the exact key structure that the vanilla sort keys have.
return (
cb.ComputationBackend.from_str(computation_backend_str),
candidate.version,
vanilla_sort_key(candidate_evaluator, candidate)
if candidate.name not in PYTORCH_DISTRIBUTIONS
else (
cb.ComputationBackend.from_str(extract_local_specifier(candidate)),
candidate.version.base_version,
)
)

with apply_fn_patch(
Expand Down

0 comments on commit 50dbc37

Please sign in to comment.