Skip to content

Commit 24cacc9

Browse files
authored
CodeGen: Implement support for math.lerp lowering (#1609)
To implement math.lerp without branches, we add SELECT_NUM which selects one of the two inputs based on the comparison condition. For simplicity, we only support C == D for now; this can be extended to a more generic version with a IrCondition operand E, but that requires more work on the SSE side (to flip the comparison for some conditions like Greater, and expose more generic vcmpsd). Note: On AArch64 this will effectively result in a change in floating point behavior between native code and non-native code: clang synthesizes fmadd (because floating point contraction is allowed by default, and the arch always has the instruction), whereas this change will use fmul+fadd. I am not sure if this is good or bad, and if this is a problem in C or not. Specifically, clang's behavior results in different results between X64 and AArch64 when *not* using codegen, and with this change the behavior when using codegen is... the same? :) Fixing this will require either using LERP_NUM instead and hand-coding lowering, or exposing some sort of "quasi" MADD_NUM (which would lower to fma on AArch64 and mul+add on X64). A small benefit to the current approach is `lerp(1, 5, t)` constant-folds the subtraction. With LERP_NUM this optimization will need to be implemented manually as a partial constant-folding for LERP_NUM. A similar problem exists today for vector.cross & vector.dot. So maybe this is not something we need to fix, unsure.
1 parent c759cd5 commit 24cacc9

File tree

12 files changed

+108
-0
lines changed

12 files changed

+108
-0
lines changed

CodeGen/include/Luau/AssemblyBuilderX64.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class AssemblyBuilderX64
160160
void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
161161
void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
162162

163+
void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
163164
void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
164165

165166
void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3);

CodeGen/include/Luau/IrData.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ enum class IrCmd : uint8_t
185185
// A: double
186186
SIGN_NUM,
187187

188+
// Select B if C == D, otherwise select A
189+
// A, B: double (endpoints)
190+
// C, D: double (condition arguments)
191+
SELECT_NUM,
192+
188193
// Add/Sub/Mul/Div/Idiv two vectors
189194
// A, B: TValue
190195
ADD_VEC,

CodeGen/include/Luau/IrUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ inline bool hasResult(IrCmd cmd)
174174
case IrCmd::SQRT_NUM:
175175
case IrCmd::ABS_NUM:
176176
case IrCmd::SIGN_NUM:
177+
case IrCmd::SELECT_NUM:
177178
case IrCmd::ADD_VEC:
178179
case IrCmd::SUB_VEC:
179180
case IrCmd::MUL_VEC:

CodeGen/src/AssemblyBuilderX64.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2
927927
placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2);
928928
}
929929

930+
void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
931+
{
932+
placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2);
933+
}
934+
930935
void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
931936
{
932937
placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2);

CodeGen/src/IrDump.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ const char* getCmdName(IrCmd cmd)
169169
return "ABS_NUM";
170170
case IrCmd::SIGN_NUM:
171171
return "SIGN_NUM";
172+
case IrCmd::SELECT_NUM:
173+
return "SELECT_NUM";
172174
case IrCmd::ADD_VEC:
173175
return "ADD_VEC";
174176
case IrCmd::SUB_VEC:

CodeGen/src/IrLoweringA64.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
LUAU_FASTFLAG(LuauVectorLibNativeDot)
1515
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
16+
LUAU_FASTFLAG(LuauCodeGenLerp)
1617

1718
namespace Luau
1819
{
@@ -703,6 +704,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
703704
build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less));
704705
break;
705706
}
707+
case IrCmd::SELECT_NUM:
708+
{
709+
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
710+
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b, inst.c, inst.d});
711+
712+
RegisterA64 temp1 = tempDouble(inst.a);
713+
RegisterA64 temp2 = tempDouble(inst.b);
714+
RegisterA64 temp3 = tempDouble(inst.c);
715+
RegisterA64 temp4 = tempDouble(inst.d);
716+
717+
build.fcmp(temp3, temp4);
718+
build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal));
719+
break;
720+
}
706721
case IrCmd::ADD_VEC:
707722
{
708723
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});

CodeGen/src/IrLoweringX64.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
LUAU_FASTFLAG(LuauVectorLibNativeDot)
1919
LUAU_FASTFLAG(LuauCodeGenVectorDeadStoreElim)
20+
LUAU_FASTFLAG(LuauCodeGenLerp)
2021

2122
namespace Luau
2223
{
@@ -622,6 +623,30 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
622623
build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64);
623624
break;
624625
}
626+
case IrCmd::SELECT_NUM:
627+
{
628+
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
629+
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.c, inst.d}); // can't reuse b if a is a memory operand
630+
631+
ScopedRegX64 tmp{regs, SizeX64::xmmword};
632+
633+
if (inst.c.kind == IrOpKind::Inst)
634+
build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d));
635+
else
636+
{
637+
build.vmovsd(tmp.reg, memRegDoubleOp(inst.c));
638+
build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d));
639+
}
640+
641+
if (inst.a.kind == IrOpKind::Inst)
642+
build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg);
643+
else
644+
{
645+
build.vmovsd(inst.regX64, memRegDoubleOp(inst.a));
646+
build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg);
647+
}
648+
break;
649+
}
625650
case IrCmd::ADD_VEC:
626651
{
627652
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});

CodeGen/src/IrTranslateBuiltins.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ static const int kBit32BinaryOpUnrolledParams = 5;
1515

1616
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
1717
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
18+
LUAU_FASTFLAGVARIABLE(LuauCodeGenLerp);
1819

1920
namespace Luau
2021
{
@@ -284,6 +285,42 @@ static BuiltinImplResult translateBuiltinMathClamp(
284285
return {BuiltinImplType::UsesFallback, 1};
285286
}
286287

288+
static BuiltinImplResult translateBuiltinMathLerp(
289+
IrBuilder& build,
290+
int nparams,
291+
int ra,
292+
int arg,
293+
IrOp args,
294+
IrOp arg3,
295+
int nresults,
296+
IrOp fallback,
297+
int pcpos
298+
)
299+
{
300+
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
301+
302+
if (nparams < 3 || nresults > 1)
303+
return {BuiltinImplType::None, -1};
304+
305+
builtinCheckDouble(build, build.vmReg(arg), pcpos);
306+
builtinCheckDouble(build, args, pcpos);
307+
builtinCheckDouble(build, arg3, pcpos);
308+
309+
IrOp a = builtinLoadDouble(build, build.vmReg(arg));
310+
IrOp b = builtinLoadDouble(build, args);
311+
IrOp t = builtinLoadDouble(build, arg3);
312+
313+
IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t));
314+
IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0
315+
316+
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r);
317+
318+
if (ra != arg)
319+
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
320+
321+
return {BuiltinImplType::Full, 1};
322+
}
323+
287324
static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos)
288325
{
289326
if (nparams < 1 || nresults > 1)
@@ -1387,6 +1424,8 @@ BuiltinImplResult translateBuiltin(
13871424
case LBF_VECTOR_MAX:
13881425
return FFlag::LuauVectorLibNativeCodegen ? translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos)
13891426
: noneResult;
1427+
case LBF_MATH_LERP:
1428+
return FFlag::LuauCodeGenLerp ? translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos) : noneResult;
13901429
default:
13911430
return {BuiltinImplType::None, -1};
13921431
}

CodeGen/src/IrUtils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <math.h>
1414

1515
LUAU_FASTFLAG(LuauVectorLibNativeDot);
16+
LUAU_FASTFLAG(LuauCodeGenLerp);
1617

1718
namespace Luau
1819
{
@@ -70,6 +71,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
7071
case IrCmd::SQRT_NUM:
7172
case IrCmd::ABS_NUM:
7273
case IrCmd::SIGN_NUM:
74+
case IrCmd::SELECT_NUM:
7375
return IrValueKind::Double;
7476
case IrCmd::ADD_VEC:
7577
case IrCmd::SUB_VEC:
@@ -656,6 +658,16 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3
656658
substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0));
657659
}
658660
break;
661+
case IrCmd::SELECT_NUM:
662+
LUAU_ASSERT(FFlag::LuauCodeGenLerp);
663+
if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant)
664+
{
665+
double c = function.doubleOp(inst.c);
666+
double d = function.doubleOp(inst.d);
667+
668+
substitute(function, inst, c == d ? inst.b : inst.a);
669+
}
670+
break;
659671
case IrCmd::NOT_ANY:
660672
if (inst.a.kind == IrOpKind::Constant)
661673
{

CodeGen/src/OptimizeConstProp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
13821382
case IrCmd::SQRT_NUM:
13831383
case IrCmd::ABS_NUM:
13841384
case IrCmd::SIGN_NUM:
1385+
case IrCmd::SELECT_NUM:
13851386
case IrCmd::NOT_ANY:
13861387
state.substituteOrRecord(inst, index);
13871388
break;

0 commit comments

Comments
 (0)