Skip to content

Commit

Permalink
Handle unbacked symints in buffer reuse calculation (pytorch#109603)
Browse files Browse the repository at this point in the history
This is rewritten from pytorch#106655 to land faster, with peterbell10's comments.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#109603
Approved by: https://github.com/yf225
  • Loading branch information
ezyang authored and pytorchmergebot committed Sep 20, 2023
1 parent 63025d4 commit b771c04
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
12 changes: 5 additions & 7 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
cache_on_self,
get_benchmark_name,
LineContext,
sympy_dot,
sympy_product,
sympy_str,
)
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
Expand All @@ -33,15 +33,13 @@


def buffer_reuse_key(node: ir.Buffer):
size = node.get_size()
stride = node.get_stride()
last_element = sympy_dot([s - 1 for s in size], stride)
return (
node.get_device(),
node.get_dtype(),
V.graph.sizevars.simplify(sympy_product(size)),
# Detect gaps in tensor storage caused by strides
V.graph.sizevars.size_hint(last_element),
# NB: this is symbolic so that we don't try to reuse a buffer
# for s0 for s1, just because they happen to share the same
# size hint
sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())),
)


Expand Down
3 changes: 3 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,6 +2715,9 @@ def create_unbacked_symint(self):

return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node))

def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
return str(symbol).startswith("i")

@record_shapeenv_event()
def create_unbacked_symbool(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
Expand Down

0 comments on commit b771c04

Please sign in to comment.