diff --git a/main.go b/main.go index 8f75048..d9845bc 100644 --- a/main.go +++ b/main.go @@ -973,7 +973,7 @@ func (tf *transformer) recordReflectArgs(files []*ast.File) { if fnType.Pkg().Path() == "reflect" && (fnType.Name() == "TypeOf" || fnType.Name() == "ValueOf") { for _, arg := range call.Args { argType := tf.info.TypeOf(arg) - tf.recordIgnore(argType, false) + tf.recordIgnore(argType, tf.pkg.Path()) } } return true @@ -1186,7 +1186,7 @@ func (tf *transformer) transformGo(file *ast.File) *ast.File { // TODO(mvdan): add a test and think how to fix this if obfPkg := obfuscatedTypesPackage(path); obfPkg != nil { if obfPkg.Scope().Lookup(named.Obj().Name()) != nil { - tf.recordIgnore(named, true) + tf.recordIgnore(named, path) return true } } @@ -1207,7 +1207,7 @@ func (tf *transformer) transformGo(file *ast.File) *ast.File { // The type is directly referenced by name, // so obfuscatedTypesPackage can't return nil. if obfuscatedTypesPackage(path).Scope().Lookup(obj.Name()) != nil { - tf.recordIgnore(named, true) + tf.recordIgnore(named, path) return true } } @@ -1280,15 +1280,15 @@ func (tf *transformer) transformGo(file *ast.File) *ast.File { // recordIgnore adds any named types (including fields) under typ to // ignoreObjects. // -// When allPkgs is false, we stop if we encounter a named type defined in a -// dependency package. This is useful to only record uses of reflection on local -// types. -func (tf *transformer) recordIgnore(t types.Type, allPkgs bool) { +// Only the names declared in package pkgPath are recorded. This is to ensure +// that reflection detection only happens within the package declaring a type. +// Detecting it in downstream packages could result in inconsistencies. +func (tf *transformer) recordIgnore(t types.Type, pkgPath string) { switch t := t.(type) { case *types.Named: obj := t.Obj() - if !allPkgs && obj.Pkg() != tf.pkg { - return // not from the current package + if obj.Pkg() == nil || obj.Pkg().Path() != pkgPath { + return // not from the specified package } if tf.ignoreObjects[obj] { return // prevent endless recursion @@ -1296,7 +1296,7 @@ func (tf *transformer) recordIgnore(t types.Type, allPkgs bool) { tf.ignoreObjects[obj] = true // Record the underlying type, too. - tf.recordIgnore(t.Underlying(), allPkgs) + tf.recordIgnore(t.Underlying(), pkgPath) case *types.Struct: for i := 0; i < t.NumFields(); i++ { @@ -1305,19 +1305,19 @@ func (tf *transformer) recordIgnore(t types.Type, allPkgs bool) { // This check is similar to the one in *types.Named. // It's necessary for unnamed struct types, // as they aren't named but still have named fields. - if !allPkgs && field.Pkg() != tf.pkg { - return // not from the current package + if field.Pkg() == nil || field.Pkg().Path() != pkgPath { + return // not from the specified package } // Record the field itself, too. tf.ignoreObjects[field] = true - tf.recordIgnore(field.Type(), allPkgs) + tf.recordIgnore(field.Type(), pkgPath) } case interface{ Elem() types.Type }: // Get past pointers, slices, etc. - tf.recordIgnore(t.Elem(), allPkgs) + tf.recordIgnore(t.Elem(), pkgPath) } } diff --git a/testdata/scripts/reflect.txt b/testdata/scripts/reflect.txt index ecfbadf..919fc2f 100644 --- a/testdata/scripts/reflect.txt +++ b/testdata/scripts/reflect.txt @@ -28,6 +28,7 @@ import ( "strings" "test/main/importedpkg" + "test/main/importedpkg2" ) func main() { @@ -68,7 +69,16 @@ func main() { // Simply using the field name here used to cause build failures. _ = reflect.TypeOf(importedpkg.UnnamedWithDownstreamReflect{}) fmt.Printf("%v\n", importedpkg.UnnamedWithDownstreamReflect{ - ExportedField: "foo", + ExportedField: "downstream", + }) + + // An edge case; the struct type is defined in package importedpkg2. + // importedpkg2 does not use reflection on it, so it's not obfuscated there. + // importedpkg uses reflection on a type containing ReflectInSiblingImport. + // If our logic is incorrect, we might inconsistently obfuscate the type. + // We should not obfuscate it when building any package. + fmt.Printf("%v\n", importedpkg2.ReflectInSiblingImport{ + ExportedField: "sibling", }) } @@ -97,6 +107,8 @@ package importedpkg import ( "reflect" + + "test/main/importedpkg2" ) type ReflectTypeOf int @@ -123,6 +135,8 @@ type ReflectInDefined struct { ExportedField2 int unexportedField2 int + + importedpkg2.ReflectInSiblingImport } var ReflectInDefinedVar = ReflectInDefined{ExportedField2: 9000} @@ -146,11 +160,19 @@ type UnnamedWithDownstreamReflect = struct { ExportedField string } +-- importedpkg2/imported2.go -- +package importedpkg2 + +type ReflectInSiblingImport struct { + ExportedField string +} + -- main.stdout -- 9000 -{5 0} +{5 0 {}} ReflectTypeOf ReflectTypeOfIndirect ReflectValueOf{ExportedField:"abc", unexportedField:""} [method: abc] -{foo} +{downstream} +{sibling}