From 387861bb2f5b29ae1f09a4b8762432da1cae1f65 Mon Sep 17 00:00:00 2001 From: Rafael Moraes <50295204+glomatico@users.noreply.github.com> Date: Sun, 24 May 2026 14:21:21 -0300 Subject: [PATCH] Support file-backed samples and streaming decrypt --- gamdl/downloader/amdecrypt.py | 447 +++++++++++++++++++++++----------- 1 file changed, 309 insertions(+), 138 deletions(-) diff --git a/gamdl/downloader/amdecrypt.py b/gamdl/downloader/amdecrypt.py index ac67bb3..52f6620 100644 --- a/gamdl/downloader/amdecrypt.py +++ b/gamdl/downloader/amdecrypt.py @@ -128,6 +128,40 @@ def _sample_size(sample: SampleInfo) -> int: return sample.size or len(sample.data) +def _sample_data(sample: SampleInfo) -> bytes: + """Return sample payload bytes, loading file-backed samples on demand.""" + if sample.data: + return sample.data + if sample.data_path and sample.size: + with open(sample.data_path, "rb") as f: + f.seek(sample.data_offset) + data = f.read(sample.size) + if len(data) != sample.size: + raise IOError( + f"unexpected EOF while reading sample at {sample.data_offset} " + f"from {sample.data_path}" + ) + return data + return sample.data + + +def _with_sample_data(sample: SampleInfo, data: bytes) -> SampleInfo: + """Return a copy of a sample with materialized payload bytes.""" + return SampleInfo( + data=data, + duration=sample.duration, + desc_index=sample.desc_index, + iv=sample.iv, + subsamples=sample.subsamples, + composition_time_offset=sample.composition_time_offset, + sample_flags=sample.sample_flags, + is_sync=sample.is_sync, + size=sample.size or len(data), + data_path=sample.data_path, + data_offset=sample.data_offset, + ) + + def _decrypt_cbcs_sample_with_key( sample: SampleInfo, key: bytes, enc_info: EncryptionInfo ) -> bytes: @@ -262,6 +296,8 @@ class SampleInfo: sample_flags: int = 0 is_sync: bool = True size: int = 0 + data_path: Optional[str] = None + data_offset: int = 0 @dataclass @@ -361,7 +397,11 @@ def find_box(data: bytes, box_path: List[str]) -> Optional[bytes]: return f.read() -def extract_song(input_path: str, handler_type: bytes = b"soun") -> SongInfo: +def extract_song( + input_path: str, + handler_type: bytes = b"soun", + file_backed_samples: bool = False, +) -> SongInfo: """ Extract media samples and metadata from encrypted MP4 file. @@ -370,38 +410,78 @@ def extract_song(input_path: str, handler_type: bytes = b"soun") -> SongInfo: - Individual audio samples from mdat boxes - Sample durations and description indices from moof boxes """ - with open(input_path, "rb") as f: - raw_data = f.read() - song_info = SongInfo(handler_type=handler_type) # First pass: collect all top-level boxes boxes = [] - offset = 0 - while offset < len(raw_data) - 8: - size = struct.unpack(">I", raw_data[offset : offset + 4])[0] - box_type = raw_data[offset + 4 : offset + 8].decode("ascii", errors="replace") + if file_backed_samples: + file_size = os.path.getsize(input_path) + with open(input_path, "rb") as f: + offset = 0 + while offset + 8 <= file_size: + f.seek(offset) + header = f.read(8) + if len(header) < 8: + break + size = struct.unpack(">I", header[:4])[0] + box_type = header[4:8].decode("ascii", errors="replace") + header_size = 8 + if size == 0: + size = file_size - offset + elif size == 1: + ext_size = f.read(8) + if len(ext_size) < 8: + break + size = struct.unpack(">Q", ext_size)[0] + header_size = 16 + if size < header_size or offset + size > file_size: + break - header_size = 8 - if size == 0: - break - if size == 1: - # Extended size - if offset + 16 > len(raw_data): + data = b"" + if box_type in ("ftyp", "moov", "moof"): + f.seek(offset) + data = f.read(size) + boxes.append( + { + "offset": offset, + "size": size, + "type": box_type, + "header_size": header_size, + "data": data, + } + ) + offset += size + else: + with open(input_path, "rb") as f: + raw_data = f.read() + + offset = 0 + while offset < len(raw_data) - 8: + size = struct.unpack(">I", raw_data[offset : offset + 4])[0] + box_type = raw_data[offset + 4 : offset + 8].decode( + "ascii", errors="replace" + ) + + header_size = 8 + if size == 0: break - size = struct.unpack(">Q", raw_data[offset + 8 : offset + 16])[0] - header_size = 16 + if size == 1: + # Extended size + if offset + 16 > len(raw_data): + break + size = struct.unpack(">Q", raw_data[offset + 8 : offset + 16])[0] + header_size = 16 - boxes.append( - { - "offset": offset, - "size": size, - "type": box_type, - "header_size": header_size, - "data": raw_data[offset : offset + size], - } - ) - offset += size + boxes.append( + { + "offset": offset, + "size": size, + "type": box_type, + "header_size": header_size, + "data": raw_data[offset : offset + size], + } + ) + offset += size # Extract ftyp and moov for box in boxes: @@ -457,7 +537,12 @@ def extract_song(input_path: str, handler_type: bytes = b"soun") -> SongInfo: elif box["type"] == "mdat" and moof_box is not None: # Parse this moof/mdat pair moof_data = moof_box["data"] - mdat_data = box["data"][box["header_size"] :] # Skip mdat header + if file_backed_samples: + mdat_data = b"" + mdat_data_size = box["size"] - box["header_size"] + else: + mdat_data = box["data"][box["header_size"] :] # Skip mdat header + mdat_data_size = len(mdat_data) # Parse moof for tfhd (sample description index, defaults) and trun (entries) _iv_size = ( @@ -475,6 +560,8 @@ def extract_song(input_path: str, handler_type: bytes = b"soun") -> SongInfo: moof_offset=moof_box["offset"], mdat_data_offset=box["offset"] + box["header_size"], per_sample_iv_size=_iv_size, + mdat_data_size=mdat_data_size, + mdat_source_path=input_path if file_backed_samples else None, ) song_info.samples.extend(samples_from_pair) moof_box = None @@ -507,6 +594,8 @@ def _parse_moof_mdat( moof_offset: int = 0, mdat_data_offset: int = 0, per_sample_iv_size: int = 0, + mdat_data_size: Optional[int] = None, + mdat_source_path: Optional[str] = None, ) -> List[SampleInfo]: """Parse a moof box and extract samples from corresponding mdat. @@ -520,6 +609,7 @@ def _parse_moof_mdat( per_sample_iv_size: IV size per sample from tenc (0, 8, or 16). """ samples = [] + available_mdat_bytes = len(mdat_data) if mdat_data_size is None else mdat_data_size # Simple box parsing inside moof offset = 8 # Skip moof header @@ -624,8 +714,9 @@ def _parse_moof_mdat( "sample_flags", tfhd_info["default_sample_flags"] ) - if sample_size > 0 and mdat_read_offset + sample_size <= len( - mdat_data + if ( + sample_size > 0 + and mdat_read_offset + sample_size <= available_mdat_bytes ): sample_iv = b"" sample_subsamples: List[tuple] = [] @@ -634,11 +725,17 @@ def _parse_moof_mdat( sample_subsamples = senc_entries[sample_index_in_traf][ "subsamples" ] + if mdat_source_path: + sample_data = b"" + sample_data_offset = mdat_data_offset + mdat_read_offset + else: + sample_data = mdat_data[ + mdat_read_offset : mdat_read_offset + sample_size + ] + sample_data_offset = 0 sample = SampleInfo( - data=mdat_data[ - mdat_read_offset : mdat_read_offset + sample_size - ], + data=sample_data, duration=sample_duration, desc_index=desc_index, iv=sample_iv, @@ -649,6 +746,8 @@ def _parse_moof_mdat( sample_flags=sample_flags, is_sync=not bool(sample_flags & 0x10000), size=sample_size, + data_path=mdat_source_path, + data_offset=sample_data_offset, ) samples.append(sample) mdat_read_offset += sample_size @@ -893,6 +992,7 @@ async def decrypt_samples( *, use_single_content_key: bool = False, progress_callback=None, + decrypted_data_path: Optional[str] = None, ) -> bytes: """ Send track-key samples to wrapper-v2 (HTTP POST /decrypt) for CBCS @@ -916,6 +1016,8 @@ async def decrypt_samples( """ keys = [fairplay_key] if use_single_content_key else [PREFETCH_KEY, fairplay_key] decrypted_data = bytearray() + decrypted_output = open(decrypted_data_path, "wb") if decrypted_data_path else None + decrypted_bytes = 0 last_desc_index: int = 255 total_samples = len(samples) bytes_processed = 0 @@ -927,6 +1029,14 @@ async def decrypt_samples( # Pending (sample, aligned_cbc, tail) for one SKD segment, flushed in batches. crypto_batch: List[tuple] = [] + def emit(data: bytes) -> None: + nonlocal decrypted_bytes + if decrypted_output: + decrypted_output.write(data) + else: + decrypted_data.extend(data) + decrypted_bytes += len(data) + async def flush_crypto_batch() -> None: if not crypto_batch: return @@ -939,75 +1049,96 @@ async def decrypt_samples( if len(plains) != len(chunks): raise IOError("wrapper-v2: plaintext batch count mismatch") for s, plain, tail in zip(sources, plains, tails): - _append_reassembled_sample(decrypted_data, s, plain, tail) + emit(_reassemble_cbcs_sample(s, plain, tail)) crypto_batch.clear() - for i, sample in enumerate(samples): - if last_desc_index != sample.desc_index: - await flush_crypto_batch() - if use_single_content_key: - segment_adam = track_id - segment_uri = fairplay_key - else: - key_uri = keys[min(sample.desc_index, len(keys) - 1)] - segment_adam = "0" if key_uri == PREFETCH_KEY else track_id - segment_uri = key_uri - last_desc_index = sample.desc_index + try: + for i, original_sample in enumerate(samples): + sample = ( + _with_sample_data(original_sample, _sample_data(original_sample)) + if not original_sample.data and original_sample.data_path + else original_sample + ) + if last_desc_index != sample.desc_index: + await flush_crypto_batch() + if use_single_content_key: + segment_adam = track_id + segment_uri = fairplay_key + else: + key_uri = keys[min(sample.desc_index, len(keys) - 1)] + segment_adam = "0" if key_uri == PREFETCH_KEY else track_id + segment_uri = key_uri + last_desc_index = sample.desc_index + + if not use_single_content_key and segment_adam == "0": + await flush_crypto_batch() + enc_info = ( + encryption_info_per_desc.get(sample.desc_index) + if encryption_info_per_desc + and sample.desc_index in encryption_info_per_desc + else encryption_info + ) + emit( + _decrypt_cbcs_sample_with_key( + sample, DEFAULT_SONG_DECRYPTION_KEY, enc_info + ) + ) + bytes_processed += _sample_size(sample) + now = time.time() + if progress_callback and ( + i % 50 == 0 + or now - last_progress_time > 0.5 + or i == total_samples - 1 + ): + elapsed = now - start_time + speed = bytes_processed / elapsed if elapsed > 0 else 0 + progress_callback(i + 1, total_samples, bytes_processed, speed) + last_progress_time = now + continue - if not use_single_content_key and segment_adam == "0": - await flush_crypto_batch() enc_info = ( encryption_info_per_desc.get(sample.desc_index) if encryption_info_per_desc and sample.desc_index in encryption_info_per_desc else encryption_info ) - decrypted_data.extend( - _decrypt_cbcs_sample_with_key( - sample, DEFAULT_SONG_DECRYPTION_KEY, enc_info + if enc_info.crypt_byte_block and enc_info.skip_byte_block: + raise IOError( + "wrapper-v2 pattern CBCS decrypt is not supported by gamdl's " + "batch decrypt path; use hex-key decrypt for this track" ) - ) - bytes_processed += len(sample.data) + + parts = _cbcs_ciphertext_for_sample(sample) + if parts is None: + await flush_crypto_batch() + emit(sample.data) + else: + aligned, tail = parts + if len(aligned) == 0: + await flush_crypto_batch() + emit(_reassemble_cbcs_sample(sample, b"", tail)) + else: + crypto_batch.append((sample, aligned, tail)) + if len(crypto_batch) >= WRAPPER_DECRYPT_BATCH_SIZE: + await flush_crypto_batch() + + bytes_processed += _sample_size(sample) + now = time.time() if progress_callback and ( - i % 50 == 0 - or now - last_progress_time > 0.5 - or i == total_samples - 1 + i % 50 == 0 or now - last_progress_time > 0.5 or i == total_samples - 1 ): elapsed = now - start_time speed = bytes_processed / elapsed if elapsed > 0 else 0 progress_callback(i + 1, total_samples, bytes_processed, speed) last_progress_time = now - continue - parts = _cbcs_ciphertext_for_sample(sample) - if parts is None: - await flush_crypto_batch() - decrypted_data.extend(sample.data) - else: - aligned, tail = parts - if len(aligned) == 0: - await flush_crypto_batch() - _append_reassembled_sample(decrypted_data, sample, b"", tail) - else: - crypto_batch.append((sample, aligned, tail)) - if len(crypto_batch) >= WRAPPER_DECRYPT_BATCH_SIZE: - await flush_crypto_batch() + await flush_crypto_batch() + finally: + if decrypted_output: + decrypted_output.close() - bytes_processed += len(sample.data) - - now = time.time() - if progress_callback and ( - i % 50 == 0 or now - last_progress_time > 0.5 or i == total_samples - 1 - ): - elapsed = now - start_time - speed = bytes_processed / elapsed if elapsed > 0 else 0 - progress_callback(i + 1, total_samples, bytes_processed, speed) - last_progress_time = now - - await flush_crypto_batch() - - logger.debug(f"Decrypted {len(samples)} samples ({len(decrypted_data)} bytes)") + logger.debug(f"Decrypted {len(samples)} samples ({decrypted_bytes} bytes)") return bytes(decrypted_data) @@ -1043,11 +1174,11 @@ def write_decrypted_m4a( timescale = 44100 # Default fallback preferred_desc_index = _preferred_sample_description_index(song_info.samples) - if original_path: + if song_info.moov_data: + orig_data = song_info.ftyp_data + song_info.moov_data + elif original_path: with open(original_path, "rb") as f: orig_data = f.read() - elif song_info.moov_data: - orig_data = song_info.ftyp_data + song_info.moov_data else: orig_data = None @@ -1133,11 +1264,11 @@ def write_decrypted_mp4_track( timescale = 44100 if track_info.handler_type == b"soun" else 90000 preferred_desc_index = _preferred_sample_description_index(track_info.samples) - if original_path: + if track_info.moov_data: + orig_data = track_info.ftyp_data + track_info.moov_data + elif original_path: with open(original_path, "rb") as f: orig_data = f.read() - elif track_info.moov_data: - orig_data = track_info.ftyp_data + track_info.moov_data else: orig_data = None @@ -1228,11 +1359,11 @@ def _build_decrypted_track_moov( timescale = 44100 if track_info.handler_type == b"soun" else 90000 preferred_desc_index = _preferred_sample_description_index(track_info.samples) - if original_path: + if track_info.moov_data: + orig_data = track_info.ftyp_data + track_info.moov_data + elif original_path: with open(original_path, "rb") as f: orig_data = f.read() - elif track_info.moov_data: - orig_data = track_info.ftyp_data + track_info.moov_data else: orig_data = None @@ -1301,6 +1432,11 @@ def _decrypted_track_payload_source(track: DecryptedTrack): return (None, 0, len(track.data), track.data) +def _sample_payload_bytes(samples: List[SampleInfo]) -> bytes: + """Materialize only the payload bytes for the given samples.""" + return b"".join(_sample_data(sample) for sample in samples) + + def mux_decrypted_media_direct( decrypted_media: DecryptedMedia, output_path: str, @@ -1310,17 +1446,11 @@ def mux_decrypted_media_direct( if decrypted_media.video is None: raise ValueError("direct AV mux requires a video track") - video_moov = _build_decrypted_track_moov( - decrypted_media.video.track_info, - decrypted_media.video.input_path, - ) - audio_moov = _build_decrypted_track_moov( - decrypted_media.audio.track_info, - decrypted_media.audio.input_path, - ) + video_moov = _build_decrypted_track_moov(decrypted_media.video.track_info) + audio_moov = _build_decrypted_track_moov(decrypted_media.audio.track_info) extra_track_files = [ ( - _build_decrypted_track_moov(caption.track_info, caption.input_path), + _build_decrypted_track_moov(caption.track_info), _decrypted_track_payload_source(caption), ) for caption in decrypted_media.captions @@ -1497,7 +1627,9 @@ async def _decrypt_track_hex( ``True`` (web AAC, muxed MV audio): every sample description uses ``decryption_key``. """ - track_info = await asyncio.to_thread(extract_song, input_path, handler_type) + track_info = await asyncio.to_thread( + extract_song, input_path, handler_type, file_backed + ) track_key = bytes.fromhex(decryption_key) if use_single_content_key: @@ -1581,21 +1713,27 @@ async def decrypt_file_hex( video_key = decryption_key_video or decryption_key_audio video_task = asyncio.create_task( - _decrypt_track_hex(input_video_path, video_key, b"vide", file_backed=True) + _decrypt_track_hex( + input_video_path, + video_key, + b"vide", + use_cenc=use_cenc, + file_backed=True, + ) ) caption_tracks = [ track for track in await asyncio.gather( - asyncio.to_thread(extract_song, input_video_path, b"clcp"), - asyncio.to_thread(extract_song, input_video_path, b"text"), - asyncio.to_thread(extract_song, input_video_path, b"sbtl"), - asyncio.to_thread(extract_song, input_video_path, b"subt"), + asyncio.to_thread(extract_song, input_video_path, b"clcp", True), + asyncio.to_thread(extract_song, input_video_path, b"text", True), + asyncio.to_thread(extract_song, input_video_path, b"sbtl", True), + asyncio.to_thread(extract_song, input_video_path, b"subt", True), ) if track.samples ] captions = [] for caption_track in caption_tracks: - caption_data = b"".join(sample.data for sample in caption_track.samples) + caption_data = _sample_payload_bytes(caption_track.samples) if caption_track.encryption_info: caption_key = bytes.fromhex(video_key) caption_enc_info_per_desc = await asyncio.to_thread( @@ -3106,10 +3244,13 @@ async def _decrypt_track_wrapper( handler_type: bytes = b"soun", *, use_single_content_key: bool = False, + file_backed: bool = False, progress_callback=None, ) -> DecryptedTrack: """Decrypt one track through wrapper-v2 (CBCS via FairPlay SKD).""" - song_info = await asyncio.to_thread(extract_song, input_path, handler_type) + song_info = await asyncio.to_thread( + extract_song, input_path, handler_type, file_backed + ) enc_info = song_info.encryption_info or EncryptionInfo(scheme_type="cbcs") enc_info_per_desc = None if song_info.moov_data: @@ -3119,16 +3260,39 @@ async def _decrypt_track_wrapper( handler_type, ) - decrypted_data = await decrypt_samples( - wrapper_api, - track_id, - fairplay_key, - song_info.samples, - enc_info, - enc_info_per_desc, - use_single_content_key=use_single_content_key, - progress_callback=progress_callback, - ) + temp_path = None + if file_backed: + temp_file = tempfile.NamedTemporaryFile( + prefix="gamdl_decrypted_", suffix=".bin", delete=False + ) + temp_path = temp_file.name + temp_file.close() + try: + decrypted_data = await decrypt_samples( + wrapper_api, + track_id, + fairplay_key, + song_info.samples, + enc_info, + enc_info_per_desc, + use_single_content_key=use_single_content_key, + progress_callback=progress_callback, + decrypted_data_path=temp_path, + ) + except Exception: + if temp_path: + try: + os.remove(temp_path) + except FileNotFoundError: + pass + raise + if temp_path: + return DecryptedTrack( + input_path, + song_info, + data_path=temp_path, + data_size=os.path.getsize(temp_path), + ) return DecryptedTrack(input_path, song_info, decrypted_data) @@ -3173,16 +3337,17 @@ async def decrypt_wrapper( input_video_path, b"vide", use_single_content_key=use_single_content_key, + file_backed=True, progress_callback=progress_callback, ) ) caption_tracks = [ track for track in await asyncio.gather( - asyncio.to_thread(extract_song, input_video_path, b"clcp"), - asyncio.to_thread(extract_song, input_video_path, b"text"), - asyncio.to_thread(extract_song, input_video_path, b"sbtl"), - asyncio.to_thread(extract_song, input_video_path, b"subt"), + asyncio.to_thread(extract_song, input_video_path, b"clcp", True), + asyncio.to_thread(extract_song, input_video_path, b"text", True), + asyncio.to_thread(extract_song, input_video_path, b"sbtl", True), + asyncio.to_thread(extract_song, input_video_path, b"subt", True), ) if track.samples ] @@ -3190,7 +3355,7 @@ async def decrypt_wrapper( DecryptedTrack( input_video_path, caption_track, - b"".join(sample.data for sample in caption_track.samples), + _sample_payload_bytes(caption_track.samples), ) for caption_track in caption_tracks ] @@ -3225,14 +3390,13 @@ def decrypt_samples_hex( Returns: Concatenated decrypted sample data. """ - is_cenc = encryption_info.scheme_type == "cenc" decrypted = bytearray() for sample in samples: key = keys.get(sample.desc_index) if key is None: # No key for this desc_index — keep data as-is (shouldn't happen) - decrypted.extend(sample.data) + decrypted.extend(_sample_data(sample)) continue # Get encryption info for this sample's desc_index (if per-description info exists) @@ -3241,8 +3405,13 @@ def decrypt_samples_hex( else: enc_info = encryption_info + if not sample.data and sample.data_path: + sample = _with_sample_data(sample, _sample_data(sample)) + + is_cenc = enc_info.scheme_type == "cenc" if is_cenc: # AES-128-CTR: per-sample IV from senc, zero-padded to 16 bytes + data = sample.data iv = sample.iv if len(iv) < 16: iv = iv + b"\x00" * (16 - len(iv)) @@ -3252,16 +3421,16 @@ def decrypt_samples_hex( plaintext = bytearray() offset = 0 for clear_bytes, encrypted_bytes in sample.subsamples: - plaintext.extend(sample.data[offset : offset + clear_bytes]) + plaintext.extend(data[offset : offset + clear_bytes]) offset += clear_bytes plaintext.extend( - cipher.decrypt(sample.data[offset : offset + encrypted_bytes]) + cipher.decrypt(data[offset : offset + encrypted_bytes]) ) offset += encrypted_bytes - plaintext.extend(sample.data[offset:]) + plaintext.extend(data[offset:]) decrypted.extend(plaintext) else: - decrypted.extend(cipher.decrypt(sample.data)) + decrypted.extend(cipher.decrypt(data)) else: # CBCS (AES-128-CBC): constant IV or per-sample IV @@ -3294,12 +3463,16 @@ def _decrypt_sample_hex( sample: SampleInfo, key: Optional[bytes], encryption_info: EncryptionInfo, - is_cenc: bool, ) -> bytes: """Decrypt one sample with a raw AES key.""" - if key is None: - return sample.data + data = _sample_data(sample) + if data is not sample.data: + sample = _with_sample_data(sample, data) + if key is None: + return data + + is_cenc = encryption_info.scheme_type == "cenc" if is_cenc: iv = sample.iv if len(iv) < 16: @@ -3307,18 +3480,18 @@ def _decrypt_sample_hex( cipher = AES.new(key, AES.MODE_CTR, nonce=b"", initial_value=iv) if not sample.subsamples: - return cipher.decrypt(sample.data) + return cipher.decrypt(data) plaintext = bytearray() offset = 0 for clear_bytes, encrypted_bytes in sample.subsamples: - plaintext.extend(sample.data[offset : offset + clear_bytes]) + plaintext.extend(data[offset : offset + clear_bytes]) offset += clear_bytes plaintext.extend( - cipher.decrypt(sample.data[offset : offset + encrypted_bytes]) + cipher.decrypt(data[offset : offset + encrypted_bytes]) ) offset += encrypted_bytes - plaintext.extend(sample.data[offset:]) + plaintext.extend(data[offset:]) return bytes(plaintext) if encryption_info.crypt_byte_block and encryption_info.skip_byte_block: @@ -3349,7 +3522,6 @@ def decrypt_samples_hex_to_file( release_sample_data: bool = False, ) -> int: """Decrypt samples to a raw payload file without building one large bytes object.""" - is_cenc = encryption_info.scheme_type == "cenc" bytes_written = 0 with open(output_path, "wb") as f: for sample in samples: @@ -3363,7 +3535,6 @@ def decrypt_samples_hex_to_file( sample, keys.get(sample.desc_index), enc_info, - is_cenc, ) f.write(decrypted_sample) sample.size = len(decrypted_sample)