add control flow obfuscation
Implemented control flow flattening with additional features such as block splitting and junk jumpspull/774/head
parent
d89a55687c
commit
0e2e483472
@ -0,0 +1,186 @@
|
||||
package ctrlflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"log"
|
||||
mathrand "math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/go/ssa"
|
||||
ah "mvdan.cc/garble/internal/asthelper"
|
||||
"mvdan.cc/garble/internal/ssa2ast"
|
||||
)
|
||||
|
||||
const (
|
||||
mergedFileName = "GARBLE_controlflow.go"
|
||||
directiveName = "//garble:controlflow"
|
||||
importPrefix = "___garble_import"
|
||||
|
||||
defaultBlockSplits = 0
|
||||
defaultJunkJumps = 0
|
||||
defaultFlattenPasses = 1
|
||||
)
|
||||
|
||||
type directiveParamMap map[string]string
|
||||
|
||||
func (m directiveParamMap) GetInt(name string, def int) int {
|
||||
rawVal, ok := m[name]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
|
||||
val, err := strconv.Atoi(rawVal)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("invalid flag %s format: %v", name, err))
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// parseDirective parses a directive string and returns a map of directive parameters.
|
||||
// Each parameter should be in the form "key=value" or "key"
|
||||
func parseDirective(directive string) (directiveParamMap, bool) {
|
||||
fieldsStr, ok := strings.CutPrefix(directive, directiveName)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fields := strings.Fields(fieldsStr)
|
||||
if len(fields) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
m := make(map[string]string)
|
||||
for _, v := range fields {
|
||||
key, value, ok := strings.Cut(v, "=")
|
||||
if ok {
|
||||
m[key] = value
|
||||
} else {
|
||||
m[key] = ""
|
||||
}
|
||||
}
|
||||
return m, true
|
||||
}
|
||||
|
||||
// Obfuscate obfuscates control flow of all functions with directive using control flattening.
|
||||
// All obfuscated functions are removed from the original file and moved to the new one.
|
||||
// Obfuscation can be customized by passing parameters from the directive, example:
|
||||
//
|
||||
// //garble:controlflow flatten_passes=1 junk_jumps=0 block_splits=0
|
||||
// func someMethod() {}
|
||||
//
|
||||
// flatten_passes - controls number of passes of control flow flattening. Have exponential complexity and more than 3 passes are not recommended in most cases.
|
||||
// junk_jumps - controls how many junk jumps are added. It does not affect final binary by itself, but together with flattening linearly increases complexity.
|
||||
// block_splits - controls number of times largest block must be splitted. Together with flattening improves obfuscation of long blocks without branches.
|
||||
func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfRand *mathrand.Rand) (newFileName string, newFile *ast.File, affectedFiles []*ast.File, err error) {
|
||||
var ssaFuncs []*ssa.Function
|
||||
var ssaParams []directiveParamMap
|
||||
|
||||
for _, file := range files {
|
||||
affected := false
|
||||
for _, decl := range file.Decls {
|
||||
funcDecl, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || funcDecl.Doc == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, comment := range funcDecl.Doc.List {
|
||||
params, hasDirective := parseDirective(comment.Text)
|
||||
if !hasDirective {
|
||||
continue
|
||||
}
|
||||
|
||||
path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos())
|
||||
ssaFunc := ssa.EnclosingFunction(ssaPkg, path)
|
||||
if ssaFunc == nil {
|
||||
panic("function exists in ast but not found in ssa")
|
||||
}
|
||||
|
||||
ssaFuncs = append(ssaFuncs, ssaFunc)
|
||||
ssaParams = append(ssaParams, params)
|
||||
|
||||
log.Printf("detected function for controlflow %s (params: %v)", funcDecl.Name.Name, params)
|
||||
|
||||
// Remove inplace function from original file
|
||||
// TODO: implement a complete function removal
|
||||
funcDecl.Name = ast.NewIdent("_")
|
||||
funcDecl.Body = ah.BlockStmt()
|
||||
funcDecl.Recv = nil
|
||||
funcDecl.Type = &ast.FuncType{Params: &ast.FieldList{}}
|
||||
affected = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if affected {
|
||||
affectedFiles = append(affectedFiles, file)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ssaFuncs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
newFile = &ast.File{
|
||||
Package: token.Pos(fset.Base()),
|
||||
Name: ast.NewIdent(files[0].Name.Name),
|
||||
}
|
||||
fset.AddFile(mergedFileName, int(newFile.Package), 1) // required for correct printer output
|
||||
|
||||
funcConfig := ssa2ast.DefaultConfig()
|
||||
imports := make(map[string]string) // TODO: indirect imports turned into direct currently brake build process
|
||||
funcConfig.ImportNameResolver = func(pkg *types.Package) *ast.Ident {
|
||||
if pkg == nil || pkg.Path() == ssaPkg.Pkg.Path() {
|
||||
return nil
|
||||
}
|
||||
|
||||
name, ok := imports[pkg.Path()]
|
||||
if !ok {
|
||||
name = importPrefix + strconv.Itoa(len(imports))
|
||||
imports[pkg.Path()] = name
|
||||
astutil.AddNamedImport(fset, newFile, name, pkg.Path())
|
||||
}
|
||||
return ast.NewIdent(name)
|
||||
}
|
||||
|
||||
for idx, ssaFunc := range ssaFuncs {
|
||||
params := ssaParams[idx]
|
||||
|
||||
split := params.GetInt("block_splits", defaultBlockSplits)
|
||||
junkCount := params.GetInt("junk_jumps", defaultJunkJumps)
|
||||
passes := params.GetInt("flatten_passes", defaultFlattenPasses)
|
||||
|
||||
applyObfuscation := func(ssaFunc *ssa.Function) {
|
||||
for i := 0; i < split; i++ {
|
||||
if !applySplitting(ssaFunc, obfRand) {
|
||||
break // no more candidates for splitting
|
||||
}
|
||||
}
|
||||
if junkCount > 0 {
|
||||
addJunkBlocks(ssaFunc, junkCount, obfRand)
|
||||
}
|
||||
for i := 0; i < passes; i++ {
|
||||
applyFlattening(ssaFunc, obfRand)
|
||||
}
|
||||
fixBlockIndexes(ssaFunc)
|
||||
}
|
||||
|
||||
applyObfuscation(ssaFunc)
|
||||
for _, anonFunc := range ssaFunc.AnonFuncs {
|
||||
applyObfuscation(anonFunc)
|
||||
}
|
||||
|
||||
astFunc, err := ssa2ast.Convert(ssaFunc, funcConfig)
|
||||
if err != nil {
|
||||
return "", nil, nil, err
|
||||
}
|
||||
newFile.Decls = append(newFile.Decls, astFunc)
|
||||
}
|
||||
|
||||
newFileName = mergedFileName
|
||||
return
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
package ctrlflow
|
||||
|
||||
import (
|
||||
"go/constant"
|
||||
"go/types"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/tools/go/ssa"
|
||||
)
|
||||
|
||||
// setUnexportedField is used to modify unexported fields of ssa api structures.
|
||||
// TODO: find an alternative way to access private fields or raise a feature request upstream
|
||||
func setUnexportedField(objRaw interface{}, name string, valRaw interface{}) {
|
||||
obj := reflect.ValueOf(objRaw)
|
||||
for obj.Kind() == reflect.Pointer || obj.Kind() == reflect.Interface {
|
||||
obj = obj.Elem()
|
||||
}
|
||||
|
||||
field := obj.FieldByName(name)
|
||||
if !field.IsValid() {
|
||||
panic("invalid field: " + name)
|
||||
}
|
||||
|
||||
fakeStruct := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr()))
|
||||
fakeStruct.Elem().Set(reflect.ValueOf(valRaw))
|
||||
}
|
||||
|
||||
func setBlockParent(block *ssa.BasicBlock, ssaFunc *ssa.Function) {
|
||||
setUnexportedField(block, "parent", ssaFunc)
|
||||
}
|
||||
|
||||
func setBlock(instr ssa.Instruction, block *ssa.BasicBlock) {
|
||||
setUnexportedField(instr, "block", block)
|
||||
}
|
||||
|
||||
func setType(instr ssa.Instruction, typ types.Type) {
|
||||
setUnexportedField(instr, "typ", typ)
|
||||
}
|
||||
|
||||
func makeSsaInt(i int) *ssa.Const {
|
||||
return ssa.NewConst(constant.MakeInt64(int64(i)), types.Typ[types.Int])
|
||||
}
|
@ -0,0 +1,212 @@
|
||||
package ctrlflow
|
||||
|
||||
import (
|
||||
"go/token"
|
||||
"go/types"
|
||||
mathrand "math/rand"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/tools/go/ssa"
|
||||
)
|
||||
|
||||
type blockMapping struct {
|
||||
Fake, Target *ssa.BasicBlock
|
||||
}
|
||||
|
||||
// applyFlattening adds a dispatcher block and uses ssa.Phi to redirect all ssa.Jump and ssa.If to the dispatcher,
|
||||
// additionally shuffle all blocks
|
||||
func applyFlattening(ssaFunc *ssa.Function, obfRand *mathrand.Rand) {
|
||||
if len(ssaFunc.Blocks) < 3 {
|
||||
return
|
||||
}
|
||||
|
||||
phiInstr := &ssa.Phi{Comment: "ctrflow.phi"}
|
||||
setType(phiInstr, types.Typ[types.Int])
|
||||
|
||||
entryBlock := &ssa.BasicBlock{
|
||||
Comment: "ctrflow.entry",
|
||||
Instrs: []ssa.Instruction{phiInstr},
|
||||
}
|
||||
setBlockParent(entryBlock, ssaFunc)
|
||||
|
||||
makeJumpBlock := func(from *ssa.BasicBlock) *ssa.BasicBlock {
|
||||
jumpBlock := &ssa.BasicBlock{
|
||||
Comment: "ctrflow.jump",
|
||||
Instrs: []ssa.Instruction{&ssa.Jump{}},
|
||||
Preds: []*ssa.BasicBlock{from},
|
||||
Succs: []*ssa.BasicBlock{entryBlock},
|
||||
}
|
||||
setBlockParent(jumpBlock, ssaFunc)
|
||||
return jumpBlock
|
||||
}
|
||||
|
||||
// map for track fake block -> real block jump
|
||||
var blocksMapping []blockMapping
|
||||
for _, block := range ssaFunc.Blocks {
|
||||
existInstr := block.Instrs[len(block.Instrs)-1]
|
||||
switch existInstr.(type) {
|
||||
case *ssa.Jump:
|
||||
targetBlock := block.Succs[0]
|
||||
fakeBlock := makeJumpBlock(block)
|
||||
blocksMapping = append(blocksMapping, blockMapping{fakeBlock, targetBlock})
|
||||
block.Succs[0] = fakeBlock
|
||||
case *ssa.If:
|
||||
tblock, fblock := block.Succs[0], block.Succs[1]
|
||||
fakeTblock, fakeFblock := makeJumpBlock(tblock), makeJumpBlock(fblock)
|
||||
|
||||
blocksMapping = append(blocksMapping, blockMapping{fakeTblock, tblock})
|
||||
blocksMapping = append(blocksMapping, blockMapping{fakeFblock, fblock})
|
||||
|
||||
block.Succs[0] = fakeTblock
|
||||
block.Succs[1] = fakeFblock
|
||||
case *ssa.Return, *ssa.Panic:
|
||||
// control flow flattening is not applicable
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
phiIdxs := obfRand.Perm(len(blocksMapping))
|
||||
for i := range phiIdxs {
|
||||
phiIdxs[i]++ // 0 reserved for real entry block
|
||||
}
|
||||
|
||||
var entriesBlocks []*ssa.BasicBlock
|
||||
obfuscatedBlocks := ssaFunc.Blocks
|
||||
for i, m := range blocksMapping {
|
||||
entryBlock.Preds = append(entryBlock.Preds, m.Fake)
|
||||
phiInstr.Edges = append(phiInstr.Edges, makeSsaInt(phiIdxs[i]))
|
||||
|
||||
obfuscatedBlocks = append(obfuscatedBlocks, m.Fake)
|
||||
|
||||
cond := &ssa.BinOp{X: phiInstr, Op: token.EQL, Y: makeSsaInt(phiIdxs[i])}
|
||||
setType(cond, types.Typ[types.Bool])
|
||||
|
||||
*phiInstr.Referrers() = append(*phiInstr.Referrers(), cond)
|
||||
|
||||
ifInstr := &ssa.If{Cond: cond}
|
||||
*cond.Referrers() = append(*cond.Referrers(), ifInstr)
|
||||
|
||||
ifBlock := &ssa.BasicBlock{
|
||||
Instrs: []ssa.Instruction{cond, ifInstr},
|
||||
Succs: []*ssa.BasicBlock{m.Target, nil}, // false branch fulfilled in next iteration or linked to real entry block
|
||||
}
|
||||
setBlockParent(ifBlock, ssaFunc)
|
||||
|
||||
setBlock(cond, ifBlock)
|
||||
setBlock(ifInstr, ifBlock)
|
||||
entriesBlocks = append(entriesBlocks, ifBlock)
|
||||
|
||||
if i == 0 {
|
||||
entryBlock.Instrs = append(entryBlock.Instrs, &ssa.Jump{})
|
||||
entryBlock.Succs = []*ssa.BasicBlock{ifBlock}
|
||||
ifBlock.Preds = append(ifBlock.Preds, entryBlock)
|
||||
} else {
|
||||
// link previous block to current
|
||||
entriesBlocks[i-1].Succs[1] = ifBlock
|
||||
ifBlock.Preds = append(ifBlock.Preds, entriesBlocks[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
lastFakeEntry := entriesBlocks[len(entriesBlocks)-1]
|
||||
|
||||
realEntryBlock := ssaFunc.Blocks[0]
|
||||
lastFakeEntry.Succs[1] = realEntryBlock
|
||||
realEntryBlock.Preds = append(realEntryBlock.Preds, lastFakeEntry)
|
||||
|
||||
obfuscatedBlocks = append(obfuscatedBlocks, entriesBlocks...)
|
||||
obfRand.Shuffle(len(obfuscatedBlocks), func(i, j int) {
|
||||
obfuscatedBlocks[i], obfuscatedBlocks[j] = obfuscatedBlocks[j], obfuscatedBlocks[i]
|
||||
})
|
||||
ssaFunc.Blocks = append([]*ssa.BasicBlock{entryBlock}, obfuscatedBlocks...)
|
||||
}
|
||||
|
||||
// addJunkBlocks adds junk jumps into random blocks. Can create chains of junk jumps.
|
||||
func addJunkBlocks(ssaFunc *ssa.Function, count int, obfRand *mathrand.Rand) {
|
||||
if count == 0 {
|
||||
return
|
||||
}
|
||||
var candidates []*ssa.BasicBlock
|
||||
for _, block := range ssaFunc.Blocks {
|
||||
if len(block.Succs) > 0 {
|
||||
candidates = append(candidates, block)
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
targetBlock := candidates[obfRand.Intn(len(candidates))]
|
||||
succsIdx := obfRand.Intn(len(targetBlock.Succs))
|
||||
succs := targetBlock.Succs[succsIdx]
|
||||
|
||||
fakeBlock := &ssa.BasicBlock{
|
||||
Comment: "ctrflow.fake." + strconv.Itoa(i),
|
||||
Instrs: []ssa.Instruction{&ssa.Jump{}},
|
||||
Preds: []*ssa.BasicBlock{targetBlock},
|
||||
Succs: []*ssa.BasicBlock{succs},
|
||||
}
|
||||
setBlockParent(fakeBlock, ssaFunc)
|
||||
targetBlock.Succs[succsIdx] = fakeBlock
|
||||
|
||||
ssaFunc.Blocks = append(ssaFunc.Blocks, fakeBlock)
|
||||
candidates = append(candidates, fakeBlock)
|
||||
}
|
||||
}
|
||||
|
||||
// applySplitting splits biggest block into 2 parts of random size.
|
||||
// Returns false if no block large enough for splitting is found
|
||||
func applySplitting(ssaFunc *ssa.Function, obfRand *mathrand.Rand) bool {
|
||||
var targetBlock *ssa.BasicBlock
|
||||
for _, block := range ssaFunc.Blocks {
|
||||
if targetBlock == nil || len(block.Instrs) > len(targetBlock.Instrs) {
|
||||
targetBlock = block
|
||||
}
|
||||
}
|
||||
|
||||
const minInstrCount = 1 + 1 // 1 exit instruction + 1 any instruction
|
||||
if targetBlock == nil || len(targetBlock.Instrs) <= minInstrCount {
|
||||
return false
|
||||
}
|
||||
|
||||
splitIdx := 1 + obfRand.Intn(len(targetBlock.Instrs)-2)
|
||||
|
||||
firstPart := make([]ssa.Instruction, splitIdx+1)
|
||||
copy(firstPart, targetBlock.Instrs)
|
||||
firstPart[len(firstPart)-1] = &ssa.Jump{}
|
||||
|
||||
secondPart := targetBlock.Instrs[splitIdx:]
|
||||
targetBlock.Instrs = firstPart
|
||||
|
||||
newBlock := &ssa.BasicBlock{
|
||||
Comment: "ctrflow.split." + strconv.Itoa(targetBlock.Index),
|
||||
Instrs: secondPart,
|
||||
Preds: []*ssa.BasicBlock{targetBlock},
|
||||
Succs: targetBlock.Succs,
|
||||
}
|
||||
setBlockParent(newBlock, ssaFunc)
|
||||
for _, instr := range newBlock.Instrs {
|
||||
setBlock(instr, newBlock)
|
||||
}
|
||||
|
||||
// Fix preds for ssa.Phi working
|
||||
for _, succ := range targetBlock.Succs {
|
||||
for i, pred := range succ.Preds {
|
||||
if pred == targetBlock {
|
||||
succ.Preds[i] = newBlock
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ssaFunc.Blocks = append(ssaFunc.Blocks, newBlock)
|
||||
targetBlock.Succs = []*ssa.BasicBlock{newBlock}
|
||||
return true
|
||||
}
|
||||
|
||||
func fixBlockIndexes(ssaFunc *ssa.Function) {
|
||||
for i, block := range ssaFunc.Blocks {
|
||||
block.Index = i
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,398 @@
|
||||
package ssa2ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/importer"
|
||||
"go/printer"
|
||||
"go/types"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/go/ssa"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"golang.org/x/tools/go/ssa/ssautil"
|
||||
)
|
||||
|
||||
const sigSrc = `package main
|
||||
|
||||
import "unsafe"
|
||||
|
||||
type genericStruct[T interface{}] struct{}
|
||||
type plainStruct struct {
|
||||
Dummy struct{}
|
||||
}
|
||||
|
||||
func (s *plainStruct) plainStructFunc() {
|
||||
|
||||
}
|
||||
|
||||
func (*plainStruct) plainStructAnonFunc() {
|
||||
|
||||
}
|
||||
|
||||
func (s *genericStruct[T]) genericStructFunc() {
|
||||
|
||||
}
|
||||
func (s *genericStruct[T]) genericStructAnonFunc() (test T) {
|
||||
return
|
||||
}
|
||||
|
||||
func plainFuncSignature(a int, b string, c struct{}, d struct{ string }, e interface{ Dummy() string }, pointer unsafe.Pointer) (i int, er error) {
|
||||
return
|
||||
}
|
||||
|
||||
func genericFuncSignature[T interface{ interface{} | ~int64 | bool }, X interface{ comparable }](a T, b X, c genericStruct[struct{ a T }], d genericStruct[T]) (res T) {
|
||||
return
|
||||
}
|
||||
`
|
||||
|
||||
func TestConvertSignature(t *testing.T) {
|
||||
conv := newFuncConverter(DefaultConfig())
|
||||
|
||||
f, _, info, _ := mustParseAndTypeCheckFile(sigSrc)
|
||||
for _, funcName := range []string{"plainStructFunc", "plainStructAnonFunc", "genericStructFunc", "plainFuncSignature", "genericFuncSignature"} {
|
||||
funcDecl := findFunc(f, funcName)
|
||||
funcDecl.Body = nil
|
||||
|
||||
funcObj := info.Defs[funcDecl.Name].(*types.Func)
|
||||
funcDeclConverted, err := conv.convertSignatureToFuncDecl(funcObj.Name(), funcObj.Type().(*types.Signature))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if structDiff := cmp.Diff(funcDecl, funcDeclConverted, astCmpOpt); structDiff != "" {
|
||||
t.Fatalf("method decl not equals: %s", structDiff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const mainSrc = `package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func main() {
|
||||
methodOps()
|
||||
slicesOps()
|
||||
iterAndMapsOps()
|
||||
chanOps()
|
||||
flowOps()
|
||||
typeOps()
|
||||
}
|
||||
|
||||
func makeSprintf(tag string) func(vals ...interface{}) {
|
||||
i := 0
|
||||
return func(vals ...interface{}) {
|
||||
fmt.Printf("%s(%d): %v\n", tag, i, vals)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func return42() int {
|
||||
return 42
|
||||
}
|
||||
|
||||
type arrayOfInts []int
|
||||
|
||||
type structOfArraysOfInts struct {
|
||||
a arrayOfInts
|
||||
b arrayOfInts
|
||||
}
|
||||
|
||||
func slicesOps() {
|
||||
sprintf := makeSprintf("slicesOps")
|
||||
|
||||
slice := [...]int{1, 2}
|
||||
sprintf(slice[0:1:2])
|
||||
// *ssa.IndexAddr
|
||||
sprintf(slice)
|
||||
slice[0] += 1
|
||||
sprintf(slice)
|
||||
|
||||
sprintf(slice[:1])
|
||||
sprintf(slice[slice[0]:])
|
||||
sprintf(slice[0:2])
|
||||
|
||||
sprintf((*[2]int)(slice[:])[return42()%2]) // *ssa.SliceToArrayPointer
|
||||
|
||||
sprintf("test"[return42()%3]) // *ssa.Index
|
||||
|
||||
structOfArrays := structOfArraysOfInts{a: slice[1:], b: slice[:1]}
|
||||
sprintf(structOfArrays.a[:1])
|
||||
sprintf(structOfArrays.b[:1])
|
||||
|
||||
slice2 := make([]string, return42(), return42()*2)
|
||||
slice2[return42()-1] = "test"
|
||||
sprintf(slice2)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func iterAndMapsOps() {
|
||||
sprintf := makeSprintf("iterAndMapsOps")
|
||||
|
||||
// *ssa.MakeMap + *ssa.MapUpdate
|
||||
mmap := map[string]time.Month{
|
||||
"April": time.April,
|
||||
"December": time.December,
|
||||
"January": time.January,
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for k := range mmap {
|
||||
vals = append(vals, k)
|
||||
}
|
||||
for _, v := range mmap {
|
||||
vals = append(vals, v.String())
|
||||
}
|
||||
sort.Strings(vals) // Required. Order of map iteration not guaranteed
|
||||
sprintf(vals)
|
||||
|
||||
if v, ok := mmap["?"]; ok {
|
||||
panic("unreachable: " + v.String())
|
||||
}
|
||||
for idx, s := range "hello world" {
|
||||
sprintf(idx, s)
|
||||
}
|
||||
|
||||
sprintf(mmap["April"].String())
|
||||
return
|
||||
}
|
||||
|
||||
type interfaceCalls interface {
|
||||
Return1() string
|
||||
}
|
||||
|
||||
type structCalls struct {
|
||||
}
|
||||
|
||||
func (r structCalls) Return1() string {
|
||||
return "Return1"
|
||||
}
|
||||
|
||||
func (r *structCalls) Return2() string {
|
||||
return "Return2"
|
||||
}
|
||||
|
||||
func multiOutputRes() (int, string) {
|
||||
return 42, "24"
|
||||
}
|
||||
|
||||
func returnInterfaceCalls() interfaceCalls {
|
||||
return structCalls{}
|
||||
}
|
||||
|
||||
func methodOps() {
|
||||
sprintf := makeSprintf("methodOps")
|
||||
|
||||
defer func() {
|
||||
sprintf("from defer")
|
||||
}()
|
||||
defer sprintf("from defer 2")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
sprintf("from go")
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
i, s := multiOutputRes()
|
||||
sprintf(strconv.Itoa(i))
|
||||
|
||||
var strct structCalls
|
||||
|
||||
strct.Return1()
|
||||
strct.Return2()
|
||||
|
||||
intrfs := returnInterfaceCalls()
|
||||
intrfs.Return1()
|
||||
|
||||
sprintf(strconv.Itoa(len(s)))
|
||||
|
||||
strconv.Itoa(binary.Size(4))
|
||||
sprintf(binary.LittleEndian.AppendUint32(nil, 42))
|
||||
|
||||
if len(s) == 0 {
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
sprintf(*unsafe.StringData(s))
|
||||
|
||||
thunkMethod1 := structCalls.Return1
|
||||
sprintf(thunkMethod1(strct))
|
||||
|
||||
thunkMethod2 := (*structCalls).Return2
|
||||
sprintf(thunkMethod2(&strct))
|
||||
|
||||
closureVar := "c " + s
|
||||
anonFnc := func(n func(structCalls) string) string {
|
||||
return n(structCalls{}) + "anon" + closureVar
|
||||
}
|
||||
|
||||
sprintf(anonFnc(structCalls.Return1))
|
||||
}
|
||||
|
||||
func chanOps() {
|
||||
sprintf := makeSprintf("chanOps")
|
||||
|
||||
a := make(chan string)
|
||||
b := make(chan string)
|
||||
c := make(chan string)
|
||||
d := make(chan string)
|
||||
|
||||
select {
|
||||
case r1, ok := <-a:
|
||||
sprintf(r1, ok)
|
||||
case r2 := <-b:
|
||||
sprintf(r2)
|
||||
case <-c:
|
||||
sprintf("r3")
|
||||
case d <- "test":
|
||||
sprintf("d triggered")
|
||||
default:
|
||||
sprintf("default")
|
||||
}
|
||||
|
||||
e := make(chan string, 1)
|
||||
e <- "hi"
|
||||
|
||||
sprintf(<-e)
|
||||
|
||||
close(a)
|
||||
val, ok := <-a
|
||||
|
||||
sprintf(val, ok)
|
||||
return
|
||||
}
|
||||
|
||||
func flowOps() {
|
||||
sprintf := makeSprintf("flowOps")
|
||||
i := 1
|
||||
if return42()%2 == 0 {
|
||||
sprintf("a")
|
||||
i++
|
||||
} else {
|
||||
sprintf("b")
|
||||
}
|
||||
sprintf(i)
|
||||
|
||||
switch return42() {
|
||||
case 1:
|
||||
sprintf("1")
|
||||
case 2:
|
||||
sprintf("2")
|
||||
case 3:
|
||||
sprintf("3")
|
||||
case 42:
|
||||
sprintf("42")
|
||||
}
|
||||
}
|
||||
|
||||
type interfaceB interface {
|
||||
}
|
||||
|
||||
type testStruct struct {
|
||||
A, B int
|
||||
}
|
||||
|
||||
func typeOps() {
|
||||
sprintf := makeSprintf("typeOps")
|
||||
|
||||
// *ssa.ChangeType
|
||||
var interA interfaceCalls
|
||||
sprintf(interA)
|
||||
|
||||
// *ssa.ChangeInterface
|
||||
var interB interfaceB = struct{}{}
|
||||
var inter0 interface{} = interB
|
||||
sprintf(inter0)
|
||||
|
||||
// *ssa.Convert
|
||||
var f float64 = 1.0
|
||||
sprintf(int(f))
|
||||
|
||||
casted, ok := inter0.(interfaceB)
|
||||
sprintf(casted, ok)
|
||||
|
||||
casted2 := inter0.(interfaceB)
|
||||
sprintf(casted2)
|
||||
|
||||
strc := testStruct{return42(), return42() + 2}
|
||||
strc.B += strc.A
|
||||
sprintf(strc)
|
||||
|
||||
// Access to unexported structure
|
||||
discard := io.Discard
|
||||
if return42() == 0 {
|
||||
sprintf(discard) // Trigger phi block
|
||||
}
|
||||
_, _ = discard.Write([]byte("test"))
|
||||
}`
|
||||
|
||||
func TestConvert(t *testing.T) {
|
||||
runGoFile := func(f string) string {
|
||||
cmd := exec.Command("go", "run", f)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("compile failed: %v\n%s", err, string(out))
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
testFile := filepath.Join(t.TempDir(), "convert.go")
|
||||
if err := os.WriteFile(testFile, []byte(mainSrc), 0o777); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
originalOut := runGoFile(testFile)
|
||||
file, fset, _, _ := mustParseAndTypeCheckFile(mainSrc)
|
||||
ssaPkg, _, err := ssautil.BuildPackage(&types.Config{Importer: importer.Default()}, fset, types.NewPackage("test/main", ""), []*ast.File{file}, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for fIdx, decl := range file.Decls {
|
||||
funcDecl, ok := decl.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos())
|
||||
ssaFunc := ssa.EnclosingFunction(ssaPkg, path)
|
||||
|
||||
astFunc, err := Convert(ssaFunc, DefaultConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
file.Decls[fIdx] = astFunc
|
||||
}
|
||||
|
||||
convertedFile := filepath.Join(t.TempDir(), "main.go")
|
||||
f, err := os.Create(convertedFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := printer.Fprint(f, fset, file); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = f.Close()
|
||||
|
||||
convertedOut := runGoFile(convertedFile)
|
||||
|
||||
if convertedOut != originalOut {
|
||||
t.Fatalf("Output not equals:\n\n%s\n\n%s", originalOut, convertedOut)
|
||||
}
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
package ssa2ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/importer"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
)
|
||||
|
||||
var astCmpOpt = cmpopts.IgnoreTypes(token.NoPos, &ast.Object{})
|
||||
|
||||
func findStruct(file *ast.File, structName string) (name *ast.Ident, structType *ast.StructType) {
|
||||
ast.Inspect(file, func(node ast.Node) bool {
|
||||
if structType != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
typeSpec, ok := node.(*ast.TypeSpec)
|
||||
if !ok || typeSpec.Name == nil || typeSpec.Name.Name != structName {
|
||||
return true
|
||||
}
|
||||
typ, ok := typeSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
structType = typ
|
||||
name = typeSpec.Name
|
||||
return true
|
||||
})
|
||||
|
||||
if structType == nil {
|
||||
panic(structName + " not found")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func findFunc(file *ast.File, funcName string) *ast.FuncDecl {
|
||||
for _, decl := range file.Decls {
|
||||
fDecl, ok := decl.(*ast.FuncDecl)
|
||||
if ok && fDecl.Name.Name == funcName {
|
||||
return fDecl
|
||||
}
|
||||
}
|
||||
panic(funcName + " not found")
|
||||
}
|
||||
|
||||
func mustParseAndTypeCheckFile(src string) (*ast.File, *token.FileSet, *types.Info, *types.Package) {
|
||||
fset := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fset, "a.go", src, 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config := types.Config{Importer: importer.Default()}
|
||||
info := &types.Info{
|
||||
Types: make(map[ast.Expr]types.TypeAndValue),
|
||||
Defs: make(map[*ast.Ident]types.Object),
|
||||
Uses: make(map[*ast.Ident]types.Object),
|
||||
Instances: make(map[*ast.Ident]types.Instance),
|
||||
Implicits: make(map[ast.Node]types.Object),
|
||||
Scopes: make(map[ast.Node]*types.Scope),
|
||||
Selections: make(map[*ast.SelectorExpr]*types.Selection),
|
||||
}
|
||||
pkg, err := config.Check("test/main", fset, []*ast.File{f}, info)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return f, fset, info, pkg
|
||||
}
|
@ -0,0 +1,176 @@
|
||||
package ssa2ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
)
|
||||
|
||||
func makeMapIteratorPolyfill(tc *typeConverter, mapType *types.Map) (ast.Expr, types.Type, error) {
|
||||
keyTypeExpr, err := tc.Convert(mapType.Key())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
valueTypeExpr, err := tc.Convert(mapType.Elem())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nextType := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(
|
||||
types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool]),
|
||||
types.NewVar(token.NoPos, nil, "", mapType.Key()),
|
||||
types.NewVar(token.NoPos, nil, "", mapType.Elem()),
|
||||
), false)
|
||||
|
||||
// Generated using https://github.com/lu4p/astextract from snippet:
|
||||
/*
|
||||
func(m map[<key type>]<value type>) func() (bool, <key type>, <value type>) {
|
||||
keys := make([]<key type>, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
i := 0
|
||||
return func() (ok bool, k <key type>, r <value type>) {
|
||||
if i < len(keys) {
|
||||
k = keys[i]
|
||||
ok, r = true, m[k]
|
||||
i++
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
*/
|
||||
return &ast.FuncLit{
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{List: []*ast.Field{{
|
||||
Names: []*ast.Ident{{Name: "m"}},
|
||||
Type: &ast.MapType{
|
||||
Key: keyTypeExpr,
|
||||
Value: valueTypeExpr,
|
||||
},
|
||||
}}},
|
||||
Results: &ast.FieldList{List: []*ast.Field{{
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{},
|
||||
Results: &ast.FieldList{List: []*ast.Field{
|
||||
{Type: &ast.Ident{Name: "bool"}},
|
||||
{Type: keyTypeExpr},
|
||||
{Type: valueTypeExpr},
|
||||
}},
|
||||
},
|
||||
}}},
|
||||
},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{&ast.Ident{Name: "keys"}},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.CallExpr{
|
||||
Fun: &ast.Ident{Name: "make"},
|
||||
Args: []ast.Expr{
|
||||
&ast.ArrayType{Elt: keyTypeExpr},
|
||||
&ast.BasicLit{Kind: token.INT, Value: "0"},
|
||||
&ast.CallExpr{
|
||||
Fun: &ast.Ident{Name: "len"},
|
||||
Args: []ast.Expr{&ast.Ident{Name: "m"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
&ast.RangeStmt{
|
||||
Key: &ast.Ident{Name: "k"},
|
||||
Tok: token.DEFINE,
|
||||
X: &ast.Ident{Name: "m"},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{&ast.Ident{Name: "keys"}},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.CallExpr{
|
||||
Fun: &ast.Ident{Name: "append"},
|
||||
Args: []ast.Expr{
|
||||
&ast.Ident{Name: "keys"},
|
||||
&ast.Ident{Name: "k"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{&ast.Ident{Name: "i"}},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}},
|
||||
},
|
||||
&ast.ReturnStmt{Results: []ast.Expr{
|
||||
&ast.FuncLit{
|
||||
Type: &ast.FuncType{
|
||||
Params: &ast.FieldList{},
|
||||
Results: &ast.FieldList{List: []*ast.Field{
|
||||
{
|
||||
Names: []*ast.Ident{{Name: "ok"}},
|
||||
Type: &ast.Ident{Name: "bool"},
|
||||
},
|
||||
{
|
||||
Names: []*ast.Ident{{Name: "k"}},
|
||||
Type: keyTypeExpr,
|
||||
},
|
||||
{
|
||||
Names: []*ast.Ident{{Name: "r"}},
|
||||
Type: valueTypeExpr,
|
||||
},
|
||||
}},
|
||||
},
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.IfStmt{
|
||||
Cond: &ast.BinaryExpr{
|
||||
X: &ast.Ident{Name: "i"},
|
||||
Op: token.LSS,
|
||||
Y: &ast.CallExpr{
|
||||
Fun: &ast.Ident{Name: "len"},
|
||||
Args: []ast.Expr{&ast.Ident{Name: "keys"}},
|
||||
},
|
||||
},
|
||||
Body: &ast.BlockStmt{List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{&ast.Ident{Name: "k"}},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{&ast.IndexExpr{
|
||||
X: &ast.Ident{Name: "keys"},
|
||||
Index: &ast.Ident{Name: "i"},
|
||||
}},
|
||||
},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{
|
||||
&ast.Ident{Name: "ok"},
|
||||
&ast.Ident{Name: "r"},
|
||||
},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.Ident{Name: "true"},
|
||||
&ast.IndexExpr{
|
||||
X: &ast.Ident{Name: "m"},
|
||||
Index: &ast.Ident{Name: "k"},
|
||||
},
|
||||
},
|
||||
},
|
||||
&ast.IncDecStmt{
|
||||
X: &ast.Ident{Name: "i"},
|
||||
Tok: token.INC,
|
||||
},
|
||||
}},
|
||||
},
|
||||
&ast.ReturnStmt{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}, nextType, nil
|
||||
}
|
@ -0,0 +1,249 @@
|
||||
package ssa2ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type typeConverter struct {
|
||||
resolver ImportNameResolver
|
||||
}
|
||||
|
||||
func (tc *typeConverter) Convert(t types.Type) (ast.Expr, error) {
|
||||
switch typ := t.(type) {
|
||||
case *types.Array:
|
||||
eltExpr, err := tc.Convert(typ.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ast.ArrayType{
|
||||
Len: &ast.BasicLit{
|
||||
Kind: token.INT,
|
||||
Value: strconv.FormatInt(typ.Len(), 10),
|
||||
},
|
||||
Elt: eltExpr,
|
||||
}, nil
|
||||
case *types.Basic:
|
||||
if typ.Kind() == types.UnsafePointer {
|
||||
unsafePkgIdent := tc.resolver(types.Unsafe)
|
||||
if unsafePkgIdent == nil {
|
||||
return nil, fmt.Errorf("cannot resolve unsafe package")
|
||||
}
|
||||
return &ast.SelectorExpr{X: unsafePkgIdent, Sel: ast.NewIdent("Pointer")}, nil
|
||||
}
|
||||
return ast.NewIdent(typ.Name()), nil
|
||||
case *types.Chan:
|
||||
chanValueExpr, err := tc.Convert(typ.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chanExpr := &ast.ChanType{Value: chanValueExpr}
|
||||
switch typ.Dir() {
|
||||
case types.SendRecv:
|
||||
chanExpr.Dir = ast.SEND | ast.RECV
|
||||
case types.RecvOnly:
|
||||
chanExpr.Dir = ast.RECV
|
||||
case types.SendOnly:
|
||||
chanExpr.Dir = ast.SEND
|
||||
}
|
||||
return chanExpr, nil
|
||||
case *types.Interface:
|
||||
methods := &ast.FieldList{}
|
||||
for i := 0; i < typ.NumEmbeddeds(); i++ {
|
||||
embeddedType := typ.EmbeddedType(i)
|
||||
embeddedExpr, err := tc.Convert(embeddedType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
methods.List = append(methods.List, &ast.Field{Type: embeddedExpr})
|
||||
}
|
||||
for i := 0; i < typ.NumExplicitMethods(); i++ {
|
||||
method := typ.ExplicitMethod(i)
|
||||
methodSig, err := tc.Convert(method.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
methods.List = append(methods.List, &ast.Field{
|
||||
Names: []*ast.Ident{ast.NewIdent(method.Name())},
|
||||
Type: methodSig,
|
||||
})
|
||||
}
|
||||
return &ast.InterfaceType{Methods: methods}, nil
|
||||
case *types.Map:
|
||||
keyExpr, err := tc.Convert(typ.Key())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueExpr, err := tc.Convert(typ.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ast.MapType{Key: keyExpr, Value: valueExpr}, nil
|
||||
case *types.Named:
|
||||
obj := typ.Obj()
|
||||
|
||||
// TODO: rewrite struct inlining without reflection hack
|
||||
if parent := obj.Parent(); parent != nil {
|
||||
isFuncScope := reflect.ValueOf(parent).Elem().FieldByName("isFunc")
|
||||
if isFuncScope.Bool() {
|
||||
return tc.Convert(obj.Type().Underlying())
|
||||
}
|
||||
}
|
||||
|
||||
var namedExpr ast.Expr
|
||||
if pkgIdent := tc.resolver(obj.Pkg()); pkgIdent != nil {
|
||||
// reference to unexported named emulated through new interface with explicit declarated methods
|
||||
if !token.IsExported(obj.Name()) {
|
||||
var methods []*types.Func
|
||||
for i := 0; i < typ.NumMethods(); i++ {
|
||||
method := typ.Method(i)
|
||||
if token.IsExported(method.Name()) {
|
||||
methods = append(methods, method)
|
||||
}
|
||||
}
|
||||
|
||||
fakeInterface := types.NewInterfaceType(methods, nil)
|
||||
return tc.Convert(fakeInterface)
|
||||
}
|
||||
namedExpr = &ast.SelectorExpr{X: pkgIdent, Sel: ast.NewIdent(obj.Name())}
|
||||
} else {
|
||||
namedExpr = ast.NewIdent(obj.Name())
|
||||
}
|
||||
|
||||
typeParams := typ.TypeArgs()
|
||||
if typeParams == nil || typeParams.Len() == 0 {
|
||||
return namedExpr, nil
|
||||
}
|
||||
if typeParams.Len() == 1 {
|
||||
typeParamExpr, err := tc.Convert(typeParams.At(0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ast.IndexExpr{X: namedExpr, Index: typeParamExpr}, nil
|
||||
}
|
||||
genericExpr := &ast.IndexListExpr{X: namedExpr}
|
||||
for i := 0; i < typeParams.Len(); i++ {
|
||||
typeArgs := typeParams.At(i)
|
||||
typeParamExpr, err := tc.Convert(typeArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
genericExpr.Indices = append(genericExpr.Indices, typeParamExpr)
|
||||
}
|
||||
return genericExpr, nil
|
||||
case *types.Pointer:
|
||||
expr, err := tc.Convert(typ.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ast.StarExpr{X: expr}, nil
|
||||
case *types.Signature:
|
||||
funcSigExpr := &ast.FuncType{Params: &ast.FieldList{}}
|
||||
if sigParams := typ.Params(); sigParams != nil {
|
||||
for i := 0; i < sigParams.Len(); i++ {
|
||||
param := sigParams.At(i)
|
||||
|
||||
var paramType ast.Expr
|
||||
if typ.Variadic() && i == sigParams.Len()-1 {
|
||||
slice := param.Type().(*types.Slice)
|
||||
|
||||
eltExpr, err := tc.Convert(slice.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paramType = &ast.Ellipsis{Elt: eltExpr}
|
||||
} else {
|
||||
paramExpr, err := tc.Convert(param.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paramType = paramExpr
|
||||
}
|
||||
f := &ast.Field{Type: paramType}
|
||||
if name := param.Name(); name != "" {
|
||||
f.Names = []*ast.Ident{ast.NewIdent(name)}
|
||||
}
|
||||
funcSigExpr.Params.List = append(funcSigExpr.Params.List, f)
|
||||
}
|
||||
}
|
||||
if sigResults := typ.Results(); sigResults != nil {
|
||||
funcSigExpr.Results = &ast.FieldList{}
|
||||
for i := 0; i < sigResults.Len(); i++ {
|
||||
result := sigResults.At(i)
|
||||
resultExpr, err := tc.Convert(result.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &ast.Field{Type: resultExpr}
|
||||
if name := result.Name(); name != "" {
|
||||
f.Names = []*ast.Ident{ast.NewIdent(name)}
|
||||
}
|
||||
funcSigExpr.Results.List = append(funcSigExpr.Results.List, f)
|
||||
}
|
||||
}
|
||||
if typeParams := typ.TypeParams(); typeParams != nil {
|
||||
funcSigExpr.TypeParams = &ast.FieldList{}
|
||||
for i := 0; i < typeParams.Len(); i++ {
|
||||
typeParam := typeParams.At(i)
|
||||
resultExpr, err := tc.Convert(typeParam.Constraint().Underlying())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f := &ast.Field{Type: resultExpr, Names: []*ast.Ident{ast.NewIdent(typeParam.Obj().Name())}}
|
||||
funcSigExpr.TypeParams.List = append(funcSigExpr.TypeParams.List, f)
|
||||
}
|
||||
}
|
||||
return funcSigExpr, nil
|
||||
case *types.Slice:
|
||||
eltExpr, err := tc.Convert(typ.Elem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ast.ArrayType{Elt: eltExpr}, nil
|
||||
case *types.Struct:
|
||||
fieldList := &ast.FieldList{}
|
||||
for i := 0; i < typ.NumFields(); i++ {
|
||||
f := typ.Field(i)
|
||||
fieldExpr, err := tc.Convert(f.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
field := &ast.Field{Type: fieldExpr}
|
||||
if !f.Anonymous() {
|
||||
field.Names = []*ast.Ident{ast.NewIdent(f.Name())}
|
||||
}
|
||||
if tag := typ.Tag(i); len(tag) > 0 {
|
||||
field.Tag = &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", tag)}
|
||||
}
|
||||
fieldList.List = append(fieldList.List, field)
|
||||
}
|
||||
return &ast.StructType{Fields: fieldList}, nil
|
||||
case *types.TypeParam:
|
||||
return ast.NewIdent(typ.Obj().Name()), nil
|
||||
case *types.Union:
|
||||
var unionExpr ast.Expr
|
||||
for i := 0; i < typ.Len(); i++ {
|
||||
term := typ.Term(i)
|
||||
expr, err := tc.Convert(term.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if term.Tilde() {
|
||||
expr = &ast.UnaryExpr{Op: token.TILDE, X: expr}
|
||||
}
|
||||
if unionExpr == nil {
|
||||
unionExpr = expr
|
||||
} else {
|
||||
unionExpr = &ast.BinaryExpr{X: unionExpr, Op: token.OR, Y: expr}
|
||||
}
|
||||
}
|
||||
return unionExpr, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("type %v: %w", typ, ErrUnsupported)
|
||||
}
|
||||
}
|
@ -0,0 +1,108 @@
|
||||
package ssa2ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
const typesSrc = `package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type localNamed bool
|
||||
|
||||
type embedStruct struct {
|
||||
int
|
||||
}
|
||||
|
||||
type genericStruct[K comparable, V int64 | float64] struct {
|
||||
int
|
||||
}
|
||||
|
||||
type exampleStruct struct {
|
||||
embedStruct
|
||||
|
||||
// *types.Array
|
||||
array [3]int
|
||||
array2 [0]int
|
||||
|
||||
// *types.Basic
|
||||
bool // anonymous
|
||||
string string "test:\"tag\""
|
||||
int int
|
||||
int8 int8
|
||||
int16 int16
|
||||
int32 int32
|
||||
int64 int64
|
||||
uint uint
|
||||
uint8 uint8
|
||||
uint16 uint16
|
||||
uint32 uint32
|
||||
uint64 uint64
|
||||
uintptr uintptr
|
||||
byte byte
|
||||
rune rune
|
||||
float32 float32
|
||||
float64 float64
|
||||
complex64 complex64
|
||||
complex128 complex128
|
||||
|
||||
// *types.Chan
|
||||
chanSendRecv chan struct{}
|
||||
chanRecv <-chan struct{}
|
||||
chanSend chan<- struct{}
|
||||
|
||||
// *types.Interface
|
||||
interface1 interface{}
|
||||
interface2 interface{ io.Reader }
|
||||
interface3 interface{ Dummy(int) bool }
|
||||
interface4 interface {
|
||||
io.Reader
|
||||
io.ByteReader
|
||||
Dummy(int) bool
|
||||
}
|
||||
|
||||
// *types.Map
|
||||
strMap map[string]string
|
||||
|
||||
// *types.Named
|
||||
localNamed localNamed
|
||||
importedNamed time.Month
|
||||
|
||||
// *types.Pointer
|
||||
pointer1 *string
|
||||
pointer2 **string
|
||||
|
||||
// *types.Signature
|
||||
func1 func(int, int) int
|
||||
func2 func(a int, b int, varargs ...struct{ string }) (res int)
|
||||
|
||||
// *types.Slice
|
||||
slice1 []int
|
||||
slice2 [][]int
|
||||
|
||||
// generics
|
||||
generic genericStruct[genericStruct[genericStruct[bool, int64], int64], int64]
|
||||
}
|
||||
`
|
||||
|
||||
func TestTypeToExpr(t *testing.T) {
|
||||
f, _, info, _ := mustParseAndTypeCheckFile(typesSrc)
|
||||
name, structAst := findStruct(f, "exampleStruct")
|
||||
obj := info.Defs[name]
|
||||
fc := &typeConverter{resolver: defaultImportNameResolver}
|
||||
convAst, err := fc.Convert(obj.Type().Underlying())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
structConvAst := convAst.(*ast.StructType)
|
||||
if structDiff := cmp.Diff(structAst, structConvAst, astCmpOpt); structDiff != "" {
|
||||
t.Fatalf("struct not equals: %s", structDiff)
|
||||
}
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
env GARBLE_EXPERIMENTAL_CONTROLFLOW=1
|
||||
exec garble -literals -debugdir=debug -seed=0002deadbeef build -o=main$exe
|
||||
exec ./main
|
||||
cmp stderr main.stderr
|
||||
|
||||
# simple check to ensure that control flow will work. Must be a minimum of 10 goto's
|
||||
grep 'goto _s2a_l10' $WORK/debug/test/main/GARBLE_controlflow.go
|
||||
|
||||
# obfuscated function must be removed from original file
|
||||
! grep 'main\(\)' $WORK/debug/test/main/garble_main.go
|
||||
# original file must contains empty function
|
||||
grep '\_\(\)' $WORK/debug/test/main/garble_main.go
|
||||
|
||||
# switch must be simplified
|
||||
! grep switch $WORK/debug/test/main/GARBLE_controlflow.go
|
||||
|
||||
# obfuscated file must contains interface for unexported interface emulation
|
||||
grep 'GoString\(\) string' $WORK/debug/test/main/GARBLE_controlflow.go
|
||||
grep 'String\(\) string' $WORK/debug/test/main/GARBLE_controlflow.go
|
||||
|
||||
# control flow obfuscation should work correctly with literals obfuscation
|
||||
! binsubstr main$exe 'correct name'
|
||||
|
||||
-- go.mod --
|
||||
module test/main
|
||||
|
||||
go 1.20
|
||||
-- garble_main.go --
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"hash/crc32"
|
||||
)
|
||||
|
||||
//garble:controlflow flatten_passes=1 junk_jumps=10 block_splits=10
|
||||
func main() {
|
||||
// Reference to the unexported interface triggers creation of a new interface
|
||||
// with a list of all functions of the private interface
|
||||
endian := binary.LittleEndian
|
||||
println(endian.String())
|
||||
println(endian.GoString())
|
||||
println(endian.Uint16([]byte{0, 1}))
|
||||
|
||||
// Switch statement should be simplified to if statements
|
||||
switch endian.String() {
|
||||
case "LittleEndian":
|
||||
println("correct name")
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// Indirect import "hash" package
|
||||
hash := crc32.New(crc32.IEEETable)
|
||||
hash.Write([]byte("1"))
|
||||
hash.Write([]byte("2"))
|
||||
hash.Write([]byte("3"))
|
||||
|
||||
println(hex.EncodeToString(hash.Sum(nil)))
|
||||
}
|
||||
|
||||
-- main.stderr --
|
||||
LittleEndian
|
||||
binary.LittleEndian
|
||||
256
|
||||
correct name
|
||||
884863d2
|
Loading…
Reference in New Issue