Debug PyTorch AOT Autograd stage - functionalization, decompositions, IR transformations, joint forward+backward graph (when requires_grad=True), partitioning/recomputation, and post-grad passes. Use for tracing AOT stage and understanding decomposition application.
Slash command: /torch-compile:compile-trace-aot
How to trace and debug AOT Autograd: functionalization, joint graph creation, partitioning, and post-grad passes.
AOT Autograd = Ahead-of-Time autograd lowering (training-specific transformations)
What it does:
Pipeline Position:
Dynamo → [Pre-Grad] → AOT Autograd → [Post-Grad] → Inductor
                           ↓
                      Functionalization
                      Joint Graph
                      Partitioning
Key Location: torch/_functorch/aot_autograd.py
Training Path (needs_autograd=True):
Inference Path (needs_autograd=False):
Check logs:
TORCH_LOGS="aot" python script.py
Output shows:
[AOT] Compiling forward graph: model__0_forward_0
[AOT] Compiling backward graph: model__0_backward_0
If AOT didn't run (inference):
# No AOT messages, goes straight to Inductor
Minimal (AOT compilation info):
TORCH_LOGS="aot" python script.py
Standard (with graphs):
TORCH_LOGS="aot,aot_graphs" python script.py
Comprehensive (including joint graph):
TORCH_LOGS="aot,aot_graphs,aot_joint_graph,post_grad_graphs" python script.py
| Logger | What It Shows | When to Use |
|---|---|---|
| aot | Basic AOT compilation tracking | Verify AOT ran |
| aot_graphs | Forward/backward graphs after partitioning | Understanding graph structure |
| aot_joint_graph | Combined forward+backward before split | Debugging partitioning |
| post_grad_graphs | FX graphs before/after post-grad passes | Pattern matching effects |
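import os
# Set TORCH_LOGS before torch is imported so the logging setup picks it up
os.environ['TORCH_LOGS'] = 'aot,aot_graphs,aot_joint_graph'
import torch._inductor.config as config
config.debug = True  # Enable Inductor debug output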
Format: {model_name}_{aot_id}__{graph_type}_{nth_graph}
Examples:
model__0__forward_0.py # First forward graph
model__0__backward_0.py # First backward graph
model__0__joint_0.py # Joint graph (if logged)
model__0__forward_transformed_0.py # After post-grad passes
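When sorting a directory of dumps, the naming scheme above can be parsed mechanically. The sketch below is a minimal, standard-library-only helper. `GraphFile` and `parse_graph_filename` are hypothetical names, and the regex accepts one or more underscores between fields because the separators in the format line and in the examples differ.

```python
import re
from typing import NamedTuple, Optional


class GraphFile(NamedTuple):
    model_name: str
    aot_id: int
    graph_type: str
    nth_graph: int


# Hypothetical pattern: fields separated by one or more underscores, ".py" suffix.
_GRAPH_FILE = re.compile(
    r"^(?P<model>.+?)_+(?P<aot_id>\d+)_+(?P<graph_type>[a-z_]+?)_(?P<nth>\d+)\.py$"
)


def parse_graph_filename(name: str) -> Optional[GraphFile]:
    """Split a dump filename into its fields, or return None if it does not match."""
    m = _GRAPH_FILE.match(name)
    if m is None:
        return None
    return GraphFile(m["model"], int(m["aot_id"]), m["graph_type"], int(m["nth"]))


for name in ["model__0__forward_0.py", "model__0__forward_transformed_0.py"]:
    print(parse_graph_filename(name))
```

Grouping the parsed results by `aot_id` and `nth_graph` pairs each forward graph with its backward and transformed counterparts.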
Joint Graph (before partitioning):
graph():
# Forward inputs (primals)
%arg0 : Tensor = placeholder[target=arg0]
%arg1 : Tensor = placeholder[target=arg1]
# Forward computation
%mul : Tensor = call_function[target=aten.mul](args = (%arg0, 2))
%add : Tensor = call_function[target=aten.add](args = (%mul, %arg1))
# Backward inputs (tangents)
%tangent : Tensor = placeholder[target=tangent]
# Backward computation
%mul_grad : Tensor = call_function[target=aten.mul](args = (%tangent, 2))
# Outputs: forward results + gradients
return (add, mul_grad)
Forward Graph (after partitioning):
graph():
%x : Tensor = placeholder[target=x]
%weight : Tensor = placeholder[target=weight]
%mul : Tensor = call_function[target=aten.mul](args = (%x, %weight))
%add : Tensor = call_function[target=aten.add](args = (%mul, 1))
return (add, mul) # Output + saved activations for backward
Backward Graph (after partitioning):
graph():
%saved_mul : Tensor = placeholder[target=saved_mul] # From forward
%grad_output : Tensor = placeholder[target=grad_output]
%grad_mul : Tensor = call_function[target=aten.mul](args = (%grad_output, 1))
%grad_weight : Tensor = call_function[target=aten.mul](args = (%grad_mul, %saved_mul))
return (grad_weight,)
In Joint Graph:
meta["partitioner_tag"] = "is_forward" or "is_backward"
In Partitioned Graphs:
Creates Core ATen IR - removes mutations and aliases to produce functional graph.
Before (Full ATen IR):
def f(x):
x.mul_(2) # In-place mutation
return x.add(1)
After (Core ATen IR):
def f(x):
x_new = x * 2 # Functional
return x_new + 1
# x.mul_() mutation tracked in metadata, applied at runtime
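The same kind of rewrite can be reproduced on a toy function outside the compile pipeline. This is a minimal sketch assuming `torch.func.functionalize` combined with `make_fx`. It illustrates the idea of functionalization and is not necessarily the code path AOT Autograd takes internally. The function mutates a clone rather than its input, which keeps input-mutation handling out of the example.

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    y = x.clone()
    y.mul_(2)  # In-place mutation on an intermediate tensor
    return y.add(1)


# Trace the functionalized version of f into an FX graph of ATen ops.
gm = make_fx(functionalize(f))(torch.ones(3))
print(gm.code)

# Collect the ATen targets that ended up in the traced graph.
targets = [str(n.target) for n in gm.graph.nodes if n.op == "call_function"]
print(targets)
```

If functionalization worked, the printed targets contain the out-of-place `mul` and no `mul_`.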
Logging (captured in AOT graphs):
TORCH_LOGS="aot,aot_graphs" python script.py
What to check:
No in-place operations remain (mul_, add_, etc.)
Check graph nodes:
grep "mul_\|add_\|sub_" /tmp/torchinductor_$USER/model__*__forward_0.py
# Should find none (all converted to out-of-place)
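A plain substring search for `mul_` also matches FX node names such as `%mul_1`, so the grep above can report false positives. The sketch below is a stricter, standard-library-only check. It assumes operator targets appear after an `aten.` prefix, as in the graph listings in this document, and `find_inplace_ops` is a hypothetical helper name.

```python
import re

# Matches the operator name that follows "aten." in either a printed graph
# (target=aten.mul_) or generated code (torch.ops.aten.mul_.Tensor).
_ATEN_OP = re.compile(r"aten\.([A-Za-z0-9_]+)")


def find_inplace_ops(graph_text: str) -> list:
    """Return sorted ATen op names ending in '_' (the in-place naming convention)."""
    return sorted({op for op in _ATEN_OP.findall(graph_text) if op.endswith("_")})


sample = """
%mul_1 : Tensor = call_function[target=aten.mul](args = (%x, 2))
%add_ : Tensor = call_function[target=aten.add_.Tensor](args = (%mul_1, 1))
"""
print(find_inplace_ops(sample))  # ['add_']
```

The node named `%mul_1` is ignored because only operator targets are inspected. An empty list means no in-place targets were found.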
Check metadata (in Python):
# Graph outputs include mutation info
# Look for: return (output, mutated_input)
Joint Graph = Forward + Backward traced together in single FX graph
Purpose:
TORCH_LOGS="aot_joint_graph" python script.py
Output: model__*__joint_*.py file
Node Tags (check metadata):
# Forward nodes:
%mul : Tensor = call_function[...] # meta["partitioner_tag"] = "is_forward"
# Backward nodes:
%grad_mul : Tensor = call_function[...] # meta["partitioner_tag"] = "is_backward"
Graph Flow:
Inputs (primals) → Forward computation → Outputs
↓ (saved activations)
Tangents (grad outputs) → Backward computation → Gradients
What to look for:
Input: Joint graph (forward + backward)
Output: Separate forward and backward graphs
Strategies:
TORCH_LOGS="aot,aot_graphs,aot_joint_graph" python script.py
Compare:
model__*__joint_*.py
model__*__forward_*.py
model__*__backward_*.py
Check forward outputs:
# Forward should output:
# 1. User-visible outputs
# 2. Saved activations for backward
return (output, saved_activation_1, saved_activation_2, ...)
Check backward inputs:
# Backward should receive:
# 1. Saved activations from forward
# 2. Gradient w.r.t. outputs (tangents)
def backward(saved_act_1, saved_act_2, grad_output):
...
Verify correspondence:
# Forward outputs should match backward inputs
grep "return" model__*__forward_*.py
grep "placeholder" model__*__backward_*.py
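The correspondence check can also be scripted. The sketch below is a minimal, standard-library-only version that assumes the printed-graph text format used in the listings of this document. The helper names are hypothetical, and the caller supplies the number of user outputs and tangents. Because forward and backward may name the same tensor differently (`mul` and `saved_mul`), it compares counts rather than names.

```python
import re


def forward_outputs(forward_text: str) -> list:
    """Names listed on the first `return (...)` or `return [...]` (no nested brackets)."""
    m = re.search(r"return\s*[\(\[](.*?)[\)\]]", forward_text)
    if m is None:
        return []
    return [name.strip() for name in m.group(1).split(",") if name.strip()]


def backward_placeholders(backward_text: str) -> list:
    """Targets of every placeholder node in a printed backward graph."""
    return re.findall(r"placeholder\[target=(\w+)\]", backward_text)


def saved_counts(forward_text, backward_text, num_user_outputs=1, num_tangents=1):
    """Return (activations saved by forward, activations expected by backward)."""
    saved = len(forward_outputs(forward_text)) - num_user_outputs
    expected = len(backward_placeholders(backward_text)) - num_tangents
    return saved, expected


forward = "return (add, mul)  # Output + saved activations for backward"
backward = """
%saved_mul : Tensor = placeholder[target=saved_mul]
%grad_output : Tensor = placeholder[target=grad_output]
"""
print(saved_counts(forward, backward))  # (1, 1)
```

If the two numbers differ, the partitioned graphs do not line up and are worth inspecting.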
What is recomputed:
How to identify:
# Compare joint vs backward graph
# If operation appears in both, it's recomputed
diff <(grep "call_function" joint.py | grep "is_forward") \
<(grep "call_function" backward.py)
After: Partitioning
Before: Inductor lowering
On: Both forward and backward graphs separately
TORCH_LOGS="post_grad_graphs" python script.py
Output shows:
| Pass | What It Does | How to Verify |
|---|---|---|
| Group Batch Fusion | Batches operations together | Look for fused ops |
| B2B GEMM | Fuses back-to-back matrix multiplies | Check for combined mm ops |
| Remove Noop | Eliminates no-op operations | Count nodes before/after |
| Pattern Matching | Various graph rewrites | Compare transformed graph |
Before Post-Grad:
%mm1 : Tensor = call_function[target=aten.mm](args = (%x, %w1))
%mm2 : Tensor = call_function[target=aten.mm](args = (%mm1, %w2))
After Post-Grad (B2B GEMM fusion):
%fused_mm : Tensor = call_function[target=fused_mm_template](
args = (%x, %w1, %w2)
)
Goal: Confirm AOT Autograd executed
Steps:
Enable logging:
TORCH_LOGS="aot" python script.py
Check for AOT messages:
[AOT] Compiling forward graph: ...
[AOT] Compiling backward graph: ...
If missing:
model.train()
x.requires_grad = True
Symptom: Wrong gradients after compilation
Steps:
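Compare with eager:
import torch

# Eager mode (assumes model(x) returns a scalar loss and x.requires_grad is True)
loss = model(x)
loss.backward()
grad_eager = x.grad.clone()
x.grad = None  # Reset so the compiled run does not accumulate into the eager gradient
# Compiled
model_compiled = torch.compile(model)
loss = model_compiled(x)
loss.backward()
grad_compiled = x.grad.clone()
torch.testing.assert_close(grad_eager, grad_compiled)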
Check joint graph:
TORCH_LOGS="aot_joint_graph" python script.py
# Verify backward computation looks correct
Check partitioning:
TORCH_LOGS="aot_graphs" python script.py
# Verify forward saves correct activations
# Verify backward receives correct inputs
Isolate issue:
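One way to narrow the search is to compile with the `aot_eager` backend, which runs Dynamo and AOT Autograd but skips Inductor. If gradients already differ there, the problem is likely upstream of Inductor. The sketch below uses a toy model and a hypothetical helper `input_grad`; substitute the real model and inputs.

```python
import torch


def input_grad(fn, x):
    """Run fn on a fresh leaf copy of x and return the gradient of the summed output."""
    x = x.detach().clone().requires_grad_(True)
    fn(x).sum().backward()
    return x.grad


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Tanh())
x = torch.randn(2, 4)

grad_eager = input_grad(model, x)
# aot_eager: Dynamo + AOT Autograd, with the resulting graphs executed eagerly.
grad_aot_eager = input_grad(torch.compile(model, backend="aot_eager"), x)

torch.testing.assert_close(grad_eager, grad_aot_eager)
print("aot_eager gradients match eager")
```

If this passes while the default backend still produces wrong gradients, the post-grad passes and Inductor lowering are the next places to look.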
Symptom: OOM during backward pass
Steps:
Check what's being saved:
TORCH_LOGS="aot_graphs" python script.py
# Look at forward graph outputs
# Count number of saved activations
Enable recomputation:
from torch._functorch.aot_autograd import aot_function
from functools import partial
from functorch.compile import min_cut_rematerialization_partition
# Use min-cut partitioner for memory optimization
# (Usually automatic, but can force via config)
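To see what the min-cut partitioner decides to save, the imports above can be wired into a small experiment. This is a sketch under stated assumptions: it uses `aot_function` from `functorch.compile` with pass-through compilers that only record the graphs they receive, and `capture` and `fn` are hypothetical names. The min-cut partitioner relies on `networkx` being installed.

```python
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

captured = {}


def capture(name):
    """Build a pass-through compiler that records the graph it receives."""
    def compiler(gm, example_inputs):
        captured[name] = gm
        return gm
    return compiler


def fn(x, w):
    return (torch.sin(x) * w).cos().sum()


compiled = aot_function(
    fn,
    fw_compiler=capture("forward"),
    bw_compiler=capture("backward"),
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.randn(8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
compiled(x, w).backward()  # The backward graph is compiled no later than this call

# The forward graph's output node lists user outputs followed by saved tensors.
output_node = next(n for n in captured["forward"].graph.nodes if n.op == "output")
print(len(output_node.args[0]), "forward outputs")
print(captured["backward"].code)
```

Operations that appear in the printed backward code but belong to the original forward computation (here `sin` or `cos`) are being recomputed instead of saved.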
Analyze activation memory:
# Count tensors in forward output
grep "return" model__*__forward_*.py
# Each returned tensor (except user output) is saved
Goal: Confirm expected fusion happened
Steps:
Enable logging:
TORCH_LOGS="post_grad_graphs" python script.py
Compare before/after:
# Count operations
grep "call_function" model__*__forward_0.py | wc -l
grep "call_function" model__*__forward_transformed_0.py | wc -l
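A raw node count shows that something changed but not what. The sketch below is a standard-library-only helper that tallies `call_function` targets in two printed graphs and reports the difference. It assumes the `call_function[target=...]` text format shown in the listings of this document; `op_histogram` and `op_diff` are hypothetical names.

```python
import re
from collections import Counter

_TARGET = re.compile(r"call_function\[target=([\w.]+)\]")


def op_histogram(graph_text: str) -> Counter:
    """Count call_function nodes per target in a printed FX graph."""
    return Counter(_TARGET.findall(graph_text))


def op_diff(before_text: str, after_text: str) -> dict:
    """Map each target to (count after - count before), omitting unchanged targets."""
    before, after = op_histogram(before_text), op_histogram(after_text)
    ops = sorted(set(before) | set(after))
    return {op: after[op] - before[op] for op in ops if after[op] != before[op]}


before = """
%mm1 : Tensor = call_function[target=aten.mm](args = (%x, %w1))
%mm2 : Tensor = call_function[target=aten.mm](args = (%mm1, %w2))
"""
after = """
%fused_mm : Tensor = call_function[target=fused_mm_template](
"""
print(op_diff(before, after))  # {'aten.mm': -2, 'fused_mm_template': 1}
```

Negative entries are operations a pass removed, and positive entries are operations it introduced.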
Verify specific pattern:
mm → mm fusion
Symptom: No AOT log messages, straight to Inductor
Cause: Model in inference mode (no gradients needed)
Debug:
# Check if gradients needed
print(any(p.requires_grad for p in model.parameters()))
print(x.requires_grad)
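The two checks above miss one case: gradients are also skipped when the call runs under `torch.no_grad()`. The sketch below folds all three conditions into one hypothetical helper, `needs_autograd`. It approximates the decision for debugging purposes and is not the exact logic AOT Autograd uses.

```python
import torch


def needs_autograd(model: torch.nn.Module, *inputs) -> bool:
    """Hypothetical helper: grad mode is on and some parameter or tensor input requires grad."""
    tensors = list(model.parameters())
    tensors += [t for t in inputs if isinstance(t, torch.Tensor)]
    return torch.is_grad_enabled() and any(t.requires_grad for t in tensors)


model = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

print(needs_autograd(model, x))      # True: parameters require grad by default
with torch.no_grad():
    print(needs_autograd(model, x))  # False: grad mode is disabled
```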
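Fix:
model.train() # Training mode (dropout/batchnorm behavior); does not by itself enable gradients
x.requires_grad = True # Or make input require grad
# Also make sure the call is not wrapped in torch.no_grad() or torch.inference_mode()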
Symptom: High memory usage, OOM
Debug:
TORCH_LOGS="aot_graphs" python script.py
# Check forward output size
grep "return" model__*__forward_*.py
Solutions:
torch.utils.checkpoint.checkpoint()
Symptom: Incomplete backward computation
Debug:
TORCH_LOGS="aot_joint_graph" python script.py
# Check if backward nodes present in joint graph
grep "is_backward" model__*__joint_*.py
Common causes:
Fix:
# Ensure gradient flow not broken
# Check for .detach() calls
# Verify requires_grad=True
Symptom: Expected fusion didn't occur
Debug:
TORCH_LOGS="post_grad_graphs" python script.py
# Compare before/after, verify pattern exists
Common causes:
Fix:
# Basic AOT tracing
TORCH_LOGS="aot,aot_graphs" python script.py
# With joint graph
TORCH_LOGS="aot,aot_graphs,aot_joint_graph" python script.py
# Include post-grad passes
TORCH_LOGS="aot,aot_graphs,post_grad_graphs" python script.py
# Full AOT debug
TORCH_LOGS="aot,aot_graphs,aot_joint_graph,post_grad_graphs" python script.py
# View forward graph
cat /tmp/torchinductor_$USER/model__*__forward_0.py
# View backward graph
cat /tmp/torchinductor_$USER/model__*__backward_0.py
# View joint graph (if logged)
cat /tmp/torchinductor_$USER/model__*__joint_0.py
# Compare before/after post-grad
diff model__*__forward_{0,transformed_0}.py
# Verify functionalization (no in-place ops)
grep "mul_\|add_\|sub_" model__*__forward_0.py
# Should return nothing
# Check partitioning (forward outputs match backward inputs)
grep "return" model__*__forward_0.py
grep "placeholder" model__*__backward_0.py
# Count forward outputs (user outputs plus saved activations)
grep "return" model__*__forward_0.py | tr ',' '\n' | wc -l
After AOT Stage: Load compile-trace-inductor skill - Tracing Inductor lowering through codegen
Reference: See compile-overview skill for complete pipeline context.
npx claudepluginhub torchedhat/ai-marketplace --plugin torch-compile