|
| 1 | +// Copyright 2025 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package bloop |
| 6 | + |
| 7 | +// This file contains support routines for keeping |
| 8 | +// statements alive |
| 9 | +// in such loops (example): |
| 10 | +// |
| 11 | +// for b.Loop() { |
| 12 | +// var a, b int |
| 13 | +// a = 5 |
| 14 | +// b = 6 |
| 15 | +// f(a, b) |
| 16 | +// } |
| 17 | +// |
| 18 | +// The results of a, b and f(a, b) will be kept alive. |
| 19 | +// |
| 20 | +// Formally, the lhs (if they are [ir.Name]-s) of |
| 21 | +// [ir.AssignStmt], [ir.AssignListStmt], |
| 22 | +// [ir.AssignOpStmt], and the results of [ir.CallExpr] |
| 23 | +// or its args if it doesn't return a value will be kept |
| 24 | +// alive. |
| 25 | +// |
| 26 | +// The keep alive logic is implemented with as wrapping a |
| 27 | +// runtime.KeepAlive around the Name. |
| 28 | +// |
| 29 | +// TODO: currently this is implemented with KeepAlive |
| 30 | +// because it will prevent DSE and DCE which is probably |
| 31 | +// what we want right now. And KeepAlive takes an ssa |
| 32 | +// value instead of a symbol, which is easier to manage. |
| 33 | +// But since KeepAlive's context was mainly in the runtime |
| 34 | +// and GC, should we implement a new intrinsic that lowers |
| 35 | +// to OpVarLive? Peeling out the symbols is a bit tricky |
| 36 | +// and also VarLive seems to assume that there exists a |
| 37 | +// VarDef on the same symbol that dominates it. |
| 38 | + |
| 39 | +import ( |
| 40 | + "cmd/compile/internal/base" |
| 41 | + "cmd/compile/internal/ir" |
| 42 | + "cmd/compile/internal/reflectdata" |
| 43 | + "cmd/compile/internal/typecheck" |
| 44 | + "cmd/compile/internal/types" |
| 45 | + "fmt" |
| 46 | +) |
| 47 | + |
| 48 | +// getNameFromNode tries to iteratively peel down the node to |
| 49 | +// get the name. |
| 50 | +func getNameFromNode(n ir.Node) *ir.Name { |
| 51 | + var ret *ir.Name |
| 52 | + if n.Op() == ir.ONAME { |
| 53 | + ret = n.(*ir.Name) |
| 54 | + } else { |
| 55 | + // avoid infinite recursion on circular referencing nodes. |
| 56 | + seen := map[ir.Node]bool{n: true} |
| 57 | + var findName func(ir.Node) bool |
| 58 | + findName = func(a ir.Node) bool { |
| 59 | + if a.Op() == ir.ONAME { |
| 60 | + ret = a.(*ir.Name) |
| 61 | + return true |
| 62 | + } |
| 63 | + if !seen[a] { |
| 64 | + seen[a] = true |
| 65 | + return ir.DoChildren(a, findName) |
| 66 | + } |
| 67 | + return false |
| 68 | + } |
| 69 | + ir.DoChildren(n, findName) |
| 70 | + } |
| 71 | + return ret |
| 72 | +} |
| 73 | + |
| 74 | +// keepAliveAt returns a statement that is either curNode, or a |
| 75 | +// block containing curNode followed by a call to runtime.keepAlive for each |
| 76 | +// ONAME in ns. These calls ensure that names in ns will be live until |
| 77 | +// after curNode's execution. |
| 78 | +func keepAliveAt(ns []*ir.Name, curNode ir.Node) ir.Node { |
| 79 | + if len(ns) == 0 { |
| 80 | + return curNode |
| 81 | + } |
| 82 | + |
| 83 | + pos := curNode.Pos() |
| 84 | + calls := []ir.Node{curNode} |
| 85 | + for _, n := range ns { |
| 86 | + if n == nil { |
| 87 | + continue |
| 88 | + } |
| 89 | + if n.Sym() == nil { |
| 90 | + continue |
| 91 | + } |
| 92 | + if n.Sym().IsBlank() { |
| 93 | + continue |
| 94 | + } |
| 95 | + arg := ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], n) |
| 96 | + if !n.Type().IsInterface() { |
| 97 | + srcRType0 := reflectdata.TypePtrAt(pos, n.Type()) |
| 98 | + arg.TypeWord = srcRType0 |
| 99 | + arg.SrcRType = srcRType0 |
| 100 | + } |
| 101 | + callExpr := typecheck.Call(pos, |
| 102 | + typecheck.LookupRuntime("KeepAlive"), |
| 103 | + []ir.Node{arg}, false).(*ir.CallExpr) |
| 104 | + callExpr.IsCompilerVarLive = true |
| 105 | + callExpr.NoInline = true |
| 106 | + calls = append(calls, callExpr) |
| 107 | + } |
| 108 | + |
| 109 | + return ir.NewBlockStmt(pos, calls) |
| 110 | +} |
| 111 | + |
| 112 | +func debugName(name *ir.Name, line string) { |
| 113 | + if base.Flag.LowerM > 0 { |
| 114 | + if name.Linksym() != nil { |
| 115 | + fmt.Printf("%v: %s will be kept alive\n", line, name.Linksym().Name) |
| 116 | + } else { |
| 117 | + fmt.Printf("%v: expr will be kept alive\n", line) |
| 118 | + } |
| 119 | + } |
| 120 | +} |
| 121 | + |
| 122 | +// preserveStmt transforms stmt so that any names defined/assigned within it |
| 123 | +// are used after stmt's execution, preventing their dead code elimination |
| 124 | +// and dead store elimination. The return value is the transformed statement. |
| 125 | +func preserveStmt(curFn *ir.Func, stmt ir.Node) (ret ir.Node) { |
| 126 | + ret = stmt |
| 127 | + switch n := stmt.(type) { |
| 128 | + case *ir.AssignStmt: |
| 129 | + // Peel down struct and slice indexing to get the names |
| 130 | + name := getNameFromNode(n.X) |
| 131 | + if name != nil { |
| 132 | + debugName(name, ir.Line(stmt)) |
| 133 | + ret = keepAliveAt([]*ir.Name{name}, n) |
| 134 | + } |
| 135 | + case *ir.AssignListStmt: |
| 136 | + names := []*ir.Name{} |
| 137 | + for _, lhs := range n.Lhs { |
| 138 | + name := getNameFromNode(lhs) |
| 139 | + if name != nil { |
| 140 | + debugName(name, ir.Line(stmt)) |
| 141 | + names = append(names, name) |
| 142 | + } |
| 143 | + } |
| 144 | + ret = keepAliveAt(names, n) |
| 145 | + case *ir.AssignOpStmt: |
| 146 | + name := getNameFromNode(n.X) |
| 147 | + if name != nil { |
| 148 | + debugName(name, ir.Line(stmt)) |
| 149 | + ret = keepAliveAt([]*ir.Name{name}, n) |
| 150 | + } |
| 151 | + case *ir.CallExpr: |
| 152 | + names := []*ir.Name{} |
| 153 | + curNode := stmt |
| 154 | + if n.Fun != nil && n.Fun.Type() != nil && n.Fun.Type().NumResults() != 0 { |
| 155 | + // This function's results are not assigned, assign them to |
| 156 | + // auto tmps and then keepAliveAt these autos. |
| 157 | + // Note: markStmt assumes the context that it's called - this CallExpr is |
| 158 | + // not within another OAS2, which is guaranteed by the case above. |
| 159 | + results := n.Fun.Type().Results() |
| 160 | + lhs := make([]ir.Node, len(results)) |
| 161 | + for i, res := range results { |
| 162 | + tmp := typecheck.TempAt(n.Pos(), curFn, res.Type) |
| 163 | + lhs[i] = tmp |
| 164 | + names = append(names, tmp) |
| 165 | + } |
| 166 | + |
| 167 | + // Create an assignment statement. |
| 168 | + assign := typecheck.AssignExpr( |
| 169 | + ir.NewAssignListStmt(n.Pos(), ir.OAS2, lhs, |
| 170 | + []ir.Node{n})).(*ir.AssignListStmt) |
| 171 | + assign.Def = true |
| 172 | + curNode = assign |
| 173 | + plural := "" |
| 174 | + if len(results) > 1 { |
| 175 | + plural = "s" |
| 176 | + } |
| 177 | + if base.Flag.LowerM > 0 { |
| 178 | + fmt.Printf("%v: function result%s will be kept alive\n", ir.Line(stmt), plural) |
| 179 | + } |
| 180 | + } else { |
| 181 | + // This function probably doesn't return anything, keep its args alive. |
| 182 | + argTmps := []ir.Node{} |
| 183 | + for i, a := range n.Args { |
| 184 | + if name := getNameFromNode(a); name != nil { |
| 185 | + // If they are name, keep them alive directly. |
| 186 | + debugName(name, ir.Line(stmt)) |
| 187 | + names = append(names, name) |
| 188 | + } else if a.Op() == ir.OSLICELIT { |
| 189 | + // variadic args are encoded as slice literal. |
| 190 | + s := a.(*ir.CompLitExpr) |
| 191 | + ns := []*ir.Name{} |
| 192 | + for i, n := range s.List { |
| 193 | + if name := getNameFromNode(n); name != nil { |
| 194 | + debugName(name, ir.Line(a)) |
| 195 | + ns = append(ns, name) |
| 196 | + } else { |
| 197 | + // We need a temporary to save this arg. |
| 198 | + tmp := typecheck.TempAt(n.Pos(), curFn, n.Type()) |
| 199 | + argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, n))) |
| 200 | + names = append(names, tmp) |
| 201 | + s.List[i] = tmp |
| 202 | + if base.Flag.LowerM > 0 { |
| 203 | + fmt.Printf("%v: function arg will be kept alive\n", ir.Line(n)) |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + names = append(names, ns...) |
| 208 | + } else { |
| 209 | + // expressions, we need to assign them to temps and change the original arg to reference |
| 210 | + // them. |
| 211 | + tmp := typecheck.TempAt(n.Pos(), curFn, a.Type()) |
| 212 | + argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, a))) |
| 213 | + names = append(names, tmp) |
| 214 | + n.Args[i] = tmp |
| 215 | + if base.Flag.LowerM > 0 { |
| 216 | + fmt.Printf("%v: function arg will be kept alive\n", ir.Line(stmt)) |
| 217 | + } |
| 218 | + } |
| 219 | + } |
| 220 | + if len(argTmps) > 0 { |
| 221 | + argTmps = append(argTmps, n) |
| 222 | + curNode = ir.NewBlockStmt(n.Pos(), argTmps) |
| 223 | + } |
| 224 | + } |
| 225 | + ret = keepAliveAt(names, curNode) |
| 226 | + } |
| 227 | + return |
| 228 | +} |
| 229 | + |
| 230 | +func preserveStmts(curFn *ir.Func, list ir.Nodes) { |
| 231 | + for i := range list { |
| 232 | + list[i] = preserveStmt(curFn, list[i]) |
| 233 | + } |
| 234 | +} |
| 235 | + |
| 236 | +// isTestingBLoop returns true if it matches the node as a |
| 237 | +// testing.(*B).Loop. See issue #61515. |
| 238 | +func isTestingBLoop(t ir.Node) bool { |
| 239 | + if t.Op() != ir.OFOR { |
| 240 | + return false |
| 241 | + } |
| 242 | + nFor, ok := t.(*ir.ForStmt) |
| 243 | + if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC { |
| 244 | + return false |
| 245 | + } |
| 246 | + n, ok := nFor.Cond.(*ir.CallExpr) |
| 247 | + if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR { |
| 248 | + return false |
| 249 | + } |
| 250 | + name := ir.MethodExprName(n.Fun) |
| 251 | + if name == nil { |
| 252 | + return false |
| 253 | + } |
| 254 | + if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil && |
| 255 | + fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" { |
| 256 | + // Attempting to match a function call to testing.(*B).Loop |
| 257 | + return true |
| 258 | + } |
| 259 | + return false |
| 260 | +} |
| 261 | + |
| 262 | +type editor struct { |
| 263 | + inBloop bool |
| 264 | + curFn *ir.Func |
| 265 | +} |
| 266 | + |
| 267 | +func (e editor) edit(n ir.Node) ir.Node { |
| 268 | + e.inBloop = isTestingBLoop(n) || e.inBloop |
| 269 | + // It's in bloop, mark the stmts with bodies. |
| 270 | + ir.EditChildren(n, e.edit) |
| 271 | + if e.inBloop { |
| 272 | + switch n := n.(type) { |
| 273 | + case *ir.ForStmt: |
| 274 | + preserveStmts(e.curFn, n.Body) |
| 275 | + case *ir.IfStmt: |
| 276 | + preserveStmts(e.curFn, n.Body) |
| 277 | + preserveStmts(e.curFn, n.Else) |
| 278 | + case *ir.BlockStmt: |
| 279 | + preserveStmts(e.curFn, n.List) |
| 280 | + case *ir.CaseClause: |
| 281 | + preserveStmts(e.curFn, n.List) |
| 282 | + preserveStmts(e.curFn, n.Body) |
| 283 | + case *ir.CommClause: |
| 284 | + preserveStmts(e.curFn, n.Body) |
| 285 | + } |
| 286 | + } |
| 287 | + return n |
| 288 | +} |
| 289 | + |
| 290 | +// BloopWalk performs a walk on all functions in the package |
| 291 | +// if it imports testing and wrap the results of all qualified |
| 292 | +// statements in a runtime.KeepAlive intrinsic call. See package |
| 293 | +// doc for more details. |
| 294 | +// |
| 295 | +// for b.Loop() {...} |
| 296 | +// |
| 297 | +// loop's body. |
| 298 | +func BloopWalk(pkg *ir.Package) { |
| 299 | + hasTesting := false |
| 300 | + for _, i := range pkg.Imports { |
| 301 | + if i.Path == "testing" { |
| 302 | + hasTesting = true |
| 303 | + break |
| 304 | + } |
| 305 | + } |
| 306 | + if !hasTesting { |
| 307 | + return |
| 308 | + } |
| 309 | + for _, fn := range pkg.Funcs { |
| 310 | + e := editor{false, fn} |
| 311 | + ir.EditChildren(fn, e.edit) |
| 312 | + } |
| 313 | +} |
0 commit comments