Fix unnecessary auto-vectorization in residual loop of SAD intrinsics (#2897)

There is no way to tell the compiler not to auto-vectorize a particular loop without changing the global compilation flags. It is possible, however, with inline assembly, which was stabilized in Rust 1.59.0.

Before, the compiler aggressively vectorized a loop that was supposed to stay scalar, since it is only the residual loop of an already vectorized function. With this PR, the expected scalar assembly is generated: https://godbolt.org/z/e6KTfqnxG
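As a sketch of the idiom this commit relies on (illustrative only, not the committed code): an `asm!` statement that claims to read and write the accumulator hides its dataflow from LLVM, so the surrounding loop compiles as the plain scalar loop it was written as.

```rust
use std::arch::asm;

// A minimal sketch of the barrier idiom, not the committed code:
// routing `sum` through an asm template that is only a comment hides
// its dataflow from LLVM, so the auto-vectorizer leaves the loop scalar.
fn sad_scalar_sketch(src: &[u8], dst: &[u8]) -> i64 {
    let mut sum: i64 = 0;
    for (&a, &b) in src.iter().zip(dst.iter()) {
        sum += (a as i64 - b as i64).abs();
        // SAFETY: the template contains no instructions; it only forces
        // `sum` into a register and marks it as potentially modified.
        unsafe {
            asm!("/* {0} */", inout(reg) sum, options(nostack, nomem, pure));
        }
    }
    sum
}
```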
redzic authored Mar 1, 2022
1 parent 5d38546 commit f1a80d8
Showing 5 changed files with 35 additions and 34 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/rav1e.yml
@@ -44,7 +44,7 @@ jobs:
matrix:
conf:
- beta-build
- 1.56.0-tests
- 1.59.0-tests
- aom-tests
- dav1d-tests
- no-asm-tests
@@ -59,8 +59,8 @@
include:
- conf: beta-build
toolchain: beta
- conf: 1.56.0-tests
toolchain: 1.56.0
- conf: 1.59.0-tests
toolchain: 1.59.0
- conf: aom-tests
toolchain: stable
- conf: dav1d-tests
@@ -125,7 +125,7 @@
echo "$NASM_SHA256 nasm_${NASM_VERSION}_amd64.deb" >> CHECKSUMS
- name: Add aom
if: >
matrix.conf == '1.56.0-tests' || matrix.conf == 'aom-tests' ||
matrix.conf == '1.59.0-tests' || matrix.conf == 'aom-tests' ||
matrix.conf == 'grcov-coveralls'
env:
LINK: https://mirror.cogentco.com/debian/pool/main/a/aom
@@ -141,7 +141,7 @@
echo "$AOM_LIB3_SHA256 libaom3_${AOM3_VERSION}_amd64.deb" >> CHECKSUMS
- name: Add dav1d
if: >
matrix.conf == '1.56.0-tests' || matrix.conf == 'dav1d-tests' ||
matrix.conf == '1.59.0-tests' || matrix.conf == 'dav1d-tests' ||
matrix.conf == 'grcov-coveralls' || matrix.conf == 'fuzz' || matrix.conf == 'no-asm-tests'
env:
LINK: https://mirror.cogentco.com/debian/pool/main/d/dav1d
@@ -216,8 +216,8 @@
- name: Start sccache server
run: |
sccache --start-server
- name: Run 1.56.0 tests
if: matrix.toolchain == '1.56.0' && matrix.conf == '1.56.0-tests'
- name: Run 1.59.0 tests
if: matrix.toolchain == '1.59.0' && matrix.conf == '1.59.0-tests'
run: |
cargo test --workspace --verbose \
--features=decode_test,decode_test_dav1d,quick_test,capi
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -3,7 +3,7 @@ name = "rav1e"
version = "0.5.0"
authors = ["Thomas Daede <tdaede@xiph.org>"]
edition = "2021"
rust-version = "1.56.0"
rust-version = "1.59.0"
build = "build.rs"
include = [
"/Cargo.toml",
2 changes: 1 addition & 1 deletion README.md
@@ -49,7 +49,7 @@ For the foreseeable future, a weekly pre-release of rav1e will be [published](ht

### Toolchain: Rust

rav1e currently requires Rust 1.56.0 or later to build.
rav1e currently requires Rust 1.59.0 or later to build.

### Dependency: NASM
Some `x86_64`-specific optimizations require [NASM](https://nasm.us/) `2.14.02` or newer and are enabled by default.
2 changes: 1 addition & 1 deletion build.rs
@@ -220,7 +220,7 @@ fn build_asm_files() {
fn rustc_version_check() {
// This should match the version in the CI
// Make sure to update README.md when this changes.
const REQUIRED_VERSION: &str = "1.56.0";
const REQUIRED_VERSION: &str = "1.59.0";
if version().unwrap() < Version::parse(REQUIRED_VERSION).unwrap() {
eprintln!("rav1e requires rustc >= {}.", REQUIRED_VERSION);
exit(1);
49 changes: 25 additions & 24 deletions src/asm/x86/sad_row.rs
@@ -16,30 +16,31 @@ use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use std::arch::asm;
use std::hint::unreachable_unchecked;
use std::mem;

/// SAFETY: src and dst must have the same length, which must be less than 16
#[inline(always)]
unsafe fn sad_below16_8bpc(src: &[u8], dst: &[u8]) -> i64 {
// we have a separate function for this so that the autovectorizer
// does not unroll the loop too much

unsafe fn sad_scalar(src: &[u8], dst: &[u8]) -> i64 {
if src.len() != dst.len() {
unreachable_unchecked()
}
if src.len() >= 16 {
unreachable_unchecked()
}
if dst.len() >= 16 {
unreachable_unchecked()
}

let mut sum = 0;

for i in 0..src.len() {
// We use inline assembly here to keep the compiler from auto-vectorizing
// this loop; it is only the scalar residual of an already vectorized path.
asm!(
"add {sum}, {x}",
sum = inout(reg) sum,
x = in(reg) (*src.get_unchecked(i) as i64 - *dst.get_unchecked(i) as i64).abs(),
options(nostack)
);
}

src
.iter()
.zip(dst.iter())
.map(|(&p1, &p2)| (p1 as i16 - p2 as i16).abs() as i64)
.sum::<i64>()
sum
}

/// SAFETY: src and dst must have the same length, which must be less than 32
@@ -51,8 +51,8 @@ unsafe fn sad_below32_8bpc_sse2(src: &[u8], dst: &[u8]) -> i64 {
}

if src.len() >= 16 {
let src_u8x16 = _mm_loadu_si128(src.as_ptr() as *const _);
let dst_u8x16 = _mm_loadu_si128(dst.as_ptr() as *const _);
let result = _mm_sad_epu8(src_u8x16, dst_u8x16);
let mut sum = mem::transmute::<_, [i64; 2]>(result).iter().sum::<i64>();

@@ -62,12 +62,12 @@ unsafe fn sad_below32_8bpc_sse2(src: &[u8], dst: &[u8]) -> i64 {
if remaining != 0 {
let src_extra = src.get_unchecked(16..);
let dst_extra = dst.get_unchecked(16..);
sum += sad_below16_8bpc(src_extra, dst_extra);
sum += sad_scalar(src_extra, dst_extra);
}

sum
} else {
sad_below16_8bpc(src, dst)
sad_scalar(src, dst)
}
}
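For context on the `transmute::<_, [i64; 2]>` above: `_mm_sad_epu8` produces two partial sums, one per 64-bit lane, so the horizontal reduction is just adding the two lanes. A minimal sketch, with an assumed helper name (not part of the commit):

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Illustrative helper: SAD over exactly 16 bytes with SSE2.
/// `_mm_sad_epu8` sums absolute differences of each 8-byte half into
/// the low 16 bits of the corresponding 64-bit lane.
#[cfg(target_arch = "x86_64")]
unsafe fn sad16_sketch(src: &[u8; 16], dst: &[u8; 16]) -> i64 {
    let a = _mm_loadu_si128(src.as_ptr() as *const _);
    let b = _mm_loadu_si128(dst.as_ptr() as *const _);
    let sad = _mm_sad_epu8(a, b);
    // Reinterpret the 128-bit vector as two i64 lanes and add them,
    // mirroring the `transmute::<_, [i64; 2]>` in the hunk above.
    let [lo, hi] = std::mem::transmute::<__m128i, [i64; 2]>(sad);
    lo + hi
}
```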

@@ -90,8 +91,8 @@ unsafe fn sad_8bpc_avx2(src: &[u8], dst: &[u8]) -> i64 {
let main_sum = src_chunks
.zip(dst_chunks)
.map(|(src_chunk, dst_chunk)| {
let src = _mm256_loadu_si256(src_chunk.as_ptr() as *const _);
let dst = _mm256_loadu_si256(dst_chunk.as_ptr() as *const _);

_mm256_sad_epu8(src, dst)
})
@@ -123,13 +124,13 @@ unsafe fn sad_8bpc_sse2(src: &[u8], dst: &[u8]) -> i64 {
let (src_rem, dst_rem) = (src_chunks.remainder(), dst_chunks.remainder());

if src_chunks.len() == 0 {
sad_below16_8bpc(src_rem, dst_rem)
sad_scalar(src_rem, dst_rem)
} else {
let main_sum = src_chunks
.zip(dst_chunks)
.map(|(src_chunk, dst_chunk)| {
let src = _mm_loadu_si128(src_chunk.as_ptr() as *const _);
let dst = _mm_loadu_si128(dst_chunk.as_ptr() as *const _);

_mm_sad_epu8(src, dst)
})
@@ -140,7 +141,7 @@ unsafe fn sad_8bpc_sse2(src: &[u8], dst: &[u8]) -> i64 {
mem::transmute::<_, [i64; 2]>(main_sum).iter().sum::<i64>();

if !src_rem.is_empty() {
main_sum += sad_below16_8bpc(src_rem, dst_rem);
main_sum += sad_scalar(src_rem, dst_rem);
}

main_sum
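The two dispatchers above share one shape: run the SIMD kernel over fixed-size chunks, then hand the leftover tail to the scalar routine. A generic sketch of that pattern (illustrative only; the committed functions accumulate in vector registers rather than through a closure):

```rust
/// Illustrative shape of `sad_8bpc_sse2`/`sad_8bpc_avx2`: a SIMD kernel
/// over fixed-size chunks plus a scalar tail for the remainder.
fn sad_chunked_sketch(
    src: &[u8], dst: &[u8], kernel: impl Fn(&[u8], &[u8]) -> i64,
) -> i64 {
    const CHUNK: usize = 16;
    let (src_chunks, dst_chunks) = (src.chunks_exact(CHUNK), dst.chunks_exact(CHUNK));
    // `remainder()` is the sub-CHUNK tail the kernel cannot handle.
    let (src_rem, dst_rem) = (src_chunks.remainder(), dst_chunks.remainder());

    let main_sum: i64 = src_chunks.zip(dst_chunks).map(|(s, d)| kernel(s, d)).sum();

    // Scalar tail: at most CHUNK - 1 elements; this is the loop the
    // commit keeps scalar via the asm barrier.
    let tail_sum: i64 = src_rem
        .iter()
        .zip(dst_rem.iter())
        .map(|(&a, &b)| (a as i64 - b as i64).abs())
        .sum();

    main_sum + tail_sum
}
```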
