@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:
391391 output = self .block_builder .emit (relax .op .permute_dims (output , axes = [1 , 0 , 2 ]))
392392 return output
393393
394+ def _gru (self , node : fx .Node ) -> relax .Var :
395+ args = self .retrieve_args (node )
396+ input_tensor = args [0 ]
397+ hx = args [1 ] if len (args ) > 1 else None
398+ params = args [2 ] if len (args ) > 2 else None
399+ has_biases = args [3 ] if len (args ) > 3 else True
400+ num_layers = args [4 ] if len (args ) > 4 else 1
401+ _dropout = args [5 ] if len (args ) > 5 else 0.0 # Not used in inference
402+ _train = args [6 ] if len (args ) > 6 else False # Not used in inference
403+ bidirectional = args [7 ] if len (args ) > 7 else False
404+ batch_first = args [8 ] if len (args ) > 8 else False
405+
406+ if bidirectional :
407+ raise NotImplementedError ("Bidirectional GRU is not yet supported" )
408+
409+ input_shape = self .shape_of (input_tensor )
410+ if batch_first :
411+ batch_size , seq_len , input_size = input_shape
412+ else :
413+ seq_len , batch_size , input_size = input_shape
414+
415+ if isinstance (seq_len , tvm .tir .IntImm ):
416+ seq_len = seq_len .value
417+ if isinstance (batch_size , tvm .tir .IntImm ):
418+ batch_size = batch_size .value
419+ if isinstance (input_size , tvm .tir .IntImm ):
420+ input_size = input_size .value
421+
422+ if params and len (params ) >= 2 :
423+ # For multi-layer, we need to extract the first layer's weights
424+ # to determine hidden size
425+ if num_layers > 1 :
426+ # Multi-layer: params[0] is first layer's weight_ih
427+ weight_ih = params [0 ]
428+ else :
429+ # Single layer: params[0] is weight_ih
430+ weight_ih = params [0 ]
431+ # Extract hidden size from weight dimensions
432+ # weight_ih has shape (3 * hidden_size, input_size)
433+ weight_ih_shape = self .shape_of (weight_ih )
434+ hidden_size = weight_ih_shape [0 ] // 3 # 3 gates: reset, update, new
435+ else :
436+ # Fallback to a default hidden size
437+ hidden_size = 16
438+
439+ # Implement actual GRU computation using Relax operations
440+ # GRU equations:
441+ # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
442+ # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
443+ # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
444+ # h_t = (1 - z_t) * n_t + z_t * h_{t-1}
445+ dtype = input_tensor .struct_info .dtype
446+
447+ # Reshape input for processing
448+ if batch_first :
449+ # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
450+ input_reshaped = self .block_builder .emit (
451+ relax .op .permute_dims (input_tensor , axes = [1 , 0 , 2 ])
452+ )
453+ else :
454+ input_reshaped = input_tensor
455+
456+ # Initialize hidden states for all layers
457+ if hx is not None :
458+ # hx shape: (num_layers, batch_size, hidden_size)
459+ h_states = []
460+ for layer in range (num_layers ):
461+ h_layer = self .block_builder .emit (
462+ relax .op .take (hx , relax .const (layer , "int64" ), axis = 0 , mode = "clip" )
463+ )
464+ h_states .append (h_layer )
465+ else :
466+ h_states = []
467+ for layer in range (num_layers ):
468+ h_layer = self .block_builder .emit (
469+ relax .op .zeros (relax .ShapeExpr ((batch_size , hidden_size )), dtype )
470+ )
471+ h_states .append (h_layer )
472+
473+ outputs = []
474+
475+ for t in range (seq_len ):
476+ # Get input at time t: (batch_size, input_size)
477+ x_t = self .block_builder .emit (
478+ relax .op .take (input_reshaped , relax .const (t , "int64" ), axis = 0 , mode = "clip" )
479+ )
480+
481+ # Process through each layer
482+ current_input = x_t
483+ new_h_states = []
484+
485+ for layer in range (num_layers ):
486+ # Get layer parameters
487+ if params and len (params ) >= 4 * num_layers :
488+ # Multi-layer case: params are organized as
489+ # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
490+ param_offset = layer * 4
491+ weight_ih = params [param_offset ]
492+ weight_hh = params [param_offset + 1 ]
493+ bias_ih = params [param_offset + 2 ] if has_biases else None
494+ bias_hh = params [param_offset + 3 ] if has_biases else None
495+ elif params and len (params ) >= 4 :
496+ # Single layer case
497+ weight_ih = params [0 ]
498+ weight_hh = params [1 ]
499+ bias_ih = params [2 ] if has_biases else None
500+ bias_hh = params [3 ] if has_biases else None
501+ else :
502+ # Fallback: create zero weights
503+ weight_ih = self .block_builder .emit (
504+ relax .op .zeros (
505+ relax .ShapeExpr (
506+ (3 * hidden_size , input_size if layer == 0 else hidden_size )
507+ ),
508+ dtype ,
509+ )
510+ )
511+ weight_hh = self .block_builder .emit (
512+ relax .op .zeros (relax .ShapeExpr ((3 * hidden_size , hidden_size )), dtype )
513+ )
514+ bias_ih = None
515+ bias_hh = None
516+
517+ # Get previous hidden state for this layer
518+ h_prev = h_states [layer ]
519+
520+ # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
521+ gate_size = hidden_size
522+
523+ # Reset gate weights
524+ weight_ih_r = self .block_builder .emit (
525+ relax .op .strided_slice (weight_ih , axes = [0 ], begin = [0 ], end = [gate_size ])
526+ )
527+ weight_hh_r = self .block_builder .emit (
528+ relax .op .strided_slice (weight_hh , axes = [0 ], begin = [0 ], end = [gate_size ])
529+ )
530+
531+ # Update gate weights
532+ weight_ih_z = self .block_builder .emit (
533+ relax .op .strided_slice (
534+ weight_ih , axes = [0 ], begin = [gate_size ], end = [2 * gate_size ]
535+ )
536+ )
537+ weight_hh_z = self .block_builder .emit (
538+ relax .op .strided_slice (
539+ weight_hh , axes = [0 ], begin = [gate_size ], end = [2 * gate_size ]
540+ )
541+ )
542+
543+ # New gate weights
544+ weight_ih_n = self .block_builder .emit (
545+ relax .op .strided_slice (
546+ weight_ih , axes = [0 ], begin = [2 * gate_size ], end = [3 * gate_size ]
547+ )
548+ )
549+ weight_hh_n = self .block_builder .emit (
550+ relax .op .strided_slice (
551+ weight_hh , axes = [0 ], begin = [2 * gate_size ], end = [3 * gate_size ]
552+ )
553+ )
554+
555+ # Transpose weights for matmul
556+ weight_ih_r_t = self .block_builder .emit (
557+ relax .op .permute_dims (weight_ih_r , axes = [1 , 0 ])
558+ )
559+ weight_hh_r_t = self .block_builder .emit (
560+ relax .op .permute_dims (weight_hh_r , axes = [1 , 0 ])
561+ )
562+ weight_ih_z_t = self .block_builder .emit (
563+ relax .op .permute_dims (weight_ih_z , axes = [1 , 0 ])
564+ )
565+ weight_hh_z_t = self .block_builder .emit (
566+ relax .op .permute_dims (weight_hh_z , axes = [1 , 0 ])
567+ )
568+ weight_ih_n_t = self .block_builder .emit (
569+ relax .op .permute_dims (weight_ih_n , axes = [1 , 0 ])
570+ )
571+ weight_hh_n_t = self .block_builder .emit (
572+ relax .op .permute_dims (weight_hh_n , axes = [1 , 0 ])
573+ )
574+
575+ # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
576+ r_ih = self .block_builder .emit (
577+ relax .op .linear_algebra .matmul (current_input , weight_ih_r_t )
578+ )
579+ r_hh = self .block_builder .emit (
580+ relax .op .linear_algebra .matmul (h_prev , weight_hh_r_t )
581+ )
582+ if bias_ih is not None and bias_hh is not None :
583+ bias_ih_r = self .block_builder .emit (
584+ relax .op .strided_slice (bias_ih , axes = [0 ], begin = [0 ], end = [gate_size ])
585+ )
586+ bias_hh_r = self .block_builder .emit (
587+ relax .op .strided_slice (bias_hh , axes = [0 ], begin = [0 ], end = [gate_size ])
588+ )
589+ r_t = self .block_builder .emit (
590+ relax .op .sigmoid (
591+ relax .op .add (
592+ relax .op .add (relax .op .add (r_ih , bias_ih_r ), r_hh ), bias_hh_r
593+ )
594+ )
595+ )
596+ else :
597+ r_t = self .block_builder .emit (relax .op .sigmoid (relax .op .add (r_ih , r_hh )))
598+
599+ # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
600+ z_ih = self .block_builder .emit (
601+ relax .op .linear_algebra .matmul (current_input , weight_ih_z_t )
602+ )
603+ z_hh = self .block_builder .emit (
604+ relax .op .linear_algebra .matmul (h_prev , weight_hh_z_t )
605+ )
606+ if bias_ih is not None and bias_hh is not None :
607+ bias_ih_z = self .block_builder .emit (
608+ relax .op .strided_slice (
609+ bias_ih , axes = [0 ], begin = [gate_size ], end = [2 * gate_size ]
610+ )
611+ )
612+ bias_hh_z = self .block_builder .emit (
613+ relax .op .strided_slice (
614+ bias_hh , axes = [0 ], begin = [gate_size ], end = [2 * gate_size ]
615+ )
616+ )
617+ z_t = self .block_builder .emit (
618+ relax .op .sigmoid (
619+ relax .op .add (
620+ relax .op .add (relax .op .add (z_ih , bias_ih_z ), z_hh ), bias_hh_z
621+ )
622+ )
623+ )
624+ else :
625+ z_t = self .block_builder .emit (relax .op .sigmoid (relax .op .add (z_ih , z_hh )))
626+
627+ # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
628+ n_ih = self .block_builder .emit (
629+ relax .op .linear_algebra .matmul (current_input , weight_ih_n_t )
630+ )
631+ n_hh = self .block_builder .emit (
632+ relax .op .linear_algebra .matmul (h_prev , weight_hh_n_t )
633+ )
634+ if bias_ih is not None and bias_hh is not None :
635+ bias_ih_n = self .block_builder .emit (
636+ relax .op .strided_slice (
637+ bias_ih , axes = [0 ], begin = [2 * gate_size ], end = [3 * gate_size ]
638+ )
639+ )
640+ bias_hh_n = self .block_builder .emit (
641+ relax .op .strided_slice (
642+ bias_hh , axes = [0 ], begin = [2 * gate_size ], end = [3 * gate_size ]
643+ )
644+ )
645+ n_t = self .block_builder .emit (
646+ relax .op .tanh (
647+ relax .op .add (
648+ relax .op .add (n_ih , bias_ih_n ),
649+ relax .op .multiply (r_t , relax .op .add (n_hh , bias_hh_n )),
650+ )
651+ )
652+ )
653+ else :
654+ n_t = self .block_builder .emit (
655+ relax .op .tanh (relax .op .add (n_ih , relax .op .multiply (r_t , n_hh )))
656+ )
657+
658+ # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
659+ one_minus_z = self .block_builder .emit (
660+ relax .op .subtract (relax .const (1.0 , dtype ), z_t )
661+ )
662+ h_t = self .block_builder .emit (
663+ relax .op .add (
664+ relax .op .multiply (one_minus_z , n_t ), relax .op .multiply (z_t , h_prev )
665+ )
666+ )
667+
668+ new_h_states .append (h_t )
669+
670+ current_input = h_t
671+
672+ # Update hidden states for next time step
673+ h_states = new_h_states
674+
675+ # Store output (from the last layer)
676+ outputs .append (h_states [- 1 ])
677+
678+ # Stack outputs: (seq_len, batch_size, hidden_size)
679+ output = self .block_builder .emit (relax .op .stack (outputs , axis = 0 ))
680+
681+ # Reshape back to batch_first if needed
682+ if batch_first :
683+ # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
684+ output = self .block_builder .emit (relax .op .permute_dims (output , axes = [1 , 0 , 2 ]))
685+
686+ return output
687+
394688 ########## Manipulation ##########
395689
396690 def _narrow (self , node : fx .Node ) -> relax .Var :
@@ -652,6 +946,7 @@ def create_convert_map(
652946 "layer_norm.default" : self ._layer_norm ,
653947 "linear.default" : self ._linear ,
654948 "lstm.input" : self ._lstm ,
949+ "gru.input" : self ._gru ,
655950 "max_pool1d.default" : self ._max_pool1d ,
656951 "max_pool2d.default" : self ._max_pool2d ,
657952 "max_pool3d.default" : self ._max_pool3d ,
0 commit comments