@@ -478,6 +478,13 @@ def test_qnn_backend_full_like(self):
478
478
sample_input = (torch .randn (1 , 2 , 3 , 4 ),)
479
479
self .lower_module_and_test_output (module , sample_input )
480
480
481
+ def test_qnn_backend_gather (self ):
482
+ module = Gather () # noqa: F405
483
+ shape = (2 , 2 , 3 , 4 )
484
+ sample_input = (torch .randn (shape ), torch .randn (shape ))
485
+ module = self .get_qdq_module (module , sample_input )
486
+ self .lower_module_and_test_output (module , sample_input )
487
+
481
488
def test_qnn_backend_gelu (self ):
482
489
module = Gelu () # noqa: F405
483
490
sample_input = (torch .randn (2 , 5 , 1 , 3 ),)
@@ -821,12 +828,17 @@ def test_qnn_backend_select_copy(self):
821
828
self .lower_module_and_test_output (module , sample_input )
822
829
823
830
def test_qnn_backend_slice_copy (self ):
824
- modules = [SliceCopy (), SliceCopyWithStep ()] # noqa: F405
825
- sample_input = (
826
- torch .randn ([1 , 512 ]),
827
- torch .randn ([1 , 8 ]),
828
- )
829
- for module in modules :
831
+ modules = [
832
+ SliceCopyDefaultParameter (),
833
+ SliceCopy (),
834
+ SliceCopyWithStep (),
835
+ ] # noqa: F405
836
+ sample_inputs = [
837
+ (torch .randn ([2 , 1 , 320 , 512 ]),),
838
+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
839
+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
840
+ ]
841
+ for module , sample_input in zip (modules , sample_inputs ):
830
842
self .lower_module_and_test_output (module , sample_input )
831
843
832
844
def test_qnn_backend_stack (self ):
@@ -1593,6 +1605,13 @@ def test_qnn_backend_full_like(self):
1593
1605
module = self .get_qdq_module (module , sample_input )
1594
1606
self .lower_module_and_test_output (module , sample_input )
1595
1607
1608
+ def test_qnn_backend_gather (self ):
1609
+ module = Gather () # noqa: F405
1610
+ shape = (2 , 2 , 3 , 4 )
1611
+ sample_input = (torch .randn (shape ), torch .randn (shape ))
1612
+ module = self .get_qdq_module (module , sample_input )
1613
+ self .lower_module_and_test_output (module , sample_input )
1614
+
1596
1615
def test_qnn_backend_gelu (self ):
1597
1616
module = Gelu () # noqa: F405
1598
1617
sample_input = (torch .randn (2 , 5 , 1 , 3 ),)
@@ -1991,12 +2010,17 @@ def test_qnn_backend_sin(self):
1991
2010
self .lower_module_and_test_output (module , sample_input )
1992
2011
1993
2012
def test_qnn_backend_slice_copy (self ):
1994
- modules = [SliceCopy (), SliceCopyWithStep ()] # noqa: F405
1995
- sample_input = (
1996
- torch .randn ([1 , 512 ]),
1997
- torch .randn ([1 , 8 ]),
1998
- )
1999
- for module in modules :
2013
+ modules = [
2014
+ SliceCopyDefaultParameter (),
2015
+ SliceCopy (),
2016
+ SliceCopyWithStep (),
2017
+ ] # noqa: F405
2018
+ sample_inputs = [
2019
+ (torch .randn ([2 , 1 , 320 , 512 ]),),
2020
+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
2021
+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
2022
+ ]
2023
+ for module , sample_input in zip (modules , sample_inputs ):
2000
2024
module = self .get_qdq_module (module , sample_input )
2001
2025
self .lower_module_and_test_output (module , sample_input )
2002
2026
0 commit comments