Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework decompression algorithm to operate over 4-byte chunks #45

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 81 additions & 27 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,46 +250,100 @@ impl<'a> Decompressor<'a> {
Self { symbols, lengths }
}

/// Expands one non-escape `code` into the output buffer.
///
/// Looks up the symbol and its length for `code`, stores the symbol's full
/// 8-byte representation at `decoded[*out_pos..]` with a single unaligned
/// write, then advances `*in_pos` by one code and `*out_pos` by the symbol's
/// length. All 8 bytes are written regardless of the symbol's length; only
/// the first `length` of them are counted as output.
///
/// # Safety
///
/// - `code` must be a valid index into both `self.symbols` and
///   `self.lengths`.
/// - The memory behind `decoded` must be writable for at least 8 bytes
///   starting at offset `*out_pos`.
#[inline]
unsafe fn nonescape(&self, code: u8, decoded: &mut[u8], in_pos: &mut usize, out_pos: &mut usize) {
    debug_assert!(code != ESCAPE_CODE);

    let index = usize::from(code);

    // SAFETY: `index` is at most 255; the caller guarantees both tables are
    // large enough to be indexed by any code.
    let (symbol, length) = unsafe {
        (
            *self.symbols.get_unchecked(index),
            *self.lengths.get_unchecked(index),
        )
    };

    // SAFETY: the caller guarantees 8 writable bytes at `*out_pos`. The
    // destination has no alignment guarantee, hence the unaligned store.
    unsafe {
        decoded
            .as_mut_ptr()
            .add(*out_pos)
            .cast::<u64>()
            .write_unaligned(symbol.as_u64());
    }

    // One code consumed; only `length` of the 8 stored bytes are kept.
    *out_pos += length as usize;
    *in_pos += 1;
}

/// Handles an escape sequence: copies the literal byte that follows the
/// escape code straight into the output buffer.
///
/// On entry `*in_pos` is the position of the escape code itself. The byte at
/// `*in_pos + 1` is written to `decoded[*out_pos]`, after which `*in_pos`
/// advances by 2 (escape code + literal) and `*out_pos` advances by 1.
///
/// # Safety
///
/// - `*in_pos + 1` must be in bounds of `compressed`, i.e. the escape code
///   must not be the last byte of the compressed stream.
/// - The memory behind `decoded` must be writable for at least 1 byte at
///   offset `*out_pos`.
#[inline]
unsafe fn escape(&self, compressed: &[u8], decoded: &mut[u8], in_pos: &mut usize, out_pos: &mut usize) {
    // Surface a violated stream invariant in debug builds instead of reading
    // out of bounds silently.
    debug_assert!(*in_pos + 1 < compressed.len());

    // SAFETY: the caller guarantees the escape code is not the final byte of
    // `compressed`, so `*in_pos + 1` is in bounds.
    let literal = unsafe { *compressed.get_unchecked(*in_pos + 1) };

    // SAFETY: the caller guarantees at least one writable byte at `*out_pos`.
    unsafe {
        decoded.as_mut_ptr().add(*out_pos).write(literal);
    }

    *in_pos += 2;
    *out_pos += 1;
}

/// Decompress a byte slice that was previously returned by a compressor using
/// the same symbol table.
pub fn decompress(&self, compressed: &[u8]) -> Vec<u8> {
let mut decoded: Vec<u8> = Vec::with_capacity(size_of::<Symbol>() * (compressed.len() + 1));
let ptr = decoded.as_mut_ptr();

let mut in_pos = 0;
let mut out_pos = 0;

while in_pos + 4 <= compressed.len() {
// out_pos can grow at most 32 bytes per iteration, and we start at 0
debug_assert!(out_pos <= decoded.capacity() - 4 * size_of::<Symbol>());

while in_pos < compressed.len() {
// out_pos can grow at most 8 bytes per iteration, and we start at 0
debug_assert!(out_pos <= decoded.capacity() - size_of::<Symbol>());
// SAFETY: in_pos is always in range 0..compressed.len()
let code = unsafe { *compressed.get_unchecked(in_pos) };
if code == ESCAPE_CODE {
// Advance by one, do raw write.
in_pos += 1;
// SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer
// SAFETY: ESCAPE_CODE can not be the last byte of the compressed stream
let mut next_block: u32 = 0;
// SAFETY: in_pos is always in range 0..(compress.len() - 4)
unsafe {
std::ptr::copy_nonoverlapping(compressed.as_ptr().byte_add(in_pos), &mut next_block as *mut u32 as *mut u8, size_of::<u32>());
};
let escape_mask: u32 = (next_block & 0x80808080) & ((((!next_block) & 0x7F7F7F7F) + 0x7F7F7F7F) ^ 0x80808080);
debug_assert!(escape_mask & 0x7F7F7F7F == 0);
if escape_mask == 0 {
// fast path no escape codes
// SAFETY: TODO
unsafe {
let write_addr = ptr.byte_add(out_pos);
std::ptr::write(write_addr, *compressed.get_unchecked(in_pos));
}
out_pos += 1;
in_pos += 1;
self.nonescape(*compressed.get_unchecked(in_pos), &mut decoded, &mut in_pos, &mut out_pos);
self.nonescape(*compressed.get_unchecked(in_pos), &mut decoded, &mut in_pos, &mut out_pos);
self.nonescape(*compressed.get_unchecked(in_pos), &mut decoded, &mut in_pos, &mut out_pos);
self.nonescape(*compressed.get_unchecked(in_pos), &mut decoded, &mut in_pos, &mut out_pos);
};
} else {
// SAFETY: code is in range 0..255
// The symbol and length tables are both of length 256, so this is safe.
let symbol = unsafe { *self.symbols.get_unchecked(code as usize) };
let length = unsafe { *self.lengths.get_unchecked(code as usize) };
// SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer
unsafe {
let write_addr = ptr.byte_add(out_pos) as *mut u64;
// Perform 8 byte unaligned write.
write_addr.write_unaligned(symbol.as_u64());
// index of first escape (0..4)
let first_escape = escape_mask.trailing_zeros() >> 3;
debug_assert!(first_escape <= 3);
for _ in 0..first_escape {
unsafe {
let code = *compressed.get_unchecked(in_pos);
self.nonescape(code, &mut decoded, &mut in_pos, &mut out_pos);
};
}
in_pos += 1;
out_pos += length as usize;
unsafe { self.escape(compressed, &mut decoded, &mut in_pos, &mut out_pos); };
}
}
// handle up to 3 final bytes if they exist
// SAFETY: TODO
unsafe {
if in_pos + 2 <= compressed.len() {
let code = *compressed.get_unchecked(in_pos);
if code == ESCAPE_CODE {
self.escape(compressed, &mut decoded, &mut in_pos, &mut out_pos);
} else {
self.nonescape(code, &mut decoded, &mut in_pos, &mut out_pos);
let code = *compressed.get_unchecked(in_pos);
if code == ESCAPE_CODE {
self.escape(compressed, &mut decoded, &mut in_pos, &mut out_pos);
} else {
self.nonescape(code, &mut decoded, &mut in_pos, &mut out_pos);
}
}
}
if in_pos < compressed.len() {
debug_assert!(in_pos + 1 == compressed.len());
// last code cannot be an escape
self.nonescape(*compressed.get_unchecked(in_pos), &mut decoded, &mut in_pos, &mut out_pos);
}
};

assert!(
in_pos >= compressed.len(),
Expand Down
Loading