diff --git a/main.go b/main.go index 8efd34a..612df12 100644 --- a/main.go +++ b/main.go @@ -1359,14 +1359,14 @@ func (tf *transformer) typecheck(files []*ast.File) error { // A bit hacky, but I could not find an easier way to do this. for _, obj := range tf.info.Defs { if obj != nil { - tf.recordType(obj.Type()) + tf.recordType(obj.Type(), nil) } } for name, obj := range tf.info.Uses { if obj == nil { continue } - tf.recordType(obj.Type()) + tf.recordType(obj.Type(), nil) // Record into KnownEmbeddedAliasFields. obj, ok := obj.(*types.TypeName) @@ -1388,7 +1388,7 @@ func (tf *transformer) typecheck(files []*ast.File) error { cachedOutput.KnownEmbeddedAliasFields[vrStr] = aliasTypeName } for _, tv := range tf.info.Types { - tf.recordType(tv.Type) + tf.recordType(tv.Type, nil) } return nil } @@ -1396,27 +1396,38 @@ func (tf *transformer) typecheck(files []*ast.File) error { // recordType visits every reachable type after typechecking a package. // Right now, all it does is fill the fieldToStruct field. // Since types can be recursive, we need a map to avoid cycles. -func (tf *transformer) recordType(t types.Type) { - if tf.recordTypeDone[t] { +func (tf *transformer) recordType(used, origin types.Type) { + if tf.recordTypeDone[used] { return } - tf.recordTypeDone[t] = true - switch t := t.(type) { - case interface{ Elem() types.Type }: - tf.recordType(t.Elem()) - case *types.Named: - tf.recordType(t.Underlying()) - } - strct, _ := t.(*types.Struct) - if strct == nil { - return + if origin == nil { + origin = used } - for i := 0; i < strct.NumFields(); i++ { - field := strct.Field(i) - tf.fieldToStruct[field] = strct + type Container interface{ Elem() types.Type } + tf.recordTypeDone[used] = true + switch used := used.(type) { + case Container: + origin := origin.(Container) + tf.recordType(used.Elem(), origin.Elem()) + case *types.Named: + // If we have a generic struct like + // + // type Foo[T any] struct { Bar T } + // + // then we want the hashing to use the original "Bar T", + // because otherwise different instances like "Bar int" and "Bar bool" + // will result in different hashes and the field names will break. + // Ensure we record the original generic struct, if there is one. + tf.recordType(used.Underlying(), used.Origin().Underlying()) + case *types.Struct: + origin := origin.(*types.Struct) + for i := 0; i < used.NumFields(); i++ { + field := used.Field(i) + tf.fieldToStruct[field] = origin - if field.Embedded() { - tf.recordType(field.Type()) + if field.Embedded() { + tf.recordType(field.Type(), origin.Field(i).Type()) + } } } } diff --git a/testdata/scripts/typeparams.txt b/testdata/scripts/typeparams.txt index c64bd61..061495c 100644 --- a/testdata/scripts/typeparams.txt +++ b/testdata/scripts/typeparams.txt @@ -12,14 +12,22 @@ go 1.18 package main func main() { - //var _ GenericVector[int] GenericFunc[int, int](1, 2) + var _ GenericVector[int] + + g := GenericGraph[string]{Content: "Foo"} + g.Edges = make([]GenericGraph[string], 1) } func GenericFunc[GenericParamA, B any](x GenericParamA, y B) {} type GenericVector[GenericParamT any] []GenericParamT +type GenericGraph[T any] struct { + Content T + Edges []GenericGraph[T] +} + type PredeclaredSignedInteger interface { int | int8 | int16 | int32 | int64 }