From pytorch-skills
Debug and fix torch.compile graph breaks: prioritize and eliminate them, and optionally achieve fullgraph compilation
Slash command: /pytorch-skills:debug-graph-breaks
Diagnose and eliminate `torch.compile` graph breaks by running the script, reading the output, applying fixes directly, and re-running to verify — whether the goal is `fullgraph=True` or simply reducing the most impactful breaks.
When this skill is invoked, help the user find, understand, and fix graph breaks in their torch.compiled code. The workflow has three phases: Detect, Diagnose, Fix.
The graph break documentation website at https://meta-pytorch.org/compile-graph-break-site/ has detailed pages for each known graph break type. When the user encounters a specific graph break (e.g., GB0059), fetch the corresponding page (e.g., https://meta-pytorch.org/compile-graph-break-site/gb/gb0059.html) to get detailed context, examples, and fix suggestions specific to that break. If the user provides a local path to a clone of the graph break website repository, read the documentation files directly from that directory instead.
If the user provides a script path and no existing logs, run it with TORCH_LOGS="graph_breaks":
TORCH_LOGS="graph_breaks" python your_script.py 2>&1
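This prints each graph break with its type and reason, an explanation, hints, the user code location and stack trace, and a link to the matching graph break documentation page.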
If the output is ambiguous or a break's origin is unclear, re-run with verbose mode:
TORCH_LOGS="graph_breaks" TORCHDYNAMO_VERBOSE=1 python your_script.py 2>&1
Verbose mode adds the internal Dynamo stack trace and the most recent bytecode instructions.
If the user provides existing TORCH_LOGS output, parse it directly and skip to Phase 2 (Diagnose).
Other detection methods are available if the user requests them or provides the relevant artifacts:
If the user has a tlparse report (a structured HTML report generated from TORCH_TRACE logs), they can provide it for analysis.
A tlparse report contains:
- index.html: the main page, with a stack trie showing all compilations as a tree, color-coded by status (green = success, lime = graph break, red = error). Each compilation is identified by a compile ID such as [0/0] (frame 0, first compile) or [1/0] (frame 1, first compile; often a resume after a graph break).
- failures_and_restarts.html: a table of all RestartAnalysis events (graph breaks) and compilation failures, with the full graph break reason for each.
- Per-compilation directories (e.g., -_0_0_0/, -_1_0_1/), which contain build products:
  - dynamo_graph_break_reason_N.txt: full graph break details (user code location, reason, explanation, hints, user stack trace, and internal Dynamo traceback).
  - dynamo_output_graph_N.txt: the FX graph produced by that compilation.
  - compilation_metrics_N.html: compile time, graph metrics (ops, nodes, inputs), restart reasons, guard count, and cache metrics.
  - *_ORIGINAL_BYTECODE_N.txt / *_MODIFIED_BYTECODE_N.txt: bytecode before and after Dynamo's modification.
When reading a tlparse report:
- Start with failures_and_restarts.html for a quick summary of all graph breaks.
- Open dynamo_graph_break_reason_N.txt for the full details of each break.
- Check compilation_metrics_N.html to understand the compile-time impact of each subgraph.
- Compile IDs with a _1 suffix (e.g., [0/0_1]) indicate a restart: Dynamo restarted analysis after discovering a graph break on attempt _0.
Using fullgraph=True (fix one break at a time):
If the user wants to fix breaks one at a time rather than all at once, use fullgraph=True:
torch.compile(fn, fullgraph=True)(*args)
This raises an error on the first graph break encountered. Fix it, re-run, fix the next.
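A minimal sketch of that loop, assuming PyTorch 2.x: an explicit torch._dynamo.graph_break() plays the role of a real break inside the hypothetical toy_fn, and with fullgraph=True the first break surfaces as a torch._dynamo.exc.Unsupported exception instead of a silent split. The "eager" backend is used only to skip code generation while checking for breaks.

```python
import torch
import torch._dynamo


def toy_fn(x: torch.Tensor) -> torch.Tensor:
    y = x.sin()
    torch._dynamo.graph_break()  # stands in for a real graph break
    return y.cos()


try:
    torch.compile(toy_fn, fullgraph=True, backend="eager")(torch.randn(8))
    first_break = None  # no break: the whole function compiled as one graph
except torch._dynamo.exc.Unsupported as err:
    # The message describes the first graph break Dynamo hit; fix it and re-run.
    first_break = str(err)

print(first_break)
```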
For each graph break found, do the following:
Parse the graph break reason. Extract the GB type string (e.g., "Failed to trace builtin operator", "Unsupported Tensor.item() call with capture_scalar_outputs=False") and the GB documentation URL from the output.
Fetch the graph break documentation page. Use the URL from the log output (e.g., https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0059.html) to get detailed information, examples, and suggested fixes for that specific graph break type. If a local clone of the graph break site was provided, read the corresponding file from the local directory instead.
Classify the graph break using the hint categories from the log output (the hint text is defined in torch/_dynamo/graph_break_hints.py).
Identify the user code location. Use the user stack trace to pinpoint the exact line in user code causing the break, then read that code directly to understand its context.
Read the user's script. Always read the relevant code in the user's script to understand the surrounding context. Compare it to examples from the graph break documentation page to identify the exact pattern causing the break and determine the most targeted fix.
Prioritize. Not all graph breaks are equal. Decide which to fix first:
- Easy fixes first: print() calls, logging, and simple .item() calls have straightforward fixes.
- If the goal is not fullgraph=True, identify which breaks matter most for performance and which can be left alone.
For each graph break, apply a fix based on its type. Always read the user's code before making changes, then apply the changes directly.
When graph breaks originate in a third-party library (not the user's code), the user cannot directly modify the library. Instead:
Inspect the library source. If the user provides the path to the installed package (e.g., the transformers site-packages directory), read the relevant source files to understand the problematic code and its call chain.
Subclass to override. Create a subclass of the model that overrides the method causing the graph break with a compile-compatible implementation. This avoids modifying the library and is portable:
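import torch
from transformers import BartConfig, BartModel

class CompilableBartModel(BartModel):
    # Hypothetical name: override whichever method actually causes the graph break
    def _override_problematic_method(self, *args, **kwargs):
        # Compile-compatible reimplementation
        ...

config = BartConfig()  # or the config of the checkpoint being used
model = CompilableBartModel(config)
compiled_model = torch.compile(model, fullgraph=True)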
Verify with fullgraph=True. After applying the subclass fix, re-run with fullgraph=True to confirm all graph breaks are resolved.
This approach is preferred over monkey-patching because it creates a self-contained, reusable fix.
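The subclass-override pattern can be tried without transformers. In this sketch, LibraryBlock is a hypothetical stand-in for a third-party module: its _scale helper calls .item(), which breaks the graph under the default capture_scalar_outputs=False. The subclass keeps the same math as a tensor operation, and both versions are checked with fullgraph=True.

```python
import torch
import torch._dynamo
from torch import nn


class LibraryBlock(nn.Module):
    """Hypothetical third-party module that we cannot edit."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def _scale(self, h: torch.Tensor) -> torch.Tensor:
        # .item() converts to a Python float: a graph break by default.
        return h / h.norm().item()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._scale(self.proj(x))


class CompilableLibraryBlock(LibraryBlock):
    def _scale(self, h: torch.Tensor) -> torch.Tensor:
        # Same math, but the norm stays a tensor so Dynamo can trace it.
        return h / h.norm()


def breaks_under_fullgraph(module: nn.Module, x: torch.Tensor) -> bool:
    """Return True if compiling `module` with fullgraph=True hits a graph break."""
    torch._dynamo.reset()  # start from a clean cache for each check
    try:
        torch.compile(module, fullgraph=True, backend="eager")(x)
    except torch._dynamo.exc.Unsupported:
        return True
    return False


x = torch.randn(4, 16)
print("original breaks:", breaks_under_fullgraph(LibraryBlock(16), x))
print("subclass breaks:", breaks_under_fullgraph(CompilableLibraryBlock(16), x))
```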
These are APIs that let code run inside a compiled region without causing graph breaks:
torch._higher_order_ops.print(format_str, *args, **kwargs)
Compile-safe replacement for print(). Uses Python format-string syntax.
Example: torch._higher_order_ops.print("Activated shape: {}", h.shape)
torch._dynamo.config.reorderable_logging_functions
Add logging functions (e.g., print, logging.info, custom loggers) to this set. Dynamo reorders these calls to avoid graph breaks while preserving the logging call.
Example: torch._dynamo.config.reorderable_logging_functions.add(print)
torch._dynamo.config.ignore_logging_functions
Calls to functions in this list are dropped during compilation.
Example: torch._dynamo.config.ignore_logging_functions = [print]
@torch._dynamo.decorators.leaf_function
Runs the decorated function eagerly at runtime without a graph break. Pair it with @fn.register_fake for shape inference at compile time.
import torch
import torch._dynamo

@torch._dynamo.decorators.leaf_function
def my_function(x):
    # This runs eagerly at runtime, no graph break
    ...

@my_function.register_fake
def my_function_fake(x):
    # Return a tensor with the correct shape/dtype for compile-time tracing
    return torch.empty_like(x)
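If the function mutates one of its arguments, declare it, e.g. @leaf_function(mutates_args={"buf"}).
@torch._dynamo.decorators.nonstrict_trace
Dynamo does not trace the function body, but AOT Autograd still does.
import torch._dynamo

@torch._dynamo.decorators.nonstrict_trace
def my_function(x):
    # Dynamo won't trace this, but AOT autograd will
    return x.relu()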
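torch.compiler.disable()
Excludes a function from compilation so it always runs eagerly. Unlike the options above, calling a disabled function from compiled code still causes a graph break at the call site.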
Custom ops (torch.library)
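import torch
import torch.library

@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
    # Arbitrary non-traceable code can run here (a NumPy round-trip as an example).
    # A custom op must return a tensor, so convert the NumPy result back.
    return torch.from_numpy(x.detach().cpu().numpy().copy()).to(x.device)

@my_op.register_fake
def my_op_fake(x):
    # Describe the output's shape/dtype/device without computing it
    return torch.empty_like(x)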
Compared with @leaf_function, custom ops are part of the public torch.library API and show up as named nodes in the FX graph, which makes them a better fit when the op will be reused across multiple compiled functions.
print() / logging (GB0059: "Failed to trace builtin operator")
- torch._dynamo.config.reorderable_logging_functions.add(print): keeps the print; no code change needed.
- torch._dynamo.config.ignore_logging_functions = [print]: drops the print during compilation.
- torch._higher_order_ops.print("format {}", arg): compile-safe print HOP.
Data-dependent branching (GB0035, GB0170)
- Pattern: if tensor.item() > 0: or if tensor.sum(): (control flow depends on tensor values).
- Fix: use torch.cond() for simple if/else, or restructure to avoid data-dependent branches. torch.where() works as a branchless replacement for simple cases.
- torch._check() or make it a Python constant.
Tensor.item() / tolist() (GB0124, GB0109)
- Keep the value as a tensor (e.g., torch.norm(h) instead of h.norm().item() followed by scalar math).
- Set torch._dynamo.config.capture_scalar_outputs = True to capture .item() in the graph.
- Move the .item() call outside the compiled region.
Unsupported function call (GB0147)
- Use @leaf_function with register_fake if shape inference is possible.
- Use @nonstrict_trace if the function body contains ops that benefit from compilation.
- Use torch.compiler.disable() if the function doesn't need compilation.
Attempt to trace generator (GB0003)
Unsupported context manager (GB0142)
- Switch to a supported context manager (e.g., torch.no_grad() instead of torch.inference_mode()).
Module-level hooks (GB0083, GB0029)
- Use compiled autograd (torch._dynamo.config.compiled_autograd=True) or remove the hooks during compilation.
Graph break in loop (GB7000)
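- Restructure for torch.compile: split your function so the compilable core is in one function and non-compilable setup/teardown stays outside.
- Replace Python builtins with torch ops: max() -> torch.max(), abs() -> torch.abs(), sorted() -> torch.argsort().
- Avoid tensor-to-Python conversions inside the compiled region: .item(), .tolist(), .numpy(), bool(tensor), int(tensor).
- Prefer tensor operations for scalar math and simple branching (e.g., torch.where, torch.clamp, torch.norm).
After applying fixes, re-run the script with TORCH_LOGS="graph_breaks" (or fullgraph=True) to verify: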
- With fullgraph=True, the script runs without errors, or the most impactful breaks are eliminated.
After fixing graph breaks, help the user measure the improvement. Suggest these non-intrusive approaches:
Measuring compile time:
Use torch._dynamo.utils.CompileProfiler as the backend to see per-graph compile times:
import torch
import torch._dynamo.utils

with torch._dynamo.utils.CompileProfiler() as prof:
    compiled_fn = torch.compile(fn, backend=prof)
    compiled_fn(*args)
print(prof.report())
Measuring runtime performance:
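import time
import torch

# Warmup: the first calls trigger compilation, so keep them out of the timing
for _ in range(3):
    compiled_fn(*args)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for queued GPU work before starting the clock

start = time.perf_counter()
for _ in range(100):
    compiled_fn(*args)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure all 100 iterations have finished
elapsed = (time.perf_counter() - start) / 100
print(f"Average iteration: {elapsed*1000:.2f}ms")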
Comparing before/after:
- Graph break count: compare TORCH_LOGS="graph_breaks" output (count the "Graph break in user code" lines) or a tlparse report's stack trie.
Key source files:
- Graph break hints: torch/_dynamo/graph_break_hints.py
- leaf_function: torch/_dynamo/decorators.py
- nonstrict_trace: torch/_dynamo/decorators.py
- torch.compiler.disable(): torch/_dynamo/decorators.py
- error_on_graph_break(): torch/_dynamo/decorators.py
- torch._higher_order_ops.print: torch/_higher_order_ops/print.py
- CompileProfiler: torch/_dynamo/utils.py
npx claudepluginhub meta-pytorch/skills --plugin pytorch-skills