diff --git a/gamdl/downloader/amdecrypt.py b/gamdl/downloader/amdecrypt.py index 5ca032d..8d7b74c 100644 --- a/gamdl/downloader/amdecrypt.py +++ b/gamdl/downloader/amdecrypt.py @@ -140,6 +140,12 @@ def extract_song(input_path: str) -> SongInfo: default_sample_duration = 1024 default_sample_size = 0 + # Determine which track is the audio track + audio_track_id = ( + _extract_audio_track_id(song_info.moov_data) if song_info.moov_data else 1 + ) + logger.debug(f"Audio track ID: {audio_track_id}") + # Parse moof/mdat pairs moof_box = None for box in boxes: @@ -152,7 +158,13 @@ def extract_song(input_path: str) -> SongInfo: # Parse moof for tfhd (sample description index, defaults) and trun (entries) samples_from_pair = _parse_moof_mdat( - moof_data, mdat_data, default_sample_duration, default_sample_size + moof_data, + mdat_data, + default_sample_duration, + default_sample_size, + audio_track_id=audio_track_id, + moof_offset=moof_box["offset"], + mdat_data_offset=box["offset"] + box["header_size"], ) song_info.samples.extend(samples_from_pair) moof_box = None @@ -166,18 +178,21 @@ def _parse_moof_mdat( mdat_data: bytes, default_sample_duration: int, default_sample_size: int, + audio_track_id: int = 1, + moof_offset: int = 0, + mdat_data_offset: int = 0, ) -> List[SampleInfo]: - """Parse a moof box and extract samples from corresponding mdat.""" - samples = [] + """Parse a moof box and extract samples from corresponding mdat. - # Parse moof to find tfhd (track fragment header) and trun (track run) - tfhd_info = { - "desc_index": 0, - "default_duration": default_sample_duration, - "default_size": default_sample_size, - "flags": 0, - } - trun_entries = [] + Handles multi-track fragmented MP4s by only extracting samples from + the traf matching the audio track ID. + + Args: + audio_track_id: Track ID of the audio track to extract. + moof_offset: Absolute file offset of the moof box. + mdat_data_offset: Absolute file offset of the mdat content (after header). + """ + samples = [] # Simple box parsing inside moof offset = 8 # Skip moof header @@ -189,7 +204,18 @@ def _parse_moof_mdat( break if box_type == "traf": - # Parse inside traf + # Parse inside traf with per-traf state + tfhd_info = { + "track_id": 0, + "desc_index": 0, + "default_duration": default_sample_duration, + "default_size": default_sample_size, + "flags": 0, + "base_data_offset": None, + } + trun_entries = [] + first_trun_data_offset = None + traf_offset = offset + 8 traf_end = offset + size while traf_offset < traf_end - 8: @@ -208,33 +234,52 @@ def _parse_moof_mdat( moof_data[traf_offset + 8 : traf_offset + inner_size], tfhd_info ) elif inner_type == "trun": - trun_entries = _parse_trun( + entries, data_off = _parse_trun( moof_data[traf_offset + 8 : traf_offset + inner_size], tfhd_info ) + if first_trun_data_offset is None: + first_trun_data_offset = data_off + trun_entries.extend(entries) traf_offset += inner_size + # Only process this traf if it matches the audio track + if tfhd_info["track_id"] != audio_track_id: + offset += size + continue + + # Compute starting offset in mdat_data + base = tfhd_info.get("base_data_offset") + if base is None: + base = moof_offset # Default: first byte of containing moof + + if first_trun_data_offset is not None: + mdat_idx = (base + first_trun_data_offset) - mdat_data_offset + else: + mdat_idx = 0 + + mdat_read_offset = max(0, mdat_idx) + desc_index = tfhd_info["desc_index"] + if desc_index > 0: + desc_index -= 1 # Convert to 0-indexed + + for entry in trun_entries: + sample_size = entry.get("size", tfhd_info["default_size"]) + sample_duration = entry.get("duration", tfhd_info["default_duration"]) + + if sample_size > 0 and mdat_read_offset + sample_size <= len(mdat_data): + sample = SampleInfo( + data=mdat_data[ + mdat_read_offset : mdat_read_offset + sample_size + ], + duration=sample_duration, + desc_index=desc_index, + ) + samples.append(sample) + mdat_read_offset += sample_size + offset += size - # Extract samples from mdat using trun entries - mdat_offset = 0 - desc_index = tfhd_info["desc_index"] - if desc_index > 0: - desc_index -= 1 # Convert to 0-indexed - - for entry in trun_entries: - sample_size = entry.get("size", tfhd_info["default_size"]) - sample_duration = entry.get("duration", tfhd_info["default_duration"]) - - if sample_size > 0 and mdat_offset + sample_size <= len(mdat_data): - sample = SampleInfo( - data=mdat_data[mdat_offset : mdat_offset + sample_size], - duration=sample_duration, - desc_index=desc_index, - ) - samples.append(sample) - mdat_offset += sample_size - return samples @@ -249,9 +294,13 @@ def _parse_tfhd(data: bytes, tfhd_info: dict): tfhd_info["flags"] = flags # After version+flags is track_id(4) + tfhd_info["track_id"] = struct.unpack(">I", data[4:8])[0] offset = 4 + 4 # version+flags + track_id if flags & 0x01 and offset + 8 <= len(data): # base_data_offset + tfhd_info["base_data_offset"] = struct.unpack(">Q", data[offset : offset + 8])[ + 0 + ] offset += 8 if flags & 0x02 and offset + 4 <= len(data): # sample_description_index tfhd_info["desc_index"] = struct.unpack(">I", data[offset : offset + 4])[0] @@ -265,11 +314,17 @@ def _parse_tfhd(data: bytes, tfhd_info: dict): tfhd_info["default_size"] = struct.unpack(">I", data[offset : offset + 4])[0] -def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]: - """Parse track run box to get sample entries (FullBox: version + flags + content).""" +def _parse_trun(data: bytes, tfhd_info: dict) -> tuple[List[dict], Optional[int]]: + """Parse track run box to get sample entries and data_offset. + + Returns: + Tuple of (entries, data_offset). data_offset is the signed offset from + base_data_offset to the first sample's data, or None if not present. + """ entries = [] + data_offset_value = None if len(data) < 8: # version(1) + flags(3) + sample_count(4) - return entries + return entries, data_offset_value # FullBox: version(1) + flags(3) version = data[0] @@ -279,6 +334,7 @@ def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]: # Start reading entries after header fields offset = 8 # version+flags(4) + sample_count(4) if flags & 0x01: # data_offset present + data_offset_value = struct.unpack(">i", data[offset : offset + 4])[0] offset += 4 if flags & 0x04: # first_sample_flags present offset += 4 @@ -297,7 +353,7 @@ def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]: offset += 4 entries.append(entry) - return entries + return entries, data_offset_value async def decrypt_samples( @@ -1020,6 +1076,51 @@ def _extract_timestamps_from_box(data: bytes, box_type: bytes) -> tuple[int, int return 0, 0 +def _extract_audio_track_id(moov_data: bytes) -> int: + """Extract the track ID of the audio track from the moov box. + + Parses trak boxes in moov to find one with handler_type 'soun' (sound), + then returns its track_id from tkhd. Defaults to 1 if not found. + """ + offset = 8 # Skip moov box header + while offset < len(moov_data) - 8: + size = struct.unpack(">I", moov_data[offset : offset + 4])[0] + box_type = moov_data[offset + 4 : offset + 8] + + if size < 8 or offset + size > len(moov_data): + break + + if box_type == b"trak": + trak_data = moov_data[offset : offset + size] + + # Check handler type in hdlr box + hdlr_idx = trak_data.find(b"hdlr") + if hdlr_idx > 0: + # hdlr FullBox: after 'hdlr' type comes version+flags(4) + pre_defined(4) + handler_type(4) + handler_offset = hdlr_idx + 4 + 4 + 4 + if handler_offset + 4 <= len(trak_data): + handler_type = trak_data[handler_offset : handler_offset + 4] + if handler_type == b"soun": + # Found audio track, extract track_id from tkhd + tkhd_idx = trak_data.find(b"tkhd") + if tkhd_idx > 0: + version = trak_data[tkhd_idx + 4] + if version == 0: + # v0: ver+flags(4) + creation(4) + modification(4) + track_id(4) + tid_offset = tkhd_idx + 4 + 4 + 4 + 4 + else: + # v1: ver+flags(4) + creation(8) + modification(8) + track_id(4) + tid_offset = tkhd_idx + 4 + 4 + 8 + 8 + if tid_offset + 4 <= len(trak_data): + return struct.unpack( + ">I", trak_data[tid_offset : tid_offset + 4] + )[0] + + offset += size + + return 1 # Default to track 1 + + async def decrypt_file( wrapper_ip: str, track_id: str,