Skip to content

Commit

Permalink
move StaticType to devirtualize pkg
Browse files Browse the repository at this point in the history
Change-Id: Iadcbf84a666a3e43ce2b460e064b111efa0f2022
  • Loading branch information
mateusz834 committed Feb 15, 2025
1 parent 291add4 commit f63e6fa
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 37 deletions.
89 changes: 88 additions & 1 deletion src/cmd/compile/internal/devirtualize/devirtualize.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,16 @@ func StaticCall(call *ir.CallExpr) {
}

sel := call.Fun.(*ir.SelectorExpr)
typ := ir.StaticType(sel.X)
typ := staticType(sel.X)
if typ == nil {
return
}

// Don't try to devirtualize calls that we statically know that would have failed at runtime.
// This can happen in such case: any(0).(interface {A()}).A(), this typechecks without
// any errors, but will cause a runtime panic. We statically know that int(0) does not
// implement that interface, thus we skip the devirtualization, as it is not possible
// to make a type assertion from interface{A()} to int (int does not implement interface{A()}).
if !typecheck.Implements(typ, sel.X.Type()) {
return
}
Expand Down Expand Up @@ -136,3 +141,85 @@ func StaticCall(call *ir.CallExpr) {
// Desugar OCALLMETH, if we created one (#57309).
typecheck.FixMethodCall(call)
}

func staticType(n ir.Node) *types.Type {
for {
switch n1 := n.(type) {
case *ir.ConvExpr:
if n1.Op() == ir.OCONVNOP || n1.Op() == ir.OCONVIFACE {
n = n1.X
continue
}
case *ir.InlinedCallExpr:
if n1.Op() == ir.OINLCALL {
n = n1.SingleResult()
continue
}
case *ir.ParenExpr:
n = n1.X
continue
case *ir.TypeAssertExpr:
n = n1.X
continue
}

n1 := staticValue(n)
if n1 == nil {
if n.Type().IsInterface() {
return nil
}
return n.Type()
}
n = n1
}
}

func staticValue(nn ir.Node) ir.Node {
if nn.Op() != ir.ONAME {
return nil
}

n := nn.(*ir.Name).Canonical()
if n.Class != ir.PAUTO {
return nil
}

defn := n.Defn
if defn == nil {
return nil
}

var rhs ir.Node
FindRHS:
switch defn.Op() {
case ir.OAS:
defn := defn.(*ir.AssignStmt)
rhs = defn.Y
case ir.OAS2:
defn := defn.(*ir.AssignListStmt)
for i, lhs := range defn.Lhs {
if lhs == n {
rhs = defn.Rhs[i]
break FindRHS
}
}
base.Fatalf("%v missing from LHS of %v", n, defn)
case ir.OAS2DOTTYPE:
defn := defn.(*ir.AssignListStmt)
if defn.Lhs[0] == n {
rhs = defn.Rhs[0]
}
default:
return nil
}

if rhs == nil {
base.Fatalf("RHS is nil: %v", defn)
}

if ir.Reassigned(n) {
return nil
}

return rhs
}
38 changes: 2 additions & 36 deletions src/cmd/compile/internal/ir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,18 +840,6 @@ func IsAddressable(n Node) bool {
return false
}

// StaticType is like StaticValue, but also follows ODOTTYPE and OCONVIFACE.
func StaticType(n Node) *types.Type {
out := staticValue(n, true)

typ := out.Type()
if typ.IsInterface() {
return nil
}

return typ
}

// StaticValue analyzes n to find the earliest expression that always
// evaluates to the same value as n, which might be from an enclosing
// function.
Expand All @@ -867,22 +855,13 @@ func StaticType(n Node) *types.Type {
// calling StaticValue on the "int(y)" expression returns the outer
// "g()" expression.
func StaticValue(n Node) Node {
return staticValue(n, false)

}

func staticValue(n Node, forDevirt bool) Node {
for {
switch n1 := n.(type) {
case *ConvExpr:
if n1.Op() == OCONVNOP {
n = n1.X
continue
}
if forDevirt && n1.Op() == OCONVIFACE {
n = n1.X
continue
}
case *InlinedCallExpr:
if n1.Op() == OINLCALL {
n = n1.SingleResult()
Expand All @@ -891,22 +870,17 @@ func staticValue(n Node, forDevirt bool) Node {
case *ParenExpr:
n = n1.X
continue
case *TypeAssertExpr:
if forDevirt {
n = n1.X
continue
}
}

n1 := staticValue1(n, forDevirt)
n1 := staticValue1(n)
if n1 == nil {
return n
}
n = n1
}
}

func staticValue1(nn Node, forDevirt bool) Node {
func staticValue1(nn Node) Node {
if nn.Op() != ONAME {
return nil
}
Expand Down Expand Up @@ -935,14 +909,6 @@ FindRHS:
}
}
base.Fatalf("%v missing from LHS of %v", n, defn)
case OAS2DOTTYPE:
if !forDevirt {
return nil
}
defn := defn.(*AssignListStmt)
if defn.Lhs[0] == n {
rhs = defn.Rhs[0]
}
default:
return nil
}
Expand Down
47 changes: 47 additions & 0 deletions test/escape_iface_with_devirt_type_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,55 @@ func callIfA(m M) { // ERROR "can inline" "leaking param"
}
}

//go:noinline
func newImplNoInline() *Impl {
return &Impl{} // ERROR "escapes"
}

func t3() {
{
var a A = newImplNoInline()
if v, ok := a.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
{
m := make(map[*Impl]struct{}) // ERROR "does not escape"
for v := range m {
var v A = v
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}
{
m := make(map[int]*Impl) // ERROR "does not escape"
for _, v := range m {
var v A = v
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}
{
m := make(map[int]*Impl) // ERROR "does not escape"
var v A = m[0]
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
{
m := make(chan *Impl)
var v A = <-m
if v, ok := v.(M); ok {
v.M() // ERROR "devirtualizing" "inlining call"
}
}
}

//go:noinline
func testInvalidAsserts() {
any(0).(interface{ A() }).A() // ERROR "escapes"
{
var a M = &Impl{} // ERROR "escapes"
a.(C).C() // this will panic
Expand Down

0 comments on commit f63e6fa

Please sign in to comment.