diff --git a/go.sum b/go.sum
index 960292e..b452e6d 100644
--- a/go.sum
+++ b/go.sum
@@ -17,6 +17,7 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs
 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
 golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
 golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
diff --git a/hash.go b/hash.go
index 25f5af1..3b30037 100644
--- a/hash.go
+++ b/hash.go
@@ -158,6 +158,9 @@ func appendFlags(w io.Writer, forBuildHash bool) {
 		io.WriteString(w, " -seed=")
 		io.WriteString(w, flagSeed.String())
 	}
+	if flagControlFlow && forBuildHash {
+		io.WriteString(w, " -ctrlflow")
+	}
 	if literals.TestObfuscator != "" && forBuildHash {
 		io.WriteString(w, literals.TestObfuscator)
 	}
diff --git a/internal/asthelper/asthelper.go b/internal/asthelper/asthelper.go
index 5586c2f..a9d196a 100644
--- a/internal/asthelper/asthelper.go
+++ b/internal/asthelper/asthelper.go
@@ -6,6 +6,7 @@ package asthelper
 import (
 	"fmt"
 	"go/ast"
+	"go/constant"
 	"go/token"
 	"strconv"
 )
@@ -85,3 +86,56 @@ func DataToByteSlice(data []byte) *ast.CallExpr {
 		Args: []ast.Expr{StringLit(string(data))},
 	}
 }
+
+// SelectExpr "x.sel"
+func SelectExpr(x ast.Expr, sel *ast.Ident) *ast.SelectorExpr {
+	return &ast.SelectorExpr{
+		X:   x,
+		Sel: sel,
+	}
+}
+
+// AssignDefineStmt "Lhs := Rhs"
+func AssignDefineStmt(Lhs ast.Expr, Rhs ast.Expr) *ast.AssignStmt {
+	return &ast.AssignStmt{
+		Lhs: []ast.Expr{Lhs},
+		Tok: token.DEFINE,
+		Rhs: []ast.Expr{Rhs},
+	}
+}
+
+// CallExprByName "fun(args...)"
+func CallExprByName(fun string, args ...ast.Expr) *ast.CallExpr {
+	return CallExpr(ast.NewIdent(fun), args...)
+}
+
+// AssignStmt "Lhs = Rhs"
+func AssignStmt(Lhs ast.Expr, Rhs ast.Expr) *ast.AssignStmt {
+	return &ast.AssignStmt{
+		Lhs: []ast.Expr{Lhs},
+		Tok: token.ASSIGN,
+		Rhs: []ast.Expr{Rhs},
+	}
+}
+
+// IndexExprByExpr "xExpr[indexExpr]"
+func IndexExprByExpr(xExpr, indexExpr ast.Expr) *ast.IndexExpr {
+	return &ast.IndexExpr{X: xExpr, Index: indexExpr}
+}
+
+func ConstToAst(val constant.Value) ast.Expr {
+	switch val.Kind() {
+	case constant.Bool:
+		return ast.NewIdent(val.ExactString())
+	case constant.String:
+		return &ast.BasicLit{Kind: token.STRING, Value: val.ExactString()}
+	case constant.Int:
+		return &ast.BasicLit{Kind: token.INT, Value: val.ExactString()}
+	case constant.Float:
+		return &ast.BasicLit{Kind: token.FLOAT, Value: val.String()}
+	case constant.Complex:
+		return CallExprByName("complex", ConstToAst(constant.Real(val)), ConstToAst(constant.Imag(val)))
+	default:
+		panic("unreachable")
+	}
+}
diff --git a/internal/ctrlflow/ctrlflow.go b/internal/ctrlflow/ctrlflow.go
new file mode 100644
index 0000000..d637545
--- /dev/null
+++ b/internal/ctrlflow/ctrlflow.go
@@ -0,0 +1,186 @@
+package ctrlflow
+
+import (
+	"fmt"
+	"go/ast"
+	"go/token"
+	"go/types"
+	"log"
+	mathrand "math/rand"
+	"strconv"
+	"strings"
+
+	"golang.org/x/tools/go/ast/astutil"
+	"golang.org/x/tools/go/ssa"
+	ah "mvdan.cc/garble/internal/asthelper"
+	"mvdan.cc/garble/internal/ssa2ast"
+)
+
+const (
+	mergedFileName = "GARBLE_controlflow.go"
+	directiveName  = "//garble:controlflow"
+	importPrefix   = "___garble_import"
+
+	defaultBlockSplits   = 0
+	defaultJunkJumps     = 0
+	defaultFlattenPasses = 1
+)
+
+type directiveParamMap map[string]string
+
+func (m directiveParamMap) GetInt(name string, def int) int {
+	rawVal, ok := m[name]
+	if !ok {
+		return def
+	}
+
+	val, err := strconv.Atoi(rawVal)
+	if err != nil {
+		panic(fmt.Errorf("invalid flag %s format: %v", name, err))
+	}
+	return val
+}
+
+// parseDirective parses a directive string and returns a map of directive parameters.
+// Each parameter should be in the form "key=value" or "key"
+func parseDirective(directive string) (directiveParamMap, bool) {
+	fieldsStr, ok := strings.CutPrefix(directive, directiveName)
+	if !ok {
+		return nil, false
+	}
+
+	fields := strings.Fields(fieldsStr)
+	if len(fields) == 0 {
+		return nil, true
+	}
+	m := make(map[string]string)
+	for _, v := range fields {
+		key, value, ok := strings.Cut(v, "=")
+		if ok {
+			m[key] = value
+		} else {
+			m[key] = ""
+		}
+	}
+	return m, true
+}
+
+// Obfuscate obfuscates control flow of all functions with directive using control flattening.
+// All obfuscated functions are removed from the original file and moved to the new one.
+// Obfuscation can be customized by passing parameters from the directive, example:
+//
+//	//garble:controlflow flatten_passes=1 junk_jumps=0 block_splits=0
+//	func someMethod() {}
+//
+// flatten_passes - controls number of passes of control flow flattening. Has exponential complexity and more than 3 passes are not recommended in most cases.
+// junk_jumps - controls how many junk jumps are added. It does not affect final binary by itself, but together with flattening linearly increases complexity.
+// block_splits - controls number of times largest block must be split. Together with flattening improves obfuscation of long blocks without branches.
+func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfRand *mathrand.Rand) (newFileName string, newFile *ast.File, affectedFiles []*ast.File, err error) {
+	var ssaFuncs []*ssa.Function
+	var ssaParams []directiveParamMap
+
+	for _, file := range files {
+		affected := false
+		for _, decl := range file.Decls {
+			funcDecl, ok := decl.(*ast.FuncDecl)
+			if !ok || funcDecl.Doc == nil {
+				continue
+			}
+
+			for _, comment := range funcDecl.Doc.List {
+				params, hasDirective := parseDirective(comment.Text)
+				if !hasDirective {
+					continue
+				}
+
+				path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos())
+				ssaFunc := ssa.EnclosingFunction(ssaPkg, path)
+				if ssaFunc == nil {
+					panic("function exists in ast but not found in ssa")
+				}
+
+				ssaFuncs = append(ssaFuncs, ssaFunc)
+				ssaParams = append(ssaParams, params)
+
+				log.Printf("detected function for controlflow %s (params: %v)", funcDecl.Name.Name, params)
+
+				// Remove inplace function from original file
+				// TODO: implement a complete function removal
+				funcDecl.Name = ast.NewIdent("_")
+				funcDecl.Body = ah.BlockStmt()
+				funcDecl.Recv = nil
+				funcDecl.Type = &ast.FuncType{Params: &ast.FieldList{}}
+				affected = true
+
+				break
+			}
+		}
+
+		if affected {
+			affectedFiles = append(affectedFiles, file)
+		}
+	}
+
+	if len(ssaFuncs) == 0 {
+		return
+	}
+
+	newFile = &ast.File{
+		Package: token.Pos(fset.Base()),
+		Name:    ast.NewIdent(files[0].Name.Name),
+	}
+	fset.AddFile(mergedFileName, int(newFile.Package), 1) // required for correct printer output
+
+	funcConfig := ssa2ast.DefaultConfig()
+	imports := make(map[string]string) // TODO: indirect imports turned into direct currently break build process
+	funcConfig.ImportNameResolver = func(pkg *types.Package) *ast.Ident {
+		if pkg == nil || pkg.Path() == ssaPkg.Pkg.Path() {
+			return nil
+		}
+
+		name, ok := imports[pkg.Path()]
+		if !ok {
+			name = importPrefix + strconv.Itoa(len(imports))
+			imports[pkg.Path()] = name
+			astutil.AddNamedImport(fset, newFile, name, pkg.Path())
+		}
+		return ast.NewIdent(name)
+	}
+
+	for idx, ssaFunc := range ssaFuncs {
+		params := ssaParams[idx]
+
+		split := params.GetInt("block_splits", defaultBlockSplits)
+		junkCount := params.GetInt("junk_jumps", defaultJunkJumps)
+		passes := params.GetInt("flatten_passes", defaultFlattenPasses)
+
+		applyObfuscation := func(ssaFunc *ssa.Function) {
+			for i := 0; i < split; i++ {
+				if !applySplitting(ssaFunc, obfRand) {
+					break // no more candidates for splitting
+				}
+			}
+			if junkCount > 0 {
+				addJunkBlocks(ssaFunc, junkCount, obfRand)
+			}
+			for i := 0; i < passes; i++ {
+				applyFlattening(ssaFunc, obfRand)
+			}
+			fixBlockIndexes(ssaFunc)
+		}
+
+		applyObfuscation(ssaFunc)
+		for _, anonFunc := range ssaFunc.AnonFuncs {
+			applyObfuscation(anonFunc)
+		}
+
+		astFunc, err := ssa2ast.Convert(ssaFunc, funcConfig)
+		if err != nil {
+			return "", nil, nil, err
+		}
+		newFile.Decls = append(newFile.Decls, astFunc)
+	}
+
+	newFileName = mergedFileName
+	return
+}
diff --git a/internal/ctrlflow/ssa.go b/internal/ctrlflow/ssa.go
new file mode 100644
index 0000000..738beb3
--- /dev/null
+++ b/internal/ctrlflow/ssa.go
@@ -0,0 +1,43 @@
+package ctrlflow
+
+import (
+	"go/constant"
+	"go/types"
+	"reflect"
+	"unsafe"
+
+	"golang.org/x/tools/go/ssa"
+)
+
+// setUnexportedField is used to modify unexported fields of ssa api structures.
+// TODO: find an alternative way to access private fields or raise a feature request upstream
+func setUnexportedField(objRaw interface{}, name string, valRaw interface{}) {
+	obj := reflect.ValueOf(objRaw)
+	for obj.Kind() == reflect.Pointer || obj.Kind() == reflect.Interface {
+		obj = obj.Elem()
+	}
+
+	field := obj.FieldByName(name)
+	if !field.IsValid() {
+		panic("invalid field: " + name)
+	}
+
+	fakeStruct := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr()))
+	fakeStruct.Elem().Set(reflect.ValueOf(valRaw))
+}
+
+func setBlockParent(block *ssa.BasicBlock, ssaFunc *ssa.Function) {
+	setUnexportedField(block, "parent", ssaFunc)
+}
+
+func setBlock(instr ssa.Instruction, block *ssa.BasicBlock) {
+	setUnexportedField(instr, "block", block)
+}
+
+func setType(instr ssa.Instruction, typ types.Type) {
+	setUnexportedField(instr, "typ", typ)
+}
+
+func makeSsaInt(i int) *ssa.Const {
+	return ssa.NewConst(constant.MakeInt64(int64(i)), types.Typ[types.Int])
+}
diff --git a/internal/ctrlflow/transform.go b/internal/ctrlflow/transform.go
new file mode 100644
index 0000000..763044d
--- /dev/null
+++ b/internal/ctrlflow/transform.go
@@ -0,0 +1,212 @@
+package ctrlflow
+
+import (
+	"go/token"
+	"go/types"
+	mathrand "math/rand"
+	"strconv"
+
+	"golang.org/x/tools/go/ssa"
+)
+
+type blockMapping struct {
+	Fake, Target *ssa.BasicBlock
+}
+
+// applyFlattening adds a dispatcher block and uses ssa.Phi to redirect all ssa.Jump and ssa.If to the dispatcher,
+// additionally shuffle all blocks
+func applyFlattening(ssaFunc *ssa.Function, obfRand *mathrand.Rand) {
+	if len(ssaFunc.Blocks) < 3 {
+		return
+	}
+
+	phiInstr := &ssa.Phi{Comment: "ctrflow.phi"}
+	setType(phiInstr, types.Typ[types.Int])
+
+	entryBlock := &ssa.BasicBlock{
+		Comment: "ctrflow.entry",
+		Instrs:  []ssa.Instruction{phiInstr},
+	}
+	setBlockParent(entryBlock, ssaFunc)
+
+	makeJumpBlock := func(from *ssa.BasicBlock) *ssa.BasicBlock {
+		jumpBlock := &ssa.BasicBlock{
+			Comment: "ctrflow.jump",
+			Instrs:  []ssa.Instruction{&ssa.Jump{}},
+			Preds:   []*ssa.BasicBlock{from},
+			Succs:   []*ssa.BasicBlock{entryBlock},
+		}
+		setBlockParent(jumpBlock, ssaFunc)
+		return jumpBlock
+	}
+
+	// map for track fake block -> real block jump
+	var blocksMapping []blockMapping
+	for _, block := range ssaFunc.Blocks {
+		existInstr := block.Instrs[len(block.Instrs)-1]
+		switch existInstr.(type) {
+		case *ssa.Jump:
+			targetBlock := block.Succs[0]
+			fakeBlock := makeJumpBlock(block)
+			blocksMapping = append(blocksMapping, blockMapping{fakeBlock, targetBlock})
+			block.Succs[0] = fakeBlock
+		case *ssa.If:
+			tblock, fblock := block.Succs[0], block.Succs[1]
+			fakeTblock, fakeFblock := makeJumpBlock(tblock), makeJumpBlock(fblock)
+
+			blocksMapping = append(blocksMapping, blockMapping{fakeTblock, tblock})
+			blocksMapping = append(blocksMapping, blockMapping{fakeFblock, fblock})
+
+			block.Succs[0] = fakeTblock
+			block.Succs[1] = fakeFblock
+		case *ssa.Return, *ssa.Panic:
+			// control flow flattening is not applicable
+		default:
+			panic("unreachable")
+		}
+	}
+
+	phiIdxs := obfRand.Perm(len(blocksMapping))
+	for i := range phiIdxs {
+		phiIdxs[i]++ // 0 reserved for real entry block
+	}
+
+	var entriesBlocks []*ssa.BasicBlock
+	obfuscatedBlocks := ssaFunc.Blocks
+	for i, m := range blocksMapping {
+		entryBlock.Preds = append(entryBlock.Preds, m.Fake)
+		phiInstr.Edges = append(phiInstr.Edges, makeSsaInt(phiIdxs[i]))
+
+		obfuscatedBlocks = append(obfuscatedBlocks, m.Fake)
+
+		cond := &ssa.BinOp{X: phiInstr, Op: token.EQL, Y: makeSsaInt(phiIdxs[i])}
+		setType(cond, types.Typ[types.Bool])
+
+		*phiInstr.Referrers() = append(*phiInstr.Referrers(), cond)
+
+		ifInstr := &ssa.If{Cond: cond}
+		*cond.Referrers() = append(*cond.Referrers(), ifInstr)
+
+		ifBlock := &ssa.BasicBlock{
+			Instrs: []ssa.Instruction{cond, ifInstr},
+			Succs:  []*ssa.BasicBlock{m.Target, nil}, // false branch fulfilled in next iteration or linked to real entry block
+		}
+		setBlockParent(ifBlock, ssaFunc)
+
+		setBlock(cond, ifBlock)
+		setBlock(ifInstr, ifBlock)
+		entriesBlocks = append(entriesBlocks, ifBlock)
+
+		if i == 0 {
+			entryBlock.Instrs = append(entryBlock.Instrs, &ssa.Jump{})
+			entryBlock.Succs = []*ssa.BasicBlock{ifBlock}
+			ifBlock.Preds = append(ifBlock.Preds, entryBlock)
+		} else {
+			// link previous block to current
+			entriesBlocks[i-1].Succs[1] = ifBlock
+			ifBlock.Preds = append(ifBlock.Preds, entriesBlocks[i-1])
+		}
+	}
+
+	lastFakeEntry := entriesBlocks[len(entriesBlocks)-1]
+
+	realEntryBlock := ssaFunc.Blocks[0]
+	lastFakeEntry.Succs[1] = realEntryBlock
+	realEntryBlock.Preds = append(realEntryBlock.Preds, lastFakeEntry)
+
+	obfuscatedBlocks = append(obfuscatedBlocks, entriesBlocks...)
+	obfRand.Shuffle(len(obfuscatedBlocks), func(i, j int) {
+		obfuscatedBlocks[i], obfuscatedBlocks[j] = obfuscatedBlocks[j], obfuscatedBlocks[i]
+	})
+	ssaFunc.Blocks = append([]*ssa.BasicBlock{entryBlock}, obfuscatedBlocks...)
+}
+
+// addJunkBlocks adds junk jumps into random blocks. Can create chains of junk jumps.
+func addJunkBlocks(ssaFunc *ssa.Function, count int, obfRand *mathrand.Rand) {
+	if count == 0 {
+		return
+	}
+	var candidates []*ssa.BasicBlock
+	for _, block := range ssaFunc.Blocks {
+		if len(block.Succs) > 0 {
+			candidates = append(candidates, block)
+		}
+	}
+
+	if len(candidates) == 0 {
+		return
+	}
+
+	for i := 0; i < count; i++ {
+		targetBlock := candidates[obfRand.Intn(len(candidates))]
+		succsIdx := obfRand.Intn(len(targetBlock.Succs))
+		succs := targetBlock.Succs[succsIdx]
+
+		fakeBlock := &ssa.BasicBlock{
+			Comment: "ctrflow.fake." + strconv.Itoa(i),
+			Instrs:  []ssa.Instruction{&ssa.Jump{}},
+			Preds:   []*ssa.BasicBlock{targetBlock},
+			Succs:   []*ssa.BasicBlock{succs},
+		}
+		setBlockParent(fakeBlock, ssaFunc)
+		targetBlock.Succs[succsIdx] = fakeBlock
+
+		ssaFunc.Blocks = append(ssaFunc.Blocks, fakeBlock)
+		candidates = append(candidates, fakeBlock)
+	}
+}
+
+// applySplitting splits biggest block into 2 parts of random size.
+// Returns false if no block large enough for splitting is found
+func applySplitting(ssaFunc *ssa.Function, obfRand *mathrand.Rand) bool {
+	var targetBlock *ssa.BasicBlock
+	for _, block := range ssaFunc.Blocks {
+		if targetBlock == nil || len(block.Instrs) > len(targetBlock.Instrs) {
+			targetBlock = block
+		}
+	}
+
+	const minInstrCount = 1 + 1 // 1 exit instruction + 1 any instruction
+	if targetBlock == nil || len(targetBlock.Instrs) <= minInstrCount {
+		return false
+	}
+
+	splitIdx := 1 + obfRand.Intn(len(targetBlock.Instrs)-2)
+
+	firstPart := make([]ssa.Instruction, splitIdx+1)
+	copy(firstPart, targetBlock.Instrs)
+	firstPart[len(firstPart)-1] = &ssa.Jump{}
+
+	secondPart := targetBlock.Instrs[splitIdx:]
+	targetBlock.Instrs = firstPart
+
+	newBlock := &ssa.BasicBlock{
+		Comment: "ctrflow.split." + strconv.Itoa(targetBlock.Index),
+		Instrs:  secondPart,
+		Preds:   []*ssa.BasicBlock{targetBlock},
+		Succs:   targetBlock.Succs,
+	}
+	setBlockParent(newBlock, ssaFunc)
+	for _, instr := range newBlock.Instrs {
+		setBlock(instr, newBlock)
+	}
+
+	// Fix preds for ssa.Phi working
+	for _, succ := range targetBlock.Succs {
+		for i, pred := range succ.Preds {
+			if pred == targetBlock {
+				succ.Preds[i] = newBlock
+			}
+		}
+	}
+
+	ssaFunc.Blocks = append(ssaFunc.Blocks, newBlock)
+	targetBlock.Succs = []*ssa.BasicBlock{newBlock}
+	return true
+}
+
+func fixBlockIndexes(ssaFunc *ssa.Function) {
+	for i, block := range ssaFunc.Blocks {
+		block.Index = i
+	}
+}
diff --git a/internal/ssa2ast/func.go b/internal/ssa2ast/func.go
new file mode 100644
index 0000000..207af2d
--- /dev/null
+++ b/internal/ssa2ast/func.go
@@ -0,0 +1,1140 @@
+package ssa2ast
+
+import (
+	"errors"
+	"fmt"
+	"go/ast"
+	"go/token"
+	"go/types"
+	"sort"
+	"strconv"
+	"strings"
+
+	"golang.org/x/exp/slices"
+	"golang.org/x/tools/go/ssa"
+	ah "mvdan.cc/garble/internal/asthelper"
+)
+
+var ErrUnsupported = errors.New("unsupported")
+
+type NameType int
+
+type ImportNameResolver func(pkg *types.Package)
*ast.Ident + +type ConverterConfig struct { + // ImportNameResolver function to get the actual import name. + // Because converting works at function level, only the caller knows actual name of the import. + ImportNameResolver ImportNameResolver + + // NamePrefix prefix added to all new local variables. Must be reasonably unique + NamePrefix string +} + +func DefaultConfig() *ConverterConfig { + return &ConverterConfig{ + ImportNameResolver: defaultImportNameResolver, + NamePrefix: "_s2a_", + } +} + +func defaultImportNameResolver(pkg *types.Package) *ast.Ident { + if pkg == nil || pkg.Name() == "main" { + return nil + } + return ast.NewIdent(pkg.Name()) +} + +type funcConverter struct { + importNameResolver ImportNameResolver + tc *typeConverter + namePrefix string + valueNameMap map[ssa.Value]string +} + +func Convert(ssaFunc *ssa.Function, cfg *ConverterConfig) (*ast.FuncDecl, error) { + return newFuncConverter(cfg).convert(ssaFunc) +} + +func newFuncConverter(cfg *ConverterConfig) *funcConverter { + return &funcConverter{ + importNameResolver: cfg.ImportNameResolver, + tc: &typeConverter{resolver: cfg.ImportNameResolver}, + namePrefix: cfg.NamePrefix, + valueNameMap: make(map[ssa.Value]string), + } +} + +func (fc *funcConverter) getVarName(val ssa.Value) string { + if name, ok := fc.valueNameMap[val]; ok { + return name + } + + name := fc.namePrefix + strconv.Itoa(len(fc.valueNameMap)) + fc.valueNameMap[val] = name + return name +} + +func (fc *funcConverter) convertSignatureToFuncDecl(name string, signature *types.Signature) (*ast.FuncDecl, error) { + funcTypeDecl, err := fc.tc.Convert(signature) + if err != nil { + return nil, err + } + funcDecl := &ast.FuncDecl{Name: ast.NewIdent(name), Type: funcTypeDecl.(*ast.FuncType)} + if recv := signature.Recv(); recv != nil { + recvTypeExpr, err := fc.tc.Convert(recv.Type()) + if err != nil { + return nil, err + } + f := &ast.Field{Type: recvTypeExpr} + if recvName := recv.Name(); recvName != "" { + f.Names = 
[]*ast.Ident{ast.NewIdent(recvName)} + } + funcDecl.Recv = &ast.FieldList{List: []*ast.Field{f}} + } + return funcDecl, nil +} + +func (fc *funcConverter) convertSignatureToFuncLit(signature *types.Signature) (*ast.FuncLit, error) { + funcTypeDecl, err := fc.tc.Convert(signature) + if err != nil { + return nil, err + } + return &ast.FuncLit{Type: funcTypeDecl.(*ast.FuncType)}, nil +} + +type AstBlock struct { + Index int + HasRefs bool + Body []ast.Stmt + Phi []ast.Stmt + Exit ast.Stmt +} + +type AstFunc struct { + Vars map[string]types.Type + Blocks []*AstBlock +} + +func isVoidType(typ types.Type) bool { + tuple, ok := typ.(*types.Tuple) + return ok && tuple.Len() == 0 +} + +func isStringType(typ types.Type) bool { + return types.Identical(typ, types.Typ[types.String]) || types.Identical(typ, types.Typ[types.UntypedString]) +} + +func getFieldName(tp types.Type, index int) (string, error) { + if pt, ok := tp.(*types.Pointer); ok { + tp = pt.Elem() + } + if named, ok := tp.(*types.Named); ok { + tp = named.Underlying() + } + if stp, ok := tp.(*types.Struct); ok { + return stp.Field(index).Name(), nil + } + return "", fmt.Errorf("field %d not found in %v", index, tp) +} + +func (fc *funcConverter) castCallExpr(typ types.Type, x ssa.Value) (*ast.CallExpr, error) { + castExpr, err := fc.tc.Convert(typ) + if err != nil { + return nil, err + } + valExpr, err := fc.convertSsaValue(x) + if err != nil { + return nil, err + } + return ah.CallExpr(&ast.ParenExpr{X: castExpr}, valExpr), nil +} + +func (fc *funcConverter) getLabelName(blockIdx int) *ast.Ident { + return ast.NewIdent(fmt.Sprintf("%sl%d", fc.namePrefix, blockIdx)) +} + +func (fc *funcConverter) gotoStmt(blockIdx int) *ast.BranchStmt { + return &ast.BranchStmt{ + Tok: token.GOTO, + Label: fc.getLabelName(blockIdx), + } +} + +func (fc *funcConverter) getAnonFunctionName(val *ssa.Function) (*ast.Ident, error) { + parent := val.Parent() + if parent == nil { + return nil, nil + } + anonFuncIdx := 
slices.Index(parent.AnonFuncs, val) + if anonFuncIdx < 0 { + return nil, fmt.Errorf("anon func %q for call not found", val.Name()) + } + return ast.NewIdent(fc.getAnonFuncName(anonFuncIdx)), nil +} + +func (fc *funcConverter) convertCall(callCommon ssa.CallCommon) (*ast.CallExpr, error) { + callExpr := &ast.CallExpr{} + argsOffset := 0 + + if !callCommon.IsInvoke() { + switch val := callCommon.Value.(type) { + case *ssa.Function: + anonFuncName, err := fc.getAnonFunctionName(val) + if err != nil { + return nil, err + } + if anonFuncName != nil { + callExpr.Fun = anonFuncName + break + } + + thunkCall, err := fc.getThunkMethodCall(val) + if err != nil { + return nil, err + } + if thunkCall != nil { + callExpr.Fun = thunkCall + break + } + + hasRecv := val.Signature.Recv() != nil + methodName := ast.NewIdent(val.Name()) + if val.TypeParams().Len() != 0 { + // TODO: to convert a call of a generic function it is enough to cut method name, + // but in the future when implementing converting generic functions this code must be rewritten + methodName.Name = methodName.Name[:strings.IndexRune(methodName.Name, '[')] + } + + if hasRecv { + argsOffset = 1 + recvExpr, err := fc.convertSsaValue(callCommon.Args[0]) + if err != nil { + return nil, err + } + callExpr.Fun = ah.SelectExpr(recvExpr, methodName) + } else { + if val.Pkg != nil { + if pkgIdent := fc.importNameResolver(val.Pkg.Pkg); pkgIdent != nil { + callExpr.Fun = ah.SelectExpr(pkgIdent, methodName) + break + } + } + callExpr.Fun = methodName + } + case *ssa.Builtin: + name := val.Name() + if _, ok := types.Unsafe.Scope().Lookup(name).(*types.Builtin); ok { + unsafePkgIdent := fc.importNameResolver(types.Unsafe) + if unsafePkgIdent == nil { + return nil, fmt.Errorf("cannot resolve unsafe package") + } + callExpr.Fun = &ast.SelectorExpr{X: unsafePkgIdent, Sel: ast.NewIdent(name)} + } else { + callExpr.Fun = ast.NewIdent(name) + } + default: + callFunExpr, err := fc.convertSsaValue(val) + if err != nil { + return nil, 
err + } + callExpr.Fun = callFunExpr + } + } else { + recvExpr, err := fc.convertSsaValue(callCommon.Value) + if err != nil { + return nil, err + } + callExpr.Fun = ah.SelectExpr(recvExpr, ast.NewIdent(callCommon.Method.Name())) + } + + for _, arg := range callCommon.Args[argsOffset:] { + argExpr, err := fc.convertSsaValue(arg) + if err != nil { + return nil, err + } + callExpr.Args = append(callExpr.Args, argExpr) + } + if callCommon.Signature().Variadic() { + callExpr.Ellipsis = 1 + } + return callExpr, nil +} + +func (fc *funcConverter) convertSsaValueNonExplicitNil(ssaValue ssa.Value) (ast.Expr, error) { + return fc.ssaValue(ssaValue, false) +} + +func (fc *funcConverter) convertSsaValue(ssaValue ssa.Value) (ast.Expr, error) { + return fc.ssaValue(ssaValue, true) +} + +func (fc *funcConverter) getThunkMethodCall(val *ssa.Function) (ast.Expr, error) { + const thunkPrefix = "$thunk" + if !strings.HasSuffix(val.Name(), thunkPrefix) { + return nil, nil + } + thunkType, ok := val.Object().Type().Underlying().(*types.Signature) + if !ok { + return nil, fmt.Errorf("unsupported thunk type: %w", ErrUnsupported) + } + recvVar := thunkType.Recv() + if recvVar == nil { + return nil, fmt.Errorf("unsupported non method thunk: %w", ErrUnsupported) + } + + thunkTypeAst, err := fc.tc.Convert(recvVar.Type()) + if err != nil { + return nil, err + } + trimmedName := ast.NewIdent(strings.TrimSuffix(val.Name(), thunkPrefix)) + return ah.SelectExpr(&ast.ParenExpr{X: thunkTypeAst}, trimmedName), nil +} + +func (fc *funcConverter) ssaValue(ssaValue ssa.Value, explicitNil bool) (ast.Expr, error) { + switch val := ssaValue.(type) { + case *ssa.Builtin: + return ast.NewIdent(val.Name()), nil + case *ssa.Global: + globalExpr := &ast.UnaryExpr{Op: token.AND} + newName := ast.NewIdent(val.Name()) + if pkgIdent := fc.importNameResolver(val.Pkg.Pkg); pkgIdent != nil { + globalExpr.X = ah.SelectExpr(pkgIdent, newName) + } else { + globalExpr.X = newName + } + return globalExpr, nil + case 
*ssa.Function: + anonFuncName, err := fc.getAnonFunctionName(val) + if err != nil { + return nil, err + } + if anonFuncName != nil { + return anonFuncName, nil + } + + thunkCall, err := fc.getThunkMethodCall(val) + if err != nil { + return nil, err + } + if thunkCall != nil { + return thunkCall, nil + } + + name := ast.NewIdent(val.Name()) + if val.Signature.Recv() == nil && val.Pkg != nil { + if pkgIdent := fc.importNameResolver(val.Pkg.Pkg); pkgIdent != nil { + return ah.SelectExpr(pkgIdent, name), nil + } + } + return name, nil + case *ssa.Const: + var constExpr ast.Expr + if val.Value == nil { + // handle nil constant for non-pointer structs + typ := val.Type() + if _, ok := typ.(*types.Named); ok { + typ = typ.Underlying() + } + if _, ok := typ.(*types.Struct); ok { + typExpr, err := fc.tc.Convert(val.Type()) + if err != nil { + return nil, err + } + return &ast.CompositeLit{Type: typExpr}, nil + } + + constExpr = ast.NewIdent("nil") + if !explicitNil { + return constExpr, nil + } + } else { + constExpr = ah.ConstToAst(val.Value) + } + + if basicType, ok := val.Type().(*types.Basic); ok { + if basicType.Info()&(types.IsString|types.IsUntyped) != 0 { + return constExpr, nil + } + } + + castExpr, err := fc.tc.Convert(val.Type()) + if err != nil { + return nil, err + } + return ah.CallExpr(&ast.ParenExpr{X: castExpr}, constExpr), nil + case *ssa.Parameter, *ssa.FreeVar: + return ast.NewIdent(val.Name()), nil + default: + return ast.NewIdent(fc.getVarName(val)), nil + } +} + +type register interface { + Name() string + Referrers() *[]ssa.Instruction + Type() types.Type + + String() string + Parent() *ssa.Function + Pos() token.Pos +} + +func (fc *funcConverter) tupleVarName(val ssa.Value, idx int) string { + return fmt.Sprintf("%s_%d", fc.getVarName(val), idx) +} + +func (fc *funcConverter) tupleVarNameAndType(reg ssa.Value, idx int) (name string, typ types.Type, hasRefs bool) { + tupleType := reg.Type().(*types.Tuple) + typ = tupleType.At(idx).Type() + name = "_" 
+ + refs := reg.Referrers() + if refs == nil { + return + } + + for _, instr := range *refs { + extractInstr, ok := instr.(*ssa.Extract) + if ok && extractInstr.Index == idx { + hasRefs = true + name = fc.tupleVarName(reg, idx) + return + } + } + return +} + +func isNilValue(value ssa.Value) bool { + constVal, ok := value.(*ssa.Const) + return ok && constVal.Value == nil +} + +func (fc *funcConverter) convertBlock(astFunc *AstFunc, ssaBlock *ssa.BasicBlock, astBlock *AstBlock) error { + astBlock.HasRefs = len(ssaBlock.Preds) != 0 + + defineTypedVar := func(r register, typ types.Type, expr ast.Expr) ast.Stmt { + if isVoidType(typ) { + return &ast.ExprStmt{X: expr} + } + if tuple, ok := typ.(*types.Tuple); ok { + assignStmt := &ast.AssignStmt{Tok: token.ASSIGN, Rhs: []ast.Expr{expr}} + localTuple := true + tmpVars := make(map[string]types.Type) + + for i := 0; i < tuple.Len(); i++ { + name, typ, hasRefs := fc.tupleVarNameAndType(r, i) + tmpVars[name] = typ + if hasRefs { + localTuple = false + } + assignStmt.Lhs = append(assignStmt.Lhs, ast.NewIdent(name)) + } + + if !localTuple { + for n, t := range tmpVars { + astFunc.Vars[n] = t + } + } + + return assignStmt + } + + refs := r.Referrers() + if refs == nil || len(*refs) == 0 { + return ah.AssignStmt(ast.NewIdent("_"), expr) + } + + localVar := true + for _, refInstr := range *refs { + if _, ok := refInstr.(*ssa.Phi); ok || refInstr.Block() != ssaBlock { + localVar = false + } + } + + newName := fc.getVarName(r) + assignStmt := ah.AssignDefineStmt(ast.NewIdent(newName), expr) + if !localVar { + assignStmt.Tok = token.ASSIGN + astFunc.Vars[newName] = typ + } + return assignStmt + } + defineVar := func(r register, expr ast.Expr) ast.Stmt { + return defineTypedVar(r, r.Type(), expr) + } + + for _, instr := range ssaBlock.Instrs[:len(ssaBlock.Instrs)-1] { + var stmt ast.Stmt + switch instr := instr.(type) { + case *ssa.Alloc: + varType := instr.Type().Underlying().(*types.Pointer).Elem() + varExpr, err := 
fc.tc.Convert(varType) + if err != nil { + return err + } + stmt = defineVar(instr, ah.CallExprByName("new", varExpr)) + case *ssa.BinOp: + xExpr, err := fc.convertSsaValueNonExplicitNil(instr.X) + if err != nil { + return err + } + + var yExpr ast.Expr + // Handle special case: if nil == nil + if isNilValue(instr.X) && isNilValue(instr.Y) { + yExpr, err = fc.convertSsaValue(instr.Y) + } else { + yExpr, err = fc.convertSsaValueNonExplicitNil(instr.Y) + } + if err != nil { + return err + } + + stmt = defineVar(instr, &ast.BinaryExpr{ + X: xExpr, + Op: instr.Op, + Y: yExpr, + }) + case *ssa.Call: + callFunExpr, err := fc.convertCall(instr.Call) + if err != nil { + return err + } + stmt = defineVar(instr, callFunExpr) + case *ssa.ChangeInterface: + castExpr, err := fc.castCallExpr(instr.Type(), instr.X) + if err != nil { + return err + } + stmt = defineVar(instr, castExpr) + case *ssa.ChangeType: + castExpr, err := fc.castCallExpr(instr.Type(), instr.X) + if err != nil { + return err + } + stmt = defineVar(instr, castExpr) + case *ssa.Convert: + castExpr, err := fc.castCallExpr(instr.Type(), instr.X) + if err != nil { + return err + } + stmt = defineVar(instr, castExpr) + case *ssa.Defer: + callExpr, err := fc.convertCall(instr.Call) + if err != nil { + return err + } + stmt = &ast.DeferStmt{Call: callExpr} + case *ssa.Extract: + name := fc.tupleVarName(instr.Tuple, instr.Index) + stmt = defineVar(instr, ast.NewIdent(name)) + case *ssa.Field: + xExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + + fieldName, err := getFieldName(instr.X.Type(), instr.Field) + if err != nil { + return err + } + stmt = defineVar(instr, ah.SelectExpr(xExpr, ast.NewIdent(fieldName))) + case *ssa.FieldAddr: + xExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + + fieldName, err := getFieldName(instr.X.Type(), instr.Field) + if err != nil { + return err + } + stmt = defineVar(instr, &ast.UnaryExpr{ + Op: token.AND, + X: 
ah.SelectExpr(xExpr, ast.NewIdent(fieldName)), + }) + case *ssa.Go: + callExpr, err := fc.convertCall(instr.Call) + if err != nil { + return err + } + stmt = &ast.GoStmt{Call: callExpr} + case *ssa.Index: + xExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + indexExpr, err := fc.convertSsaValue(instr.Index) + if err != nil { + return err + } + stmt = defineVar(instr, ah.IndexExprByExpr(xExpr, indexExpr)) + case *ssa.IndexAddr: + xExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + indexExpr, err := fc.convertSsaValue(instr.Index) + if err != nil { + return err + } + stmt = defineVar(instr, &ast.UnaryExpr{Op: token.AND, X: ah.IndexExprByExpr(xExpr, indexExpr)}) + case *ssa.Lookup: + mapExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + + indexExpr, err := fc.convertSsaValue(instr.Index) + if err != nil { + return err + } + + mapIndexExpr := ah.IndexExprByExpr(mapExpr, indexExpr) + if instr.CommaOk { + valName, valType, valHasRefs := fc.tupleVarNameAndType(instr, 0) + okName, okType, okHasRefs := fc.tupleVarNameAndType(instr, 1) + + if valHasRefs { + astFunc.Vars[valName] = valType + } + if okHasRefs { + astFunc.Vars[okName] = okType + } + + stmt = &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(valName), ast.NewIdent(okName)}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{mapIndexExpr}, + } + } else { + stmt = defineVar(instr, mapIndexExpr) + } + case *ssa.MakeChan: + chanExpr, err := fc.tc.Convert(instr.Type()) + if err != nil { + return err + } + makeExpr := ah.CallExprByName("make", chanExpr) + if instr.Size != nil { + reserveExpr, err := fc.convertSsaValue(instr.Size) + if err != nil { + return err + } + makeExpr.Args = append(makeExpr.Args, reserveExpr) + } + stmt = defineVar(instr, makeExpr) + case *ssa.MakeInterface: + castExpr, err := fc.castCallExpr(instr.Type(), instr.X) + if err != nil { + return err + } + stmt = defineVar(instr, castExpr) + case *ssa.MakeMap: + mapExpr, err := 
fc.tc.Convert(instr.Type()) + if err != nil { + return err + } + makeExpr := ah.CallExprByName("make", mapExpr) + if instr.Reserve != nil { + reserveExpr, err := fc.convertSsaValue(instr.Reserve) + if err != nil { + return err + } + makeExpr.Args = append(makeExpr.Args, reserveExpr) + } + stmt = defineVar(instr, makeExpr) + case *ssa.MakeSlice: + sliceExpr, err := fc.tc.Convert(instr.Type()) + if err != nil { + return err + } + lenExpr, err := fc.convertSsaValue(instr.Len) + if err != nil { + return err + } + capExpr, err := fc.convertSsaValue(instr.Cap) + if err != nil { + return err + } + stmt = defineVar(instr, ah.CallExprByName("make", sliceExpr, lenExpr, capExpr)) + case *ssa.MapUpdate: + mapExpr, err := fc.convertSsaValue(instr.Map) + if err != nil { + return err + } + keyExpr, err := fc.convertSsaValue(instr.Key) + if err != nil { + return err + } + valueExpr, err := fc.convertSsaValue(instr.Value) + if err != nil { + return err + } + stmt = ah.AssignStmt(ah.IndexExprByExpr(mapExpr, keyExpr), valueExpr) + case *ssa.Next: + okName, okType, okHasRefs := fc.tupleVarNameAndType(instr, 0) + keyName, keyType, keyHasRefs := fc.tupleVarNameAndType(instr, 1) + valName, valType, valHasRefs := fc.tupleVarNameAndType(instr, 2) + if okHasRefs { + astFunc.Vars[okName] = okType + } + if keyHasRefs { + astFunc.Vars[keyName] = keyType + } + if valHasRefs { + astFunc.Vars[valName] = valType + } + + if instr.IsString { + idxName := fc.tupleVarName(instr.Iter, 0) + iterValName := fc.tupleVarName(instr.Iter, 1) + + stmt = ah.BlockStmt( + ah.AssignStmt(ast.NewIdent(okName), &ast.BinaryExpr{ + X: ast.NewIdent(idxName), + Op: token.LSS, + Y: ah.CallExprByName("len", ast.NewIdent(iterValName)), + }), + &ast.IfStmt{ + Cond: ast.NewIdent(okName), + Body: ah.BlockStmt( + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(keyName), ast.NewIdent(valName)}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent(idxName), ah.IndexExprByExpr(ast.NewIdent(iterValName), ast.NewIdent(idxName))}, + 
}, + &ast.IncDecStmt{X: ast.NewIdent(idxName), Tok: token.INC}, + ), + }, + ) + } else { + stmt = &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(okName), ast.NewIdent(keyName), ast.NewIdent(valName)}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ah.CallExprByName(fc.getVarName(instr.Iter))}, + } + } + case *ssa.Phi: + phiName := fc.getVarName(instr) + astFunc.Vars[phiName] = instr.Type() + + for predIdx, edge := range instr.Edges { + edgeExpr, err := fc.convertSsaValue(edge) + if err != nil { + return err + } + + blockIdx := ssaBlock.Preds[predIdx].Index + astFunc.Blocks[blockIdx].Phi = append(astFunc.Blocks[blockIdx].Phi, ah.AssignStmt(ast.NewIdent(phiName), edgeExpr)) + } + case *ssa.Range: + xExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + if isStringType(instr.X.Type()) { + idxName := fc.tupleVarName(instr, 0) + valName := fc.tupleVarName(instr, 1) + + astFunc.Vars[idxName] = types.Typ[types.Int] + astFunc.Vars[valName] = types.NewSlice(types.Typ[types.Rune]) + + stmt = &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(idxName), ast.NewIdent(valName)}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ + ah.IntLit(0), + ah.CallExpr(&ast.ArrayType{Elt: ast.NewIdent("rune")}, xExpr), + }, + } + } else { + makeIterExpr, nextType, err := makeMapIteratorPolyfill(fc.tc, instr.X.Type().(*types.Map)) + if err != nil { + return err + } + + stmt = defineTypedVar(instr, nextType, ah.CallExpr(makeIterExpr, xExpr)) + } + case *ssa.Select: + const reservedTupleIdx = 2 + + indexName, indexType, indexHasRefs := fc.tupleVarNameAndType(instr, 0) + okName, okType, okHasRefs := fc.tupleVarNameAndType(instr, 1) + if indexHasRefs { + astFunc.Vars[indexName] = indexType + } + if okHasRefs { + astFunc.Vars[okName] = okType + } + + var stmts []ast.Stmt + + recvIndex := 0 + for idx, state := range instr.States { + chanExpr, err := fc.convertSsaValue(state.Chan) + if err != nil { + return err + } + + var commStmt ast.Stmt + switch state.Dir { + case types.SendOnly: + 
valueExpr, err := fc.convertSsaValue(state.Send) + if err != nil { + return err + } + commStmt = &ast.SendStmt{Chan: chanExpr, Value: valueExpr} + case types.RecvOnly: + valName, valType, valHasRefs := fc.tupleVarNameAndType(instr, reservedTupleIdx+recvIndex) + if valHasRefs { + astFunc.Vars[valName] = valType + } + commStmt = ah.AssignStmt(ast.NewIdent(valName), &ast.UnaryExpr{Op: token.ARROW, X: chanExpr}) + recvIndex++ + default: + return fmt.Errorf("not supported select chan dir %d: %w", state.Dir, ErrUnsupported) + } + + stmts = append(stmts, &ast.CommClause{ + Comm: commStmt, + Body: []ast.Stmt{ + ah.AssignStmt(ast.NewIdent(indexName), ah.IntLit(idx)), + }, + }) + } + + if !instr.Blocking { + stmts = append(stmts, &ast.CommClause{Body: []ast.Stmt{ah.AssignStmt(ast.NewIdent(indexName), ah.IntLit(len(instr.States)))}}) + } + + stmt = &ast.SelectStmt{Body: ah.BlockStmt(stmts...)} + case *ssa.Send: + chanExpr, err := fc.convertSsaValue(instr.Chan) + if err != nil { + return err + } + valExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + stmt = &ast.SendStmt{ + Chan: chanExpr, + Value: valExpr, + } + case *ssa.Slice: + valExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + sliceExpr := &ast.SliceExpr{X: valExpr} + if instr.Low != nil { + sliceExpr.Low, err = fc.convertSsaValue(instr.Low) + if err != nil { + return err + } + } + if instr.High != nil { + sliceExpr.High, err = fc.convertSsaValue(instr.High) + if err != nil { + return err + } + } + if instr.Max != nil { + sliceExpr.Max, err = fc.convertSsaValue(instr.Max) + if err != nil { + return err + } + } + stmt = defineVar(instr, sliceExpr) + case *ssa.SliceToArrayPointer: + castExpr, err := fc.tc.Convert(instr.Type()) + if err != nil { + return err + } + xExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + stmt = defineVar(instr, ah.CallExpr(&ast.ParenExpr{X: castExpr}, xExpr)) + case *ssa.Store: + addrExpr, err := 
fc.convertSsaValue(instr.Addr) + if err != nil { + return err + } + valExpr, err := fc.convertSsaValue(instr.Val) + if err != nil { + return err + } + stmt = ah.AssignStmt(&ast.StarExpr{X: addrExpr}, valExpr) + case *ssa.TypeAssert: + valExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + + assertTypeExpr, err := fc.tc.Convert(instr.AssertedType) + if err != nil { + return err + } + + if instr.CommaOk { + valName, valType, valHasRefs := fc.tupleVarNameAndType(instr, 0) + okName, okType, okHasRefs := fc.tupleVarNameAndType(instr, 1) + if valHasRefs { + astFunc.Vars[valName] = valType + } + if okHasRefs { + astFunc.Vars[okName] = okType + } + + stmt = &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(valName), ast.NewIdent(okName)}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.TypeAssertExpr{X: valExpr, Type: assertTypeExpr}}, + } + } else { + stmt = defineVar(instr, &ast.TypeAssertExpr{X: valExpr, Type: assertTypeExpr}) + } + case *ssa.UnOp: + valExpr, err := fc.convertSsaValue(instr.X) + if err != nil { + return err + } + if instr.CommaOk { + if instr.Op != token.ARROW { + return fmt.Errorf("unary operator %q in %v: %w", instr.Op, instr, ErrUnsupported) + } + + valName, valType, valHasRefs := fc.tupleVarNameAndType(instr, 0) + okName, okType, okHasRefs := fc.tupleVarNameAndType(instr, 1) + if valHasRefs { + astFunc.Vars[valName] = valType + } + if okHasRefs { + astFunc.Vars[okName] = okType + } + + stmt = &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(valName), ast.NewIdent(okName)}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.UnaryExpr{ + Op: token.ARROW, + X: valExpr, + }}, + } + } else if instr.Op == token.MUL { + stmt = defineVar(instr, &ast.StarExpr{X: valExpr}) + } else { + stmt = defineVar(instr, &ast.UnaryExpr{Op: instr.Op, X: valExpr}) + } + case *ssa.MakeClosure: + anonFunc := instr.Fn.(*ssa.Function) + anonFuncName, err := fc.getAnonFunctionName(anonFunc) + if err != nil { + return err + } + if anonFuncName == nil { + return 
fmt.Errorf("make closure for non anon func %q: %w", anonFunc.Name(), ErrUnsupported) + } + + callExpr := &ast.CallExpr{Fun: anonFuncName} + for _, freeVar := range instr.Bindings { + varExr, err := fc.convertSsaValue(freeVar) + if err != nil { + return err + } + callExpr.Args = append(callExpr.Args, varExr) + } + + stmt = defineVar(instr, callExpr) + case *ssa.RunDefers, *ssa.DebugRef: + // ignored + continue + default: + return fmt.Errorf("instruction %v: %w", instr, ErrUnsupported) + } + + if stmt != nil { + astBlock.Body = append(astBlock.Body, stmt) + } + } + + exitInstr := ssaBlock.Instrs[len(ssaBlock.Instrs)-1] + switch exit := exitInstr.(type) { + case *ssa.Jump: + targetBlockIdx := ssaBlock.Succs[0].Index + astBlock.Exit = fc.gotoStmt(targetBlockIdx) + case *ssa.If: + tblock := ssaBlock.Succs[0].Index + fblock := ssaBlock.Succs[1].Index + + condExpr, err := fc.convertSsaValue(exit.Cond) + if err != nil { + return err + } + + astBlock.Exit = &ast.IfStmt{ + Cond: condExpr, + Body: ah.BlockStmt(fc.gotoStmt(tblock)), + Else: ah.BlockStmt(fc.gotoStmt(fblock)), + } + case *ssa.Return: + exitStmt := &ast.ReturnStmt{} + for _, result := range exit.Results { + resultExpr, err := fc.convertSsaValue(result) + if err != nil { + return err + } + exitStmt.Results = append(exitStmt.Results, resultExpr) + } + astBlock.Exit = exitStmt + case *ssa.Panic: + panicArgExpr, err := fc.convertSsaValue(exit.X) + if err != nil { + return err + } + astBlock.Exit = &ast.ExprStmt{X: ah.CallExprByName("panic", panicArgExpr)} + default: + return fmt.Errorf("exit instruction %v: %w", exit, ErrUnsupported) + } + + return nil +} + +func (fc *funcConverter) getAnonFuncName(idx int) string { + return fmt.Sprintf(fc.namePrefix+"anonFunc%d", idx) +} + +func (fc *funcConverter) convertAnonFuncs(anonFuncs []*ssa.Function) ([]ast.Stmt, error) { + var stmts []ast.Stmt + + for i, anonFunc := range anonFuncs { + anonLit, err := fc.convertSignatureToFuncLit(anonFunc.Signature) + if err != nil { + 
return nil, err + } + anonStmts, err := fc.convertToStmts(anonFunc) + if err != nil { + return nil, err + } + anonLit.Body = ah.BlockStmt(anonStmts...) + + if len(anonFunc.FreeVars) == 0 { + stmts = append(stmts, &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(fc.getAnonFuncName(i))}, + Tok: token.DEFINE, + Rhs: []ast.Expr{anonLit}, + }) + continue + } + + var closureVars []*types.Var + for _, freeVar := range anonFunc.FreeVars { + closureVars = append(closureVars, types.NewVar(token.NoPos, nil, freeVar.Name(), freeVar.Type())) + } + + makeClosureType := types.NewSignatureType(nil, nil, nil, types.NewTuple(closureVars...), types.NewTuple( + types.NewVar(token.NoPos, nil, "", anonFunc.Signature), + ), false) + + makeClosureLit, err := fc.convertSignatureToFuncLit(makeClosureType) + if err != nil { + return nil, err + } + makeClosureLit.Body = ah.BlockStmt(&ast.ReturnStmt{Results: []ast.Expr{anonLit}}) + + stmts = append(stmts, &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(fc.getAnonFuncName(i))}, + Tok: token.DEFINE, + Rhs: []ast.Expr{makeClosureLit}, + }) + } + return stmts, nil +} + +func (fc *funcConverter) convertToStmts(ssaFunc *ssa.Function) ([]ast.Stmt, error) { + stmts, err := fc.convertAnonFuncs(ssaFunc.AnonFuncs) + if err != nil { + return nil, err + } + + f := &AstFunc{ + Vars: make(map[string]types.Type), + Blocks: make([]*AstBlock, len(ssaFunc.Blocks)), + } + for i := range f.Blocks { + f.Blocks[i] = &AstBlock{Index: ssaFunc.Blocks[i].Index} + } + + for i, ssaBlock := range ssaFunc.Blocks { + if err := fc.convertBlock(f, ssaBlock, f.Blocks[i]); err != nil { + return nil, err + } + } + + groupedVar := make(map[types.Type][]string) + for varName, varType := range f.Vars { + exists := false + for groupedType, names := range groupedVar { + if types.Identical(varType, groupedType) { + groupedVar[groupedType] = append(names, varName) + exists = true + break + } + } + if !exists { + groupedVar[varType] = []string{varName} + } + } + var specs []ast.Spec + for 
varType, varNames := range groupedVar { + typeExpr, err := fc.tc.Convert(varType) + if err != nil { + return nil, err + } + spec := &ast.ValueSpec{ + Type: typeExpr, + } + + sort.Strings(varNames) + for _, name := range varNames { + spec.Names = append(spec.Names, ast.NewIdent(name)) + } + specs = append(specs, spec) + } + if len(specs) > 0 { + stmts = append(stmts, &ast.DeclStmt{Decl: &ast.GenDecl{ + Tok: token.VAR, + Specs: specs, + }}) + } + + for _, block := range f.Blocks { + blockStmts := &ast.BlockStmt{List: append(block.Body, block.Phi...)} + blockStmts.List = append(blockStmts.List, block.Exit) + if block.HasRefs { + stmts = append(stmts, &ast.LabeledStmt{Label: fc.getLabelName(block.Index), Stmt: blockStmts}) + } else { + stmts = append(stmts, blockStmts) + } + } + return stmts, nil +} + +func (fc *funcConverter) convert(ssaFunc *ssa.Function) (*ast.FuncDecl, error) { + if ssaFunc.Signature.TypeParams() != nil || ssaFunc.Signature.RecvTypeParams() != nil { + return nil, ErrUnsupported + } + + funcDecl, err := fc.convertSignatureToFuncDecl(ssaFunc.Name(), ssaFunc.Signature) + if err != nil { + return nil, err + } + funcStmts, err := fc.convertToStmts(ssaFunc) + if err != nil { + return nil, err + } + funcDecl.Body = ah.BlockStmt(funcStmts...) 
+ return funcDecl, err +} diff --git a/internal/ssa2ast/func_test.go b/internal/ssa2ast/func_test.go new file mode 100644 index 0000000..4d972d0 --- /dev/null +++ b/internal/ssa2ast/func_test.go @@ -0,0 +1,398 @@ +package ssa2ast + +import ( + "go/ast" + "go/importer" + "go/printer" + "go/types" + "os" + "os/exec" + "path/filepath" + "testing" + + "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ssa" + + "github.com/google/go-cmp/cmp" + "golang.org/x/tools/go/ssa/ssautil" +) + +const sigSrc = `package main + +import "unsafe" + +type genericStruct[T interface{}] struct{} +type plainStruct struct { + Dummy struct{} +} + +func (s *plainStruct) plainStructFunc() { + +} + +func (*plainStruct) plainStructAnonFunc() { + +} + +func (s *genericStruct[T]) genericStructFunc() { + +} +func (s *genericStruct[T]) genericStructAnonFunc() (test T) { + return +} + +func plainFuncSignature(a int, b string, c struct{}, d struct{ string }, e interface{ Dummy() string }, pointer unsafe.Pointer) (i int, er error) { + return +} + +func genericFuncSignature[T interface{ interface{} | ~int64 | bool }, X interface{ comparable }](a T, b X, c genericStruct[struct{ a T }], d genericStruct[T]) (res T) { + return +} +` + +func TestConvertSignature(t *testing.T) { + conv := newFuncConverter(DefaultConfig()) + + f, _, info, _ := mustParseAndTypeCheckFile(sigSrc) + for _, funcName := range []string{"plainStructFunc", "plainStructAnonFunc", "genericStructFunc", "plainFuncSignature", "genericFuncSignature"} { + funcDecl := findFunc(f, funcName) + funcDecl.Body = nil + + funcObj := info.Defs[funcDecl.Name].(*types.Func) + funcDeclConverted, err := conv.convertSignatureToFuncDecl(funcObj.Name(), funcObj.Type().(*types.Signature)) + if err != nil { + t.Fatal(err) + } + if structDiff := cmp.Diff(funcDecl, funcDeclConverted, astCmpOpt); structDiff != "" { + t.Fatalf("method decl not equals: %s", structDiff) + } + } +} + +const mainSrc = `package main + +import ( + "encoding/binary" + "fmt" + 
"io" + "sort" + "strconv" + "sync" + "time" + "unsafe" +) + +func main() { + methodOps() + slicesOps() + iterAndMapsOps() + chanOps() + flowOps() + typeOps() +} + +func makeSprintf(tag string) func(vals ...interface{}) { + i := 0 + return func(vals ...interface{}) { + fmt.Printf("%s(%d): %v\n", tag, i, vals) + i++ + } +} + +func return42() int { + return 42 +} + +type arrayOfInts []int + +type structOfArraysOfInts struct { + a arrayOfInts + b arrayOfInts +} + +func slicesOps() { + sprintf := makeSprintf("slicesOps") + + slice := [...]int{1, 2} + sprintf(slice[0:1:2]) + // *ssa.IndexAddr + sprintf(slice) + slice[0] += 1 + sprintf(slice) + + sprintf(slice[:1]) + sprintf(slice[slice[0]:]) + sprintf(slice[0:2]) + + sprintf((*[2]int)(slice[:])[return42()%2]) // *ssa.SliceToArrayPointer + + sprintf("test"[return42()%3]) // *ssa.Index + + structOfArrays := structOfArraysOfInts{a: slice[1:], b: slice[:1]} + sprintf(structOfArrays.a[:1]) + sprintf(structOfArrays.b[:1]) + + slice2 := make([]string, return42(), return42()*2) + slice2[return42()-1] = "test" + sprintf(slice2) + + return +} + +func iterAndMapsOps() { + sprintf := makeSprintf("iterAndMapsOps") + + // *ssa.MakeMap + *ssa.MapUpdate + mmap := map[string]time.Month{ + "April": time.April, + "December": time.December, + "January": time.January, + } + + var vals []string + for k := range mmap { + vals = append(vals, k) + } + for _, v := range mmap { + vals = append(vals, v.String()) + } + sort.Strings(vals) // Required. 
Order of map iteration not guaranteed + sprintf(vals) + + if v, ok := mmap["?"]; ok { + panic("unreachable: " + v.String()) + } + for idx, s := range "hello world" { + sprintf(idx, s) + } + + sprintf(mmap["April"].String()) + return +} + +type interfaceCalls interface { + Return1() string +} + +type structCalls struct { +} + +func (r structCalls) Return1() string { + return "Return1" +} + +func (r *structCalls) Return2() string { + return "Return2" +} + +func multiOutputRes() (int, string) { + return 42, "24" +} + +func returnInterfaceCalls() interfaceCalls { + return structCalls{} +} + +func methodOps() { + sprintf := makeSprintf("methodOps") + + defer func() { + sprintf("from defer") + }() + defer sprintf("from defer 2") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + sprintf("from go") + wg.Done() + }() + wg.Wait() + + i, s := multiOutputRes() + sprintf(strconv.Itoa(i)) + + var strct structCalls + + strct.Return1() + strct.Return2() + + intrfs := returnInterfaceCalls() + intrfs.Return1() + + sprintf(strconv.Itoa(len(s))) + + strconv.Itoa(binary.Size(4)) + sprintf(binary.LittleEndian.AppendUint32(nil, 42)) + + if len(s) == 0 { + panic("unreachable") + } + + sprintf(*unsafe.StringData(s)) + + thunkMethod1 := structCalls.Return1 + sprintf(thunkMethod1(strct)) + + thunkMethod2 := (*structCalls).Return2 + sprintf(thunkMethod2(&strct)) + + closureVar := "c " + s + anonFnc := func(n func(structCalls) string) string { + return n(structCalls{}) + "anon" + closureVar + } + + sprintf(anonFnc(structCalls.Return1)) +} + +func chanOps() { + sprintf := makeSprintf("chanOps") + + a := make(chan string) + b := make(chan string) + c := make(chan string) + d := make(chan string) + + select { + case r1, ok := <-a: + sprintf(r1, ok) + case r2 := <-b: + sprintf(r2) + case <-c: + sprintf("r3") + case d <- "test": + sprintf("d triggered") + default: + sprintf("default") + } + + e := make(chan string, 1) + e <- "hi" + + sprintf(<-e) + + close(a) + val, ok := <-a + + sprintf(val, 
ok) + return +} + +func flowOps() { + sprintf := makeSprintf("flowOps") + i := 1 + if return42()%2 == 0 { + sprintf("a") + i++ + } else { + sprintf("b") + } + sprintf(i) + + switch return42() { + case 1: + sprintf("1") + case 2: + sprintf("2") + case 3: + sprintf("3") + case 42: + sprintf("42") + } +} + +type interfaceB interface { +} + +type testStruct struct { + A, B int +} + +func typeOps() { + sprintf := makeSprintf("typeOps") + + // *ssa.ChangeType + var interA interfaceCalls + sprintf(interA) + + // *ssa.ChangeInterface + var interB interfaceB = struct{}{} + var inter0 interface{} = interB + sprintf(inter0) + + // *ssa.Convert + var f float64 = 1.0 + sprintf(int(f)) + + casted, ok := inter0.(interfaceB) + sprintf(casted, ok) + + casted2 := inter0.(interfaceB) + sprintf(casted2) + + strc := testStruct{return42(), return42() + 2} + strc.B += strc.A + sprintf(strc) + + // Access to unexported structure + discard := io.Discard + if return42() == 0 { + sprintf(discard) // Trigger phi block + } + _, _ = discard.Write([]byte("test")) +}` + +func TestConvert(t *testing.T) { + runGoFile := func(f string) string { + cmd := exec.Command("go", "run", f) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("compile failed: %v\n%s", err, string(out)) + } + return string(out) + } + + testFile := filepath.Join(t.TempDir(), "convert.go") + if err := os.WriteFile(testFile, []byte(mainSrc), 0o777); err != nil { + t.Fatal(err) + } + + originalOut := runGoFile(testFile) + file, fset, _, _ := mustParseAndTypeCheckFile(mainSrc) + ssaPkg, _, err := ssautil.BuildPackage(&types.Config{Importer: importer.Default()}, fset, types.NewPackage("test/main", ""), []*ast.File{file}, 0) + if err != nil { + t.Fatal(err) + } + + for fIdx, decl := range file.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + + path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos()) + ssaFunc := ssa.EnclosingFunction(ssaPkg, path) + + astFunc, err := 
Convert(ssaFunc, DefaultConfig()) + if err != nil { + t.Fatal(err) + } + file.Decls[fIdx] = astFunc + } + + convertedFile := filepath.Join(t.TempDir(), "main.go") + f, err := os.Create(convertedFile) + if err != nil { + t.Fatal(err) + } + if err := printer.Fprint(f, fset, file); err != nil { + t.Fatal(err) + } + _ = f.Close() + + convertedOut := runGoFile(convertedFile) + + if convertedOut != originalOut { + t.Fatalf("Output not equals:\n\n%s\n\n%s", originalOut, convertedOut) + } +} diff --git a/internal/ssa2ast/helpers_test.go b/internal/ssa2ast/helpers_test.go new file mode 100644 index 0000000..4d51007 --- /dev/null +++ b/internal/ssa2ast/helpers_test.go @@ -0,0 +1,72 @@ +package ssa2ast + +import ( + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + + "github.com/google/go-cmp/cmp/cmpopts" +) + +var astCmpOpt = cmpopts.IgnoreTypes(token.NoPos, &ast.Object{}) + +func findStruct(file *ast.File, structName string) (name *ast.Ident, structType *ast.StructType) { + ast.Inspect(file, func(node ast.Node) bool { + if structType != nil { + return false + } + + typeSpec, ok := node.(*ast.TypeSpec) + if !ok || typeSpec.Name == nil || typeSpec.Name.Name != structName { + return true + } + typ, ok := typeSpec.Type.(*ast.StructType) + if !ok { + return true + } + structType = typ + name = typeSpec.Name + return true + }) + + if structType == nil { + panic(structName + " not found") + } + return +} + +func findFunc(file *ast.File, funcName string) *ast.FuncDecl { + for _, decl := range file.Decls { + fDecl, ok := decl.(*ast.FuncDecl) + if ok && fDecl.Name.Name == funcName { + return fDecl + } + } + panic(funcName + " not found") +} + +func mustParseAndTypeCheckFile(src string) (*ast.File, *token.FileSet, *types.Info, *types.Package) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "a.go", src, 0) + if err != nil { + panic(err) + } + + config := types.Config{Importer: importer.Default()} + info := &types.Info{ + Types: 
make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Instances: make(map[*ast.Ident]types.Instance), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + pkg, err := config.Check("test/main", fset, []*ast.File{f}, info) + if err != nil { + panic(err) + } + return f, fset, info, pkg +} diff --git a/internal/ssa2ast/polyfill.go b/internal/ssa2ast/polyfill.go new file mode 100644 index 0000000..28791a4 --- /dev/null +++ b/internal/ssa2ast/polyfill.go @@ -0,0 +1,176 @@ +package ssa2ast + +import ( + "go/ast" + "go/token" + "go/types" +) + +func makeMapIteratorPolyfill(tc *typeConverter, mapType *types.Map) (ast.Expr, types.Type, error) { + keyTypeExpr, err := tc.Convert(mapType.Key()) + if err != nil { + return nil, nil, err + } + valueTypeExpr, err := tc.Convert(mapType.Elem()) + if err != nil { + return nil, nil, err + } + + nextType := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple( + types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool]), + types.NewVar(token.NoPos, nil, "", mapType.Key()), + types.NewVar(token.NoPos, nil, "", mapType.Elem()), + ), false) + + // Generated using https://github.com/lu4p/astextract from snippet: + /* + func(m map[]) func() (bool, , ) { + keys := make([], 0, len(m)) + for k := range m { + keys = append(keys, k) + } + i := 0 + return func() (ok bool, k , r ) { + if i < len(keys) { + k = keys[i] + ok, r = true, m[k] + i++ + } + return + } + } + */ + return &ast.FuncLit{ + Type: &ast.FuncType{ + Params: &ast.FieldList{List: []*ast.Field{{ + Names: []*ast.Ident{{Name: "m"}}, + Type: &ast.MapType{ + Key: keyTypeExpr, + Value: valueTypeExpr, + }, + }}}, + Results: &ast.FieldList{List: []*ast.Field{{ + Type: &ast.FuncType{ + Params: &ast.FieldList{}, + Results: &ast.FieldList{List: []*ast.Field{ + {Type: &ast.Ident{Name: "bool"}}, + {Type: keyTypeExpr}, + 
{Type: valueTypeExpr}, + }}, + }, + }}}, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: "keys"}}, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.Ident{Name: "make"}, + Args: []ast.Expr{ + &ast.ArrayType{Elt: keyTypeExpr}, + &ast.BasicLit{Kind: token.INT, Value: "0"}, + &ast.CallExpr{ + Fun: &ast.Ident{Name: "len"}, + Args: []ast.Expr{&ast.Ident{Name: "m"}}, + }, + }, + }, + }, + }, + &ast.RangeStmt{ + Key: &ast.Ident{Name: "k"}, + Tok: token.DEFINE, + X: &ast.Ident{Name: "m"}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: "keys"}}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.Ident{Name: "append"}, + Args: []ast.Expr{ + &ast.Ident{Name: "keys"}, + &ast.Ident{Name: "k"}, + }, + }, + }, + }, + }, + }, + }, + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: "i"}}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}}, + }, + &ast.ReturnStmt{Results: []ast.Expr{ + &ast.FuncLit{ + Type: &ast.FuncType{ + Params: &ast.FieldList{}, + Results: &ast.FieldList{List: []*ast.Field{ + { + Names: []*ast.Ident{{Name: "ok"}}, + Type: &ast.Ident{Name: "bool"}, + }, + { + Names: []*ast.Ident{{Name: "k"}}, + Type: keyTypeExpr, + }, + { + Names: []*ast.Ident{{Name: "r"}}, + Type: valueTypeExpr, + }, + }}, + }, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.IfStmt{ + Cond: &ast.BinaryExpr{ + X: &ast.Ident{Name: "i"}, + Op: token.LSS, + Y: &ast.CallExpr{ + Fun: &ast.Ident{Name: "len"}, + Args: []ast.Expr{&ast.Ident{Name: "keys"}}, + }, + }, + Body: &ast.BlockStmt{List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: "k"}}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{&ast.IndexExpr{ + X: &ast.Ident{Name: "keys"}, + Index: &ast.Ident{Name: "i"}, + }}, + }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ + &ast.Ident{Name: "ok"}, + &ast.Ident{Name: "r"}, + }, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ + 
&ast.Ident{Name: "true"}, + &ast.IndexExpr{ + X: &ast.Ident{Name: "m"}, + Index: &ast.Ident{Name: "k"}, + }, + }, + }, + &ast.IncDecStmt{ + X: &ast.Ident{Name: "i"}, + Tok: token.INC, + }, + }}, + }, + &ast.ReturnStmt{}, + }, + }, + }, + }}, + }, + }, + }, nextType, nil +} diff --git a/internal/ssa2ast/type.go b/internal/ssa2ast/type.go new file mode 100644 index 0000000..8284571 --- /dev/null +++ b/internal/ssa2ast/type.go @@ -0,0 +1,249 @@ +package ssa2ast + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "reflect" + "strconv" +) + +type typeConverter struct { + resolver ImportNameResolver +} + +func (tc *typeConverter) Convert(t types.Type) (ast.Expr, error) { + switch typ := t.(type) { + case *types.Array: + eltExpr, err := tc.Convert(typ.Elem()) + if err != nil { + return nil, err + } + return &ast.ArrayType{ + Len: &ast.BasicLit{ + Kind: token.INT, + Value: strconv.FormatInt(typ.Len(), 10), + }, + Elt: eltExpr, + }, nil + case *types.Basic: + if typ.Kind() == types.UnsafePointer { + unsafePkgIdent := tc.resolver(types.Unsafe) + if unsafePkgIdent == nil { + return nil, fmt.Errorf("cannot resolve unsafe package") + } + return &ast.SelectorExpr{X: unsafePkgIdent, Sel: ast.NewIdent("Pointer")}, nil + } + return ast.NewIdent(typ.Name()), nil + case *types.Chan: + chanValueExpr, err := tc.Convert(typ.Elem()) + if err != nil { + return nil, err + } + chanExpr := &ast.ChanType{Value: chanValueExpr} + switch typ.Dir() { + case types.SendRecv: + chanExpr.Dir = ast.SEND | ast.RECV + case types.RecvOnly: + chanExpr.Dir = ast.RECV + case types.SendOnly: + chanExpr.Dir = ast.SEND + } + return chanExpr, nil + case *types.Interface: + methods := &ast.FieldList{} + for i := 0; i < typ.NumEmbeddeds(); i++ { + embeddedType := typ.EmbeddedType(i) + embeddedExpr, err := tc.Convert(embeddedType) + if err != nil { + return nil, err + } + methods.List = append(methods.List, &ast.Field{Type: embeddedExpr}) + } + for i := 0; i < typ.NumExplicitMethods(); i++ { + method := 
typ.ExplicitMethod(i) + methodSig, err := tc.Convert(method.Type()) + if err != nil { + return nil, err + } + methods.List = append(methods.List, &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(method.Name())}, + Type: methodSig, + }) + } + return &ast.InterfaceType{Methods: methods}, nil + case *types.Map: + keyExpr, err := tc.Convert(typ.Key()) + if err != nil { + return nil, err + } + valueExpr, err := tc.Convert(typ.Elem()) + if err != nil { + return nil, err + } + return &ast.MapType{Key: keyExpr, Value: valueExpr}, nil + case *types.Named: + obj := typ.Obj() + + // TODO: rewrite struct inlining without reflection hack + if parent := obj.Parent(); parent != nil { + isFuncScope := reflect.ValueOf(parent).Elem().FieldByName("isFunc") + if isFuncScope.Bool() { + return tc.Convert(obj.Type().Underlying()) + } + } + + var namedExpr ast.Expr + if pkgIdent := tc.resolver(obj.Pkg()); pkgIdent != nil { + // reference to unexported named emulated through new interface with explicit declarated methods + if !token.IsExported(obj.Name()) { + var methods []*types.Func + for i := 0; i < typ.NumMethods(); i++ { + method := typ.Method(i) + if token.IsExported(method.Name()) { + methods = append(methods, method) + } + } + + fakeInterface := types.NewInterfaceType(methods, nil) + return tc.Convert(fakeInterface) + } + namedExpr = &ast.SelectorExpr{X: pkgIdent, Sel: ast.NewIdent(obj.Name())} + } else { + namedExpr = ast.NewIdent(obj.Name()) + } + + typeParams := typ.TypeArgs() + if typeParams == nil || typeParams.Len() == 0 { + return namedExpr, nil + } + if typeParams.Len() == 1 { + typeParamExpr, err := tc.Convert(typeParams.At(0)) + if err != nil { + return nil, err + } + return &ast.IndexExpr{X: namedExpr, Index: typeParamExpr}, nil + } + genericExpr := &ast.IndexListExpr{X: namedExpr} + for i := 0; i < typeParams.Len(); i++ { + typeArgs := typeParams.At(i) + typeParamExpr, err := tc.Convert(typeArgs) + if err != nil { + return nil, err + } + genericExpr.Indices = 
append(genericExpr.Indices, typeParamExpr) + } + return genericExpr, nil + case *types.Pointer: + expr, err := tc.Convert(typ.Elem()) + if err != nil { + return nil, err + } + return &ast.StarExpr{X: expr}, nil + case *types.Signature: + funcSigExpr := &ast.FuncType{Params: &ast.FieldList{}} + if sigParams := typ.Params(); sigParams != nil { + for i := 0; i < sigParams.Len(); i++ { + param := sigParams.At(i) + + var paramType ast.Expr + if typ.Variadic() && i == sigParams.Len()-1 { + slice := param.Type().(*types.Slice) + + eltExpr, err := tc.Convert(slice.Elem()) + if err != nil { + return nil, err + } + paramType = &ast.Ellipsis{Elt: eltExpr} + } else { + paramExpr, err := tc.Convert(param.Type()) + if err != nil { + return nil, err + } + paramType = paramExpr + } + f := &ast.Field{Type: paramType} + if name := param.Name(); name != "" { + f.Names = []*ast.Ident{ast.NewIdent(name)} + } + funcSigExpr.Params.List = append(funcSigExpr.Params.List, f) + } + } + if sigResults := typ.Results(); sigResults != nil { + funcSigExpr.Results = &ast.FieldList{} + for i := 0; i < sigResults.Len(); i++ { + result := sigResults.At(i) + resultExpr, err := tc.Convert(result.Type()) + if err != nil { + return nil, err + } + + f := &ast.Field{Type: resultExpr} + if name := result.Name(); name != "" { + f.Names = []*ast.Ident{ast.NewIdent(name)} + } + funcSigExpr.Results.List = append(funcSigExpr.Results.List, f) + } + } + if typeParams := typ.TypeParams(); typeParams != nil { + funcSigExpr.TypeParams = &ast.FieldList{} + for i := 0; i < typeParams.Len(); i++ { + typeParam := typeParams.At(i) + resultExpr, err := tc.Convert(typeParam.Constraint().Underlying()) + if err != nil { + return nil, err + } + f := &ast.Field{Type: resultExpr, Names: []*ast.Ident{ast.NewIdent(typeParam.Obj().Name())}} + funcSigExpr.TypeParams.List = append(funcSigExpr.TypeParams.List, f) + } + } + return funcSigExpr, nil + case *types.Slice: + eltExpr, err := tc.Convert(typ.Elem()) + if err != nil { + return 
nil, err + } + return &ast.ArrayType{Elt: eltExpr}, nil + case *types.Struct: + fieldList := &ast.FieldList{} + for i := 0; i < typ.NumFields(); i++ { + f := typ.Field(i) + fieldExpr, err := tc.Convert(f.Type()) + if err != nil { + return nil, err + } + field := &ast.Field{Type: fieldExpr} + if !f.Anonymous() { + field.Names = []*ast.Ident{ast.NewIdent(f.Name())} + } + if tag := typ.Tag(i); len(tag) > 0 { + field.Tag = &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", tag)} + } + fieldList.List = append(fieldList.List, field) + } + return &ast.StructType{Fields: fieldList}, nil + case *types.TypeParam: + return ast.NewIdent(typ.Obj().Name()), nil + case *types.Union: + var unionExpr ast.Expr + for i := 0; i < typ.Len(); i++ { + term := typ.Term(i) + expr, err := tc.Convert(term.Type()) + if err != nil { + return nil, err + } + if term.Tilde() { + expr = &ast.UnaryExpr{Op: token.TILDE, X: expr} + } + if unionExpr == nil { + unionExpr = expr + } else { + unionExpr = &ast.BinaryExpr{X: unionExpr, Op: token.OR, Y: expr} + } + } + return unionExpr, nil + default: + return nil, fmt.Errorf("type %v: %w", typ, ErrUnsupported) + } +} diff --git a/internal/ssa2ast/type_test.go b/internal/ssa2ast/type_test.go new file mode 100644 index 0000000..403471b --- /dev/null +++ b/internal/ssa2ast/type_test.go @@ -0,0 +1,108 @@ +package ssa2ast + +import ( + "go/ast" + "testing" + + "github.com/google/go-cmp/cmp" +) + +const typesSrc = `package main + +import ( + "io" + "time" +) + +type localNamed bool + +type embedStruct struct { + int +} + +type genericStruct[K comparable, V int64 | float64] struct { + int +} + +type exampleStruct struct { + embedStruct + + // *types.Array + array [3]int + array2 [0]int + + // *types.Basic + bool // anonymous + string string "test:\"tag\"" + int int + int8 int8 + int16 int16 + int32 int32 + int64 int64 + uint uint + uint8 uint8 + uint16 uint16 + uint32 uint32 + uint64 uint64 + uintptr uintptr + byte byte + rune rune + float32 float32 + 
float64 float64 + complex64 complex64 + complex128 complex128 + + // *types.Chan + chanSendRecv chan struct{} + chanRecv <-chan struct{} + chanSend chan<- struct{} + + // *types.Interface + interface1 interface{} + interface2 interface{ io.Reader } + interface3 interface{ Dummy(int) bool } + interface4 interface { + io.Reader + io.ByteReader + Dummy(int) bool + } + + // *types.Map + strMap map[string]string + + // *types.Named + localNamed localNamed + importedNamed time.Month + + // *types.Pointer + pointer1 *string + pointer2 **string + + // *types.Signature + func1 func(int, int) int + func2 func(a int, b int, varargs ...struct{ string }) (res int) + + // *types.Slice + slice1 []int + slice2 [][]int + + // generics + generic genericStruct[genericStruct[genericStruct[bool, int64], int64], int64] +} +` + +func TestTypeToExpr(t *testing.T) { + f, _, info, _ := mustParseAndTypeCheckFile(typesSrc) + name, structAst := findStruct(f, "exampleStruct") + obj := info.Defs[name] + fc := &typeConverter{resolver: defaultImportNameResolver} + convAst, err := fc.Convert(obj.Type().Underlying()) + if err != nil { + t.Fatal(err) + } + + structConvAst := convAst.(*ast.StructType) + if structDiff := cmp.Diff(structAst, structConvAst, astCmpOpt); structDiff != "" { + t.Fatalf("struct not equals: %s", structDiff) + } +} diff --git a/main.go b/main.go index 9245ba9..46c870e 100644 --- a/main.go +++ b/main.go @@ -42,6 +42,7 @@ import ( "golang.org/x/mod/semver" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ssa" + "mvdan.cc/garble/internal/ctrlflow" "mvdan.cc/garble/internal/linker" "mvdan.cc/garble/internal/literals" @@ -55,6 +56,8 @@ var ( flagDebug bool flagDebugDir string flagSeed seedFlag + // TODO(pagran): in the future, when control flow obfuscation will be stable migrate to flag + flagControlFlow = os.Getenv("GARBLE_EXPERIMENTAL_CONTROLFLOW") == "1" ) func init() { @@ -150,6 +153,8 @@ var ( parentWorkDir = os.Getenv("GARBLE_PARENT_WORK") ) +const 
actionGraphFileName = "action-graph.json" + type importerWithMap struct { importMap map[string]string importFrom func(path, dir string, mode types.ImportMode) (*types.Package, error) @@ -617,6 +622,9 @@ This command wraps "go %s". Below is its help: toolexecFlag.WriteString(" toolexec") goArgs = append(goArgs, toolexecFlag.String()) + if flagControlFlow { + goArgs = append(goArgs, "-debug-actiongraph", filepath.Join(sharedTempDir, actionGraphFileName)) + } if flagDebugDir != "" { // In case the user deletes the debug directory, // and a previous build is cached, @@ -936,6 +944,14 @@ func (tf *transformer) transformCompile(args []string) ([]string, error) { return nil, err } + // Literal and control flow obfuscation uses math/rand, so seed it deterministically. + randSeed := tf.curPkg.GarbleActionID[:] + if flagSeed.present() { + randSeed = flagSeed.bytes + } + // log.Printf("seeding math/rand with %x\n", randSeed) + tf.obfRand = mathrand.New(mathrand.NewSource(int64(binary.BigEndian.Uint64(randSeed)))) + // Even if loadPkgCache below finds a direct cache hit, // other parts of garble still need type information to obfuscate. 
// We could potentially avoid this by saving the type info we need in the cache, @@ -945,7 +961,39 @@ func (tf *transformer) transformCompile(args []string) ([]string, error) { return nil, err } - if tf.curPkgCache, err = loadPkgCache(tf.curPkg, tf.pkg, files, tf.info); err != nil { + var ( + ssaPkg *ssa.Package + requiredPkgs []string + ) + if flagControlFlow { + ssaPkg = ssaBuildPkg(tf.pkg, files, tf.info) + + newFileName, newFile, affectedFiles, err := ctrlflow.Obfuscate(fset, ssaPkg, files, tf.obfRand) + if err != nil { + return nil, err + } + + if newFile != nil { + files = append(files, newFile) + paths = append(paths, newFileName) + for _, file := range affectedFiles { + tf.useAllImports(file) + } + if tf.pkg, tf.info, err = typecheck(tf.curPkg.ImportPath, files, tf.origImporter); err != nil { + return nil, err + } + + for _, imp := range newFile.Imports { + path, err := strconv.Unquote(imp.Path.Value) + if err != nil { + panic(err) // should never happen + } + requiredPkgs = append(requiredPkgs, path) + } + } + } + + if tf.curPkgCache, err = loadPkgCache(tf.curPkg, tf.pkg, files, tf.info, ssaPkg); err != nil { return nil, err } @@ -958,19 +1006,11 @@ func (tf *transformer) transformCompile(args []string) ([]string, error) { } flags = alterTrimpath(flags) - newImportCfg, err := tf.processImportCfg(flags) + newImportCfg, err := tf.processImportCfg(flags, requiredPkgs) if err != nil { return nil, err } - // Literal obfuscation uses math/rand, so seed it deterministically. - randSeed := tf.curPkg.GarbleActionID[:] - if flagSeed.present() { - randSeed = flagSeed.bytes - } - // log.Printf("seeding math/rand with %x\n", randSeed) - tf.obfRand = mathrand.New(mathrand.NewSource(int64(binary.BigEndian.Uint64(randSeed)))) - // If this is a package to obfuscate, swap the -p flag with the new package path. // We don't if it's the main package, as that just uses "-p main". 
// We only set newPkgPath if we're obfuscating the import path, @@ -1170,7 +1210,7 @@ func (tf *transformer) transformLinkname(localName, newName string) (string, str // processImportCfg parses the importcfg file passed to a compile or link step. // It also builds a new importcfg file to account for obfuscated import paths. -func (tf *transformer) processImportCfg(flags []string) (newImportCfg string, _ error) { +func (tf *transformer) processImportCfg(flags []string, requiredPkgs []string) (newImportCfg string, _ error) { importCfg := flagValue(flags, "-importcfg") if importCfg == "" { return "", fmt.Errorf("could not find -importcfg argument") @@ -1182,6 +1222,15 @@ func (tf *transformer) processImportCfg(flags []string) (newImportCfg string, _ var packagefiles, importmaps [][2]string + // using for track required but not imported packages + var newIndirectImports map[string]bool + if requiredPkgs != nil { + newIndirectImports = make(map[string]bool) + for _, pkg := range requiredPkgs { + newIndirectImports[pkg] = true + } + } + for _, line := range strings.Split(string(data), "\n") { if line == "" || strings.HasPrefix(line, "#") { continue @@ -1203,6 +1252,7 @@ func (tf *transformer) processImportCfg(flags []string) (newImportCfg string, _ continue } packagefiles = append(packagefiles, [2]string{importPath, objectPath}) + delete(newIndirectImports, importPath) } } @@ -1231,6 +1281,45 @@ func (tf *transformer) processImportCfg(flags []string) (newImportCfg string, _ } fmt.Fprintf(newCfg, "importmap %s=%s\n", beforePath, afterPath) } + + if len(newIndirectImports) > 0 { + f, err := os.Open(filepath.Join(sharedTempDir, actionGraphFileName)) + if err != nil { + return "", fmt.Errorf("cannot open action graph file: %v", err) + } + defer f.Close() + + var actions []struct { + Mode string + Package string + Objdir string + } + if err := json.NewDecoder(f).Decode(&actions); err != nil { + return "", fmt.Errorf("cannot parse action graph file: %v", err) + } + + // 
theoretically action graph can be long, to optimise it process it in one pass + // with an early exit when all the required imports are found + for _, action := range actions { + if action.Mode != "build" { + continue + } + if ok := newIndirectImports[action.Package]; !ok { + continue + } + + packagefiles = append(packagefiles, [2]string{action.Package, filepath.Join(action.Objdir, "_pkg_.a")}) // file name hardcoded in compiler + delete(newIndirectImports, action.Package) + if len(newIndirectImports) == 0 { + break + } + } + + if len(newIndirectImports) > 0 { + return "", fmt.Errorf("cannot resolve required packages from action graph file: %v", requiredPkgs) + } + } + for _, pair := range packagefiles { impPath, pkgfile := pair[0], pair[1] lpkg, err := listPackage(tf.curPkg, impPath) @@ -1306,6 +1395,27 @@ func (c *pkgCache) CopyFrom(c2 pkgCache) { maps.Copy(c.EmbeddedAliasFields, c2.EmbeddedAliasFields) } +func ssaBuildPkg(pkg *types.Package, files []*ast.File, info *types.Info) *ssa.Package { + // Create SSA packages for all imports. Order is not significant. + ssaProg := ssa.NewProgram(fset, 0) + created := make(map[*types.Package]bool) + var createAll func(pkgs []*types.Package) + createAll = func(pkgs []*types.Package) { + for _, p := range pkgs { + if !created[p] { + created[p] = true + ssaProg.CreatePackage(p, nil, nil, true) + createAll(p.Imports()) + } + } + } + createAll(pkg.Imports()) + + ssaPkg := ssaProg.CreatePackage(pkg, files, info, false) + ssaPkg.Build() + return ssaPkg +} + func openCache() (*cache.Cache, error) { // Use a subdirectory for the hashed build cache, to clarify what it is, // and to allow us to have other directories or files later on without mixing. 
@@ -1316,7 +1426,7 @@ func openCache() (*cache.Cache, error) { return cache.Open(dir) } -func loadPkgCache(lpkg *listedPackage, pkg *types.Package, files []*ast.File, info *types.Info) (pkgCache, error) { +func loadPkgCache(lpkg *listedPackage, pkg *types.Package, files []*ast.File, info *types.Info, ssaPkg *ssa.Package) (pkgCache, error) { fsCache, err := openCache() if err != nil { return pkgCache{}, err @@ -1335,10 +1445,10 @@ func loadPkgCache(lpkg *listedPackage, pkg *types.Package, files []*ast.File, in } return loaded, nil } - return computePkgCache(fsCache, lpkg, pkg, files, info) + return computePkgCache(fsCache, lpkg, pkg, files, info, ssaPkg) } -func computePkgCache(fsCache *cache.Cache, lpkg *listedPackage, pkg *types.Package, files []*ast.File, info *types.Info) (pkgCache, error) { +func computePkgCache(fsCache *cache.Cache, lpkg *listedPackage, pkg *types.Package, files []*ast.File, info *types.Info, ssaPkg *ssa.Package) (pkgCache, error) { // Not yet in the cache. Load the cache entries for all direct dependencies, // build our cache entry, and write it to disk. // Note that practically all errors from Cache.GetFile are a cache miss; @@ -1395,7 +1505,7 @@ func computePkgCache(fsCache *cache.Cache, lpkg *listedPackage, pkg *types.Packa if err != nil { return err } - computedImp, err := computePkgCache(fsCache, lpkg, pkg, files, info) + computedImp, err := computePkgCache(fsCache, lpkg, pkg, files, info, nil) if err != nil { return err } @@ -1428,29 +1538,14 @@ func computePkgCache(fsCache *cache.Cache, lpkg *listedPackage, pkg *types.Packa } // Fill the reflect info from SSA, which builds on top of the syntax tree and type info. - // Create SSA packages for all imports. Order is not significant. 
- ssaProg := ssa.NewProgram(fset, 0) - created := make(map[*types.Package]bool) - var createAll func(pkgs []*types.Package) - createAll = func(pkgs []*types.Package) { - for _, p := range pkgs { - if !created[p] { - created[p] = true - ssaProg.CreatePackage(p, nil, nil, true) - createAll(p.Imports()) - } - } - } - createAll(pkg.Imports()) - - ssaPkg := ssaProg.CreatePackage(pkg, files, info, false) - ssaPkg.Build() - inspector := reflectInspector{ pkg: pkg, checkedAPIs: make(map[string]bool), result: computed, // append the results } + if ssaPkg == nil { + ssaPkg = ssaBuildPkg(pkg, files, info) + } inspector.recordReflection(ssaPkg) // Unlikely that we could stream the gob encode, as cache.Put wants an io.ReadSeeker. @@ -1551,6 +1646,10 @@ type transformer struct { // of packages, without any obfuscation. This is helpful to make // decisions on how to obfuscate our input code. origImporter importerWithMap + + // usedAllImportsFiles is used to prevent multiple calls of tf.useAllImports function on one file + // in case of simultaneously applied control flow and literals obfuscation + usedAllImportsFiles map[*ast.File]bool } func typecheck(pkgPath string, files []*ast.File, origImporter importerWithMap) (*types.Package, *types.Info, error) { @@ -1656,6 +1755,13 @@ func isSafeForInstanceType(typ types.Type) bool { } func (tf *transformer) useAllImports(file *ast.File) { + if tf.usedAllImportsFiles == nil { + tf.usedAllImportsFiles = make(map[*ast.File]bool) + } else if ok := tf.usedAllImportsFiles[file]; ok { + return + } + tf.usedAllImportsFiles[file] = true + for _, imp := range file.Imports { if imp.Name != nil && imp.Name.Name == "_" { continue @@ -2004,7 +2110,7 @@ func (tf *transformer) transformLink(args []string) ([]string, error) { // lack any extension. 
flags, args := splitFlagsFromArgs(args) - newImportCfg, err := tf.processImportCfg(flags) + newImportCfg, err := tf.processImportCfg(flags, nil) if err != nil { return nil, err } diff --git a/testdata/script/ctrlflow.txtar b/testdata/script/ctrlflow.txtar new file mode 100644 index 0000000..8ff3273 --- /dev/null +++ b/testdata/script/ctrlflow.txtar @@ -0,0 +1,68 @@ +env GARBLE_EXPERIMENTAL_CONTROLFLOW=1 +exec garble -literals -debugdir=debug -seed=0002deadbeef build -o=main$exe +exec ./main +cmp stderr main.stderr + +# simple check to ensure that control flow will work. Must be a minimum of 10 goto's +grep 'goto _s2a_l10' $WORK/debug/test/main/GARBLE_controlflow.go + +# obfuscated function must be removed from original file +! grep 'main\(\)' $WORK/debug/test/main/garble_main.go +# original file must contain an empty function +grep '\_\(\)' $WORK/debug/test/main/garble_main.go + +# switch must be simplified +! grep switch $WORK/debug/test/main/GARBLE_controlflow.go + +# obfuscated file must contain an interface for unexported interface emulation +grep 'GoString\(\) string' $WORK/debug/test/main/GARBLE_controlflow.go +grep 'String\(\) string' $WORK/debug/test/main/GARBLE_controlflow.go + +# control flow obfuscation should work correctly with literals obfuscation +! 
binsubstr main$exe 'correct name' + +-- go.mod -- +module test/main + +go 1.20 +-- garble_main.go -- +package main + +import ( + "encoding/binary" + "encoding/hex" + "hash/crc32" +) + +//garble:controlflow flatten_passes=1 junk_jumps=10 block_splits=10 +func main() { + // Reference to the unexported interface triggers creation of a new interface + // with a list of all functions of the private interface + endian := binary.LittleEndian + println(endian.String()) + println(endian.GoString()) + println(endian.Uint16([]byte{0, 1})) + + // Switch statement should be simplified to if statements + switch endian.String() { + case "LittleEndian": + println("correct name") + default: + panic("unreachable") + } + + // Indirect import "hash" package + hash := crc32.New(crc32.IEEETable) + hash.Write([]byte("1")) + hash.Write([]byte("2")) + hash.Write([]byte("3")) + + println(hex.EncodeToString(hash.Sum(nil))) +} + +-- main.stderr -- +LittleEndian +binary.LittleEndian +256 +correct name +884863d2