Skip to content

Commit 8189f05

Browse files
committed
Remove MaskElement from public interface
1 parent f0e9acf commit 8189f05

File tree

3 files changed

+92
-107
lines changed

3 files changed

+92
-107
lines changed

crates/core_simd/src/masks.rs

Lines changed: 19 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
//! Types representing
33
#![allow(non_camel_case_types)]
44

5-
use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount};
5+
use crate::core_simd::vector::sealed::MaskElement;
6+
use crate::simd::{LaneCount, Select, Simd, SimdElement, SupportedLaneCount};
67
use core::cmp::Ordering;
78
use core::{fmt, mem};
89

@@ -29,91 +30,6 @@ macro_rules! impl_fix_endianness {
2930

3031
impl_fix_endianness! { u8, u16, u32, u64 }
3132

32-
mod sealed {
33-
use super::*;
34-
35-
/// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
36-
/// from bleeding into the parent bounds.
37-
///
38-
/// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
39-
/// prevent us from ever removing that bound, or from implementing `MaskElement` on
40-
/// non-`PartialEq` types in the future.
41-
pub trait Sealed {
42-
fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43-
where
44-
LaneCount<N>: SupportedLaneCount,
45-
Self: SimdElement;
46-
47-
fn eq(self, other: Self) -> bool;
48-
49-
fn to_usize(self) -> usize;
50-
fn max_unsigned() -> u64;
51-
52-
type Unsigned: SimdElement;
53-
54-
const TRUE: Self;
55-
56-
const FALSE: Self;
57-
}
58-
}
59-
use sealed::Sealed;
60-
61-
/// Marker trait for types that may be used as SIMD mask elements.
62-
///
63-
/// # Safety
64-
/// Type must be a signed integer.
65-
pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
66-
67-
macro_rules! impl_element {
68-
{ $ty:ty, $unsigned:ty } => {
69-
impl Sealed for $ty {
70-
#[inline]
71-
fn valid<const N: usize>(value: Simd<Self, N>) -> bool
72-
where
73-
LaneCount<N>: SupportedLaneCount,
74-
{
75-
// We can't use `Simd` directly, because `Simd`'s functions call this function and
76-
// we will end up with an infinite loop.
77-
// Safety: `value` is an integer vector
78-
unsafe {
79-
use core::intrinsics::simd;
80-
let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
81-
let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
82-
let valid: Simd<Self, N> = simd::simd_or(falses, trues);
83-
simd::simd_reduce_all(valid)
84-
}
85-
}
86-
87-
#[inline]
88-
fn eq(self, other: Self) -> bool { self == other }
89-
90-
#[inline]
91-
fn to_usize(self) -> usize {
92-
self as usize
93-
}
94-
95-
#[inline]
96-
fn max_unsigned() -> u64 {
97-
<$unsigned>::MAX as u64
98-
}
99-
100-
type Unsigned = $unsigned;
101-
102-
const TRUE: Self = -1;
103-
const FALSE: Self = 0;
104-
}
105-
106-
// Safety: this is a valid mask element type
107-
unsafe impl MaskElement for $ty {}
108-
}
109-
}
110-
111-
impl_element! { i8, u8 }
112-
impl_element! { i16, u16 }
113-
impl_element! { i32, u32 }
114-
impl_element! { i64, u64 }
115-
impl_element! { isize, usize }
116-
11733
/// A SIMD vector mask for `N` elements matching the element type `T`.
11834
///
11935
/// Masks represent boolean inclusion/exclusion on a per-element basis.
@@ -155,9 +71,9 @@ where
15571
#[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
15672
pub const fn splat(value: bool) -> Self {
15773
Self(Simd::splat(if value {
158-
<T::Mask as Sealed>::TRUE
74+
<T::Mask as MaskElement>::TRUE
15975
} else {
160-
<T::Mask as Sealed>::FALSE
76+
<T::Mask as MaskElement>::FALSE
16177
}))
16278
}
16379

@@ -208,7 +124,7 @@ where
208124
pub unsafe fn from_simd_unchecked(value: Simd<T::Mask, N>) -> Self {
209125
// Safety: the caller must confirm this invariant
210126
unsafe {
211-
core::intrinsics::assume(<T::Mask as Sealed>::valid(value));
127+
core::intrinsics::assume(<T::Mask as MaskElement>::valid(value));
212128
}
213129
Self(value)
214130
}
@@ -223,7 +139,7 @@ where
223139
#[track_caller]
224140
pub fn from_simd(value: Simd<T::Mask, N>) -> Self {
225141
assert!(
226-
<T::Mask as Sealed>::valid(value),
142+
<T::Mask as MaskElement>::valid(value),
227143
"all values must be either 0 or -1",
228144
);
229145
// Safety: the validity has been checked
@@ -256,9 +172,9 @@ where
256172
pub unsafe fn test_unchecked(&self, index: usize) -> bool {
257173
// Safety: the caller must confirm this invariant
258174
unsafe {
259-
<T::Mask as Sealed>::eq(
175+
<T::Mask as MaskElement>::eq(
260176
*self.0.as_array().get_unchecked(index),
261-
<T::Mask as Sealed>::TRUE,
177+
<T::Mask as MaskElement>::TRUE,
262178
)
263179
}
264180
}
@@ -271,7 +187,7 @@ where
271187
#[must_use = "method returns a new bool and does not mutate the original value"]
272188
#[track_caller]
273189
pub fn test(&self, index: usize) -> bool {
274-
<T::Mask as Sealed>::eq(self.0[index], <T::Mask as Sealed>::TRUE)
190+
<T::Mask as MaskElement>::eq(self.0[index], <T::Mask as MaskElement>::TRUE)
275191
}
276192

277193
/// Sets the value of the specified element.
@@ -283,9 +199,9 @@ where
283199
// Safety: the caller must confirm this invariant
284200
unsafe {
285201
*self.0.as_mut_array().get_unchecked_mut(index) = if value {
286-
<T::Mask as Sealed>::TRUE
202+
<T::Mask as MaskElement>::TRUE
287203
} else {
288-
<T::Mask as Sealed>::FALSE
204+
<T::Mask as MaskElement>::FALSE
289205
}
290206
}
291207
}
@@ -298,9 +214,9 @@ where
298214
#[track_caller]
299215
pub fn set(&mut self, index: usize, value: bool) {
300216
self.0[index] = if value {
301-
<T::Mask as Sealed>::TRUE
217+
<T::Mask as MaskElement>::TRUE
302218
} else {
303-
<T::Mask as Sealed>::FALSE
219+
<T::Mask as MaskElement>::FALSE
304220
}
305221
}
306222

@@ -372,8 +288,8 @@ where
372288
#[must_use = "method returns a new mask and does not mutate the original value"]
373289
pub fn from_bitmask(bitmask: u64) -> Self {
374290
Self(bitmask.select(
375-
Simd::splat(<T::Mask as Sealed>::TRUE),
376-
Simd::splat(<T::Mask as Sealed>::FALSE),
291+
Simd::splat(<T::Mask as MaskElement>::TRUE),
292+
Simd::splat(<T::Mask as MaskElement>::FALSE),
377293
))
378294
}
379295

@@ -426,20 +342,20 @@ where
426342
};
427343

428344
// Safety: the input and output are integer vectors
429-
let masked_index: Simd<<T::Mask as Sealed>::Unsigned, N> =
345+
let masked_index: Simd<<T::Mask as MaskElement>::Unsigned, N> =
430346
unsafe { core::intrinsics::simd::simd_cast(masked_index) };
431347

432348
// Safety: the input is an integer vectors
433-
let min_index: <T::Mask as Sealed>::Unsigned =
349+
let min_index: <T::Mask as MaskElement>::Unsigned =
434350
unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };
435351

436352
// Safety: the return value is the unsigned version of T
437353
let min_index: T::Mask = unsafe { core::mem::transmute_copy(&min_index) };
438354

439-
if <T::Mask as Sealed>::eq(min_index, <T::Mask as Sealed>::TRUE) {
355+
if <T::Mask as MaskElement>::eq(min_index, <T::Mask as MaskElement>::TRUE) {
440356
None
441357
} else {
442-
Some(<T::Mask as Sealed>::to_usize(min_index))
358+
Some(<T::Mask as MaskElement>::to_usize(min_index))
443359
}
444360
}
445361
}

crates/core_simd/src/vector.rs

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::simd::{
2-
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
2+
LaneCount, Mask, SupportedLaneCount, Swizzle,
33
cmp::SimdPartialOrd,
44
num::SimdUint,
55
ptr::{SimdConstPtr, SimdMutPtr},
@@ -1074,8 +1074,32 @@ where
10741074
}
10751075
}
10761076

1077-
mod sealed {
1077+
pub(crate) mod sealed {
1078+
use super::*;
1079+
use crate::simd::SimdCast;
1080+
10781081
pub trait Sealed {}
1082+
1083+
/// These functions prevent other traits from bleeding into the parent bounds.
1084+
///
1085+
/// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
1086+
/// prevent us from ever removing that bound, or from implementing `MaskElement` on
1087+
/// non-`PartialEq` types in the future.
1088+
pub trait MaskElement: SimdCast + Sealed {
1089+
fn valid<const N: usize>(value: Simd<Self, N>) -> bool
1090+
where
1091+
LaneCount<N>: SupportedLaneCount,
1092+
Self: SimdElement;
1093+
1094+
fn eq(self, other: Self) -> bool;
1095+
1096+
fn to_usize(self) -> usize;
1097+
1098+
type Unsigned: SimdElement;
1099+
1100+
const TRUE: Self;
1101+
const FALSE: Self;
1102+
}
10791103
}
10801104
use sealed::Sealed;
10811105

@@ -1089,7 +1113,7 @@ use sealed::Sealed;
10891113
/// even when no soundness guarantees are broken by allowing the user to try.
10901114
pub unsafe trait SimdElement: Sealed + Copy {
10911115
/// The mask element type corresponding to this element type.
1092-
type Mask: MaskElement;
1116+
type Mask: sealed::MaskElement;
10931117
}
10941118

10951119
impl Sealed for u8 {}
@@ -1133,6 +1157,51 @@ impl Sealed for i8 {}
11331157
unsafe impl SimdElement for i8 {
11341158
type Mask = i8;
11351159
}
1160+
macro_rules! impl_mask_element {
1161+
($ty:ty, $unsigned:ty) => {
1162+
impl sealed::MaskElement for $ty {
1163+
#[inline]
1164+
fn valid<const N: usize>(value: Simd<Self, N>) -> bool
1165+
where
1166+
LaneCount<N>: SupportedLaneCount,
1167+
Self: SimdElement,
1168+
{
1169+
unsafe {
1170+
use core::intrinsics::simd;
1171+
let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
1172+
let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
1173+
let valid: Simd<Self, N> = simd::simd_or(falses, trues);
1174+
simd::simd_reduce_all(valid)
1175+
}
1176+
}
1177+
1178+
#[inline]
1179+
fn eq(self, other: Self) -> bool {
1180+
self == other
1181+
}
1182+
1183+
#[inline]
1184+
fn to_usize(self) -> usize {
1185+
self as usize
1186+
}
1187+
1188+
type Unsigned = $unsigned;
1189+
1190+
const TRUE: Self = -1;
1191+
const FALSE: Self = 0;
1192+
}
1193+
};
1194+
}
1195+
1196+
impl_mask_element! { i8, u8 }
1197+
1198+
impl_mask_element! { i16, u16 }
1199+
1200+
impl_mask_element! { i32, u32 }
1201+
1202+
impl_mask_element! { i64, u64 }
1203+
1204+
impl_mask_element! { isize, usize }
11361205

11371206
impl Sealed for i16 {}
11381207

crates/core_simd/tests/masks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ macro_rules! test_mask_api {
113113

114114
#[test]
115115
fn cast() {
116-
fn cast_impl<T: core_simd::simd::MaskElement>()
116+
fn cast_impl<T: core_simd::simd::SimdElement>()
117117
where
118118
Mask<$type, 8>: Into<Mask<T, 8>>,
119119
{

0 commit comments

Comments
 (0)