LZ77 compression is now fully functional! (But still very slow)

This commit is contained in:
Campbell 2025-01-23 21:27:15 -05:00
parent d6d0af0623
commit e9a110bb1e
Signed by: NinjaCheetah
GPG Key ID: 39C2500E1778B156

View File

@ -7,10 +7,10 @@ import io
from dataclasses import dataclass as _dataclass from dataclasses import dataclass as _dataclass
LZ_MIN_DISTANCE = 0x01 # Minimum distance for each reference. _LZ_MIN_DISTANCE = 0x01 # Minimum distance for each reference.
LZ_MAX_DISTANCE = 0x1000 # Maximum distance for each reference. _LZ_MAX_DISTANCE = 0x1000 # Maximum distance for each reference.
LZ_MIN_LENGTH = 0x03 # Minimum length for each reference. _LZ_MIN_LENGTH = 0x03 # Minimum length for each reference.
LZ_MAX_LENGTH = 0x12 # Maximum length for each reference. _LZ_MAX_LENGTH = 0x12 # Maximum length for each reference.
@_dataclass @_dataclass
@ -23,10 +23,8 @@ class _LZNode:
def _compress_compare_bytes(byte1: bytes, offset1: int, byte2: bytes, offset2: int, abs_len_max: int) -> int: def _compress_compare_bytes(byte1: bytes, offset1: int, byte2: bytes, offset2: int, abs_len_max: int) -> int:
# Compare bytes up to the maximum length we can match. # Compare bytes up to the maximum length we can match.
num_matched = 0 num_matched = 0
mem1 = memoryview(byte1)
mem2 = memoryview(byte2)
while num_matched < abs_len_max: while num_matched < abs_len_max:
if mem1[offset1 + num_matched] != mem2[offset2 + num_matched]: if byte1[offset1 + num_matched] != byte2[offset2 + num_matched]:
break break
num_matched += 1 num_matched += 1
return num_matched return num_matched
@ -34,18 +32,18 @@ def _compress_compare_bytes(byte1: bytes, offset1: int, byte2: bytes, offset2: i
def _compress_search_matches(buffer: bytes, pos: int) -> (int, int): def _compress_search_matches(buffer: bytes, pos: int) -> (int, int):
bytes_left = len(buffer) - pos bytes_left = len(buffer) - pos
global LZ_MAX_DISTANCE, LZ_MAX_LENGTH, LZ_MIN_DISTANCE global _LZ_MAX_DISTANCE, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE
# Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should # Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should
# only look as far back as we've gone. # only look as far back as we've gone.
max_dist = min(LZ_MAX_DISTANCE, pos) max_dist = min(_LZ_MAX_DISTANCE, pos)
# Default to only matching up to 18 bytes, unless fewer than 18 bytes remain, in which case we can only match # Default to only matching up to 18 bytes, unless fewer than 18 bytes remain, in which case we can only match
# up to that many bytes. # up to that many bytes.
max_len = min(LZ_MAX_LENGTH, bytes_left) max_len = min(_LZ_MAX_LENGTH, bytes_left)
# Log the longest match we found and its offset. # Log the longest match we found and its offset.
biggest_match = 0 biggest_match = 0
biggest_match_pos = 0 biggest_match_pos = 0
# Search for matches. # Search for matches.
for i in range(LZ_MIN_DISTANCE, max_dist + 1): for i in range(_LZ_MIN_DISTANCE, max_dist + 1):
num_matched = _compress_compare_bytes(buffer, pos - i, buffer, pos, max_len) num_matched = _compress_compare_bytes(buffer, pos - i, buffer, pos, max_len)
if num_matched > biggest_match: if num_matched > biggest_match:
biggest_match = num_matched biggest_match = num_matched
@ -56,11 +54,11 @@ def _compress_search_matches(buffer: bytes, pos: int) -> (int, int):
def _compress_node_is_ref(node: _LZNode) -> bool:
    """Return True if this node should be emitted as a back-reference.

    Matches shorter than _LZ_MIN_LENGTH cost more to encode than the bytes
    they replace, so only nodes at least that long count as references.
    """
    return _LZ_MIN_LENGTH <= node.len
def _compress_get_node_cost(length: int) -> int: def _compress_get_node_cost(length: int) -> int:
if length >= LZ_MAX_LENGTH: if length >= _LZ_MIN_LENGTH:
num_bytes = 2 num_bytes = 2
else: else:
num_bytes = 1 num_bytes = 1
@ -81,31 +79,33 @@ def compress_lz77(data: bytes) -> bytes:
nodes = [_LZNode() for _ in range(len(data))] nodes = [_LZNode() for _ in range(len(data))]
# Iterate over the uncompressed data, starting from the end. # Iterate over the uncompressed data, starting from the end.
pos = len(data) pos = len(data)
global LZ_MAX_LENGTH, LZ_MIN_LENGTH, LZ_MIN_DISTANCE global _LZ_MAX_LENGTH, _LZ_MIN_LENGTH, _LZ_MIN_DISTANCE
iters = 0
while pos: while pos:
iters += 1
pos -= 1 pos -= 1
node = nodes[pos] node = nodes[pos]
# Limit the maximum search length when we're near the end of the file. # Limit the maximum search length when we're near the end of the file.
max_search_len = LZ_MAX_LENGTH max_search_len = _LZ_MAX_LENGTH
if max_search_len > (len(data) - pos): if max_search_len > (len(data) - pos):
max_search_len = len(data) - pos max_search_len = len(data) - pos
if max_search_len < LZ_MIN_DISTANCE: if max_search_len < _LZ_MIN_DISTANCE:
max_search_len = 1 max_search_len = 1
# Initialize as 1 for each, since that's all we could use if we weren't compressing. # Initialize as 1 for each, since that's all we could use if we weren't compressing.
length, dist = 1, 1 length, dist = 1, 1
if max_search_len >= LZ_MIN_LENGTH: if max_search_len >= _LZ_MIN_LENGTH:
length, dist = _compress_search_matches(data, pos) length, dist = _compress_search_matches(data, pos)
# Treat as direct bytes if it's too short to copy. # Treat as direct bytes if it's too short to copy.
if length == 0 or length < LZ_MIN_LENGTH: if length == 0 or length < _LZ_MIN_LENGTH:
length = 1 length = 1
# If the node goes to the end of the file, the weight is the cost of the node. # If the node goes to the end of the file, the weight is the cost of the node.
if pos + length == len(data): if (pos + length) == len(data):
node.len = length node.len = length
node.dist = dist node.dist = dist
node.weight = _compress_get_node_cost(length) node.weight = _compress_get_node_cost(length)
# Otherwise, search for possible matches and determine the one with the best cost. # Otherwise, search for possible matches and determine the one with the best cost.
else: else:
weight_best = 0xFFFFFFFF # This was originally UINT_MAX, but that isn't a thing here. weight_best = 0xFFFFFFFF # This was originally UINT_MAX, but that isn't a thing here so 32-bit it is!
len_best = 1 len_best = 1
while length: while length:
weight_next = nodes[pos + length].weight weight_next = nodes[pos + length].weight
@ -114,13 +114,11 @@ def compress_lz77(data: bytes) -> bytes:
len_best = length len_best = length
weight_best = weight weight_best = weight
length -= 1 length -= 1
if length != 0 and length < LZ_MIN_LENGTH: if length != 0 and length < _LZ_MIN_LENGTH:
length = 1 length = 1
node.len = len_best node.len = len_best
node.dist = dist node.dist = dist
node.weight = weight_best node.weight = weight_best
# Maximum size of the compressed file.
max_compressed_size = int(4 + len(data) + (len(data) + 7) / 8)
# Write the header data. # Write the header data.
with io.BytesIO() as buffer: with io.BytesIO() as buffer:
# Write the header data. # Write the header data.
@ -131,6 +129,7 @@ def compress_lz77(data: bytes) -> bytes:
while src_pos < len(data): while src_pos < len(data):
head = 0 head = 0
head_pos = buffer.tell() head_pos = buffer.tell()
buffer.write(b'\x00') # Reserve a byte for the chunk head.
i = 0 i = 0
while i < 8 and src_pos < len(data): while i < 8 and src_pos < len(data):
@ -139,14 +138,14 @@ def compress_lz77(data: bytes) -> bytes:
dist = current_node.dist dist = current_node.dist
# This is a reference node. # This is a reference node.
if _compress_node_is_ref(current_node): if _compress_node_is_ref(current_node):
encoded = ((dist - LZ_MIN_DISTANCE) | ((length - LZ_MAX_LENGTH) << 12)) & 0xFFFF # A uint16_t. encoded = (((length - _LZ_MIN_LENGTH) & 0xF) << 12) | ((dist - _LZ_MIN_DISTANCE) & 0xFFF)
buffer.write(((encoded >> 8) & 0xFF).to_bytes(1)) buffer.write(encoded.to_bytes(2))
buffer.write(((encoded >> 0) & 0xFF).to_bytes(1))
head = (head | (1 << (7 - i))) & 0xFF head = (head | (1 << (7 - i))) & 0xFF
# This is a direct copy node. # This is a direct copy node.
else: else:
buffer.write(data[src_pos:src_pos + 1]) buffer.write(data[src_pos:src_pos + 1])
src_pos += length src_pos += length
i += 1
pos = buffer.tell() pos = buffer.tell()
buffer.seek(head_pos) buffer.seek(head_pos)