diff --git a/main.go b/main.go index 17030d0..a1ba761 100644 --- a/main.go +++ b/main.go @@ -1402,8 +1402,12 @@ func (tf *transformer) loadPkgCache(files []*ast.File) error { ssaPkg := ssaProg.CreatePackage(tf.pkg, files, tf.info, false) ssaPkg.Build() - tf.reflectCheckedAPIs = make(map[string]bool) - tf.recordReflection(ssaPkg) + inspector := reflectInspector{ + pkg: tf.pkg, + checkedAPIs: make(map[string]bool), + result: curPkgCache, // append the results + } + inspector.recordReflection(ssaPkg) // Unlikely that we could stream the gob encode, as cache.Put wants an io.ReadSeeker. var buf bytes.Buffer @@ -1490,8 +1494,6 @@ type transformer struct { // fieldToStruct helps locate struct types from any of their field // objects. Useful when obfuscating field names. fieldToStruct map[*types.Var]*types.Struct - - reflectCheckedAPIs map[string]bool } func (tf *transformer) typecheck(files []*ast.File) error { @@ -1778,6 +1780,7 @@ func (tf *transformer) transformGoFile(file *ast.File) *ast.File { // TODO: We match by object name here, which is actually imprecise. // For example, in package embed we match the type FS, but we would also // match any field or method named FS. + // Can we instead use an object map like ReflectObjects? path := pkg.Path() switch path { case "sync/atomic", "runtime/internal/atomic": @@ -1795,13 +1798,6 @@ func (tf *transformer) transformGoFile(file *ast.File) *ast.File { // the Method and MethodByName methods are what drive the logic. case "Method", "MethodByName": return true - // Some packages reach into reflect internals, like go-spew. - // It's not particularly right of them to do that, - // and it's entirely unsupported, but try to accomodate for now. - // At least it's enough to leave the rtype and Value types intact. - case "rtype", "Value": - tf.recursivelyRecordUsedForReflect(obj.Type()) - return true } case "crypto/x509/pkix": // For better or worse, encoding/asn1 detects a "SET" suffix on slice type names diff --git a/reflect.go b/reflect.go index 0c542c5..d94cbd3 100644 --- a/reflect.go +++ b/reflect.go @@ -10,37 +10,55 @@ import ( "golang.org/x/tools/go/ssa" ) +type reflectInspector struct { + pkg *types.Package + + checkedAPIs map[string]bool + + result pkgCache +} + // Record all instances of reflection use, and don't obfuscate types which are used in reflection. -func (tf *transformer) recordReflection(ssaPkg *ssa.Package) { +func (ri *reflectInspector) recordReflection(ssaPkg *ssa.Package) { if reflectSkipPkg[ssaPkg.Pkg.Path()] { return } - lenPrevReflectAPIs := len(curPkgCache.ReflectAPIs) + lenPrevReflectAPIs := len(ri.result.ReflectAPIs) // find all unchecked APIs to add them to checkedAPIs after the pass notCheckedAPIs := make(map[string]bool) - for _, knownAPI := range maps.Keys(curPkgCache.ReflectAPIs) { - if !tf.reflectCheckedAPIs[knownAPI] { + for _, knownAPI := range maps.Keys(ri.result.ReflectAPIs) { + if !ri.checkedAPIs[knownAPI] { notCheckedAPIs[knownAPI] = true } } - tf.ignoreReflectedTypes(ssaPkg) + ri.ignoreReflectedTypes(ssaPkg) // all previously unchecked APIs have now been checked add them to checkedAPIs, // to avoid checking them twice - maps.Copy(tf.reflectCheckedAPIs, notCheckedAPIs) + maps.Copy(ri.checkedAPIs, notCheckedAPIs) // if a new reflectAPI is found we need to Re-evaluate all functions which might be using that API - if len(curPkgCache.ReflectAPIs) > lenPrevReflectAPIs { - tf.recordReflection(ssaPkg) + if len(ri.result.ReflectAPIs) > lenPrevReflectAPIs { + ri.recordReflection(ssaPkg) } } // find all functions, methods and interface declarations of a package and record their // reflection use -func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) { +func (ri *reflectInspector) ignoreReflectedTypes(ssaPkg *ssa.Package) { + // Some packages reach into reflect internals, like go-spew. + // It's not particularly right of them to do that, + // and it's entirely unsupported, but try to accomodate for now. + // At least it's enough to leave the rtype and Value types intact. + if ri.pkg.Path() == "reflect" { + scope := ri.pkg.Scope() + ri.recursivelyRecordUsedForReflect(scope.Lookup("rtype").Type()) + ri.recursivelyRecordUsedForReflect(scope.Lookup("Value").Type()) + } + for _, memb := range ssaPkg.Members { switch x := memb.(type) { case *ssa.Type: @@ -52,11 +70,11 @@ func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) { at := mset.At(i) if m := ssaPkg.Prog.MethodValue(at); m != nil { - tf.checkFunction(m) + ri.checkFunction(m) } else { m := at.Obj().(*types.Func) // handle interface declarations - tf.checkInterfaceMethod(m) + ri.checkInterfaceMethod(m) } } @@ -73,7 +91,7 @@ func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) { // these not only include top level functions, but also synthetic // functions like the initialization of global variables - tf.checkFunction(x) + ri.checkFunction(x) } } } @@ -85,7 +103,7 @@ func (tf *transformer) ignoreReflectedTypes(ssaPkg *ssa.Package) { // and treat them like a parameter which is actually used in reflection. // // See "UnnamedStructMethod" in the reflect.txtar test for an example. -func (tf *transformer) checkMethodSignature(reflectParams map[int]bool, sig *types.Signature) { +func (ri *reflectInspector) checkMethodSignature(reflectParams map[int]bool, sig *types.Signature) { if sig.Recv() == nil { return } @@ -114,31 +132,31 @@ func (tf *transformer) checkMethodSignature(reflectParams map[int]bool, sig *typ if ignore { reflectParams[i] = true - tf.recursivelyRecordUsedForReflect(param.Type()) + ri.recursivelyRecordUsedForReflect(param.Type()) } } } // Checks the signature of an interface method for potential reflection use. -func (tf *transformer) checkInterfaceMethod(m *types.Func) { +func (ri *reflectInspector) checkInterfaceMethod(m *types.Func) { reflectParams := make(map[int]bool) - maps.Copy(reflectParams, curPkgCache.ReflectAPIs[m.FullName()]) + maps.Copy(reflectParams, ri.result.ReflectAPIs[m.FullName()]) sig := m.Type().(*types.Signature) if m.Exported() { - tf.checkMethodSignature(reflectParams, sig) + ri.checkMethodSignature(reflectParams, sig) } if len(reflectParams) > 0 { - curPkgCache.ReflectAPIs[m.FullName()] = reflectParams + ri.result.ReflectAPIs[m.FullName()] = reflectParams /* fmt.Printf("curPkgCache.ReflectAPIs: %v\n", curPkgCache.ReflectAPIs) */ } } // Checks all callsites in a function declaration for use of reflection. -func (tf *transformer) checkFunction(fun *ssa.Function) { +func (ri *reflectInspector) checkFunction(fun *ssa.Function) { /* if fun != nil && fun.Synthetic != "loaded from gc object file" { // fun.WriteTo crashes otherwise fun.WriteTo(os.Stdout) @@ -148,10 +166,10 @@ func (tf *transformer) checkFunction(fun *ssa.Function) { reflectParams := make(map[int]bool) if f != nil { - maps.Copy(reflectParams, curPkgCache.ReflectAPIs[f.FullName()]) + maps.Copy(reflectParams, ri.result.ReflectAPIs[f.FullName()]) if f.Exported() { - tf.checkMethodSignature(reflectParams, fun.Signature) + ri.checkMethodSignature(reflectParams, fun.Signature) } } @@ -171,7 +189,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) { callName = call.Call.Method.FullName() } - if tf.reflectCheckedAPIs[callName] { + if ri.checkedAPIs[callName] { // only check apis which were not already checked continue } @@ -179,7 +197,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) { /* fmt.Printf("callName: %v\n", callName) */ // record each call argument passed to a function parameter which is used in reflection - knownParams := curPkgCache.ReflectAPIs[callName] + knownParams := ri.result.ReflectAPIs[callName] for knownParam := range knownParams { if len(call.Call.Args) <= knownParam { continue @@ -190,7 +208,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) { /* fmt.Printf("flagging arg: %v\n", arg) */ visited := make(map[ssa.Value]bool) - reflectedParam := tf.recordArgReflected(arg, visited) + reflectedParam := ri.recordArgReflected(arg, visited) if reflectedParam == nil { continue } @@ -208,7 +226,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) { } if len(reflectParams) > 0 { - curPkgCache.ReflectAPIs[f.FullName()] = reflectParams + ri.result.ReflectAPIs[f.FullName()] = reflectParams /* fmt.Printf("curPkgCache.ReflectAPIs: %v\n", curPkgCache.ReflectAPIs) */ } @@ -217,7 +235,7 @@ func (tf *transformer) checkFunction(fun *ssa.Function) { // recordArgReflected finds the type(s) of a function argument, which is being used in reflection // and excludes these types from obfuscation // It also checks if this argument has any relation to a function paramter and returns it if found. -func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter { +func (ri *reflectInspector) recordArgReflected(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter { // make sure we visit every val only once, otherwise there will be infinite recursion if visited[val] { return nil @@ -230,26 +248,26 @@ func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]b case *ssa.IndexAddr: for _, ref := range *val.Referrers() { if store, ok := ref.(*ssa.Store); ok { - tf.recordArgReflected(store.Val, visited) + ri.recordArgReflected(store.Val, visited) } } - return tf.recordArgReflected(val.X, visited) + return ri.recordArgReflected(val.X, visited) case *ssa.Slice: - return tf.recordArgReflected(val.X, visited) + return ri.recordArgReflected(val.X, visited) case *ssa.MakeInterface: - return tf.recordArgReflected(val.X, visited) + return ri.recordArgReflected(val.X, visited) case *ssa.UnOp: - return tf.recordArgReflected(val.X, visited) + return ri.recordArgReflected(val.X, visited) case *ssa.FieldAddr: - return tf.recordArgReflected(val.X, visited) + return ri.recordArgReflected(val.X, visited) case *ssa.Alloc: /* fmt.Printf("recording val %v \n", *val.Referrers()) */ - tf.recursivelyRecordUsedForReflect(val.Type()) + ri.recursivelyRecordUsedForReflect(val.Type()) for _, ref := range *val.Referrers() { if idx, ok := ref.(*ssa.IndexAddr); ok { - tf.recordArgReflected(idx, visited) + ri.recordArgReflected(idx, visited) } } @@ -260,9 +278,9 @@ func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]b return relatedParam(val, visited) case *ssa.Const: - tf.recursivelyRecordUsedForReflect(val.Type()) + ri.recursivelyRecordUsedForReflect(val.Type()) case *ssa.Global: - tf.recursivelyRecordUsedForReflect(val.Type()) + ri.recursivelyRecordUsedForReflect(val.Type()) // TODO: this might need similar logic to *ssa.Alloc, however // reassigning a function param to a global variable and then reflecting @@ -271,7 +289,7 @@ func (tf *transformer) recordArgReflected(val ssa.Value, visited map[ssa.Value]b // this only finds the parameters who want to be found, // otherwise relatedParam is used for more in depth analysis - tf.recursivelyRecordUsedForReflect(val.Type()) + ri.recursivelyRecordUsedForReflect(val.Type()) return val } @@ -347,23 +365,23 @@ func relatedParam(val ssa.Value, visited map[ssa.Value]bool) *ssa.Parameter { // Only the names declared in the current package are recorded. This is to ensure // that reflection detection only happens within the package declaring a type. // Detecting it in downstream packages could result in inconsistencies. -func (tf *transformer) recursivelyRecordUsedForReflect(t types.Type) { +func (ri *reflectInspector) recursivelyRecordUsedForReflect(t types.Type) { switch t := t.(type) { case *types.Named: obj := t.Obj() // TODO: the transformer is only needed in this function, there is // probably a way to do this with only the ssa information. - if obj.Pkg() == nil || obj.Pkg() != tf.pkg { + if obj.Pkg() == nil || obj.Pkg() != ri.pkg { return // not from the specified package } if usedForReflect(obj) { return // prevent endless recursion } - recordUsedForReflect(obj) + ri.recordUsedForReflect(obj) // Record the underlying type, too. - tf.recursivelyRecordUsedForReflect(t.Underlying()) + ri.recursivelyRecordUsedForReflect(t.Underlying()) case *types.Struct: for i := 0; i < t.NumFields(); i++ { @@ -372,19 +390,19 @@ func (tf *transformer) recursivelyRecordUsedForReflect(t types.Type) { // This check is similar to the one in *types.Named. // It's necessary for unnamed struct types, // as they aren't named but still have named fields. - if field.Pkg() == nil || field.Pkg() != tf.pkg { + if field.Pkg() == nil || field.Pkg() != ri.pkg { return // not from the specified package } // Record the field itself, too. - recordUsedForReflect(field) + ri.recordUsedForReflect(field) - tf.recursivelyRecordUsedForReflect(field.Type()) + ri.recursivelyRecordUsedForReflect(field.Type()) } case interface{ Elem() types.Type }: // Get past pointers, slices, etc. - tf.recursivelyRecordUsedForReflect(t.Elem()) + ri.recursivelyRecordUsedForReflect(t.Elem()) } } @@ -423,8 +441,8 @@ func recordedObjectString(obj types.Object) objectString { // recordUsedForReflect records the objects whose names we cannot obfuscate due to reflection. // We currently record named types and fields. -func recordUsedForReflect(obj types.Object) { - if obj.Pkg().Path() != curPkg.ImportPath { +func (ri *reflectInspector) recordUsedForReflect(obj types.Object) { + if obj.Pkg().Path() != ri.pkg.Path() { panic("called recordUsedForReflect with a foreign object") } objStr := recordedObjectString(obj) @@ -433,7 +451,7 @@ func recordUsedForReflect(obj types.Object) { // do we need to record it at all? return } - curPkgCache.ReflectObjects[objStr] = struct{}{} + ri.result.ReflectObjects[objStr] = struct{}{} } func usedForReflect(obj types.Object) bool {