Fix all mypy errors

This commit is contained in:
Campbell 2025-05-27 13:46:21 -04:00
parent 79ab33c18a
commit 2ca2ff1f44
Signed by: NinjaCheetah
GPG Key ID: 39C2500E1778B156
15 changed files with 77 additions and 75 deletions

View File

@ -23,7 +23,8 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"pycryptodome", "pycryptodome",
"requests" "requests",
"types-requests"
] ]
keywords = ["Wii", "wii"] keywords = ["Wii", "wii"]

View File

@ -1,6 +1,7 @@
build build
pycryptodome pycryptodome
requests requests
types-requests
sphinx sphinx
sphinx-book-theme sphinx-book-theme
myst-parser myst-parser

View File

@ -8,10 +8,11 @@
# See <link pending> for details about the ASH compression format. # See <link pending> for details about the ASH compression format.
import io import io
from dataclasses import dataclass as _dataclass from dataclasses import dataclass
from typing import List
@_dataclass @dataclass
class _ASHBitReader: class _ASHBitReader:
""" """
An _ASHBitReader class used to parse individual words in an ASH file. Private class used by the ASH module. An _ASHBitReader class used to parse individual words in an ASH file. Private class used by the ASH module.
@ -93,7 +94,7 @@ def _ash_bit_reader_read_bits(bit_reader: _ASHBitReader, num_bits: int):
return bits return bits
def _ash_read_tree(bit_reader: _ASHBitReader, width: int, left_tree: [int], right_tree: [int]): def _ash_read_tree(bit_reader: _ASHBitReader, width: int, left_tree: List[int], right_tree: List[int]):
# Read either the symbol or distance tree from the ASH file, and return the root of that tree. # Read either the symbol or distance tree from the ASH file, and return the root of that tree.
work = [0] * (2 * (1 << width)) work = [0] * (2 * (1 << width))
work_pos = 0 work_pos = 0

View File

@ -5,7 +5,7 @@
import io import io
from dataclasses import dataclass as _dataclass from dataclasses import dataclass as _dataclass
from typing import List as _List from typing import List, Tuple
_LZ_MIN_DISTANCE = 0x01 # Minimum distance for each reference. _LZ_MIN_DISTANCE = 0x01 # Minimum distance for each reference.
@ -21,7 +21,7 @@ class _LZNode:
weight: int = 0 weight: int = 0
def _compress_compare_bytes(buffer: _List[int], offset1: int, offset2: int, abs_len_max: int) -> int: def _compress_compare_bytes(buffer: List[int], offset1: int, offset2: int, abs_len_max: int) -> int:
# Compare bytes up to the maximum length we can match. Start by comparing the first 3 bytes, since that's the # Compare bytes up to the maximum length we can match. Start by comparing the first 3 bytes, since that's the
# minimum match length and this allows for a more optimized early exit. # minimum match length and this allows for a more optimized early exit.
num_matched = 0 num_matched = 0
@ -32,7 +32,7 @@ def _compress_compare_bytes(buffer: _List[int], offset1: int, offset2: int, abs_
return num_matched return num_matched
def _compress_search_matches_optimized(buffer: _List[int], pos: int) -> (int, int): def _compress_search_matches_optimized(buffer: List[int], pos: int) -> Tuple[int, int]:
bytes_left = len(buffer) - pos bytes_left = len(buffer) - pos
global _LZ_MAX_DISTANCE, _LZ_MIN_LENGTH, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE global _LZ_MAX_DISTANCE, _LZ_MIN_LENGTH, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE
# Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should # Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should
@ -54,7 +54,7 @@ def _compress_search_matches_optimized(buffer: _List[int], pos: int) -> (int, in
return biggest_match, biggest_match_pos return biggest_match, biggest_match_pos
def _compress_search_matches_greedy(buffer: _List[int], pos: int) -> (int, int): def _compress_search_matches_greedy(buffer: List[int], pos: int) -> Tuple[int, int]:
# Finds and returns the first valid match, rather that finding the best one. # Finds and returns the first valid match, rather that finding the best one.
bytes_left = len(buffer) - pos bytes_left = len(buffer) - pos
global _LZ_MAX_DISTANCE, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE global _LZ_MAX_DISTANCE, _LZ_MAX_LENGTH, _LZ_MIN_DISTANCE

View File

@ -68,16 +68,16 @@ class U8Archive:
self.root_node: _U8Node = _U8Node(0, 0, 0, 0) self.root_node: _U8Node = _U8Node(0, 0, 0, 0)
self.imet_header: IMETHeader = IMETHeader() self.imet_header: IMETHeader = IMETHeader()
def load(self, u8_data: bytes) -> None: def load(self, u8: bytes) -> None:
""" """
Loads raw U8 data into a new U8 object. This allows for extracting the file and updating its contents. Loads raw U8 data into a new U8 object. This allows for extracting the file and updating its contents.
Parameters Parameters
---------- ----------
u8_data : bytes u8 : bytes
The data for the U8 file to load. The data for the U8 file to load.
""" """
with io.BytesIO(u8_data) as u8_data: with io.BytesIO(u8) as u8_data:
# Read the first 4 bytes of the file to ensure that it's a U8 archive. # Read the first 4 bytes of the file to ensure that it's a U8 archive.
u8_data.seek(0x0) u8_data.seek(0x0)
self.u8_magic = u8_data.read(4) self.u8_magic = u8_data.read(4)
@ -126,7 +126,7 @@ class U8Archive:
# Seek back before the root node so that it gets read with all the rest. # Seek back before the root node so that it gets read with all the rest.
u8_data.seek(u8_data.tell() - 12) u8_data.seek(u8_data.tell() - 12)
# Iterate over the number of nodes that the root node lists. # Iterate over the number of nodes that the root node lists.
for node in range(root_node_size): for _ in range(root_node_size):
node_type = int.from_bytes(u8_data.read(1)) node_type = int.from_bytes(u8_data.read(1))
node_name_offset = int.from_bytes(u8_data.read(3)) node_name_offset = int.from_bytes(u8_data.read(3))
node_data_offset = int.from_bytes(u8_data.read(4)) node_data_offset = int.from_bytes(u8_data.read(4))
@ -160,7 +160,7 @@ class U8Archive:
# This is 0 because the header size DOES NOT include the initial 32 bytes describing the file. # This is 0 because the header size DOES NOT include the initial 32 bytes describing the file.
header_size = 0 header_size = 0
# Add 12 bytes for each node, since that's how many bytes each one is made up of. # Add 12 bytes for each node, since that's how many bytes each one is made up of.
for node in range(len(self.u8_node_list)): for _ in range(len(self.u8_node_list)):
header_size += 12 header_size += 12
# Add the number of bytes used for each file/folder name in the string table. # Add the number of bytes used for each file/folder name in the string table.
for file_name in self.file_name_list: for file_name in self.file_name_list:
@ -170,13 +170,13 @@ class U8Archive:
# Adjust all nodes to place file data in the same order as the nodes. Why isn't it already like this? # Adjust all nodes to place file data in the same order as the nodes. Why isn't it already like this?
current_data_offset = data_offset current_data_offset = data_offset
current_name_offset = 0 current_name_offset = 0
for node in range(len(self.u8_node_list)): for idx in range(len(self.u8_node_list)):
if self.u8_node_list[node].type == 0: if self.u8_node_list[idx].type == 0:
self.u8_node_list[node].data_offset = _align_value(current_data_offset, 32) self.u8_node_list[idx].data_offset = _align_value(current_data_offset, 32)
current_data_offset += _align_value(self.u8_node_list[node].size, 32) current_data_offset += _align_value(self.u8_node_list[idx].size, 32)
# Calculate the name offsets, including the extra 1 for the NULL byte at the end of each name. # Calculate the name offsets, including the extra 1 for the NULL byte at the end of each name.
self.u8_node_list[node].name_offset = current_name_offset self.u8_node_list[idx].name_offset = current_name_offset
current_name_offset += len(self.file_name_list[node]) + 1 current_name_offset += len(self.file_name_list[idx]) + 1
# Begin joining all the U8 archive data into bytes. # Begin joining all the U8 archive data into bytes.
u8_data = b'' u8_data = b''
# Magic number. # Magic number.
@ -300,7 +300,7 @@ def _pack_u8_dir(u8_archive: U8Archive, current_path, node_count, parent_node):
return u8_archive, node_count return u8_archive, node_count
def pack_u8(input_path, generate_imet=False, imet_titles:List[str]=None) -> bytes: def pack_u8(input_path, generate_imet=False, imet_titles:List[str] | None = None) -> bytes:
""" """
Packs the provided file or folder into a new U8 archive, and returns the raw file data for it. Packs the provided file or folder into a new U8 archive, and returns the raw file data for it.
@ -513,13 +513,15 @@ class IMETHeader:
raise ValueError(f"The specified language is not valid!") raise ValueError(f"The specified language is not valid!")
return self.channel_names[target_languages] return self.channel_names[target_languages]
# If multiple channel names were requested. # If multiple channel names were requested.
else: elif type(target_languages) == List:
channel_names = [] channel_names = []
for lang in target_languages: for lang in target_languages:
if lang not in self.LocalizedTitles: if lang not in self.LocalizedTitles:
raise ValueError(f"The specified language at index {target_languages.index(lang)} is not valid!") raise ValueError(f"The specified language at index {target_languages.index(lang)} is not valid!")
channel_names.append(self.channel_names[lang]) channel_names.append(self.channel_names[lang])
return channel_names return channel_names
else:
raise TypeError("Target languages must be type int or List[int]!")
def set_channel_names(self, channel_names: Tuple[int, str] | List[Tuple[int, str]]) -> None: def set_channel_names(self, channel_names: Tuple[int, str] | List[Tuple[int, str]]) -> None:
""" """
@ -544,7 +546,7 @@ class IMETHeader:
f"42 characters!") f"42 characters!")
self.channel_names[channel_names[0]] = channel_names[1] self.channel_names[channel_names[0]] = channel_names[1]
# If a list of channel names was provided. # If a list of channel names was provided.
else: elif type(channel_names) == list:
for name in channel_names: for name in channel_names:
if name[0] not in self.LocalizedTitles: if name[0] not in self.LocalizedTitles:
raise ValueError(f"The target language \"{name[0]}\" for the name at index {channel_names.index(name)} " raise ValueError(f"The target language \"{name[0]}\" for the name at index {channel_names.index(name)} "
@ -553,3 +555,5 @@ class IMETHeader:
raise ValueError(f"The channel name \"{name[1]}\" at index {channel_names.index(name)} is too long! " raise ValueError(f"The channel name \"{name[1]}\" at index {channel_names.index(name)} is too long! "
f"Channel names cannot exceed 42 characters!") f"Channel names cannot exceed 42 characters!")
self.channel_names[name[0]] = name[1] self.channel_names[name[0]] = name[1]
else:
raise TypeError("Channel names must be type Tuple[int, str] or List[Tuple[int, str]]!")

View File

@ -7,7 +7,7 @@ import os
import pathlib import pathlib
import shutil import shutil
from dataclasses import dataclass as _dataclass from dataclasses import dataclass as _dataclass
from typing import List from typing import Callable, List
from ..title.ticket import Ticket from ..title.ticket import Ticket
from ..title.title import Title from ..title.title import Title
from ..title.tmd import TMD from ..title.tmd import TMD
@ -32,7 +32,7 @@ class EmuNAND:
emunand_root : pathlib.Path emunand_root : pathlib.Path
The path to the EmuNAND root directory. The path to the EmuNAND root directory.
""" """
def __init__(self, emunand_root: str | pathlib.Path, callback: callable = None): def __init__(self, emunand_root: str | pathlib.Path, callback: Callable | None = None):
self.emunand_root = pathlib.Path(emunand_root) self.emunand_root = pathlib.Path(emunand_root)
self.log = callback if callback is not None else None self.log = callback if callback is not None else None

View File

@ -4,6 +4,7 @@
# See https://wiibrew.org/wiki//title/00000001/00000002/data/setting.txt for information about setting.txt. # See https://wiibrew.org/wiki//title/00000001/00000002/data/setting.txt for information about setting.txt.
import io import io
from typing import List
from ..shared import _pad_bytes from ..shared import _pad_bytes
@ -53,16 +54,16 @@ class SettingTxt:
""" """
with io.BytesIO(setting_txt) as setting_data: with io.BytesIO(setting_txt) as setting_data:
global _key # I still don't actually know what *kind* of encryption this is. global _key # I still don't actually know what *kind* of encryption this is.
setting_txt_dec: [int] = [] setting_txt_dec: List[int] = []
for i in range(0, 256): for i in range(0, 256):
setting_txt_dec.append(int.from_bytes(setting_data.read(1)) ^ (_key & 0xff)) setting_txt_dec.append(int.from_bytes(setting_data.read(1)) ^ (_key & 0xff))
_key = (_key << 1) | (_key >> 31) _key = (_key << 1) | (_key >> 31)
setting_txt_dec = bytes(setting_txt_dec) setting_txt_bytes = bytes(setting_txt_dec)
try: try:
setting_str = setting_txt_dec.decode('utf-8') setting_str = setting_txt_bytes.decode('utf-8')
except UnicodeDecodeError: except UnicodeDecodeError:
last_newline_pos = setting_txt_dec.rfind(b'\n') # This makes sure we don't try to decode any garbage data. last_newline_pos = setting_txt_bytes.rfind(b'\n') # This makes sure we don't try to decode any garbage data.
setting_str = setting_txt_dec[:last_newline_pos + 1].decode('utf-8') setting_str = setting_txt_bytes[:last_newline_pos + 1].decode('utf-8')
self.load_decrypted(setting_str) self.load_decrypted(setting_str)
def load_decrypted(self, setting_txt: str) -> None: def load_decrypted(self, setting_txt: str) -> None:
@ -104,13 +105,13 @@ class SettingTxt:
setting_txt_dec = setting_str.encode() setting_txt_dec = setting_str.encode()
global _key global _key
# This could probably be made more efficient somehow. # This could probably be made more efficient somehow.
setting_txt_enc: [int] = [] setting_txt_enc: List[int] = []
with io.BytesIO(setting_txt_dec) as setting_data: with io.BytesIO(setting_txt_dec) as setting_data:
for i in range(0, len(setting_txt_dec)): for i in range(0, len(setting_txt_dec)):
setting_txt_enc.append(int.from_bytes(setting_data.read(1)) ^ (_key & 0xff)) setting_txt_enc.append(int.from_bytes(setting_data.read(1)) ^ (_key & 0xff))
_key = (_key << 1) | (_key >> 31) _key = (_key << 1) | (_key >> 31)
setting_txt_enc = _pad_bytes(bytes(setting_txt_enc), 256) setting_txt_bytes = _pad_bytes(bytes(setting_txt_enc), 256)
return setting_txt_enc return setting_txt_bytes
def dump_decrypted(self) -> str: def dump_decrypted(self) -> str:
""" """

View File

@ -91,11 +91,8 @@ class UidSys:
The UID assigned to the new Title ID. The UID assigned to the new Title ID.
""" """
if type(title_id) is bytes: if type(title_id) is bytes:
# This catches the format b'0000000100000002'
if len(title_id) == 16:
title_id_converted = title_id.encode()
# This catches the format b'\x00\x00\x00\x01\x00\x00\x00\x02' # This catches the format b'\x00\x00\x00\x01\x00\x00\x00\x02'
elif len(title_id) == 8: if len(title_id) == 8:
title_id_converted = binascii.hexlify(title_id).decode() title_id_converted = binascii.hexlify(title_id).decode()
# If it isn't one of those lengths, it cannot possibly be valid, so reject it. # If it isn't one of those lengths, it cannot possibly be valid, so reject it.
else: else:

0
src/libWiiPy/py.typed Normal file
View File

View File

@ -61,10 +61,10 @@ class Certificate:
The exponent of this certificate's public key. Combined with the modulus to get the full key. The exponent of this certificate's public key. Combined with the modulus to get the full key.
""" """
def __init__(self): def __init__(self):
self.type: CertificateType | None = None self.type: CertificateType = CertificateType.RSA_4096
self.signature: bytes = b'' self.signature: bytes = b''
self.issuer: str = "" self.issuer: str = ""
self.pub_key_type: CertificateKeyType | None = None self.pub_key_type: CertificateKeyType = CertificateKeyType.RSA_4096
self.child_name: str = "" self.child_name: str = ""
self.pub_key_id: int = 0 self.pub_key_id: int = 0
self.pub_key_modulus: int = 0 self.pub_key_modulus: int = 0

View File

@ -66,16 +66,16 @@ class ContentRegion:
start_offset += 64 - (content.content_size % 64) start_offset += 64 - (content.content_size % 64)
self.content_start_offsets.append(start_offset) self.content_start_offsets.append(start_offset)
# Build a list of all the encrypted content data. # Build a list of all the encrypted content data.
for content in range(self.num_contents): for idx in range(self.num_contents):
# Seek to the start of the content based on the list of offsets. # Seek to the start of the content based on the list of offsets.
content_region_data.seek(self.content_start_offsets[content]) content_region_data.seek(self.content_start_offsets[idx])
# Calculate the number of bytes we need to read by adding bytes up the nearest multiple of 16 if needed. # Calculate the number of bytes we need to read by adding bytes up the nearest multiple of 16 if needed.
bytes_to_read = self.content_records[content].content_size content_size = self.content_records[idx].content_size
if (bytes_to_read % 16) != 0: if (content_size % 16) != 0:
bytes_to_read += 16 - (bytes_to_read % 16) content_size += 16 - (content_size % 16)
# Read the file based on the size of the content in the associated record, then append that data to # Read the file based on the size of the content in the associated record, then append that data to
# the list of content. # the list of content.
content_enc = content_region_data.read(bytes_to_read) content_enc = content_region_data.read(content_size)
self.content_list.append(content_enc) self.content_list.append(content_enc)
def dump(self) -> tuple[bytes, int]: def dump(self) -> tuple[bytes, int]:
@ -336,8 +336,8 @@ class ContentRegion:
enc_content = encrypt_content(dec_content, title_key, index) enc_content = encrypt_content(dec_content, title_key, index)
self.add_enc_content(enc_content, cid, index, content_type, content_size, content_hash) self.add_enc_content(enc_content, cid, index, content_type, content_size, content_hash)
def set_enc_content(self, enc_content: bytes, index: int, content_size: int, content_hash: bytes, cid: int = None, def set_enc_content(self, enc_content: bytes, index: int, content_size: int, content_hash: bytes,
content_type: int = None) -> None: cid: int | None = None, content_type: int | None = None) -> None:
""" """
Sets the content at the provided content index to the provided new encrypted content. The provided hash and Sets the content at the provided content index to the provided new encrypted content. The provided hash and
content size are set in the corresponding content record. A new Content ID or content type can also be content size are set in the corresponding content record. A new Content ID or content type can also be
@ -373,8 +373,8 @@ class ContentRegion:
self.content_list.append(b'') self.content_list.append(b'')
self.content_list[index] = enc_content self.content_list[index] = enc_content
def set_content(self, dec_content: bytes, index: int, title_key: bytes, cid: int = None, def set_content(self, dec_content: bytes, index: int, title_key: bytes, cid: int | None = None,
content_type: int = None) -> None: content_type: int | None = None) -> None:
""" """
Sets the content at the provided content index to the provided new decrypted content. The hash and content size Sets the content at the provided content index to the provided new decrypted content. The hash and content size
of this content will be generated and then set in the corresponding content record. A new Content ID or content of this content will be generated and then set in the corresponding content record. A new Content ID or content

View File

@ -33,8 +33,8 @@ class DownloadCallback(Protocol):
... ...
def download_title(title_id: str, title_version: int = None, wiiu_endpoint: bool = False, def download_title(title_id: str, title_version: int | None = None, wiiu_endpoint: bool = False,
endpoint_override: str = None, progress: DownloadCallback = lambda done, total: None) -> Title: endpoint_override: str | None = None, progress: DownloadCallback = lambda done, total: None) -> Title:
""" """
Download an entire title and all of its contents, then load the downloaded components into a Title object for Download an entire title and all of its contents, then load the downloaded components into a Title object for
further use. This method is NOT recommended for general use, as it has extremely limited verbosity. It is instead further use. This method is NOT recommended for general use, as it has extremely limited verbosity. It is instead
@ -81,8 +81,8 @@ def download_title(title_id: str, title_version: int = None, wiiu_endpoint: bool
return title return title
def download_tmd(title_id: str, title_version: int = None, wiiu_endpoint: bool = False, def download_tmd(title_id: str, title_version: int | None = None, wiiu_endpoint: bool = False,
endpoint_override: str = None, progress: DownloadCallback = lambda done, total: None) -> bytes: endpoint_override: str | None = None, progress: DownloadCallback = lambda done, total: None) -> bytes:
""" """
Downloads the TMD of the Title specified in the object. Will download the latest version by default, or another Downloads the TMD of the Title specified in the object. Will download the latest version by default, or another
version if it was manually specified in the object. version if it was manually specified in the object.
@ -151,7 +151,7 @@ def download_tmd(title_id: str, title_version: int = None, wiiu_endpoint: bool =
return tmd return tmd
def download_ticket(title_id: str, wiiu_endpoint: bool = False, endpoint_override: str = None, def download_ticket(title_id: str, wiiu_endpoint: bool = False, endpoint_override: str | None = None,
progress: DownloadCallback = lambda done, total: None) -> bytes: progress: DownloadCallback = lambda done, total: None) -> bytes:
""" """
Downloads the Ticket of the Title specified in the object. This will only work if the Title ID specified is for Downloads the Ticket of the Title specified in the object. This will only work if the Title ID specified is for
@ -215,7 +215,7 @@ def download_ticket(title_id: str, wiiu_endpoint: bool = False, endpoint_overrid
return ticket return ticket
def download_cert_chain(wiiu_endpoint: bool = False, endpoint_override: str = None) -> bytes: def download_cert_chain(wiiu_endpoint: bool = False, endpoint_override: str | None = None) -> bytes:
""" """
Downloads the signing certificate chain used by all WADs. This uses System Menu 4.3U as the source. Downloads the signing certificate chain used by all WADs. This uses System Menu 4.3U as the source.
@ -266,8 +266,8 @@ def download_cert_chain(wiiu_endpoint: bool = False, endpoint_override: str = No
return cert_chain return cert_chain
def download_content(title_id: str, content_id: int, wiiu_endpoint: bool = False, def download_content(title_id: str, content_id: int, wiiu_endpoint: bool = False, endpoint_override: str | None = None,
endpoint_override: str = None, progress: DownloadCallback = lambda done, total: None) -> bytes: progress: DownloadCallback = lambda done, total: None) -> bytes:
""" """
Downloads a specified content for the title specified in the object. Downloads a specified content for the title specified in the object.
@ -330,7 +330,7 @@ def download_content(title_id: str, content_id: int, wiiu_endpoint: bool = False
return content return content
def download_contents(title_id: str, tmd: TMD, wiiu_endpoint: bool = False, endpoint_override: str = None, def download_contents(title_id: str, tmd: TMD, wiiu_endpoint: bool = False, endpoint_override: str | None = None,
progress: DownloadCallback = lambda done, total: None) -> List[bytes]: progress: DownloadCallback = lambda done, total: None) -> List[bytes]:
""" """
Downloads all the contents for the title specified in the object. This requires a TMD to already be available Downloads all the contents for the title specified in the object. This requires a TMD to already be available

View File

@ -178,7 +178,7 @@ class Title:
self.tmd.set_title_version(title_version) self.tmd.set_title_version(title_version)
self.ticket.set_title_version(title_version) self.ticket.set_title_version(title_version)
def get_content_by_index(self, index: id, skip_hash=False) -> bytes: def get_content_by_index(self, index: int, skip_hash=False) -> bytes:
""" """
Gets an individual content from the content region based on the provided index, in decrypted form. Gets an individual content from the content region based on the provided index, in decrypted form.
@ -321,8 +321,8 @@ class Title:
# Update the TMD to match. # Update the TMD to match.
self.tmd.content_records = self.content.content_records self.tmd.content_records = self.content.content_records
def set_enc_content(self, enc_content: bytes, index: int, content_size: int, content_hash: bytes, cid: int = None, def set_enc_content(self, enc_content: bytes, index: int, content_size: int, content_hash: bytes,
content_type: int = None) -> None: cid: int | None = None, content_type: int | None = None) -> None:
""" """
Sets the content at the provided index to the provided new encrypted content. The provided hash and content size Sets the content at the provided index to the provided new encrypted content. The provided hash and content size
are set in the corresponding content record. A new Content ID or content type can also be specified, but if it are set in the corresponding content record. A new Content ID or content type can also be specified, but if it
@ -350,7 +350,8 @@ class Title:
# Update the TMD to match. # Update the TMD to match.
self.tmd.content_records = self.content.content_records self.tmd.content_records = self.content.content_records
def set_content(self, dec_content: bytes, index: int, cid: int = None, content_type: int = None) -> None: def set_content(self, dec_content: bytes, index: int, cid: int | None = None,
content_type: int | None = None) -> None:
""" """
Sets the content at the provided index to the provided new decrypted content. The hash and content size of this Sets the content at the provided index to the provided new decrypted content. The hash and content size of this
content will be generated and then set in the corresponding content record. A new Content ID or content type can content will be generated and then set in the corresponding content record. A new Content ID or content type can

View File

@ -12,7 +12,7 @@ from typing import List
from enum import IntEnum as _IntEnum from enum import IntEnum as _IntEnum
from ..types import _ContentRecord from ..types import _ContentRecord
from ..shared import _bitmask from ..shared import _bitmask
from .util import title_ver_dec_to_standard, title_ver_standard_to_dec from .util import title_ver_standard_to_dec
class TMD: class TMD:
@ -36,7 +36,7 @@ class TMD:
""" """
def __init__(self): def __init__(self):
self.blob_header: bytes = b'' self.blob_header: bytes = b''
self.signature_type: int = 0 self.signature_type: bytes = b''
self.signature: bytes = b'' self.signature: bytes = b''
self.signature_issuer: str = "" # Follows the format "Root-CA%08x-CP%08x" self.signature_issuer: str = "" # Follows the format "Root-CA%08x-CP%08x"
self.tmd_version: int = 0 # This seems to always be 0 no matter what? self.tmd_version: int = 0 # This seems to always be 0 no matter what?
@ -55,7 +55,6 @@ class TMD:
self.reserved2: bytes = b'' # Other "Reserved" data from WiiBrew. self.reserved2: bytes = b'' # Other "Reserved" data from WiiBrew.
self.access_rights: int = 0 self.access_rights: int = 0
self.title_version: int = 0 # The version of the associated title. self.title_version: int = 0 # The version of the associated title.
self.title_version_converted: int = 0 # The title version in vX.X format.
self.num_contents: int = 0 # The number of contents contained in the associated title. self.num_contents: int = 0 # The number of contents contained in the associated title.
self.boot_index: int = 0 # The content index that contains the bootable executable. self.boot_index: int = 0 # The content index that contains the bootable executable.
self.minor_version: int = 0 # Minor version (unused typically). self.minor_version: int = 0 # Minor version (unused typically).
@ -137,8 +136,6 @@ class TMD:
# Version number straight from the TMD. # Version number straight from the TMD.
tmd_data.seek(0x1DC) tmd_data.seek(0x1DC)
self.title_version = int.from_bytes(tmd_data.read(2)) self.title_version = int.from_bytes(tmd_data.read(2))
# Calculate the converted version number via util module.
self.title_version_converted = title_ver_dec_to_standard(self.title_version, self.title_id, bool(self.vwii))
# The number of contents listed in the TMD. # The number of contents listed in the TMD.
tmd_data.seek(0x1DE) tmd_data.seek(0x1DE)
self.num_contents = int.from_bytes(tmd_data.read(2)) self.num_contents = int.from_bytes(tmd_data.read(2))
@ -305,6 +302,8 @@ class TMD:
return "None" return "None"
case 4: case 4:
return "KOR" return "KOR"
case _:
raise ValueError(f"Title contains unknown region \"{self.region}\".")
def get_title_type(self) -> str: def get_title_type(self) -> str:
""" """
@ -500,7 +499,7 @@ class TMD:
Parameters Parameters
---------- ----------
new_version : str, int new_version : str, int
The new version of the title. See description for valid formats. The new version of the title.
""" """
if type(new_version) is str: if type(new_version) is str:
# Validate string input is in the correct format, then validate that the version isn't higher than v255.0. # Validate string input is in the correct format, then validate that the version isn't higher than v255.0.
@ -510,8 +509,7 @@ class TMD:
raise ValueError("Title version is not valid! String version must be entered in format \"X.X\".") raise ValueError("Title version is not valid! String version must be entered in format \"X.X\".")
if int(version_str_split[0]) > 255 or int(version_str_split[1]) > 255: if int(version_str_split[0]) > 255 or int(version_str_split[1]) > 255:
raise ValueError("Title version is not valid! String version number cannot exceed v255.255.") raise ValueError("Title version is not valid! String version number cannot exceed v255.255.")
self.title_version_converted = new_version version_converted: int = title_ver_standard_to_dec(new_version, self.title_id)
version_converted = title_ver_standard_to_dec(new_version, self.title_id)
self.title_version = version_converted self.title_version = version_converted
elif type(new_version) is int: elif type(new_version) is int:
# Validate that the version isn't higher than v65280. If the check passes, set that as the title version, # Validate that the version isn't higher than v65280. If the check passes, set that as the title version,
@ -519,7 +517,5 @@ class TMD:
if new_version > 65535: if new_version > 65535:
raise ValueError("Title version is not valid! Integer version number cannot exceed v65535.") raise ValueError("Title version is not valid! Integer version number cannot exceed v65535.")
self.title_version = new_version self.title_version = new_version
version_converted = title_ver_dec_to_standard(new_version, self.title_id, bool(self.vwii))
self.title_version_converted = version_converted
else: else:
raise TypeError("Title version type is not valid! Type must be either integer or string.") raise TypeError("Title version type is not valid! Type must be either integer or string.")

View File

@ -49,17 +49,17 @@ class WAD:
self.wad_content_data: bytes = b'' self.wad_content_data: bytes = b''
self.wad_meta_data: bytes = b'' self.wad_meta_data: bytes = b''
def load(self, wad_data: bytes) -> None: def load(self, wad: bytes) -> None:
""" """
Loads raw WAD data and sets all attributes of the WAD object. This allows for manipulating an already Loads raw WAD data and sets all attributes of the WAD object. This allows for manipulating an already
existing WAD file. existing WAD file.
Parameters Parameters
---------- ----------
wad_data : bytes wad : bytes
The data for the WAD file to load. The data for the WAD file to load.
""" """
with io.BytesIO(wad_data) as wad_data: with io.BytesIO(wad) as wad_data:
# Read the first 8 bytes of the file to ensure that it's a WAD. Has two possible valid values for the two # Read the first 8 bytes of the file to ensure that it's a WAD. Has two possible valid values for the two
# different types of WADs that might be encountered. # different types of WADs that might be encountered.
wad_data.seek(0x0) wad_data.seek(0x0)
@ -311,7 +311,7 @@ class WAD:
# Calculate the size of the new Ticket data. # Calculate the size of the new Ticket data.
self.wad_tik_size = len(tik_data) self.wad_tik_size = len(tik_data)
def set_content_data(self, content_data, size: int = None) -> None: def set_content_data(self, content_data, size: int | None = None) -> None:
""" """
Sets the content data of the WAD. Also calculates the new size. Sets the content data of the WAD. Also calculates the new size.