Skip to content

Commit 3b8d324

Browse files
authored
[Relax][PyTorch] Support gru op for ExportedProgram importer (#18360)
1 parent f30b29c commit 3b8d324

File tree

2 files changed

+366
-0
lines changed

2 files changed

+366
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:
391391
output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
392392
return output
393393

394+
def _gru(self, node: fx.Node) -> relax.Var:
    """Lower ``torch.ops.aten.gru.input`` to Relax operations.

    Implements a uni-directional (optionally multi-layer) GRU by fully
    unrolling the recurrence over time and expressing each step with
    elementary Relax ops, following the PyTorch gate equations:

        r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
        z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
        n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
        h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Parameters
    ----------
    node : fx.Node
        The FX node for ``gru.input``; its args follow the ATen signature
        ``(input, hx, params, has_biases, num_layers, dropout, train,
        bidirectional, batch_first)``.

    Returns
    -------
    relax.Var
        The stacked per-step outputs of the last layer (the first element
        of PyTorch's ``(output, h_n)`` result), laid out according to
        ``batch_first``.

    Raises
    ------
    NotImplementedError
        If ``bidirectional`` is True.
    """
    args = self.retrieve_args(node)
    input_tensor = args[0]
    hx = args[1] if len(args) > 1 else None
    params = args[2] if len(args) > 2 else None
    has_biases = args[3] if len(args) > 3 else True
    num_layers = args[4] if len(args) > 4 else 1
    _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
    _train = args[6] if len(args) > 6 else False  # Not used in inference
    bidirectional = args[7] if len(args) > 7 else False
    batch_first = args[8] if len(args) > 8 else False

    if bidirectional:
        raise NotImplementedError("Bidirectional GRU is not yet supported")

    input_shape = self.shape_of(input_tensor)
    if batch_first:
        batch_size, seq_len, input_size = input_shape
    else:
        seq_len, batch_size, input_size = input_shape

    # The time loop below is unrolled in Python, so these dims must be
    # concrete ints rather than symbolic shape expressions.
    if isinstance(seq_len, tvm.tir.IntImm):
        seq_len = seq_len.value
    if isinstance(batch_size, tvm.tir.IntImm):
        batch_size = batch_size.value
    if isinstance(input_size, tvm.tir.IntImm):
        input_size = input_size.value

    if params and len(params) >= 2:
        # params[0] is the first layer's weight_ih regardless of depth.
        # Its shape is (3 * hidden_size, input_size) — 3 gates: reset,
        # update, new — so the hidden size follows from the first dim.
        weight_ih_shape = self.shape_of(params[0])
        dim0 = weight_ih_shape[0]
        # Unwrap IntImm for consistency with seq_len/batch_size above,
        # since hidden_size is used in Python-level arithmetic below.
        if isinstance(dim0, tvm.tir.IntImm):
            dim0 = dim0.value
        hidden_size = dim0 // 3
    else:
        # Fallback to a default hidden size
        hidden_size = 16

    dtype = input_tensor.struct_info.dtype

    # Process the sequence in (seq_len, batch, input_size) order.
    if batch_first:
        input_reshaped = self.block_builder.emit(
            relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
        )
    else:
        input_reshaped = input_tensor

    # Initial hidden state per layer: either sliced out of the provided
    # hx (num_layers, batch, hidden) or all-zeros.
    h_states = []
    for layer in range(num_layers):
        if hx is not None:
            h_layer = self.block_builder.emit(
                relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
            )
        else:
            h_layer = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
            )
        h_states.append(h_layer)

    # PyTorch flattens GRU parameters per layer as
    # [w_ih, w_hh, b_ih, b_hh] with biases, or [w_ih, w_hh] without.
    params_per_layer = 4 if has_biases else 2

    outputs = []

    for t in range(seq_len):
        # Input at time t: (batch_size, input_size)
        x_t = self.block_builder.emit(
            relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
        )

        current_input = x_t
        new_h_states = []

        for layer in range(num_layers):
            if params and len(params) >= params_per_layer * num_layers:
                param_offset = layer * params_per_layer
                weight_ih = params[param_offset]
                weight_hh = params[param_offset + 1]
                bias_ih = params[param_offset + 2] if has_biases else None
                bias_hh = params[param_offset + 3] if has_biases else None
            else:
                # Fallback: zero weights, no biases.
                weight_ih = self.block_builder.emit(
                    relax.op.zeros(
                        relax.ShapeExpr(
                            (3 * hidden_size, input_size if layer == 0 else hidden_size)
                        ),
                        dtype,
                    )
                )
                weight_hh = self.block_builder.emit(
                    relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
                )
                bias_ih = None
                bias_hh = None

            h_prev = h_states[layer]

            # Split weights by gate. PyTorch GRU gate order: reset,
            # update, new (r, z, n), stacked along dim 0.
            gate_size = hidden_size

            weight_ih_r = self.block_builder.emit(
                relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
            )
            weight_hh_r = self.block_builder.emit(
                relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
            )

            weight_ih_z = self.block_builder.emit(
                relax.op.strided_slice(
                    weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
                )
            )
            weight_hh_z = self.block_builder.emit(
                relax.op.strided_slice(
                    weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
                )
            )

            weight_ih_n = self.block_builder.emit(
                relax.op.strided_slice(
                    weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                )
            )
            weight_hh_n = self.block_builder.emit(
                relax.op.strided_slice(
                    weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                )
            )

            # Transpose (hidden, in) -> (in, hidden) so x @ W works.
            weight_ih_r_t = self.block_builder.emit(
                relax.op.permute_dims(weight_ih_r, axes=[1, 0])
            )
            weight_hh_r_t = self.block_builder.emit(
                relax.op.permute_dims(weight_hh_r, axes=[1, 0])
            )
            weight_ih_z_t = self.block_builder.emit(
                relax.op.permute_dims(weight_ih_z, axes=[1, 0])
            )
            weight_hh_z_t = self.block_builder.emit(
                relax.op.permute_dims(weight_hh_z, axes=[1, 0])
            )
            weight_ih_n_t = self.block_builder.emit(
                relax.op.permute_dims(weight_ih_n, axes=[1, 0])
            )
            weight_hh_n_t = self.block_builder.emit(
                relax.op.permute_dims(weight_hh_n, axes=[1, 0])
            )

            # Reset gate: r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_prev + b_hr)
            r_ih = self.block_builder.emit(
                relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
            )
            r_hh = self.block_builder.emit(
                relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
            )
            if bias_ih is not None and bias_hh is not None:
                bias_ih_r = self.block_builder.emit(
                    relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
                )
                bias_hh_r = self.block_builder.emit(
                    relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
                )
                r_t = self.block_builder.emit(
                    relax.op.sigmoid(
                        relax.op.add(
                            relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
                        )
                    )
                )
            else:
                r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))

            # Update gate: z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_prev + b_hz)
            z_ih = self.block_builder.emit(
                relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
            )
            z_hh = self.block_builder.emit(
                relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
            )
            if bias_ih is not None and bias_hh is not None:
                bias_ih_z = self.block_builder.emit(
                    relax.op.strided_slice(
                        bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
                    )
                )
                bias_hh_z = self.block_builder.emit(
                    relax.op.strided_slice(
                        bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
                    )
                )
                z_t = self.block_builder.emit(
                    relax.op.sigmoid(
                        relax.op.add(
                            relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
                        )
                    )
                )
            else:
                z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))

            # New gate: n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_prev + b_hn))
            # Note the reset gate multiplies the *hidden* contribution only.
            n_ih = self.block_builder.emit(
                relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
            )
            n_hh = self.block_builder.emit(
                relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
            )
            if bias_ih is not None and bias_hh is not None:
                bias_ih_n = self.block_builder.emit(
                    relax.op.strided_slice(
                        bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                    )
                )
                bias_hh_n = self.block_builder.emit(
                    relax.op.strided_slice(
                        bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
                    )
                )
                n_t = self.block_builder.emit(
                    relax.op.tanh(
                        relax.op.add(
                            relax.op.add(n_ih, bias_ih_n),
                            relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
                        )
                    )
                )
            else:
                n_t = self.block_builder.emit(
                    relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
                )

            # Hidden state update: h_t = (1 - z_t) * n_t + z_t * h_prev
            one_minus_z = self.block_builder.emit(
                relax.op.subtract(relax.const(1.0, dtype), z_t)
            )
            h_t = self.block_builder.emit(
                relax.op.add(
                    relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
                )
            )

            new_h_states.append(h_t)
            # Each layer's output feeds the next layer's input.
            current_input = h_t

        h_states = new_h_states
        # The sequence output comes from the last layer only.
        outputs.append(h_states[-1])

    # Stack per-step outputs: (seq_len, batch_size, hidden_size)
    output = self.block_builder.emit(relax.op.stack(outputs, axis=0))

    if batch_first:
        # (seq_len, batch, hidden) -> (batch, seq_len, hidden)
        output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))

    return output
687+
394688
########## Manipulation ##########
395689

396690
def _narrow(self, node: fx.Node) -> relax.Var:
@@ -652,6 +946,7 @@ def create_convert_map(
652946
"layer_norm.default": self._layer_norm,
653947
"linear.default": self._linear,
654948
"lstm.input": self._lstm,
949+
"gru.input": self._gru,
655950
"max_pool1d.default": self._max_pool1d,
656951
"max_pool2d.default": self._max_pool2d,
657952
"max_pool3d.default": self._max_pool3d,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6050,5 +6050,76 @@ def main(
60506050
verify_model(TensorNoneModel(), example_args, {}, Expected)
60516051

60526052

6053+
def test_gru():
    """End-to-end check of aten::gru.input lowering against eager PyTorch."""

    def _check_gru(model, example_input):
        # Shared pipeline for every case: run eager PyTorch, export,
        # import into Relax, build for LLVM, run on the Relax VM, and
        # compare the outputs numerically.
        with torch.no_grad():
            expected = model(example_input)
        exported_program = export(model, args=(example_input,))
        mod = from_exported_program(exported_program)
        ex = relax.build(mod, tvm.target.Target("llvm"))
        vm = relax.VirtualMachine(ex, tvm.cpu())
        result = vm["main"](tvm.runtime.tensor(example_input.numpy()))
        # The VM may return the tensor directly or wrapped in a tuple.
        result_np = result.numpy() if hasattr(result, "numpy") else result[0].numpy()
        assert (
            expected.shape == result_np.shape
        ), f"Shape mismatch: PyTorch {expected.shape} vs TVM {result_np.shape}"
        np.testing.assert_allclose(expected.numpy(), result_np, rtol=1e-4, atol=1e-5)

    class BasicGRU(nn.Module):
        # batch_first layout: input is (batch, seq_len, input_size)
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(42)
    _check_gru(BasicGRU(), torch.randn(2, 3, 4, dtype=torch.float32))

    class SeqFirstGRU(nn.Module):
        # seq-first layout: input is (seq_len, batch, input_size)
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(43)
    _check_gru(SeqFirstGRU(), torch.randn(4, 2, 3, dtype=torch.float32))
6122+
6123+
60536124
if __name__ == "__main__":
60546125
tvm.testing.main()

0 commit comments

Comments
 (0)