blacklist struct fields with reflection too

In the added test, the unexported field used to be garbled.

Reflection can only reach exported methods, exported fields, and
unexported fields. Exported methods and fields are currently never
garbled, so unexported fields was the only missing piece.
pull/28/head
Daniel Martí 4 years ago
parent 809b7a8dda
commit 80538f19c7

@ -452,12 +452,29 @@ func hashWith(salt, value string) string {
// used with reflect.TypeOf or reflect.ValueOf. Since we obfuscate one package // used with reflect.TypeOf or reflect.ValueOf. Since we obfuscate one package
// at a time, we only detect those if the type definition and the reflect usage // at a time, we only detect those if the type definition and the reflect usage
// are both in the same package. // are both in the same package.
func buildBlacklist(files []*ast.File, info *types.Info, pkg *types.Package) (blacklist []types.Object) { //
// The blacklist mainly contains named types and their field declarations.
func buildBlacklist(files []*ast.File, info *types.Info, pkg *types.Package) map[types.Object]struct{} {
// Keep track of the current syntax tree level. If reflectCallLevel is // Keep track of the current syntax tree level. If reflectCallLevel is
// non-negative, we are under a reflect call. // non-negative, we are under a reflect call.
level := 0 level := 0
reflectCallLevel := -1 reflectCallLevel := -1
blacklist := make(map[types.Object]struct{})
addToBlacklist := func(named *types.Named) {
obj := named.Obj()
if obj == nil || obj.Pkg() != pkg {
return
}
blacklist[obj] = struct{}{}
strct, _ := named.Underlying().(*types.Struct)
if strct != nil {
for i := 0; i < strct.NumFields(); i++ {
blacklist[strct.Field(i)] = struct{}{}
}
}
}
visit := func(node ast.Node) bool { visit := func(node ast.Node) bool {
if node == nil { if node == nil {
if level == reflectCallLevel { if level == reflectCallLevel {
@ -468,8 +485,9 @@ func buildBlacklist(files []*ast.File, info *types.Info, pkg *types.Package) (bl
} }
if reflectCallLevel >= 0 && level >= reflectCallLevel { if reflectCallLevel >= 0 && level >= reflectCallLevel {
expr, _ := node.(ast.Expr) expr, _ := node.(ast.Expr)
if obj := objOf(info.TypeOf(expr)); obj != nil && obj.Pkg() == pkg { named := namedType(info.TypeOf(expr))
blacklist = append(blacklist, obj) if named != nil {
addToBlacklist(named)
} }
} }
level++ level++
@ -495,7 +513,7 @@ func buildBlacklist(files []*ast.File, info *types.Info, pkg *types.Package) (bl
} }
// transformGo garbles the provided Go syntax node. // transformGo garbles the provided Go syntax node.
func transformGo(file *ast.File, info *types.Info, blacklist []types.Object) *ast.File { func transformGo(file *ast.File, info *types.Info, blacklist map[types.Object]struct{}) *ast.File {
// Remove all comments, minus the "//go:" compiler directives. // Remove all comments, minus the "//go:" compiler directives.
// The final binary should still not contain comment text, but removing // The final binary should still not contain comment text, but removing
// it helps ensure that (and makes position info less predictable). // it helps ensure that (and makes position info less predictable).
@ -538,7 +556,7 @@ func transformGo(file *ast.File, info *types.Info, blacklist []types.Object) *as
if vr, ok := obj.(*types.Var); ok && vr.Embedded() { if vr, ok := obj.(*types.Var); ok && vr.Embedded() {
// ObjectOf returns the field for embedded struct // ObjectOf returns the field for embedded struct
// fields, not the type it uses. Use the type. // fields, not the type it uses. Use the type.
obj = objOf(obj.Type()) obj = namedType(obj.Type()).Obj()
pkg = obj.Pkg() pkg = obj.Pkg()
} }
@ -548,11 +566,9 @@ func transformGo(file *ast.File, info *types.Info, blacklist []types.Object) *as
return true // could be a Go plugin API return true // could be a Go plugin API
} }
// TODO: also do this for method receivers // The object itself is blacklisted, e.g. a type definition.
for _, item := range blacklist { if _, ok := blacklist[obj]; ok {
if obj == item { return true
return true
}
} }
// log.Printf("%#v %T", node, obj) // log.Printf("%#v %T", node, obj)
@ -620,15 +636,15 @@ func implementedOutsideGo(obj *types.Func) bool {
(obj.Scope() != nil && obj.Scope().Pos() == token.NoPos) (obj.Scope() != nil && obj.Scope().Pos() == token.NoPos)
} }
// objOf tries to obtain the object behind a *types.Named, even if it's behind a // named tries to obtain the *types.Named behind a type, if there is one.
// pointer type. This is useful to obtain "testing.T" from "*testing.T", or to // This is useful to obtain "testing.T" from "*testing.T", or to obtain the type
// obtain the type declaration object from an embedded field. // declaration object from an embedded field.
func objOf(t types.Type) types.Object { func namedType(t types.Type) *types.Named {
switch t := t.(type) { switch t := t.(type) {
case *types.Named: case *types.Named:
return t.Obj() return t
case interface{ Elem() types.Type }: case interface{ Elem() types.Type }:
return objOf(t.Elem()) return namedType(t.Elem())
default: default:
return nil return nil
} }
@ -643,7 +659,7 @@ func isTestSignature(sign *types.Signature) bool {
if params.Len() != 1 { if params.Len() != 1 {
return false return false
} }
obj := objOf(params.At(0).Type()) obj := namedType(params.At(0).Type()).Obj()
return obj != nil && obj.Pkg().Path() == "testing" && obj.Name() == "T" return obj != nil && obj.Pkg().Path() == "testing" && obj.Name() == "T"
} }

@ -23,6 +23,7 @@ package main
import ( import (
"fmt" "fmt"
"reflect"
_ "unsafe" _ "unsafe"
"test/main/imported" "test/main/imported"
@ -36,12 +37,20 @@ func linkedPrintln(a ...interface{}) (n int, err error)
func main() { func main() {
fmt.Println(imported.ImportedVar) fmt.Println(imported.ImportedVar)
fmt.Println(imported.ImportedConst) fmt.Println(imported.ImportedConst)
imported.ImportedFunc('x') fmt.Println(imported.ImportedFunc('x'))
fmt.Println(imported.ImportedType(3)) fmt.Println(imported.ImportedType(3))
fmt.Printf("%T\n", imported.ReflectTypeOf(2)) fmt.Printf("%T\n", imported.ReflectTypeOf(2))
fmt.Printf("%T\n", imported.ReflectTypeOfIndirect(4)) fmt.Printf("%T\n", imported.ReflectTypeOfIndirect(4))
fmt.Printf("%#v\n", imported.ReflectValueOfVar)
v := imported.ReflectValueOfVar
fmt.Printf("%#v\n", v)
method := reflect.ValueOf(&v).MethodByName("ExportedMethodName")
if method.IsValid() {
fmt.Println(method.Call(nil))
} else {
fmt.Println("method not found")
}
linkedPrintln(nil) linkedPrintln(nil)
fmt.Println(quote.Go()) fmt.Println(quote.Go())
@ -68,10 +77,14 @@ type ReflectTypeOfIndirect int
var _ = reflect.TypeOf(new([]*ReflectTypeOfIndirect)) var _ = reflect.TypeOf(new([]*ReflectTypeOfIndirect))
type ReflectValueOf struct { type ReflectValueOf struct {
Foo int `bar:"baz"` ExportedField string
unexportedField string
} }
var ReflectValueOfVar = ReflectValueOf{Foo: 3} func (r *ReflectValueOf) ExportedMethodName() string { return "method: "+r.ExportedField }
var ReflectValueOfVar = ReflectValueOf{ExportedField: "abc"}
var _ = reflect.TypeOf(ReflectValueOfVar) var _ = reflect.TypeOf(ReflectValueOfVar)
@ -81,9 +94,11 @@ type ImportedType int
-- main.stdout -- -- main.stdout --
imported var value imported var value
imported const value imported const value
x
3 3
imported.ReflectTypeOf imported.ReflectTypeOf
imported.ReflectTypeOfIndirect imported.ReflectTypeOfIndirect
imported.ReflectValueOf{Foo:3} imported.ReflectValueOf{ExportedField:"abc", unexportedField:""}
[method: abc]
<nil> <nil>
Don't communicate by sharing memory, share memory by communicating. Don't communicate by sharing memory, share memory by communicating.

Loading…
Cancel
Save