Skip to content

Commit 6c3740b

Browse files
committed
Wasm: specialization of number comparisons
1 parent c71968d commit 6c3740b

File tree

6 files changed

+188
-51
lines changed

6 files changed

+188
-51
lines changed

compiler/lib-wasm/generate.ml

Lines changed: 126 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ module Generate (Target : Target_sig.S) = struct
6868
type repr =
6969
| Value
7070
| Float
71+
| Int
7172
| Int32
7273
| Nativeint
7374
| Int64
@@ -76,24 +77,23 @@ module Generate (Target : Target_sig.S) = struct
7677
match r with
7778
| Value -> Type.value
7879
| Float -> F64
79-
| Int32 -> I32
80-
| Nativeint -> I32
80+
| Int | Int32 | Nativeint -> I32
8181
| Int64 -> I64
8282

8383
let specialized_primitive_type (_, params, result) =
8484
{ W.params = List.map ~f:repr_type params; result = [ repr_type result ] }
8585

8686
let box_value r e =
8787
match r with
88-
| Value -> e
88+
| Value | Int -> e
8989
| Float -> Memory.box_float e
9090
| Int32 -> Memory.box_int32 e
9191
| Nativeint -> Memory.box_nativeint e
9292
| Int64 -> Memory.box_int64 e
9393

9494
let unbox_value r e =
9595
match r with
96-
| Value -> e
96+
| Value | Int -> e
9797
| Float -> Memory.unbox_float e
9898
| Int32 -> Memory.unbox_int32 e
9999
| Nativeint -> Memory.unbox_nativeint e
@@ -106,9 +106,9 @@ module Generate (Target : Target_sig.S) = struct
106106
[ "caml_int32_bswap", (`Pure, [ Int32 ], Int32)
107107
; "caml_nativeint_bswap", (`Pure, [ Nativeint ], Nativeint)
108108
; "caml_int64_bswap", (`Pure, [ Int64 ], Int64)
109-
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Value)
110-
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Value)
111-
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Value)
109+
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Int)
110+
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Int)
111+
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Int)
112112
; "caml_string_get32", (`Mutator, [ Value; Value ], Int32)
113113
; "caml_string_get32u", (`Mutator, [ Value; Value ], Int32)
114114
; "caml_string_get64", (`Mutator, [ Value; Value ], Int64)
@@ -135,7 +135,7 @@ module Generate (Target : Target_sig.S) = struct
135135
; "caml_ldexp_float", (`Pure, [ Float; Value ], Float)
136136
; "caml_erf_float", (`Pure, [ Float ], Float)
137137
; "caml_erfc_float", (`Pure, [ Float ], Float)
138-
; "caml_float_compare", (`Pure, [ Float; Float ], Value)
138+
; "caml_float_compare", (`Pure, [ Float; Float ], Int)
139139
];
140140
h
141141

@@ -310,6 +310,38 @@ module Generate (Target : Target_sig.S) = struct
310310
(transl_prim_arg ctx ?typ:tz z)
311311
| _ -> invalid_arity name l ~expected:3)
312312

313+
let register_comparison name cmp_int cmp_boxed_int cmp_float =
314+
register_prim name `Mutable (fun ctx _ (hint : Optimization_hint.t option) l ->
315+
match l with
316+
| [ x; y ] -> (
317+
let x' = transl_prim_arg ctx x in
318+
let y' = transl_prim_arg ctx y in
319+
match hint, get_type ctx x, get_type ctx y with
320+
| _, Int _, Int _ -> cmp_int ctx x y
321+
| Some (Hint_int Int32), _, _ | _, Number Int32, Number Int32 ->
322+
let* x' = Memory.unbox_int32 x' in
323+
let* y' = Memory.unbox_int32 y' in
324+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
325+
| Some (Hint_int Nativeint), _, _ | _, Number Nativeint, Number Nativeint ->
326+
let* x' = Memory.unbox_nativeint x' in
327+
let* y' = Memory.unbox_nativeint y' in
328+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
329+
| Some (Hint_int Int64), _, _ | _, Number Int64, Number Int64 ->
330+
let* x' = Memory.unbox_int64 x' in
331+
let* y' = Memory.unbox_int64 y' in
332+
return (W.BinOp (I64 cmp_boxed_int, x', y'))
333+
| _, Number Float, Number Float -> float_comparison cmp_float x' y'
334+
| _ ->
335+
let* f =
336+
register_import
337+
~name
338+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
339+
in
340+
let* x' = x' in
341+
let* y' = y' in
342+
return (W.Call (f, [ x'; y' ])))
343+
| _ -> invalid_arity name l ~expected:2)
344+
313345
let () =
314346
register_bin_prim
315347
"caml_array_unsafe_get"
@@ -792,6 +824,92 @@ module Generate (Target : Target_sig.S) = struct
792824
~init:(return [])
793825
in
794826
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l);
827+
register_comparison
828+
"caml_greaterthan"
829+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x < y)) x y)
830+
(Gt S)
831+
Gt;
832+
register_comparison
833+
"caml_greaterequal"
834+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x <= y)) x y)
835+
(Ge S)
836+
Ge;
837+
register_comparison
838+
"caml_lessthan"
839+
(fun ctx x y -> translate_int_comparison ctx Arith.( < ) x y)
840+
(Lt S)
841+
Lt;
842+
register_comparison
843+
"caml_lessequal"
844+
(fun ctx x y -> translate_int_comparison ctx Arith.( <= ) x y)
845+
(Le S)
846+
Le;
847+
register_comparison
848+
"caml_equal"
849+
(fun ctx x y -> translate_int_equality ctx ~negate:false x y)
850+
Eq
851+
Eq;
852+
register_comparison
853+
"caml_notequal"
854+
(fun ctx x y -> translate_int_equality ctx ~negate:true x y)
855+
Ne
856+
Ne;
857+
register_prim "caml_compare" `Mutable (fun ctx _ _ l ->
858+
match l with
859+
| [ x; y ] -> (
860+
let x' = transl_prim_arg ctx x in
861+
let y' = transl_prim_arg ctx y in
862+
match get_type ctx x, get_type ctx y with
863+
| Int _, Int _ ->
864+
Arith.(
865+
(Value.int_val y' < Value.int_val x')
866+
- (Value.int_val x' < Value.int_val y'))
867+
| Number Int32, Number Int32 ->
868+
let* f =
869+
register_import
870+
~name:"caml_int32_compare"
871+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
872+
in
873+
let* x' = Memory.unbox_int32 x' in
874+
let* y' = Memory.unbox_int32 y' in
875+
return (W.Call (f, [ x'; y' ]))
876+
| Number Nativeint, Number Nativeint ->
877+
let* f =
878+
register_import
879+
~name:"caml_nativeint_compare"
880+
(Fun (Type.primitive_type 2))
881+
in
882+
let* x' = Memory.unbox_nativeint x' in
883+
let* y' = Memory.unbox_nativeint y' in
884+
return (W.Call (f, [ x'; y' ]))
885+
| Number Int64, Number Int64 ->
886+
let* f =
887+
register_import
888+
~name:"caml_int64_compare"
889+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
890+
in
891+
let* x' = Memory.unbox_int64 x' in
892+
let* y' = Memory.unbox_int64 y' in
893+
return (W.Call (f, [ x'; y' ]))
894+
| Number Float, Number Float ->
895+
let* f =
896+
register_import
897+
~name:"caml_float_compare"
898+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
899+
in
900+
let* x' = Memory.unbox_int64 x' in
901+
let* y' = Memory.unbox_int64 y' in
902+
return (W.Call (f, [ x'; y' ]))
903+
| _ ->
904+
let* f =
905+
register_import
906+
~name:"caml_compare"
907+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
908+
in
909+
let* x' = x' in
910+
let* y' = y' in
911+
return (W.Call (f, [ x'; y' ])))
912+
| _ -> invalid_arity "caml_compare" l ~expected:2);
795913
let caml_ba_get ~ctx ~context ~unsafe ~kind ~layout ta indices =
796914
let ta' = transl_prim_arg ctx ta in
797915
Bigarray.get

compiler/lib-wasm/typing.ml

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,13 @@ let prim_type ~approx prim hint args =
212212
| "caml_lessthan"
213213
| "caml_lessequal"
214214
| "caml_equal"
215-
| "caml_compare" -> Int Ref
215+
| "caml_notequal"
216+
| "caml_compare" -> Int Normalized
216217
| "caml_int32_bswap" -> Number Int32
217218
| "caml_nativeint_bswap" -> Number Nativeint
218219
| "caml_int64_bswap" -> Number Int64
219-
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" -> Int Ref
220+
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" ->
221+
Int Normalized
220222
| "caml_string_get32" -> Number Int32
221223
| "caml_string_get64" -> Number Int64
222224
| "caml_bytes_get32" -> Number Int32
@@ -227,7 +229,7 @@ let prim_type ~approx prim hint args =
227229
| "caml_nextafter_float" -> Number Float
228230
| "caml_classify_float" -> Int Ref
229231
| "caml_ldexp_float" | "caml_erf_float" | "caml_erfc_float" -> Number Float
230-
| "caml_float_compare" -> Int Ref
232+
| "caml_float_compare" -> Int Normalized
231233
| "caml_floatarray_unsafe_get" -> Number Float
232234
| "caml_bytes_unsafe_get"
233235
| "caml_string_unsafe_get"
@@ -446,6 +448,27 @@ let solver st =
446448
in
447449
Solver.f () g (propagate st)
448450

451+
let print_opt typ f e =
452+
match e with
453+
| Prim
454+
( Extern
455+
( ( "caml_greaterthan"
456+
| "caml_greaterequal"
457+
| "caml_lessthan"
458+
| "caml_lessequal"
459+
| "caml_equal"
460+
| "caml_compare" )
461+
, _ )
462+
, l ) -> (
463+
match List.map ~f:(arg_type ~approx:typ) l with
464+
| [ Int _; Int _ ]
465+
| [ Number Int32; Number Int32 ]
466+
| [ Number Int64; Number Int64 ]
467+
| [ Number Nativeint; Number Nativeint ]
468+
| [ Number Float; Number Float ] -> Format.fprintf f " OPT"
469+
| _ -> ())
470+
| _ -> ()
471+
449472
let f ~state ~info ~deadcode_sentinal p =
450473
update_deps state p;
451474
let function_parameters = mark_function_parameters p in
@@ -466,7 +489,8 @@ let f ~state ~info ~deadcode_sentinal p =
466489
Format.err_formatter
467490
(fun _ i ->
468491
match i with
469-
| Instr (Let (x, _)) -> Format.asprintf "{%a}" Domain.print (Var.Tbl.get typ x)
492+
| Instr (Let (x, e)) ->
493+
Format.asprintf "{%a}%a" Domain.print (Var.Tbl.get typ x) (print_opt typ) e
470494
| _ -> "")
471495
p);
472496
typ

runtime/wasm/compare.wat

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -556,53 +556,49 @@
556556
(i32.const 0))
557557

558558
(func (export "caml_compare")
559-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
559+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
560560
(local $res i32)
561561
(local.set $res
562562
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 1)))
563563
(if (i32.lt_s (local.get $res) (i32.const 0))
564-
(then (return (ref.i31 (i32.const -1)))))
564+
(then (return (i32.const -1))))
565565
(if (i32.gt_s (local.get $res) (i32.const 0))
566-
(then (return (ref.i31 (i32.const 1)))))
567-
(ref.i31 (i32.const 0)))
566+
(then (return (i32.const 1))))
567+
(i32.const 0))
568568

569569
(func (export "caml_equal")
570-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
571-
(ref.i31
572-
(i32.eqz
573-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
570+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
571+
(i32.eqz
572+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
574573

575574
(func (export "caml_notequal")
576-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
577-
(ref.i31
578-
(i32.ne (i32.const 0)
579-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
575+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
576+
(i32.ne (i32.const 0)
577+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
580578

581579
(func (export "caml_lessthan")
582-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
580+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
583581
(local $res i32)
584582
(local.set $res
585583
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
586-
(ref.i31
587-
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
588-
(i32.ne (local.get $res) (global.get $unordered)))))
584+
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
585+
(i32.ne (local.get $res) (global.get $unordered))))
589586

590587
(func (export "caml_lessequal")
591-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
588+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
592589
(local $res i32)
593590
(local.set $res
594591
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
595-
(ref.i31
596-
(i32.and (i32.le_s (local.get $res) (i32.const 0))
597-
(i32.ne (local.get $res) (global.get $unordered)))))
592+
(i32.and (i32.le_s (local.get $res) (i32.const 0))
593+
(i32.ne (local.get $res) (global.get $unordered))))
598594

599595
(func (export "caml_greaterthan")
600-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
601-
(ref.i31 (i32.lt_s (i32.const 0)
602-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
596+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
597+
(i32.lt_s (i32.const 0)
598+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
603599

604600
(func (export "caml_greaterequal")
605-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
606-
(ref.i31 (i32.le_s (i32.const 0)
607-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
601+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
602+
(i32.le_s (i32.const 0)
603+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
608604
)

runtime/wasm/float.wat

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,13 +1132,12 @@
11321132
(struct.new $float (local.get $y)))
11331133

11341134
(func (export "caml_float_compare")
1135-
(param $x f64) (param $y f64) (result (ref eq))
1136-
(ref.i31
1137-
(i32.add
1138-
(i32.sub (f64.gt (local.get $x) (local.get $y))
1139-
(f64.lt (local.get $x) (local.get $y)))
1140-
(i32.sub (f64.eq (local.get $x) (local.get $x))
1141-
(f64.eq (local.get $y) (local.get $y))))))
1135+
(param $x f64) (param $y f64) (result i32)
1136+
(i32.add
1137+
(i32.sub (f64.gt (local.get $x) (local.get $y))
1138+
(f64.lt (local.get $x) (local.get $y)))
1139+
(i32.sub (f64.eq (local.get $x) (local.get $x))
1140+
(f64.eq (local.get $y) (local.get $y)))))
11421141

11431142
(func (export "caml_round") (param $x f64) (result f64)
11441143
(local $y f64)

runtime/wasm/int32.wat

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@
126126

127127
(export "caml_nativeint_compare" (func $caml_int32_compare))
128128
(func $caml_int32_compare (export "caml_int32_compare")
129-
(param $i1 i32) (param $i2 i32) (result (ref eq))
130-
(ref.i31 (i32.sub (i32.gt_s (local.get $i1) (local.get $i2))
131-
(i32.lt_s (local.get $i1) (local.get $i2)))))
129+
(param $i1 i32) (param $i2 i32) (result i32)
130+
(i32.sub (i32.gt_s (local.get $i1) (local.get $i2))
131+
(i32.lt_s (local.get $i1) (local.get $i2))))
132132

133133
(global $nativeint_ops (export "nativeint_ops") (ref $custom_operations)
134134
(struct.new $custom_operations

runtime/wasm/int64.wat

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@
124124
(i64.const 8)))))
125125

126126
(func (export "caml_int64_compare")
127-
(param $i1 i64) (param $i2 i64) (result (ref eq))
128-
(ref.i31 (i32.sub (i64.gt_s (local.get $i1) (local.get $i2))
129-
(i64.lt_s (local.get $i1) (local.get $i2)))))
127+
(param $i1 i64) (param $i2 i64) (result i32)
128+
(i32.sub (i64.gt_s (local.get $i1) (local.get $i2))
129+
(i64.lt_s (local.get $i1) (local.get $i2))))
130130

131131
(@string $INT64_ERRMSG "Int64.of_string")
132132

0 commit comments

Comments
 (0)