
[CuTeDSL] Fix random_ and normal_ ops to support torch.compile fullgraph#3175

Open
Flink-ddd wants to merge 1 commit into NVIDIA:main from Flink-ddd:fix/replace-inplace-random-ops-for-dynamo

Conversation

@Flink-ddd

Purpose

Fixes #3134.

cutlass.torch.matrix and prepare_gemm_tensors internally use the in-place ops Tensor.random_() and Tensor.normal_() to initialize tensors. These ops are not supported by torch.compile(fullgraph=True), so calling cutlass_torch.matrix inside a compiled region raises torch._dynamo.exc.Unsupported: Tensor.random op.

Fix: replace all in-place random-initialization ops with their out-of-place equivalents, such as torch.randint and torch.normal, which Dynamo can trace.
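The pattern of the change can be sketched as follows. This is a minimal illustration, not the PR's actual diff: the helper names, bounds, and dtypes below are illustrative, and the real code lives in cutlass/torch.py's create_and_permute_torch_tensor.

```python
import torch

# Sketch of the swap: replace in-place random init, which breaks
# torch.compile(fullgraph=True), with out-of-place equivalents.
# Bounds and dtypes here are illustrative, not the library's defaults.

def init_uniform_int(shape, dtype=torch.float32, low=-2, high=2):
    # Before: torch.empty(shape).random_(low, high)  -> Dynamo graph break
    # After: out-of-place torch.randint traces cleanly
    return torch.randint(low, high, shape).to(dtype)

def init_normal(shape, dtype=torch.float32, mean=0.0, std=1.0):
    # Before: torch.empty(shape).normal_(mean, std)
    # After: out-of-place torch.normal with an explicit size
    return torch.normal(mean, std, size=shape).to(dtype)
```

Both out-of-place ops allocate a fresh tensor instead of mutating an existing one, which is what lets Dynamo capture them in a single graph.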

Test Plan

Verified on NVIDIA A100 80GB PCIe, CUDA 12.2, Python 3.12, PyTorch 2.8.0.

Test Result

Before fix:

Traceback (most recent call last):
File "/workspace/reproduce.py", line 18, in <module>
result = test_fn()
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
raise e.with_traceback(None) from e.cause  # User compiler error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Tensor.random op
Explanation: This is currently not supported.
Hint: Use the out-of-place version of this op
Hint: It may be possible to write Dynamo tracing rules for this code.
Please report an issue to PyTorch if you encounter this graph
break often and it is causing performance issues.
Developer debug context: Tensor.random_(args=[ConstantVariable(int: -2),
ConstantVariable(int: 2)], kwargs={})
from user code:
File "/workspace/reproduce.py", line 15, in test_fn
x = cutlass_torch.matrix(1, 8, 8, False, Float16)
File ".../nvidia_cutlass_dsl/python_packages/cutlass/torch.py", line 273, in matrix
torch_tensor = create_and_permute_torch_tensor(
File ".../nvidia_cutlass_dsl/python_packages/cutlass/torch.py", line 146, in create_and_permute_torch_tensor
f32_torch_tensor = init_torch_tensor.random_(
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace.

After fix:

matrix shape: torch.Size([8, 8, 1]), dtype: torch.float16
compiled matrix shape: torch.Size([8, 8, 1]), dtype: torch.float16
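The after-fix behavior can be smoke-tested with a small tracing check. This is a hedged sketch, not the PR's test: it substitutes torch.randint for cutlass_torch.matrix and uses the eager backend so it runs without a GPU or C++ toolchain.

```python
import torch

# Hedged sketch: confirm that out-of-place random init traces under
# torch.compile(fullgraph=True). backend="eager" avoids needing a GPU;
# the real PR exercises cutlass_torch.matrix instead of torch.randint.

@torch.compile(fullgraph=True, backend="eager")
def make_matrix(m: int, n: int) -> torch.Tensor:
    # torch.randint is out-of-place, so Dynamo captures the full graph;
    # Tensor.random_ here would raise torch._dynamo.exc.Unsupported.
    return torch.randint(-2, 2, (m, n)).to(torch.float16)

x = make_matrix(8, 8)
print(f"compiled matrix shape: {x.shape}, dtype: {x.dtype}")
```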

@Flink-ddd changed the title from "[Bugfix][CuTeDSL] Fix random_ and normal_ ops to support torch.compile fullgraph" to "[CuTeDSL] Fix random_ and normal_ ops to support torch.compile fullgraph" on Apr 19, 2026
@Flink-ddd
Author

Hi @hwu36 , Could you please take a look when you have time? Thanks.

@hwu36
Collaborator

hwu36 commented Apr 19, 2026

@brandon-yujie-sun



Development

Successfully merging this pull request may close these issues.

[BUG] [CuTeDSL] cutlass_torch.matrix throws torch._dynamo.exc.Unsupported error with torch.compile fullGraph
