diff --git a/docs/CONTROLFLOW.md b/docs/CONTROLFLOW.md index 4dacf63..2b398c8 100644 --- a/docs/CONTROLFLOW.md +++ b/docs/CONTROLFLOW.md @@ -273,7 +273,6 @@ _s2a_l12: ### Caveats * Obfuscation breaks the lazy iteration over maps. See: [ssa2ast/polyfill.go](../internal/ssa2ast/polyfill.go) -* Generic functions not supported ### Complexity benchmark diff --git a/internal/ssa2ast/func.go b/internal/ssa2ast/func.go index 207af2d..81b520d 100644 --- a/internal/ssa2ast/func.go +++ b/internal/ssa2ast/func.go @@ -199,12 +199,6 @@ func (fc *funcConverter) convertCall(callCommon ssa.CallCommon) (*ast.CallExpr, hasRecv := val.Signature.Recv() != nil methodName := ast.NewIdent(val.Name()) - if val.TypeParams().Len() != 0 { - // TODO: to convert a call of a generic function it is enough to cut method name, - // but in the future when implementing converting generic functions this code must be rewritten - methodName.Name = methodName.Name[:strings.IndexRune(methodName.Name, '[')] - } - if hasRecv { argsOffset = 1 recvExpr, err := fc.convertSsaValue(callCommon.Args[0]) @@ -221,6 +215,25 @@ func (fc *funcConverter) convertCall(callCommon ssa.CallCommon) (*ast.CallExpr, } callExpr.Fun = methodName } + if typeArgs := val.TypeArgs(); len(typeArgs) > 0 { + // Generic methods are called in a monomorphic view (e.g. "someMethod[int string]"), + // so to get the original name, delete everything starting from "[" inclusive. + methodName.Name, _, _ = strings.Cut(methodName.Name, "[") + genericCallExpr := &ast.IndexListExpr{ + X: callExpr.Fun, + } + + // For better readability of generated code and to avoid ambiguities, + // we explicitly specify generic method types (e.g. "someMethod[int, string](0, "str")") + for _, typArg := range typeArgs { + typeExpr, err := fc.tc.Convert(typArg) + if err != nil { + return nil, err + } + genericCallExpr.Indices = append(genericCallExpr.Indices, typeExpr) + } + callExpr.Fun = genericCallExpr + } case *ssa.Builtin: name := val.Name() if _, ok := types.Unsafe.Scope().Lookup(name).(*types.Builtin); ok { @@ -1123,10 +1136,6 @@ func (fc *funcConverter) convertToStmts(ssaFunc *ssa.Function) ([]ast.Stmt, erro } func (fc *funcConverter) convert(ssaFunc *ssa.Function) (*ast.FuncDecl, error) { - if ssaFunc.Signature.TypeParams() != nil || ssaFunc.Signature.RecvTypeParams() != nil { - return nil, ErrUnsupported - } - funcDecl, err := fc.convertSignatureToFuncDecl(ssaFunc.Name(), ssaFunc.Signature) if err != nil { return nil, err diff --git a/internal/ssa2ast/func_test.go b/internal/ssa2ast/func_test.go index 4d972d0..8d198be 100644 --- a/internal/ssa2ast/func_test.go +++ b/internal/ssa2ast/func_test.go @@ -89,6 +89,7 @@ func main() { chanOps() flowOps() typeOps() + genericFunc() } func makeSprintf(tag string) func(vals ...interface{}) { @@ -340,7 +341,32 @@ func typeOps() { sprintf(discard) // Trigger phi block } _, _ = discard.Write([]byte("test")) -}` +} + +func sumIntsOrFloats[K comparable, V int64 | float64](m map[K]V) V { + var s V + for _, v := range m { + s += v + } + return s +} + +func genericFunc() { + sprintf := makeSprintf("genericFunc") + + ints := map[string]int64{ + "first": 34, + "second": 12, + } + sprintf(sumIntsOrFloats[string, int64](ints)) + + floats := map[string]float64{ + "first": 34.1, + "second": 12.1, + } + sprintf(sumIntsOrFloats(floats)) +} +` func TestConvert(t *testing.T) { runGoFile := func(f string) string { diff --git a/internal/ssa2ast/type.go b/internal/ssa2ast/type.go index 8284571..ab8fb7b 100644 --- a/internal/ssa2ast/type.go +++ b/internal/ssa2ast/type.go @@ -53,14 +53,23 @@ func (tc *typeConverter) Convert(t types.Type) (ast.Expr, error) { return chanExpr, nil case *types.Interface: methods := &ast.FieldList{} + hasComparable := false for i := 0; i < typ.NumEmbeddeds(); i++ { embeddedType := typ.EmbeddedType(i) + if namedType, ok := embeddedType.(*types.Named); ok && namedType.String() == "comparable" { + hasComparable = true + } embeddedExpr, err := tc.Convert(embeddedType) if err != nil { return nil, err } methods.List = append(methods.List, &ast.Field{Type: embeddedExpr}) } + + // Special case, handle "comparable" interface itself + if !hasComparable && typ.IsComparable() { + methods.List = append(methods.List, &ast.Field{Type: ast.NewIdent("comparable")}) + } for i := 0; i < typ.NumExplicitMethods(); i++ { method := typ.ExplicitMethod(i) methodSig, err := tc.Convert(method.Type())