Skip to content

Commit e6501a8

Browse files
committed
Fix endianness correction
1 parent e75a8d8 commit e6501a8

File tree

3 files changed

+54
-30
lines changed

3 files changed

+54
-30
lines changed

crates/core_simd/src/masks.rs

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,36 @@ use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneC
66
use core::cmp::Ordering;
77
use core::{fmt, mem};
88

9+
pub(crate) trait FixEndianness {
10+
fn fix_endianness(self, elements: usize) -> Self;
11+
}
12+
13+
macro_rules! impl_fix_endianness {
14+
{ $($int:ty),* } => {
15+
$(
16+
impl FixEndianness for $int {
17+
#[inline(always)]
18+
fn fix_endianness(self, elements: usize) -> Self {
19+
if cfg!(target_endian = "big") {
20+
let rev = <$int>::reverse_bits(self);
21+
let bitsize = size_of::<$int>() * 8;
22+
if elements < bitsize {
23+
// Shift things back to the right
24+
rev >> (bitsize - elements)
25+
} else {
26+
rev
27+
}
28+
} else {
29+
self
30+
}
31+
}
32+
}
33+
)*
34+
}
35+
}
36+
37+
impl_fix_endianness! { u8, u16, u32, u64 }
38+
939
mod sealed {
1040
use super::*;
1141

@@ -283,7 +313,9 @@ where
283313
#[must_use = "method returns a new integer and does not mutate the original value"]
284314
pub fn to_bitmask(self) -> u64 {
285315
#[inline]
286-
unsafe fn to_bitmask_impl<T, U, const M: usize, const N: usize>(mask: Mask<T, N>) -> U
316+
unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
317+
mask: Mask<T, N>,
318+
) -> U
287319
where
288320
T: MaskElement,
289321
LaneCount<M>: SupportedLaneCount,
@@ -292,11 +324,14 @@ where
292324
let resized = mask.resize::<M>(false);
293325

294326
// Safety: `resized` is an integer vector with length M, which must match T
295-
unsafe { core::intrinsics::simd::simd_bitmask(resized.0) }
327+
let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };
328+
329+
// LLVM assumes bit order should match endianness
330+
bitmask.fix_endianness(N)
296331
}
297332

298333
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
299-
let bitmask = if N <= 8 {
334+
if N <= 8 {
300335
// Safety: bitmask matches length
301336
unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
302337
} else if N <= 16 {
@@ -308,13 +343,6 @@ where
308343
} else {
309344
// Safety: bitmask matches length
310345
unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
311-
};
312-
313-
// LLVM assumes bit order should match endianness
314-
if cfg!(target_endian = "big") {
315-
bitmask.reverse_bits() >> (64 - N.min(64))
316-
} else {
317-
bitmask
318346
}
319347
}
320348

crates/core_simd/src/select.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount};
1+
use crate::simd::{
2+
FixEndianness, LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount,
3+
};
24

35
/// Choose elements from two vectors using a mask.
46
///
@@ -82,21 +84,8 @@ where
8284
assert!(N <= 64, "number of elements can't be greater than 64");
8385
}
8486

85-
// LLVM assumes bit order should match endianness
86-
let bitmask = if cfg!(target_endian = "big") {
87-
let rev = self.reverse_bits();
88-
if N < 64 {
89-
// Shift things back to the right
90-
rev >> (64 - N)
91-
} else {
92-
rev
93-
}
94-
} else {
95-
self
96-
};
97-
9887
#[inline]
99-
unsafe fn select_impl<T, U, const M: usize, const N: usize>(
88+
unsafe fn select_impl<T, U: FixEndianness, const M: usize, const N: usize>(
10089
bitmask: U,
10190
true_values: Simd<T, N>,
10291
false_values: Simd<T, N>,
@@ -110,6 +99,9 @@ where
11099
let true_values = true_values.resize::<M>(default);
111100
let false_values = false_values.resize::<M>(default);
112101

102+
// LLVM assumes bit order should match endianness
103+
let bitmask = bitmask.fix_endianness(N);
104+
113105
// Safety: the caller guarantees that the size of U matches M
114106
let selected = unsafe {
115107
core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values)
@@ -120,15 +112,19 @@ where
120112

121113
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
122114
if N <= 8 {
115+
let bitmask = self as u8;
123116
// Safety: bitmask matches length
124-
unsafe { select_impl::<T, u8, 8, N>(bitmask as u8, true_values, false_values) }
117+
unsafe { select_impl::<T, u8, 8, N>(bitmask, true_values, false_values) }
125118
} else if N <= 16 {
119+
let bitmask = self as u16;
126120
// Safety: bitmask matches length
127-
unsafe { select_impl::<T, u16, 16, N>(bitmask as u16, true_values, false_values) }
121+
unsafe { select_impl::<T, u16, 16, N>(bitmask, true_values, false_values) }
128122
} else if N <= 32 {
123+
let bitmask = self as u32;
129124
// Safety: bitmask matches length
130-
unsafe { select_impl::<T, u32, 32, N>(bitmask as u32, true_values, false_values) }
125+
unsafe { select_impl::<T, u32, 32, N>(bitmask, true_values, false_values) }
131126
} else {
127+
let bitmask = self;
132128
// Safety: bitmask matches length
133129
unsafe { select_impl::<T, u64, 64, N>(bitmask, true_values, false_values) }
134130
}

crates/core_simd/src/swizzle_dyn.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd<u8, 16>, idxs: Simd<u8, 16>) -> S
139139
#[inline]
140140
#[allow(clippy::let_and_return)]
141141
unsafe fn avx2_pshufb(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
142-
use crate::simd::{cmp::SimdPartialOrd, Select};
142+
use crate::simd::{Select, cmp::SimdPartialOrd};
143143
#[cfg(target_arch = "x86")]
144144
use core::arch::x86;
145145
#[cfg(target_arch = "x86_64")]
@@ -200,7 +200,7 @@ fn zeroing_idxs<const N: usize>(idxs: Simd<u8, N>) -> Simd<u8, N>
200200
where
201201
LaneCount<N>: SupportedLaneCount,
202202
{
203-
use crate::simd::{cmp::SimdPartialOrd, Select};
203+
use crate::simd::{Select, cmp::SimdPartialOrd};
204204
idxs.simd_lt(Simd::splat(N as u8))
205205
.select(idxs, Simd::splat(u8::MAX))
206206
}

0 commit comments

Comments
 (0)