track converted types when recording reflection usage

Fixes #763.
Fixes #782.
Fixes #785.
Fixes #807.
pull/809/head
Paul Scheduikat 1 year ago committed by GitHub
parent 4271bc45ae
commit bec8043790
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1543,10 +1543,10 @@ func computePkgCache(fsCache *cache.Cache, lpkg *listedPackage, pkg *types.Packa
// Fill the reflect info from SSA, which builds on top of the syntax tree and type info. // Fill the reflect info from SSA, which builds on top of the syntax tree and type info.
inspector := reflectInspector{ inspector := reflectInspector{
pkg: pkg, pkg: pkg,
checkedAPIs: make(map[string]bool), checkedAPIs: make(map[string]bool),
propagatedStores: map[*ssa.Store]bool{}, propagatedInstr: map[ssa.Instruction]bool{},
result: computed, // append the results result: computed, // append the results
} }
if ssaPkg == nil { if ssaPkg == nil {
ssaPkg = ssaBuildPkg(pkg, files, info) ssaPkg = ssaBuildPkg(pkg, files, info)

@ -15,7 +15,7 @@ type reflectInspector struct {
checkedAPIs map[string]bool checkedAPIs map[string]bool
propagatedStores map[*ssa.Store]bool propagatedInstr map[ssa.Instruction]bool
result pkgCache result pkgCache
} }
@ -26,7 +26,7 @@ func (ri *reflectInspector) recordReflection(ssaPkg *ssa.Package) {
return return
} }
prevDone := len(ri.result.ReflectAPIs) + len(ri.propagatedStores) prevDone := len(ri.result.ReflectAPIs) + len(ri.result.ReflectObjects)
// find all unchecked APIs to add them to checkedAPIs after the pass // find all unchecked APIs to add them to checkedAPIs after the pass
notCheckedAPIs := make(map[string]bool) notCheckedAPIs := make(map[string]bool)
@ -43,7 +43,7 @@ func (ri *reflectInspector) recordReflection(ssaPkg *ssa.Package) {
maps.Copy(ri.checkedAPIs, notCheckedAPIs) maps.Copy(ri.checkedAPIs, notCheckedAPIs)
// if a new reflectAPI is found we need to Re-evaluate all functions which might be using that API // if a new reflectAPI is found we need to Re-evaluate all functions which might be using that API
newDone := len(ri.result.ReflectAPIs) + len(ri.propagatedStores) newDone := len(ri.result.ReflectAPIs) + len(ri.result.ReflectObjects)
if newDone > prevDone { if newDone > prevDone {
ri.recordReflection(ssaPkg) // TODO: avoid recursing ri.recordReflection(ssaPkg) // TODO: avoid recursing
} }
@ -181,15 +181,23 @@ func (ri *reflectInspector) checkFunction(fun *ssa.Function) {
for _, block := range fun.Blocks { for _, block := range fun.Blocks {
for _, inst := range block.Instrs { for _, inst := range block.Instrs {
if ri.propagatedInstr[inst] {
break // already done
}
// fmt.Printf("inst: %v, t: %T\n", inst, inst) // fmt.Printf("inst: %v, t: %T\n", inst, inst)
switch inst := inst.(type) { switch inst := inst.(type) {
case *ssa.Store: case *ssa.Store:
if ri.propagatedStores[inst] { obj := typeToObj(inst.Addr.Type())
break // already done if usedForReflect(ri.result, obj) {
}
if storeAddrUsedForReflect(ri.result, inst.Addr.Type()) {
ri.recordArgReflected(inst.Val, make(map[ssa.Value]bool)) ri.recordArgReflected(inst.Val, make(map[ssa.Value]bool))
ri.propagatedStores[inst] = true ri.propagatedInstr[inst] = true
}
case *ssa.ChangeType:
obj := typeToObj(inst.X.Type())
if usedForReflect(ri.result, obj) {
ri.recursivelyRecordUsedForReflect(inst.Type())
ri.propagatedInstr[inst] = true
} }
case *ssa.Call: case *ssa.Call:
callName := inst.Call.Value.String() callName := inst.Call.Value.String()
@ -265,6 +273,11 @@ func (ri *reflectInspector) recordArgReflected(val ssa.Value, visited map[ssa.Va
case *ssa.MakeInterface: case *ssa.MakeInterface:
return ri.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.UnOp: case *ssa.UnOp:
for _, ref := range *val.Referrers() {
if idx, ok := ref.(ssa.Value); ok {
ri.recordArgReflected(idx, visited)
}
}
return ri.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
case *ssa.FieldAddr: case *ssa.FieldAddr:
return ri.recordArgReflected(val.X, visited) return ri.recordArgReflected(val.X, visited)
@ -274,7 +287,7 @@ func (ri *reflectInspector) recordArgReflected(val ssa.Value, visited map[ssa.Va
ri.recursivelyRecordUsedForReflect(val.Type()) ri.recursivelyRecordUsedForReflect(val.Type())
for _, ref := range *val.Referrers() { for _, ref := range *val.Referrers() {
if idx, ok := ref.(*ssa.IndexAddr); ok { if idx, ok := ref.(ssa.Value); ok {
ri.recordArgReflected(idx, visited) ri.recordArgReflected(idx, visited)
} }
} }
@ -285,6 +298,8 @@ func (ri *reflectInspector) recordArgReflected(val ssa.Value, visited map[ssa.Va
// check if the found alloc gets tainted by function parameters // check if the found alloc gets tainted by function parameters
return relatedParam(val, visited) return relatedParam(val, visited)
case *ssa.ChangeType:
ri.recursivelyRecordUsedForReflect(val.X.Type())
case *ssa.Const: case *ssa.Const:
ri.recursivelyRecordUsedForReflect(val.Type()) ri.recursivelyRecordUsedForReflect(val.Type())
case *ssa.Global: case *ssa.Global:
@ -414,7 +429,13 @@ func (ri *reflectInspector) recursivelyRecordUsedForReflect(t types.Type) {
// TODO: consider caching recordedObjectString via a map, // TODO: consider caching recordedObjectString via a map,
// if that shows an improvement in our benchmark // if that shows an improvement in our benchmark
func recordedObjectString(obj types.Object) objectString { func recordedObjectString(obj types.Object) objectString {
if obj == nil {
return ""
}
pkg := obj.Pkg() pkg := obj.Pkg()
if pkg == nil {
return ""
}
// Names which are not at the package level still need to avoid obfuscation in some cases: // Names which are not at the package level still need to avoid obfuscation in some cases:
// //
// 1. Field names on global types, which can be reached via reflection. // 1. Field names on global types, which can be reached via reflection.
@ -468,17 +489,18 @@ func usedForReflect(cache pkgCache, obj types.Object) bool {
return ok return ok
} }
// storeAddrUsedForReflect is only used in reflectInspector // We only mark named objects, so this function looks for a named object
// to see if a [ssa.Store.Addr] has been marked as used by reflection. // corresponding to a type.
// We only mark named objects, so this function looks for a type's first struct field. func typeToObj(typ types.Type) types.Object {
func storeAddrUsedForReflect(cache pkgCache, typ types.Type) bool { switch t := typ.(type) {
switch typ := typ.(type) { case *types.Named:
return t.Obj()
case *types.Struct: case *types.Struct:
if typ.NumFields() > 0 { if t.NumFields() > 0 {
return usedForReflect(cache, typ.Field(0)) return t.Field(0)
} }
case interface{ Elem() types.Type }: case interface{ Elem() types.Type }:
return storeAddrUsedForReflect(cache, typ.Elem()) return typeToObj(t.Elem())
} }
return false return nil
} }

@ -292,6 +292,59 @@ func testx509() {
} }
} }
type pingMsg struct {
Data string `sshtype:"192"`
}
type pongMsg struct {
Data string `sshtype:"193"`
}
// golang.org/x/crypto/ssh converts a reflected type to another type
func testSSH() {
var msg = pingMsg{
Data: "data",
}
json.Marshal(msg)
_ = pongMsg(msg)
}
// variations similar to ssh
type reflectedMsg struct {
Data string
}
type convertedMsg struct {
Data string
}
func reflectConvert() {
msg := reflectedMsg(convertedMsg{})
json.Marshal(msg)
}
type reflectedMsg2 struct {
Data string
}
type convertedMsg2 struct {
Data string
}
func unrelatedConvert() {
// only discoverable by rechecking the package
_ = convertedMsg2(reflectedMsg2{})
}
func reflectUnrelatedConv() {
var msg = reflectedMsg2{
Data: "data",
}
json.Marshal(msg)
}
type StatUser struct { type StatUser struct {
Id int64 `gorm:"primaryKey"` Id int64 `gorm:"primaryKey"`

@ -23,13 +23,6 @@ cmp stdout reverse.stdout
# Note that we rely on the unix-like TMPDIR env var name. # Note that we rely on the unix-like TMPDIR env var name.
[!windows] ! grepfiles ${TMPDIR} 'garble|importcfg|cache\.gob|\.go' [!windows] ! grepfiles ${TMPDIR} 'garble|importcfg|cache\.gob|\.go'
! exec garble build ./build-error
cp stderr build-error.stderr
stdin build-error.stderr
exec garble reverse ./build-error
cmp stdout build-error-reverse.stdout
[short] stop # no need to verify this with -short [short] stop # no need to verify this with -short
# Ensure that the reversed output matches the non-garbled output. # Ensure that the reversed output matches the non-garbled output.
@ -119,25 +112,6 @@ func printStackTrace(w io.Writer) error {
_, err := w.Write(stack) _, err := w.Write(stack)
return err return err
} }
-- build-error/error.go --
package p
import "reflect"
// This program is especially crafted to work with "go build",
// but fail with "garble build".
// This is because we attempt to convert from two different struct types,
// since only the anonymous one has its field name obfuscated.
// This is useful, because we test that build errors can be reversed,
// and it also includes a field name.
type UnobfuscatedStruct struct {
SomeField int
}
var _ = reflect.TypeOf(UnobfuscatedStruct{})
var _ = struct{SomeField int}(UnobfuscatedStruct{})
-- reverse.stdout -- -- reverse.stdout --
lib filename: test/main/lib/long_lib.go lib filename: test/main/lib/long_lib.go
@ -156,8 +130,3 @@ main.main(...)
test/main/long_main.go:11 +0x?? test/main/long_main.go:11 +0x??
main filename: test/main/long_main.go main filename: test/main/long_main.go
-- build-error-reverse.stdout --
# test/main/build-error
test/main/build-error/error.go:18: cannot convert UnobfuscatedStruct{} (value of type UnobfuscatedStruct) to type struct{SomeField int}
exit status 2
exit status 1

Loading…
Cancel
Save