Skip to content

Commit d6e0ce1

Browse files
author
sleygin
committed
bug fix on generic types
1 parent e764405 commit d6e0ce1

File tree

2 files changed

+264
-0
lines changed

2 files changed

+264
-0
lines changed

printer/printer.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ func (p *Printer) PrintType(node ast.Node) (string, error) {
6767
return p.printStruct(t)
6868
case *ast.Ident:
6969
return p.printIdent(t)
70+
case *ast.IndexExpr:
71+
return p.printGeneric(t)
72+
case *ast.IndexListExpr:
73+
return p.printGenericList(t)
7074
}
7175

7276
err := printer.Fprint(p.buf, p.fs, node)
@@ -151,6 +155,43 @@ func (p *Printer) printIdent(i *ast.Ident) (string, error) {
151155
return p.buf.String(), err
152156
}
153157

158+
func (p *Printer) printGeneric(pt *ast.IndexExpr) (string, error) {
159+
t, err := p.PrintType(pt.X)
160+
if err != nil {
161+
return "", err
162+
}
163+
164+
generic, err := p.PrintType(pt.Index)
165+
if err != nil {
166+
return "", err
167+
}
168+
169+
return t + "[" + generic + "]", nil
170+
}
171+
172+
func (p *Printer) printGenericList(pt *ast.IndexListExpr) (string, error) {
173+
t, err := p.PrintType(pt.X)
174+
if err != nil {
175+
return "", err
176+
}
177+
178+
baseStr := t + "["
179+
for i, expr := range pt.Indices {
180+
generic, err := p.PrintType(expr)
181+
if err != nil {
182+
return "", err
183+
}
184+
185+
if i == len(pt.Indices)-1 {
186+
baseStr = baseStr + generic + "]"
187+
} else {
188+
baseStr = baseStr + generic + ", "
189+
}
190+
}
191+
192+
return baseStr, nil
193+
}
194+
154195
func (p *Printer) printPointer(pt *ast.StarExpr) (string, error) {
155196
pointerTo, err := p.PrintType(pt.X)
156197
if err != nil {

printer/printer_test.go

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,199 @@ func TestPrinter_printIdent(t *testing.T) {
776776
}
777777
}
778778

779+
func TestPrinter_printGeneric(t *testing.T) {
780+
tests := []struct {
781+
name string
782+
init func(t minimock.Tester) *Printer
783+
inspect func(r *Printer, t *testing.T)
784+
785+
indexExpr *ast.IndexExpr
786+
787+
want1 string
788+
wantErr bool
789+
inspectErr func(err error, t *testing.T)
790+
}{
791+
{
792+
name: "success",
793+
indexExpr: &ast.IndexExpr{
794+
X: &ast.Ident{
795+
Name: "Bar",
796+
},
797+
Index: &ast.Ident{
798+
Name: "Baz",
799+
},
800+
},
801+
init: func(t minimock.Tester) *Printer {
802+
return &Printer{
803+
typesPrefix: "prefix",
804+
fs: token.NewFileSet(),
805+
buf: bytes.NewBuffer([]byte{}),
806+
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}, {Name: &ast.Ident{Name: "Baz"}}},
807+
}
808+
},
809+
want1: "prefix.Bar[prefix.Baz]",
810+
wantErr: false,
811+
},
812+
{
813+
name: "success, generic from other package",
814+
indexExpr: &ast.IndexExpr{
815+
X: &ast.Ident{
816+
Name: "Bar",
817+
},
818+
Index: &ast.SelectorExpr{
819+
X: &ast.Ident{
820+
Name: "otherpkg",
821+
},
822+
Sel: &ast.Ident{
823+
Name: "Baz",
824+
},
825+
},
826+
},
827+
init: func(t minimock.Tester) *Printer {
828+
return &Printer{
829+
typesPrefix: "prefix",
830+
fs: token.NewFileSet(),
831+
buf: bytes.NewBuffer([]byte{}),
832+
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}},
833+
}
834+
},
835+
want1: "prefix.Bar[otherpkg.Baz]",
836+
wantErr: false,
837+
},
838+
}
839+
840+
for _, tt := range tests {
841+
t.Run(tt.name, func(t *testing.T) {
842+
mc := minimock.NewController(t)
843+
defer mc.Wait(time.Second)
844+
845+
receiver := tt.init(mc)
846+
847+
got1, err := receiver.printGeneric(tt.indexExpr)
848+
849+
if tt.inspect != nil {
850+
tt.inspect(receiver, t)
851+
}
852+
853+
assert.Equal(t, tt.want1, got1, "Printer.printGeneric returned unexpected result")
854+
855+
if tt.wantErr {
856+
if assert.Error(t, err) && tt.inspectErr != nil {
857+
tt.inspectErr(err, t)
858+
}
859+
} else {
860+
assert.NoError(t, err)
861+
}
862+
863+
})
864+
}
865+
}
866+
867+
func TestPrinter_printGenericList(t *testing.T) {
868+
tests := []struct {
869+
name string
870+
init func(t minimock.Tester) *Printer
871+
inspect func(r *Printer, t *testing.T)
872+
873+
indexListExpr *ast.IndexListExpr
874+
875+
want1 string
876+
wantErr bool
877+
inspectErr func(err error, t *testing.T)
878+
}{
879+
{
880+
name: "success",
881+
indexListExpr: &ast.IndexListExpr{
882+
X: &ast.Ident{
883+
Name: "Bar",
884+
},
885+
Indices: []ast.Expr{
886+
&ast.Ident{
887+
Name: "Baz",
888+
},
889+
&ast.Ident{
890+
Name: "Bak",
891+
},
892+
},
893+
},
894+
init: func(t minimock.Tester) *Printer {
895+
return &Printer{
896+
typesPrefix: "prefix",
897+
fs: token.NewFileSet(),
898+
buf: bytes.NewBuffer([]byte{}),
899+
types: []*ast.TypeSpec{
900+
{Name: &ast.Ident{Name: "Bar"}},
901+
{Name: &ast.Ident{Name: "Baz"}},
902+
{Name: &ast.Ident{Name: "Bak"}},
903+
},
904+
}
905+
},
906+
want1: "prefix.Bar[prefix.Baz, prefix.Bak]",
907+
wantErr: false,
908+
},
909+
{
910+
name: "success, generic from other package",
911+
indexListExpr: &ast.IndexListExpr{
912+
X: &ast.Ident{
913+
Name: "Bar",
914+
},
915+
Indices: []ast.Expr{
916+
&ast.Ident{
917+
Name: "Baz",
918+
},
919+
&ast.SelectorExpr{
920+
X: &ast.Ident{
921+
Name: "otherpkg",
922+
},
923+
Sel: &ast.Ident{
924+
Name: "Bak",
925+
},
926+
},
927+
},
928+
},
929+
init: func(t minimock.Tester) *Printer {
930+
return &Printer{
931+
typesPrefix: "prefix",
932+
fs: token.NewFileSet(),
933+
buf: bytes.NewBuffer([]byte{}),
934+
types: []*ast.TypeSpec{
935+
{Name: &ast.Ident{Name: "Bar"}},
936+
{Name: &ast.Ident{Name: "Baz"}},
937+
},
938+
}
939+
},
940+
want1: "prefix.Bar[prefix.Baz, otherpkg.Bak]",
941+
wantErr: false,
942+
},
943+
}
944+
945+
for _, tt := range tests {
946+
t.Run(tt.name, func(t *testing.T) {
947+
mc := minimock.NewController(t)
948+
defer mc.Wait(time.Second)
949+
950+
receiver := tt.init(mc)
951+
952+
got1, err := receiver.printGenericList(tt.indexListExpr)
953+
954+
if tt.inspect != nil {
955+
tt.inspect(receiver, t)
956+
}
957+
958+
assert.Equal(t, tt.want1, got1, "Printer.printGenericList returned unexpected result")
959+
960+
if tt.wantErr {
961+
if assert.Error(t, err) && tt.inspectErr != nil {
962+
tt.inspectErr(err, t)
963+
}
964+
} else {
965+
assert.NoError(t, err)
966+
}
967+
968+
})
969+
}
970+
}
971+
779972
func TestPrinter_PrintType(t *testing.T) {
780973
tests := []struct {
781974
name string
@@ -876,6 +1069,36 @@ func TestPrinter_PrintType(t *testing.T) {
8761069
},
8771070
want1: "package.Identifier",
8781071
},
1072+
{
1073+
name: "generic type",
1074+
node: &ast.IndexExpr{X: &ast.Ident{Name: "Bar"}, Index: &ast.Ident{Name: "string"}},
1075+
init: func(t minimock.Tester) *Printer {
1076+
return &Printer{
1077+
fs: token.NewFileSet(),
1078+
buf: bytes.NewBuffer([]byte{}),
1079+
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}},
1080+
}
1081+
},
1082+
want1: "Bar[string]",
1083+
},
1084+
{
1085+
name: "generic list type",
1086+
node: &ast.IndexListExpr{
1087+
X: &ast.Ident{Name: "Bar"},
1088+
Indices: []ast.Expr{
1089+
&ast.Ident{Name: "string"},
1090+
&ast.Ident{Name: "int"},
1091+
},
1092+
},
1093+
init: func(t minimock.Tester) *Printer {
1094+
return &Printer{
1095+
fs: token.NewFileSet(),
1096+
buf: bytes.NewBuffer([]byte{}),
1097+
types: []*ast.TypeSpec{{Name: &ast.Ident{Name: "Bar"}}},
1098+
}
1099+
},
1100+
want1: "Bar[string, int]",
1101+
},
8791102
}
8801103

8811104
for _, tt := range tests {

0 commit comments

Comments
 (0)