Labels
needs-triage, type: bug
Description
Expected behavior
A tiny Transformer-like block exported via torch.export should import cleanly: tvm.relax.frontend.torch.from_exported_program(ep) should translate the exported program (or at worst raise a Python-level error) rather than crash the process.
Actual behavior
torch.export succeeds, but PyTorch warns that it inserted get_attr nodes without a backing submodule/parameter/buffer (constant lifting).
Immediately afterwards, tvm.relax.frontend.torch.from_exported_program(ep) segfaults through the FFI while translating the exported program.
In my run the FFI backtrace ends in tvm::relax::Tuple::Tuple(...):
!!!!!!! TVM FFI encountered a Segfault !!!!!!!
... tvm::relax::Tuple::Tuple(...) ...
Segmentation fault (core dumped)
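To pin down which nodes trigger the warnings, the exported FX graph can be inspected for get_attr nodes whose dotted target does not resolve on the graph module. The helper below is my own diagnostic sketch built on standard torch.export/FX attributes (ep.graph_module, node.op, node.target); it is not a TVM API, and report_unbacked_get_attrs is a hypothetical name.

# Diagnostic sketch (not part of the repro): list get_attr nodes in the
# exported graph and report whether their dotted target resolves to an
# attribute on the graph module.
def report_unbacked_get_attrs(ep):
    gm = ep.graph_module
    for node in gm.graph.nodes:
        if node.op != "get_attr":
            continue
        obj, backed = gm, True
        for atom in str(node.target).split("."):
            if not hasattr(obj, atom):
                backed = False
                break
            obj = getattr(obj, atom)
        print(f"{node.name}: target={node.target} backed={backed}")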
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.7.1
Steps to reproduce
# mini_repro_export_tvm_segfault.py
import math
import torch
import torch.nn as nn

def get_attn_pad_mask(seq_q, seq_k):
    B, Lq = seq_q.size()
    B2, Lk = seq_k.size()
    assert B == B2
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (B,1,Lk)
    return pad_attn_mask.expand(B, Lq, Lk)         # (B,Lq,Lk)

class TinyMHA(nn.Module):
    def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
        super().__init__()
        self.h, self.dk = n_heads, d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.proj = nn.Linear(d_k * n_heads, d_model, bias=False)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask):  # x: (B,L,dm), attn_mask: (B,L,L)
        B, L, _ = x.shape
        q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)  # (B,H,L,dk)
        k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
        v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)  # (B,H,L,L)
        # In-place masked_fill_ with broadcasted mask coming from eq(0)+expand
        scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
        attn = torch.softmax(scores, dim=-1)
        ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
        out = self.drop(self.proj(ctx))
        return self.ln(out + x)

class MiniModel(nn.Module):
    def __init__(self, vocab=10000, d_model=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
        self.proj = nn.Linear(d_model, vocab, bias=False)

    def forward(self, enc_inputs, dec_inputs_unused=None):
        x = self.emb(enc_inputs)                          # (B,L,dm)
        mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # (B,L,L)
        y = self.mha(x, mask)                             # (B,L,dm)
        logits = self.proj(y)                             # (B,L,V)
        return logits.reshape(-1, logits.size(-1))        # (B*L, V)

def my_model_function():
    return MiniModel()

def GetInput():
    enc = torch.randint(0, 10000, (2, 5))
    enc[0, 0] = 0  # ensure eq(0) path is taken
    dec = torch.randint(0, 10000, (2, 5))
    return (enc, dec)

import numpy as np
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program

def trigger_known_bugs(model=None):
    if model is None:
        model = my_model_function()
    torch.manual_seed(42); np.random.seed(42)
    model.eval()
    args = GetInput()
    ep = torch_export(model, args)   # Emits the get_attr warnings described above
    mod = from_exported_program(ep)  # <-- TVM segfaults here in my env
    print(mod)

if __name__ == "__main__":
    import os
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6,7")
    trigger_known_bugs()
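For triage, one variable worth isolating is the in-place masked_fill_ on the broadcast mask flagged in the comment inside TinyMHA.forward. The variant below is only a narrowing sketch under my assumption that the masking step is involved; I have not verified that it changes the outcome. TinyMHANoInplace is a hypothetical name, and the math is unchanged apart from using the out-of-place masked_fill.

# Narrowing sketch (assumption, not verified): same attention block, but the
# in-place masked_fill_ is replaced by its out-of-place counterpart so the
# exported graph no longer mutates `scores`.
class TinyMHANoInplace(TinyMHA):
    def forward(self, x, attn_mask):
        B, L, _ = x.shape
        q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)
        k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
        v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)
        scores = scores.masked_fill(attn_mask.unsqueeze(1), -1e9)  # out-of-place
        attn = torch.softmax(scores, dim=-1)
        ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
        return self.ln(self.drop(self.proj(ctx)) + x)

# Usage: m = MiniModel(); m.mha = TinyMHANoInplace(d_model=64, d_k=16, n_heads=4, dropout=0.1)
#        trigger_known_bugs(m)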