
[Python] refactor slices on sorted (pytorch#86995)
Sometimes you want the smallest element of a collection and reach for `sorted(elements)[0]` without a second thought. However, this is not optimal, since the entire list must be sorted first, which takes `O(n log n)` time. It is better to use the `min(elements)` built-in provided for exactly this purpose, which runs in `O(n)`.
Similarly, `sorted(elements)[::-1]` is wasteful: `sorted(elements, reverse=True)` produces the same result and avoids the extra slice, which copies the whole list.

**TLDR: using `sorted(elements)[0]` is slow and can be replaced with `min(elements)`.**
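A minimal illustrative sketch (not part of the commit; the names `elements` and `pairs` are made up for this example) showing that the replacements are drop-in equivalents, including the `key=` form used in `test_selections.py` below:

```python
import random

elements = [random.random() for _ in range(10_000)]

# O(n log n): sorts the entire list just to read one element.
smallest_via_sort = sorted(elements)[0]

# O(n): a single pass, no intermediate sorted copy.
smallest_via_min = min(elements)
assert smallest_via_sort == smallest_via_min

# Both forms accept a key function.
pairs = [("a", 3), ("b", 1), ("c", 2)]
assert sorted(pairs, key=lambda p: p[1])[0] == min(pairs, key=lambda p: p[1])

# Descending order: reverse=True avoids the extra list copy made by [::-1].
assert sorted(elements)[::-1] == sorted(elements, reverse=True)
```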

I stumbled across these code snippets while playing around with CodeQL (see https://lgtm.com/query/4148064474379348546/).
Pull Request resolved: pytorch#86995
Approved by: https://github.com/jansel
ScholliYT authored and pytorchmergebot committed Oct 25, 2022
1 parent 98f40af commit fd60b81
Showing 4 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions tools/testing/test_selections.py
@@ -45,14 +45,14 @@ def calculate_shards(
     ]
     for test in sorted_tests:
         if must_serial(test):
-            min_sharded_job = sorted(sharded_jobs, key=lambda j: j.get_total_time())[0]
+            min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
             min_sharded_job.serial.append(test)
         else:
-            min_sharded_job = sorted(sharded_jobs, key=lambda j: j.get_total_time())[0]
+            min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
             min_sharded_job.parallel.append(test)
 
     # Round robin the unknown jobs starting with the smallest shard
-    index = sorted(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())[0]
+    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
     for test in unknown_tests:
         sharded_jobs[index].serial.append(test)
         index = (index + 1) % num_shards
2 changes: 1 addition & 1 deletion torch/distributed/rpc/api.py
@@ -191,7 +191,7 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
             _ALL_WORKER_NAMES is not None
         ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
         worker_names = _ALL_WORKER_NAMES
-    leader_name = sorted(worker_names)[0]
+    leader_name = min(worker_names)
 
     self_name = _get_current_rpc_agent().get_worker_info().name
 
4 changes: 2 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
@@ -398,13 +398,13 @@ def create_symbolic_sizes_strides(self, ex: torch.Tensor):
                     candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
             if any(x is None for x in stride):
                 # bind the smallest unbound stride to a new variable
-                val, i = sorted(
+                val, i = min(
                     [
                         (ex.stride()[i], i)
                         for i in range(len(stride))
                         if stride[i] is None
                     ]
-                )[0]
+                )
                 stride[i] = self.create_symbol(val)
         assert all(x is not None for x in stride)
         return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride]  # type: ignore[arg-type]
2 changes: 1 addition & 1 deletion torch/masked/maskedtensor/reductions.py
@@ -31,7 +31,7 @@ def _masked_all(*args, **kwargs):
 def _multidim_any(mask, dim, keepdim):
     if isinstance(dim, int):
         return _multidim_any(mask, [dim], keepdim)
-    for d in sorted(dim)[::-1]:
+    for d in sorted(dim, reverse=True):
         mask = torch.any(mask, dim=d, keepdim=keepdim)
     return mask
 
