@@ -800,6 +800,40 @@ def aten_ops_tile(
800800 )
801801
802802
803+ def zero_output_validator (node : Node ) -> bool :
804+ if 0 in node .args [1 ]:
805+ _LOGGER .debug (
806+ f"We do not support output tensor { node .args [1 ]} tensors with zero-sized dimensions for this operation."
807+ )
808+ return False
809+ else :
810+ return True
811+
812+
813+ @dynamo_tensorrt_converter (
814+ torch .ops .aten .as_strided .default ,
815+ capability_validator = zero_output_validator ,
816+ )
817+ @dynamo_tensorrt_converter (torch .ops .aten .as_strided .default )
818+ def aten_ops_as_strided (
819+ ctx : ConversionContext ,
820+ target : Target ,
821+ args : Tuple [Argument , ...],
822+ kwargs : Dict [str , Argument ],
823+ name : str ,
824+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
825+ return impl .slice .as_strided (
826+ ctx ,
827+ target ,
828+ source_ir = SourceIR .ATEN ,
829+ name = name ,
830+ input = args [0 ],
831+ size = args [1 ],
832+ stride = args [2 ],
833+ storage_offset = args_bounds_check (args , 3 , None ),
834+ )
835+
836+
803837@dynamo_tensorrt_converter (torch .ops .aten .permute .default )
804838@enforce_tensor_types (
805839 {
@@ -2185,7 +2219,6 @@ def aten_ops_linear(
21852219 bias = args_bounds_check (args , 2 , None ),
21862220 )
21872221
2188-
21892222@dynamo_tensorrt_converter (torch .ops .aten ._cdist_forward .default )
21902223def aten_ops_cdist_forward (
21912224 ctx : ConversionContext ,
@@ -2206,39 +2239,6 @@ def aten_ops_cdist_forward(
22062239 )
22072240
22082241
2209- def zero_output_validator (node : Node ) -> bool :
2210- if 0 in node .args [1 ]:
2211- _LOGGER .debug (
2212- f"We do not support output tensor { node .args [1 ]} tensors with zero-sized dimensions for this operation."
2213- )
2214- return False
2215- else :
2216- return True
2217-
2218-
2219- @dynamo_tensorrt_converter (
2220- torch .ops .aten .as_strided .default ,
2221- capability_validator = zero_output_validator ,
2222- )
2223- def aten_ops_as_strided (
2224- ctx : ConversionContext ,
2225- target : Target ,
2226- args : Tuple [Argument , ...],
2227- kwargs : Dict [str , Argument ],
2228- name : str ,
2229- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
2230- return impl .slice .as_strided (
2231- ctx ,
2232- target ,
2233- source_ir = SourceIR .ATEN ,
2234- name = name ,
2235- input = args [0 ],
2236- size = args [1 ],
2237- stride = args [2 ],
2238- storage_offset = args_bounds_check (args , 3 , None ),
2239- )
2240-
2241-
22422242def avg_pool_param_validator (pool_node : Node ) -> bool :
22432243 ceil_mode = args_bounds_check (pool_node .args , 4 , False )
22442244 divisor_override = args_bounds_check (pool_node .args , 6 )
0 commit comments