Traceback (most recent call last):
File "faster_rcnn/lower.py", line 11, in <module>
exported_program = torch.export.export_for_training(model, (example_args,))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/__init__.py", line 168, in export_for_training
return _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/_trace.py", line 1044, in wrapper
raise e
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/_trace.py", line 1944, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/export/_trace.py", line 693, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1579, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 570, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1400, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 565, in __call__
return _compile(
^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 997, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 726, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 760, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1404, in transform_code_object
transformations(instructions, code_options)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 236, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 680, in transform
tracer.run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2906, in run
super().run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1078, in run
while self.step():
^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 988, in step
self.dispatch_table[inst.opcode](self, inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 685, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2378, in CALL
self._call(inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2372, in _call
self.call_function(fn, args, kwargs)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 923, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 444, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 929, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3112, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3249, in inline_call_
self.run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1078, in run
while self.step():
^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 988, in step
self.dispatch_table[inst.opcode](self, inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 685, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1765, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 923, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 461, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 319, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 929, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3112, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3249, in inline_call_
self.run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1078, in run
while self.step():
^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 988, in step
self.dispatch_table[inst.opcode](self, inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 685, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2378, in CALL
self._call(inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2372, in _call
self.call_function(fn, args, kwargs)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 923, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 444, in call_function
return tx.inline_user_function_return(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 929, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3112, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3249, in inline_call_
self.run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1078, in run
while self.step():
^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 988, in step
self.dispatch_table[inst.opcode](self, inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 685, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1765, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 923, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 461, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 319, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 929, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3112, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3249, in inline_call_
self.run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1078, in run
while self.step():
^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 988, in step
self.dispatch_table[inst.opcode](self, inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 685, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2378, in CALL
self._call(inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2372, in _call
self.call_function(fn, args, kwargs)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 923, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 461, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 319, in call_function
return super().call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 120, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 929, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3112, in inline_call
return tracer.inline_call_()
^^^^^^^^^^^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3249, in inline_call_
self.run()
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1078, in run
while self.step():
^^^^^^^^^^^
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 988, in step
self.dispatch_table[inst.opcode](self, inst)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1841, in STORE_ATTR
not self.export
AssertionError: Mutating module attribute cell_anchors during export.
from user code:
File "faster_rcnn/venv/lib/python3.12/site-packages/torchvision/models/detection/generalized_rcnn.py", line 104, in forward
proposals, proposal_losses = self.rpn(images, features, targets)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
File "faster_rcnn/venv/lib/python3.12/site-packages/torchvision/models/detection/rpn.py", line 362, in forward
anchors = self.anchor_generator(images, features)
File "faster_rcnn/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
File "faster_rcnn/venv/lib/python3.12/site-packages/torchvision/models/detection/anchor_utils.py", line 126, in forward
self.set_cell_anchors(dtype, device)
File "faster_rcnn/venv/lib/python3.12/site-packages/torchvision/models/detection/anchor_utils.py", line 77, in set_cell_anchors
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Python version: 3.12.3 (main, Apr 9 2024, 08:09:14) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
🐛 Describe the bug
I want to be able to export `fasterrcnn_mobilenet_v3_large_fpn` for training, so it can be quantized. But running `torch.export.export_for_training` fails. The full traceback is included in this report.
Versions
PyTorch version: 2.7.0.dev20250130
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 18.1.7
CMake version: version 3.29.3
Libc version: N/A
Python version: 3.12.3 (main, Apr 9 2024, 08:09:14) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Pro
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] torch==2.7.0.dev20250130
[pip3] torchaudio==2.6.0.dev20250130
[pip3] torchvision==0.22.0.dev20250130
[conda] Could not collect
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4