diff --git a/src/libWiiPy/archive/lz77.py b/src/libWiiPy/archive/lz77.py index 55c4629..70161e8 100644 --- a/src/libWiiPy/archive/lz77.py +++ b/src/libWiiPy/archive/lz77.py @@ -7,6 +7,12 @@ import io from dataclasses import dataclass as _dataclass +LZ_MIN_DISTANCE = 0x01 # Minimum distance for each reference. +LZ_MAX_DISTANCE = 0x1000 # Maximum distance for each reference. +LZ_MIN_LENGTH = 0x03 # Minimum length for each reference. +LZ_MAX_LENGTH = 0x12 # Maximum length for each reference. + + @_dataclass class _LZNode: dist: int = 0 @@ -17,34 +23,29 @@ class _LZNode: def _compress_compare_bytes(byte1: bytes, offset1: int, byte2: bytes, offset2: int, abs_len_max: int) -> int: # Compare bytes up to the maximum length we can match. num_matched = 0 - while abs_len_max > 0: - if byte1[offset1] != byte2[offset2]: + mem1 = memoryview(byte1) + mem2 = memoryview(byte2) + while num_matched < abs_len_max: + if mem1[offset1 + num_matched] != mem2[offset2 + num_matched]: break - offset1 += 1 - offset2 += 1 - abs_len_max -= 1 num_matched += 1 return num_matched def _compress_search_matches(buffer: bytes, pos: int) -> (int, int): bytes_left = len(buffer) - pos + global LZ_MAX_DISTANCE, LZ_MAX_LENGTH, LZ_MIN_DISTANCE # Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should # only look as far back as we've gone. - max_dist = 0x1000 - if max_dist > pos: - max_dist = pos + max_dist = min(LZ_MAX_DISTANCE, pos) + # Default to only matching up to 18 bytes, unless fewer than 18 bytes remain, in which case we can only match + # up to that many bytes. + max_len = min(LZ_MAX_LENGTH, bytes_left) # Log the longest match we found and its offset. biggest_match = 0 biggest_match_pos = 0 - # Default to only matching up to 18 bytes, unless fewer than 18 bytes remain, in which case we can only match - # up to that many bytes. - max_len = 0x12 - if max_len > bytes_left: - max_len = bytes_left - min_dist = 0x01 # Search for matches. - for i in range(min_dist, max_dist + 1): + for i in range(LZ_MIN_DISTANCE, max_dist + 1): num_matched = _compress_compare_bytes(buffer, pos - i, buffer, pos, max_len) if num_matched > biggest_match: biggest_match = num_matched @@ -55,11 +56,11 @@ def _compress_search_matches(buffer: bytes, pos: int) -> (int, int): def _compress_node_is_ref(node: _LZNode) -> bool: - return node.len >= 0x03 + return node.len >= LZ_MIN_LENGTH def _compress_get_node_cost(length: int) -> int: - if length >= 0x03: + if length >= LZ_MAX_LENGTH: num_bytes = 2 else: num_bytes = 1 @@ -80,21 +81,22 @@ def compress_lz77(data: bytes) -> bytes: nodes = [_LZNode() for _ in range(len(data))] # Iterate over the uncompressed data, starting from the end. pos = len(data) + global LZ_MAX_LENGTH, LZ_MIN_LENGTH, LZ_MIN_DISTANCE while pos: pos -= 1 node = nodes[pos] # Limit the maximum search length when we're near the end of the file. - max_search_len = 0x12 + max_search_len = LZ_MAX_LENGTH if max_search_len > (len(data) - pos): max_search_len = len(data) - pos - if max_search_len < 0x03: + if max_search_len < LZ_MIN_DISTANCE: max_search_len = 1 # Initialize as 1 for each, since that's all we could use if we weren't compressing. length, dist = 1, 1 - if max_search_len >= 0x03: + if max_search_len >= LZ_MIN_LENGTH: length, dist = _compress_search_matches(data, pos) # Treat as direct bytes if it's too short to copy. - if length == 0 or length < 0x03: + if length == 0 or length < LZ_MIN_LENGTH: length = 1 # If the node goes to the end of the file, the weight is the cost of the node. if pos + length == len(data): @@ -112,7 +114,7 @@ def compress_lz77(data: bytes) -> bytes: len_best = length weight_best = weight length -= 1 - if length != 0 and length < 0x03: + if length != 0 and length < LZ_MIN_LENGTH: length = 1 node.len = len_best node.dist = dist @@ -123,9 +125,8 @@ def compress_lz77(data: bytes) -> bytes: with io.BytesIO() as buffer: # Write the header data. buffer.write(b'LZ77\x10') # The LZ type on the Wii is *always* 0x10. - buffer.write(int.to_bytes(len(data), 3, 'little')) + buffer.write(len(data).to_bytes(3, 'little')) - current_node = nodes[0] src_pos = 0 while src_pos < len(data): head = 0 @@ -133,25 +134,23 @@ def compress_lz77(data: bytes) -> bytes: i = 0 while i < 8 and src_pos < len(data): + current_node = nodes[src_pos] length = current_node.len dist = current_node.dist # This is a reference node. if _compress_node_is_ref(current_node): - encoded = ((dist - 0x01) | ((length - 0x03) << 12)) & 0xFF # This is a uint16_t. - buffer.write(int.to_bytes((encoded >> 8) & 0xFF)) - buffer.write(int.to_bytes((encoded >> 0) & 0xFF)) - head |= 1 << (7 - i) - head = head & 0xF + encoded = ((dist - LZ_MIN_DISTANCE) | ((length - LZ_MAX_LENGTH) << 12)) & 0xFFFF # A uint16_t. + buffer.write(((encoded >> 8) & 0xFF).to_bytes(1)) + buffer.write(((encoded >> 0) & 0xFF).to_bytes(1)) + head = (head | (1 << (7 - i))) & 0xFF # This is a direct copy node. else: - buffer.write(data[src_pos:]) - + buffer.write(data[src_pos:src_pos + 1]) src_pos += length - current_node = nodes[src_pos] pos = buffer.tell() buffer.seek(head_pos) - buffer.write(head) + buffer.write(head.to_bytes(1)) buffer.seek(pos) buffer.seek(0)