package ssa2ast import ( "go/ast" "go/importer" "go/printer" "go/types" "os" "os/exec" "path/filepath" "testing" "github.com/go-quicktest/qt" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) const sigSrc = `package main import "unsafe" type genericStruct[T interface{}] struct{} type plainStruct struct { Dummy struct{} } func (s *plainStruct) plainStructFunc() { } func (*plainStruct) plainStructAnonFunc() { } func (s *genericStruct[T]) genericStructFunc() { } func (s *genericStruct[T]) genericStructAnonFunc() (test T) { return } func plainFuncSignature(a int, b string, c struct{}, d struct{ string }, e interface{ Dummy() string }, pointer unsafe.Pointer) (i int, er error) { return } func genericFuncSignature[T interface{ interface{} | ~int64 | bool }, X interface{ comparable }](a T, b X, c genericStruct[struct{ a T }], d genericStruct[T]) (res T) { return } ` func TestConvertSignature(t *testing.T) { conv := newFuncConverter(DefaultConfig()) f, _, info, _ := mustParseAndTypeCheckFile(sigSrc) for _, funcName := range []string{"plainStructFunc", "plainStructAnonFunc", "genericStructFunc", "plainFuncSignature", "genericFuncSignature"} { funcDecl := findFunc(f, funcName) funcDecl.Body = nil funcObj := info.Defs[funcDecl.Name].(*types.Func) funcDeclConverted, err := conv.convertSignatureToFuncDecl(funcObj.Name(), funcObj.Signature()) qt.Assert(t, qt.IsNil(err)) qt.Assert(t, qt.CmpEquals(funcDeclConverted, funcDecl, astCmpOpt)) } } const mainSrc = `package main import ( "encoding/binary" "fmt" "io" "sort" "strconv" "sync" "time" "unsafe" ) func main() { methodOps() slicesOps() iterAndMapsOps() chanOps() flowOps() typeOps() genericFunc() } func makeSprintf(tag string) func(vals ...interface{}) { i := 0 return func(vals ...interface{}) { fmt.Printf("%s(%d): %v\n", tag, i, vals) i++ } } func return42() int { return 42 } type arrayOfInts []int type structOfArraysOfInts struct { a arrayOfInts b arrayOfInts } func slicesOps() { sprintf := makeSprintf("slicesOps") slice := [...]int{1, 2} sprintf(slice[0:1:2]) // *ssa.IndexAddr sprintf(slice) slice[0] += 1 sprintf(slice) sprintf(slice[:1]) sprintf(slice[slice[0]:]) sprintf(slice[0:2]) sprintf((*[2]int)(slice[:])[return42()%2]) // *ssa.SliceToArrayPointer sprintf("test"[return42()%3]) // *ssa.Index structOfArrays := structOfArraysOfInts{a: slice[1:], b: slice[:1]} sprintf(structOfArrays.a[:1]) sprintf(structOfArrays.b[:1]) slice2 := make([]string, return42(), return42()*2) slice2[return42()-1] = "test" sprintf(slice2) return } func iterAndMapsOps() { sprintf := makeSprintf("iterAndMapsOps") // *ssa.MakeMap + *ssa.MapUpdate mmap := map[string]time.Month{ "April": time.April, "December": time.December, "January": time.January, } var vals []string for k := range mmap { vals = append(vals, k) } for _, v := range mmap { vals = append(vals, v.String()) } sort.Strings(vals) // Required. Order of map iteration not guaranteed sprintf(vals) if v, ok := mmap["?"]; ok { panic("unreachable: " + v.String()) } for idx, s := range "hello world" { sprintf(idx, s) } sprintf(mmap["April"].String()) return } type interfaceCalls interface { Return1() string } type structCalls struct { } func (r structCalls) Return1() string { return "Return1" } func (r *structCalls) Return2() string { return "Return2" } func multiOutputRes() (int, string) { return 42, "24" } func returnInterfaceCalls() interfaceCalls { return structCalls{} } func methodOps() { sprintf := makeSprintf("methodOps") defer func() { sprintf("from defer") }() defer sprintf("from defer 2") var wg sync.WaitGroup wg.Add(1) go func() { sprintf("from go") wg.Done() }() wg.Wait() i, s := multiOutputRes() sprintf(strconv.Itoa(i)) var strct structCalls strct.Return1() strct.Return2() intrfs := returnInterfaceCalls() intrfs.Return1() sprintf(strconv.Itoa(len(s))) strconv.Itoa(binary.Size(4)) sprintf(binary.LittleEndian.AppendUint32(nil, 42)) if len(s) == 0 { panic("unreachable") } sprintf(*unsafe.StringData(s)) thunkMethod1 := structCalls.Return1 sprintf(thunkMethod1(strct)) thunkMethod2 := (*structCalls).Return2 sprintf(thunkMethod2(&strct)) closureVar := "c " + s anonFnc := func(n func(structCalls) string) string { return n(structCalls{}) + "anon" + closureVar } sprintf(anonFnc(structCalls.Return1)) } func chanOps() { sprintf := makeSprintf("chanOps") a := make(chan string) b := make(chan string) c := make(chan string) d := make(chan string) select { case r1, ok := <-a: sprintf(r1, ok) case r2 := <-b: sprintf(r2) case <-c: sprintf("r3") case d <- "test": sprintf("d triggered") default: sprintf("default") } e := make(chan string, 1) e <- "hi" sprintf(<-e) close(a) val, ok := <-a sprintf(val, ok) return } func flowOps() { sprintf := makeSprintf("flowOps") i := 1 if return42()%2 == 0 { sprintf("a") i++ } else { sprintf("b") } sprintf(i) switch return42() { case 1: sprintf("1") case 2: sprintf("2") case 3: sprintf("3") case 42: sprintf("42") } } type interfaceB interface { } type testStruct struct { A, B int } func typeOps() { sprintf := makeSprintf("typeOps") // *ssa.ChangeType var interA interfaceCalls sprintf(interA) // *ssa.ChangeInterface var interB interfaceB = struct{}{} var inter0 interface{} = interB sprintf(inter0) // *ssa.Convert var f float64 = 1.0 sprintf(int(f)) casted, ok := inter0.(interfaceB) sprintf(casted, ok) casted2 := inter0.(interfaceB) sprintf(casted2) strc := testStruct{return42(), return42() + 2} strc.B += strc.A sprintf(strc) // Access to unexported structure discard := io.Discard if return42() == 0 { sprintf(discard) // Trigger phi block } _, _ = discard.Write([]byte("test")) } func sumIntsOrFloats[K comparable, V int64 | float64](m map[K]V) V { var s V for _, v := range m { s += v } return s } func genericFunc() { sprintf := makeSprintf("genericFunc") ints := map[string]int64{ "first": 34, "second": 12, } sprintf(sumIntsOrFloats[string, int64](ints)) floats := map[string]float64{ "first": 34.1, "second": 12.1, } sprintf(sumIntsOrFloats(floats)) } ` func TestConvert(t *testing.T) { runGoFile := func(f string) string { cmd := exec.Command("go", "run", f) out, err := cmd.CombinedOutput() qt.Assert(t, qt.IsNil(err)) return string(out) } testFile := filepath.Join(t.TempDir(), "convert.go") err := os.WriteFile(testFile, []byte(mainSrc), 0o777) qt.Assert(t, qt.IsNil(err)) originalOut := runGoFile(testFile) file, fset, _, _ := mustParseAndTypeCheckFile(mainSrc) ssaPkg, _, err := ssautil.BuildPackage(&types.Config{Importer: importer.Default()}, fset, types.NewPackage("test/main", ""), []*ast.File{file}, 0) qt.Assert(t, qt.IsNil(err)) for fIdx, decl := range file.Decls { funcDecl, ok := decl.(*ast.FuncDecl) if !ok { continue } path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos()) ssaFunc := ssa.EnclosingFunction(ssaPkg, path) astFunc, err := Convert(ssaFunc, DefaultConfig()) qt.Assert(t, qt.IsNil(err)) file.Decls[fIdx] = astFunc } convertedFile := filepath.Join(t.TempDir(), "main.go") f, err := os.Create(convertedFile) qt.Assert(t, qt.IsNil(err)) err = printer.Fprint(f, fset, file) qt.Assert(t, qt.IsNil(err)) _ = f.Close() convertedOut := runGoFile(convertedFile) qt.Assert(t, qt.Equals(convertedOut, originalOut)) }