diff --git a/src/libWiiPy/archive/lz77.py b/src/libWiiPy/archive/lz77.py index 41cf0af..bc2e601 100644 --- a/src/libWiiPy/archive/lz77.py +++ b/src/libWiiPy/archive/lz77.py @@ -5,6 +5,7 @@ import io from dataclasses import dataclass as _dataclass +from typing import List as _List _LZ_MIN_DISTANCE = 0x01 # Minimum distance for each reference. @@ -20,8 +21,9 @@ class _LZNode: weight: int = 0 -def _compress_compare_bytes(buffer: bytes, offset1: int, offset2: int, abs_len_max: int) -> int: - # Compare bytes up to the maximum length we can match. +def _compress_compare_bytes(buffer: _List[int], offset1: int, offset2: int, abs_len_max: int) -> int: + # Compare bytes up to the maximum length we can match. Start by comparing the first 3 bytes, since that's the + # minimum match length and this allows for a more optimized early exit. num_matched = 0 while num_matched < abs_len_max: if buffer[offset1 + num_matched] != buffer[offset2 + num_matched]: @@ -30,9 +32,9 @@ def _compress_compare_bytes(buffer: bytes, offset1: int, offset2: int, abs_len_m return num_matched -def _compress_search_matches(buffer: bytes, pos: int) -> (int, int): +def _compress_search_matches_optimized(buffer: _List[int], pos: int) -> (int, int): bytes_left = len(buffer) - pos - global _LZ_MAX_DISTANCE, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE + global _LZ_MAX_DISTANCE, _LZ_MIN_LENGTH, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE # Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should # only look as far back as we've gone. max_dist = min(_LZ_MAX_DISTANCE, pos) @@ -52,7 +54,7 @@ def _compress_search_matches(buffer: bytes, pos: int) -> (int, int): return biggest_match, biggest_match_pos -def _compress_search_matches_greedy(buffer: bytes, pos: int) -> (int, int): +def _compress_search_matches_greedy(buffer: _List[int], pos: int) -> (int, int): # Finds and returns the first valid match, rather that finding the best one. bytes_left = len(buffer) - pos global _LZ_MAX_DISTANCE, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE @@ -90,22 +92,23 @@ def _compress_lz77_optimized(data: bytes) -> bytes: # Iterate over the uncompressed data, starting from the end. pos = len(data) global _LZ_MAX_LENGTH, _LZ_MIN_LENGTH, _LZ_MIN_DISTANCE + data_list = list(data) while pos: pos -= 1 node = nodes[pos] # Limit the maximum search length when we're near the end of the file. - max_search_len = min(_LZ_MAX_LENGTH, len(data) - pos) + max_search_len = min(_LZ_MAX_LENGTH, len(data_list) - pos) if max_search_len < _LZ_MIN_DISTANCE: max_search_len = 1 # Initialize as 1 for each, since that's all we could use if we weren't compressing. length, dist = 1, 1 if max_search_len >= _LZ_MIN_LENGTH: - length, dist = _compress_search_matches(data, pos) + length, dist = _compress_search_matches_optimized(data_list, pos) # Treat as direct bytes if it's too short to copy. if length == 0 or length < _LZ_MIN_LENGTH: length = 1 # If the node goes to the end of the file, the weight is the cost of the node. - if (pos + length) == len(data): + if (pos + length) == len(data_list): node.len = length node.dist = dist node.weight = _compress_get_node_cost(length) @@ -173,6 +176,7 @@ def _compress_lz77_greedy(data: bytes) -> bytes: buffer.write(len(data).to_bytes(3, 'little')) src_pos = 0 + data_list = list(data) while src_pos < len(data): head = 0 head_pos = buffer.tell() @@ -180,7 +184,7 @@ def _compress_lz77_greedy(data: bytes) -> bytes: i = 0 while i < 8 and src_pos < len(data): - length, dist = _compress_search_matches_greedy(data, src_pos) + length, dist = _compress_search_matches_greedy(data_list, src_pos) # This is a reference node. if length >= _LZ_MIN_LENGTH: encoded = (((length - _LZ_MIN_LENGTH) & 0xF) << 12) | ((dist - _LZ_MIN_DISTANCE) & 0xFFF)