isolate reflect.go from updating globals directly

That is, stop reusing "transformer" as the receiver on methods,
and stop writing the results to the global curPkgCache struct.

Soon we will need to support computing pkgCache for any dependency,
not just the current package, to make the caching properly robust.
This allows us to fill reflectInspector with different values.

The explicit isolation also helps prevent bugs.
For instance, we were calling recursivelyRecordUsedForReflect from
transformCompile, which happens after we have loaded or saved pkgCache.
Meaning, the current package sees a larger pkgCache than its dependents.
In this particular case it wasn't causing any bugs,
since the two reflect types in question only had unexported fields,
but it's still good to treat pkgCache as read-only in transformCompile.
pull/755/head
Daniel Martí 2 years ago
parent d108f21846
commit 4b0b2acf6f

@ -1402,8 +1402,12 @@ func (tf *transformer) loadPkgCache(files []*ast.File) error {
ssaPkg := ssaProg.CreatePackage(tf.pkg, files, tf.info, false) ssaPkg := ssaProg.CreatePackage(tf.pkg, files, tf.info, false)
ssaPkg.Build() ssaPkg.Build()
tf.reflectCheckedAPIs = make(map[string]bool) inspector := reflectInspector{
tf.recordReflection(ssaPkg) pkg: tf.pkg,
checkedAPIs: make(map[string]bool),
result: curPkgCache, // append the results
}
inspector.recordReflection(ssaPkg)
// Unlikely that we could stream the gob encode, as cache.Put wants an io.ReadSeeker. // Unlikely that we could stream the gob encode, as cache.Put wants an io.ReadSeeker.
var buf bytes.Buffer var buf bytes.Buffer
@ -1490,8 +1494,6 @@ type transformer struct {
// fieldToStruct helps locate struct types from any of their field // fieldToStruct helps locate struct types from any of their field
// objects. Useful when obfuscating field names. // objects. Useful when obfuscating field names.
fieldToStruct map[*types.Var]*types.Struct fieldToStruct map[*types.Var]*types.Struct
reflectCheckedAPIs map[string]bool
} }
func (tf *transformer) typecheck(files []*ast.File) error { func (tf *transformer) typecheck(files []*ast.File) error {
@ -1778,6 +1780,7 @@ func (tf *transformer) transformGoFile(file *ast.File) *ast.File {
// TODO: We match by object name here, which is actually imprecise. // TODO: We match by object name here, which is actually imprecise.
// For example, in package embed we match the type FS, but we would also // For example, in package embed we match the type FS, but we would also
// match any field or method named FS. // match any field or method named FS.
// Can we instead use an object map like ReflectObjects?
path := pkg.Path() path := pkg.Path()
switch path { switch path {
case "sync/atomic", "runtime/internal/atomic": case "sync/atomic", "runtime/internal/atomic":
@ -1795,13 +1798,6 @@ func (tf *transformer) transformGoFile(file *ast.File) *ast.File {
// the Method and MethodByName methods are what drive the logic. // the Method and MethodByName methods are what drive the logic.
case "Method", "MethodByName": case "Method", "MethodByName":
return true return true
// Some packages reach into reflect internals, like go-spew.
// It's not particularly right of them to do that,
// and it's entirely unsupported, but try to accommodate for now.
// At least it's enough to leave the rtype and Value types intact.
case "rtype", "Value":
tf.recursivelyRecordUsedForReflect(obj.Type())
return true
} }
case "crypto/x509/pkix": case "crypto/x509/pkix":
// For better or worse, encoding/asn1 detects a "SET" suffix on slice type names // For better or worse, encoding/asn1 detects a "SET" suffix on slice type names

@ -10,37 +10,55 @@ import (
"golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa"
) )
// reflectInspector carries the state of one reflection-detection pass over a
// single package. Its methods read and write the result field instead of
// referring to a package-level cache variable directly.
type reflectInspector struct {
// pkg is the package being inspected; only named types and fields declared
// in this package are recorded as used for reflection.
pkg *types.Package
// checkedAPIs holds the full names of known reflect APIs whose call sites
// have already been examined, so that repeated passes can skip them.
checkedAPIs map[string]bool
// result accumulates the ReflectAPIs and ReflectObjects found by the pass.
result pkgCache
}
// Record all instances of reflection use, and don't obfuscate types which are used in reflection. // Record all instances of reflection use, and don't obfuscate types which are used in reflection.
func (tf *transformer) recordReflection(ssaPkg *ssa.Package) { func (ri *reflectInspector) recordReflection(ssaPkg *ssa.Package) {
if reflectSkipPkg[ssaPkg.Pkg.Path()] { if reflectSkipPkg[ssaPkg.Pkg.Path()] {
return return
} }
lenPrevReflectAPIs := len(curPkgCache.ReflectAPIs) lenPrevReflectAPIs := len(ri.result.ReflectAPIs)
// find all unchecked APIs to add them to checkedAPIs after the pass // find all unchecked APIs to add them to checkedAPIs after the pass
notCheckedAPIs := make(map[string]bool) notCheckedAPIs := make(map[string]bool)
for _, knownAPI := range maps.Keys(curPkgCache.ReflectAPIs) { for _, knownAPI := range maps.Keys(ri.result.ReflectAPIs) {
if !tf.reflectCheckedAPIs[knownAPI] { if !ri.checkedAPIs[knownAPI] {
notCheckedAPIs[knownAPI] = true notCheckedAPIs[knownAPI] = true
} }
} }
tf.ignoreReflectedTypes(ssaPkg) ri.ignoreReflectedTypes(ssaPkg)
// all previously unchecked APIs have now been checked add them to checkedAPIs, // all previously unchecked APIs have now been checked add them to checkedAPIs,
// to avoid checking them twice // to avoid checking them twice
maps.Copy(tf.reflectCheckedAPIs, notCheckedAPIs) maps.Copy(ri.checkedAPIs, notCheckedAPIs)
// if a new reflectAPI is found we need to Re-evaluate all functions which might be using that API // if a new reflectAPI is found we need to Re-evaluate all functions which might be using that API
if len(curPkgCache.ReflectAPIs) > lenPrevReflectAPIs { if len(ri.result.ReflectAPIs) > lenPrevReflectAPIs {
tf.recordReflection(ssaPkg) ri.recordReflection(ssaPkg)
} }
} }
// find all functions, methods and interface declarations of a package and record their // find all functions, methods and interface declarations of a package and record their
// reflection use // reflection use
func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) { func (ri *reflectInspector) ignoreReflectedTypes(ssaPkg *ssa.Package) {
// Some packages reach into reflect internals, like go-spew.
// It's not particularly right of them to do that,
// and it's entirely unsupported, but try to accommodate for now.
// At least it's enough to leave the rtype and Value types intact.
if ri.pkg.Path() == "reflect" {
scope := ri.pkg.Scope()
ri.recursivelyRecordUsedForReflect(scope.Lookup("rtype").Type())
ri.recursivelyRecordUsedForReflect(scope.Lookup("Value").Type())
}
for _, memb := range ssaPkg.Members { for _, memb := range ssaPkg.Members {
switch x := memb.(type) { switch x := memb.(type) {
case *ssa.Type: case *ssa.Type:
@ -52,11 +70,11 @@ func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) {
at := mset.At(i) at := mset.At(i)
if m := ssaPkg.Prog.MethodValue(at); m != nil { if m := ssaPkg.Prog.MethodValue(at); m != nil {
tf.checkFunction(m) ri.checkFunction(m)
} else { } else {
m := at.Obj().(*types.Func) m := at.Obj().(*types.Func)
// handle interface declarations // handle interface declarations
tf.checkInterfaceMethod(m) ri.checkInterfaceMethod(m)
} }
} }
@ -73,7 +91,7 @@ func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) {
// these not only include top level functions, but also synthetic // these not only include top level functions, but also synthetic
// functions like the initialization of global variables // functions like the initialization of global variables
tf.checkFunction(x) ri.checkFunction(x)
} }
} }
} }
@ -85,7 +103,7 @@ func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) {
// and treat them like a parameter which is actually used in reflection. // and treat them like a parameter which is actually used in reflection.
// //
// See "UnnamedStructMethod" in the reflect.txtar test for an example. // See "UnnamedStructMethod" in the reflect.txtar test for an example.
func (tf *transformer) checkMethodSignature(reflectParams map[int]bool, sig *types.Signature) { func (ri *reflectInspector) checkMethodSignature(reflectParams map[int]bool, sig *types.Signature) {
if sig.Recv() == nil { if sig.Recv() == nil {
return return
} }
@ -114,31 +132,31 @@ func (tf *transformer) checkMethodSignature(reflectParams map[int]bool, sig *typ
if ignore { if ignore {
reflectParams[i] = true reflectParams[i] = true
tf.recursivelyRecordUsedForReflect(param.Type()) ri.recursivelyRecordUsedForReflect(param.Type())
} }
} }
} }
// Checks the signature of an interface method for potential reflection use. // Checks the signature of an interface method for potential reflection use.
func (tf *transformer) checkInterfaceMethod(m *types.Func) { func (ri *reflectInspector) checkInterfaceMethod(m *types.Func) {
reflectParams := make(map[int]bool) reflectParams := make(map[int]bool)
maps.Copy(reflectParams, curPkgCache.ReflectAPIs[m.FullName()]) maps.Copy(reflectParams, ri.result.ReflectAPIs[m.FullName()])
sig := m.Type().(*types.Signature) sig := m.Type().(*types.Signature)
if m.Exported() { if m.Exported() {
tf.checkMethodSignature(reflectParams, sig) ri.checkMethodSignature(reflectParams, sig)
} }
if len(reflectParams) > 0 { if len(reflectParams) > 0 {
curPkgCache.ReflectAPIs[m.FullName()] = reflectParams ri.result.ReflectAPIs[m.FullName()] = reflectParams
/* fmt.Printf("curPkgCache.ReflectAPIs: %v\n", curPkgCache.ReflectAPIs) */ /* fmt.Printf("curPkgCache.ReflectAPIs: %v\n", curPkgCache.ReflectAPIs) */
} }
} }
// Checks all callsites in a function declaration for use of reflection. // Checks all callsites in a function declaration for use of reflection.
func (tf *transformer) checkFunction(fun *ssa.Function) { func (ri *reflectInspector) checkFunction(fun *ssa.Function) {
/* if fun != nil && fun.Synthetic != "loaded from gc object file" { /* if fun != nil && fun.Synthetic != "loaded from gc object file" {
// fun.WriteTo crashes otherwise // fun.WriteTo crashes otherwise
fun.WriteTo(os.Stdout) fun.WriteTo(os.Stdout)
@ -148,10 +166,10 @@ func (tf *transformer) checkFunction(fun *ssa.Function) {
reflectParams := make(map[int]bool) reflectParams := make(map[int]bool)
if f != nil { if f != nil {
maps.Copy(reflectParams, curPkgCache.ReflectAPIs[f.FullName()]) maps.Copy(reflectParams, ri.result.ReflectAPIs[f.FullName()])
if f.Exported() { if f.Exported() {
tf.checkMethodSignature(reflectParams, fun.Signature) ri.checkMethodSignature(reflectParams, fun.Signature)
} }
} }
@ -171,7 +189,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) {
callName = call.Call.Method.FullName() callName = call.Call.Method.FullName()
} }
if tf.reflectCheckedAPIs[callName] { if ri.checkedAPIs[callName] {
// only check apis which were not already checked // only check apis which were not already checked
continue continue
} }
@ -179,7 +197,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) {
/* fmt.Printf("callName: %v\n", callName) */ /* fmt.Printf("callName: %v\n", callName) */
// record each call argument passed to a function parameter which is used in reflection // record each call argument passed to a function parameter which is used in reflection
knownParams := curPkgCache.ReflectAPIs[callName] knownParams := ri.result.ReflectAPIs[callName]
for knownParam := range knownParams { for knownParam := range knownParams {
if len(call.Call.Args) <= knownParam { if len(call.Call.Args) <= knownParam {
continue continue
@ -190,7 +208,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) {
/* fmt.Printf("flagging arg: %v\n", arg) */ /* fmt.Printf("flagging arg: %v\n", arg) */
visited := make(map[ssa.Value]bool) visited := make(map[ssa.Value]bool)
reflectedParam := tf.recordArgReflected(arg, visited) reflectedParam := ri.recordArgReflected(arg, visited)
if reflectedParam == nil { if reflectedParam == nil {
continue continue
} }
@ -208,7 +226,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) {
} }
if len(reflectParams) > 0 { if len(reflectParams) > 0 {
curPkgCache.ReflectAPIs[f.FullName()] = reflectParams ri.result.ReflectAPIs[f.FullName()] = reflectParams
/* fmt.Printf("curPkgCache.ReflectAPIs: %v\n", curPkgCache.ReflectAPIs) */ /* fmt.Printf("curPkgCache.ReflectAPIs: %v\n", curPkgCache.ReflectAPIs) */
} }
@ -217,7 +235,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) {
// recordArgReflected finds the type(s) of a function argument, which is being used in reflection // recordArgReflected finds the type(s) of a function argument, which is being used in reflection
// and excludes these types from obfuscation // and excludes these types from obfuscation
// It also checks if this argument has any relation to a function parameter and returns it if found. // It also checks if this argument has any relation to a function parameter and returns it if found.
func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter { func (ri *reflectInspector) recordArgReflected(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter {
// make sure we visit every val only once, otherwise there will be infinite recursion // make sure we visit every val only once, otherwise there will be infinite recursion
if visited[val] { if visited[val] {
return nil return nil
@ -230,26 +248,26 @@ func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]b
case *ssa.IndexAddr: case *ssa.IndexAddr:
for _, ref := range *val.Referrers() { for _, ref := range *val.Referrers() {
if store, ok := ref.(*ssa.Store); ok { if store, ok := ref.(*ssa.Store); ok {
tf.recordArgReflected(store.Val, visited) ri.recordArgReflected(store.Val, visited)
} }
} }
return tf.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.Slice: case *ssa.Slice:
return tf.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.MakeInterface: case *ssa.MakeInterface:
return tf.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.UnOp: case *ssa.UnOp:
return tf.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.FieldAddr: case *ssa.FieldAddr:
return tf.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.Alloc: case *ssa.Alloc:
/* fmt.Printf("recording val %v \n", *val.Referrers()) */ /* fmt.Printf("recording val %v \n", *val.Referrers()) */
tf.recursivelyRecordUsedForReflect(val.Type()) ri.recursivelyRecordUsedForReflect(val.Type())
for _, ref := range *val.Referrers() { for _, ref := range *val.Referrers() {
if idx, ok := ref.(*ssa.IndexAddr); ok { if idx, ok := ref.(*ssa.IndexAddr); ok {
tf.recordArgReflected(idx, visited) ri.recordArgReflected(idx, visited)
} }
} }
@ -260,9 +278,9 @@ func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]b
return relatedParam(val, visited) return relatedParam(val, visited)
case *ssa.Const: case *ssa.Const:
tf.recursivelyRecordUsedForReflect(val.Type()) ri.recursivelyRecordUsedForReflect(val.Type())
case *ssa.Global: case *ssa.Global:
tf.recursivelyRecordUsedForReflect(val.Type()) ri.recursivelyRecordUsedForReflect(val.Type())
// TODO: this might need similar logic to *ssa.Alloc, however // TODO: this might need similar logic to *ssa.Alloc, however
// reassigning a function param to a global variable and then reflecting // reassigning a function param to a global variable and then reflecting
@ -271,7 +289,7 @@ func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]b
// this only finds the parameters who want to be found, // this only finds the parameters who want to be found,
// otherwise relatedParam is used for more in depth analysis // otherwise relatedParam is used for more in depth analysis
tf.recursivelyRecordUsedForReflect(val.Type()) ri.recursivelyRecordUsedForReflect(val.Type())
return val return val
} }
@ -347,23 +365,23 @@ func relatedParam(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter {
// Only the names declared in the current package are recorded. This is to ensure // Only the names declared in the current package are recorded. This is to ensure
// that reflection detection only happens within the package declaring a type. // that reflection detection only happens within the package declaring a type.
// Detecting it in downstream packages could result in inconsistencies. // Detecting it in downstream packages could result in inconsistencies.
func (tf *transformer) recursivelyRecordUsedForReflect(t types.Type) { func (ri *reflectInspector) recursivelyRecordUsedForReflect(t types.Type) {
switch t := t.(type) { switch t := t.(type) {
case *types.Named: case *types.Named:
obj := t.Obj() obj := t.Obj()
// TODO: the transformer is only needed in this function, there is // TODO: the transformer is only needed in this function, there is
// probably a way to do this with only the ssa information. // probably a way to do this with only the ssa information.
if obj.Pkg() == nil || obj.Pkg() != tf.pkg { if obj.Pkg() == nil || obj.Pkg() != ri.pkg {
return // not from the specified package return // not from the specified package
} }
if usedForReflect(obj) { if usedForReflect(obj) {
return // prevent endless recursion return // prevent endless recursion
} }
recordUsedForReflect(obj) ri.recordUsedForReflect(obj)
// Record the underlying type, too. // Record the underlying type, too.
tf.recursivelyRecordUsedForReflect(t.Underlying()) ri.recursivelyRecordUsedForReflect(t.Underlying())
case *types.Struct: case *types.Struct:
for i := 0; i < t.NumFields(); i++ { for i := 0; i < t.NumFields(); i++ {
@ -372,19 +390,19 @@ func (tf *transformer) recursivelyRecordUsedForReflect(t types.Type) {
// This check is similar to the one in *types.Named. // This check is similar to the one in *types.Named.
// It's necessary for unnamed struct types, // It's necessary for unnamed struct types,
// as they aren't named but still have named fields. // as they aren't named but still have named fields.
if field.Pkg() == nil || field.Pkg() != tf.pkg { if field.Pkg() == nil || field.Pkg() != ri.pkg {
return // not from the specified package return // not from the specified package
} }
// Record the field itself, too. // Record the field itself, too.
recordUsedForReflect(field) ri.recordUsedForReflect(field)
tf.recursivelyRecordUsedForReflect(field.Type()) ri.recursivelyRecordUsedForReflect(field.Type())
} }
case interface{ Elem() types.Type }: case interface{ Elem() types.Type }:
// Get past pointers, slices, etc. // Get past pointers, slices, etc.
tf.recursivelyRecordUsedForReflect(t.Elem()) ri.recursivelyRecordUsedForReflect(t.Elem())
} }
} }
@ -423,8 +441,8 @@ func recordedObjectString(obj types.Object) objectString {
// recordUsedForReflect records the objects whose names we cannot obfuscate due to reflection. // recordUsedForReflect records the objects whose names we cannot obfuscate due to reflection.
// We currently record named types and fields. // We currently record named types and fields.
func recordUsedForReflect(obj types.Object) { func (ri *reflectInspector) recordUsedForReflect(obj types.Object) {
if obj.Pkg().Path() != curPkg.ImportPath { if obj.Pkg().Path() != ri.pkg.Path() {
panic("called recordUsedForReflect with a foreign object") panic("called recordUsedForReflect with a foreign object")
} }
objStr := recordedObjectString(obj) objStr := recordedObjectString(obj)
@ -433,7 +451,7 @@ func recordUsedForReflect(obj types.Object) {
// do we need to record it at all? // do we need to record it at all?
return return
} }
curPkgCache.ReflectObjects[objStr] = struct{}{} ri.result.ReflectObjects[objStr] = struct{}{}
} }
func usedForReflect(obj types.Object) bool { func usedForReflect(obj types.Object) bool {

Loading…
Cancel
Save