Add sample encryption parsing and hex-key decryption

This commit is contained in:
Rafael Moraes
2026-02-25 14:08:09 -03:00
parent 9f86c7436d
commit b3b5e6d1b2
+488 -3
View File
@@ -7,11 +7,17 @@ import asyncio
import io
import logging
import struct
import time
from dataclasses import dataclass, field
from typing import BinaryIO, List, Optional
from Crypto.Cipher import AES
logger = logging.getLogger(__name__)
# Default decryption key for songs without per-sample keys (legacy AAC)
DEFAULT_SONG_DECRYPTION_KEY = b"2\xb8\xad\xe1v\x9e&\xb1\xff\xb8\x98cRy?\xc6"
# Pre-fetch key used for first sample description
PREFETCH_KEY = "skd://itunes.apple.com/P000000000/s1/e1"
@@ -26,6 +32,20 @@ class SampleInfo:
data: bytes
duration: int
desc_index: int
iv: bytes = b"" # Per-sample IV from senc (empty if constant IV)
subsamples: List[tuple] = field(
default_factory=list
) # [(clear_bytes, encrypted_bytes), ...]
@dataclass
class EncryptionInfo:
"""Encryption scheme info extracted from sinf/schm + sinf/schi/tenc."""
scheme_type: str = "cbcs" # 'cenc' or 'cbcs'
per_sample_iv_size: int = 0 # 0, 8, or 16
constant_iv: bytes = b"" # Constant IV (when per_sample_iv_size == 0)
kid: bytes = b"" # Default Key ID (16 bytes)
@dataclass
@@ -35,6 +55,7 @@ class SongInfo:
samples: List[SampleInfo] = field(default_factory=list)
moov_data: bytes = b""
ftyp_data: bytes = b""
encryption_info: Optional[EncryptionInfo] = None
def read_box_header(f: BinaryIO) -> tuple[int, str, int]:
@@ -147,6 +168,10 @@ def extract_song(input_path: str) -> SongInfo:
)
logger.debug(f"Audio track ID: {audio_track_id}")
# Extract encryption scheme info from moov (sinf/schm + sinf/schi/tenc)
if song_info.moov_data:
song_info.encryption_info = _extract_encryption_info(song_info.moov_data)
# Parse moof/mdat pairs
moof_box = None
for box in boxes:
@@ -158,6 +183,11 @@ def extract_song(input_path: str) -> SongInfo:
mdat_data = box["data"][box["header_size"] :] # Skip mdat header
# Parse moof for tfhd (sample description index, defaults) and trun (entries)
_iv_size = (
song_info.encryption_info.per_sample_iv_size
if song_info.encryption_info
else 0
)
samples_from_pair = _parse_moof_mdat(
moof_data,
mdat_data,
@@ -166,6 +196,7 @@ def extract_song(input_path: str) -> SongInfo:
audio_track_id=audio_track_id,
moof_offset=moof_box["offset"],
mdat_data_offset=box["offset"] + box["header_size"],
per_sample_iv_size=_iv_size,
)
song_info.samples.extend(samples_from_pair)
moof_box = None
@@ -182,6 +213,7 @@ def _parse_moof_mdat(
audio_track_id: int = 1,
moof_offset: int = 0,
mdat_data_offset: int = 0,
per_sample_iv_size: int = 0,
) -> List[SampleInfo]:
"""Parse a moof box and extract samples from corresponding mdat.
@@ -192,6 +224,7 @@ def _parse_moof_mdat(
audio_track_id: Track ID of the audio track to extract.
moof_offset: Absolute file offset of the moof box.
mdat_data_offset: Absolute file offset of the mdat content (after header).
per_sample_iv_size: IV size per sample from tenc (0, 8, or 16).
"""
samples = []
@@ -216,6 +249,7 @@ def _parse_moof_mdat(
}
trun_entries = []
first_trun_data_offset = None
senc_entries = [] # Per-sample encryption info from senc box
traf_offset = offset + 8
traf_end = offset + size
@@ -241,6 +275,11 @@ def _parse_moof_mdat(
if first_trun_data_offset is None:
first_trun_data_offset = data_off
trun_entries.extend(entries)
elif inner_type == "senc":
senc_entries = _parse_senc(
moof_data[traf_offset + 8 : traf_offset + inner_size],
per_sample_iv_size,
)
traf_offset += inner_size
@@ -264,17 +303,26 @@ def _parse_moof_mdat(
if desc_index > 0:
desc_index -= 1 # Convert to 0-indexed
for entry in trun_entries:
for i, entry in enumerate(trun_entries):
sample_size = entry.get("size", tfhd_info["default_size"])
sample_duration = entry.get("duration", tfhd_info["default_duration"])
if sample_size > 0 and mdat_read_offset + sample_size <= len(mdat_data):
# Attach per-sample encryption info from senc if available
sample_iv = b""
sample_subsamples = []
if i < len(senc_entries):
sample_iv = senc_entries[i]["iv"]
sample_subsamples = senc_entries[i]["subsamples"]
sample = SampleInfo(
data=mdat_data[
mdat_read_offset : mdat_read_offset + sample_size
],
duration=sample_duration,
desc_index=desc_index,
iv=sample_iv,
subsamples=sample_subsamples,
)
samples.append(sample)
mdat_read_offset += sample_size
@@ -357,6 +405,56 @@ def _parse_trun(data: bytes, tfhd_info: dict) -> tuple[List[dict], Optional[int]
return entries, data_offset_value
def _parse_senc(data: bytes, per_sample_iv_size: int) -> List[dict]:
"""Parse Sample Encryption Box (senc) content (after box header).
Returns a list of dicts, one per sample:
{"iv": bytes, "subsamples": [(clear_bytes, encrypted_bytes), ...]}
The data starts after the 8-byte box header (size+type) but includes
the FullBox header (version 1 byte + flags 3 bytes).
per_sample_iv_size can be 0 (cbcs with constant IV from tenc) — in that case
there are 0 IV bytes per sample but subsample info may still be present.
"""
if len(data) < 8:
return []
version = data[0]
flags = struct.unpack(">I", b"\x00" + data[1:4])[0]
sample_count = struct.unpack(">I", data[4:8])[0]
entries: List[dict] = []
offset = 8
for _ in range(sample_count):
# Read per-sample IV (0 bytes when per_sample_iv_size == 0)
iv = b""
if per_sample_iv_size > 0:
if offset + per_sample_iv_size > len(data):
break
iv = data[offset : offset + per_sample_iv_size]
offset += per_sample_iv_size
subsamples = []
if flags & 0x02:
# Subsample encryption info present
if offset + 2 > len(data):
break
subsample_count = struct.unpack(">H", data[offset : offset + 2])[0]
offset += 2
for _ in range(subsample_count):
if offset + 6 > len(data):
break
clear_bytes = struct.unpack(">H", data[offset : offset + 2])[0]
encrypted_bytes = struct.unpack(">I", data[offset + 2 : offset + 6])[0]
subsamples.append((clear_bytes, encrypted_bytes))
offset += 6
entries.append({"iv": iv, "subsamples": subsamples})
return entries
async def decrypt_samples(
wrapper_ip: str,
track_id: str,
@@ -380,8 +478,6 @@ async def decrypt_samples(
Args:
progress_callback: Optional callback(current_sample, total_samples, bytes_processed) for progress tracking
"""
import time
host, port = wrapper_ip.split(":")
port = int(port)
@@ -1210,6 +1306,188 @@ def _write_udta(f):
_fixup_box_size(f, udta_start, b"udta")
def _extract_encryption_info(moov_data: bytes) -> Optional[EncryptionInfo]:
"""Extract encryption scheme info from the audio track's sinf box.
Walks moov → trak (audio) → mdia → minf → stbl → stsd → sample_entry → sinf,
then reads sinf/schm for scheme_type and sinf/schi/tenc for IV size, constant IV,
and default KID.
Returns EncryptionInfo or None if no sinf is found.
"""
trak_data = _find_audio_trak(moov_data)
if trak_data is None:
return None
# Navigate trak → mdia → minf → stbl → stsd
mdia = _find_child_box(trak_data, b"mdia")
if mdia is None:
return None
minf = _find_child_box(mdia, b"minf")
if minf is None:
return None
stbl = _find_child_box(minf, b"stbl")
if stbl is None:
return None
stsd = _find_child_box(stbl, b"stsd")
if stsd is None:
return None
# stsd is a FullBox: 4 (size) + 4 (type) + 4 (version+flags) + 4 (entry_count)
# Then the first sample entry immediately follows
if len(stsd) < 16:
return None
entry_offset = 16 # past header+version+flags+entry_count
if entry_offset + 8 > len(stsd):
return None
entry_size = struct.unpack(">I", stsd[entry_offset : entry_offset + 4])[0]
entry_data = stsd[entry_offset : entry_offset + entry_size]
# Find sinf inside this sample entry
# Audio sample entries have a 36-byte fixed header:
# size(4) + type(4) + reserved(6) + data_ref_index(2) + audio_data(20)
# Child boxes (including sinf) start at offset 36
sinf = _find_child_box(entry_data, b"sinf", skip_header=36)
if sinf is None:
return None
info = EncryptionInfo()
# Parse schm (Scheme Type Box) inside sinf
schm = _find_child_box(sinf, b"schm")
if schm and len(schm) >= 20:
# schm: 4(size) + 4(type) + 4(ver+flags) + 4(scheme_type) + 4(scheme_version)
info.scheme_type = schm[12:16].decode("ascii", errors="replace")
logger.debug(f"Encryption scheme: {info.scheme_type}")
# Parse tenc (Track Encryption Box) inside sinf/schi
schi = _find_child_box(sinf, b"schi")
if schi:
tenc = _find_child_box(schi, b"tenc")
if tenc and len(tenc) >= 32:
# tenc FullBox layout (offsets include 8-byte box header):
# [0:4] size
# [4:8] type "tenc"
# [8] version
# [9:12] flags (3 bytes)
# [12] reserved
# [13] reserved (v0) / crypt_byte_block|skip_byte_block (v1)
# [14] default_isProtected
# [15] default_Per_Sample_IV_Size
# [16:32] default_KID (16 bytes)
# if per_sample_iv_size==0:
# [32] default_constant_IV_size
# [33..] default_constant_IV
tenc_version = tenc[8]
per_sample_iv_size = tenc[15]
kid = tenc[16:32]
info.per_sample_iv_size = per_sample_iv_size
info.kid = kid
logger.debug(
f"tenc: per_sample_iv_size={per_sample_iv_size}, " f"kid={kid.hex()}"
)
# If per_sample_iv_size is 0, a constant IV follows the KID
if per_sample_iv_size == 0 and len(tenc) > 32:
constant_iv_size = tenc[32]
if len(tenc) >= 33 + constant_iv_size:
info.constant_iv = tenc[33 : 33 + constant_iv_size]
logger.debug(f"Constant IV: {info.constant_iv.hex()}")
return info
def _extract_encryption_info_per_stsd(moov_data: bytes) -> Optional[dict]:
"""Extract encryption scheme info for each stsd entry (sample description).
Returns a dict mapping desc_index (0-based) → EncryptionInfo, or None if no
encryption found. This handles cases where different sample descriptions have
different encryption parameters (e.g., different IVs or key schemes).
"""
trak_data = _find_audio_trak(moov_data)
if trak_data is None:
return None
# Navigate trak → mdia → minf → stbl → stsd
mdia = _find_child_box(trak_data, b"mdia")
if mdia is None:
return None
minf = _find_child_box(mdia, b"minf")
if minf is None:
return None
stbl = _find_child_box(minf, b"stbl")
if stbl is None:
return None
stsd = _find_child_box(stbl, b"stsd")
if stsd is None:
return None
if len(stsd) < 16:
return None
entry_count = struct.unpack(">I", stsd[12:16])[0]
if entry_count == 0:
return None
encryption_info_per_desc = {}
entry_offset = 16 # past header+version+flags+entry_count
for desc_idx in range(entry_count):
if entry_offset + 8 > len(stsd):
break
entry_size = struct.unpack(">I", stsd[entry_offset : entry_offset + 4])[0]
if entry_size < 8 or entry_offset + entry_size > len(stsd):
break
entry_data = stsd[entry_offset : entry_offset + entry_size]
# Find sinf inside this sample entry
# Audio sample entries have 36-byte fixed header before child boxes
sinf = _find_child_box(entry_data, b"sinf", skip_header=36)
if sinf is not None:
# Extract encryption info for this stsd entry
info = EncryptionInfo()
# Parse schm
schm = _find_child_box(sinf, b"schm")
if schm and len(schm) >= 20:
info.scheme_type = schm[12:16].decode("ascii", errors="replace")
logger.debug(
f"Encryption scheme for desc_index {desc_idx}: {info.scheme_type}"
)
# Parse tenc
schi = _find_child_box(sinf, b"schi")
if schi:
tenc = _find_child_box(schi, b"tenc")
if tenc and len(tenc) >= 32:
per_sample_iv_size = tenc[15]
kid = tenc[16:32]
info.per_sample_iv_size = per_sample_iv_size
info.kid = kid
logger.debug(
f"tenc (desc {desc_idx}): per_sample_iv_size={per_sample_iv_size}"
)
# If per_sample_iv_size is 0, extract constant IV
if per_sample_iv_size == 0 and len(tenc) > 32:
constant_iv_size = tenc[32]
if len(tenc) >= 33 + constant_iv_size:
info.constant_iv = tenc[33 : 33 + constant_iv_size]
logger.debug(
f"Constant IV (desc {desc_idx}): {info.constant_iv.hex()}"
)
encryption_info_per_desc[desc_idx] = info
entry_offset += entry_size
return encryption_info_per_desc if encryption_info_per_desc else None
def _extract_audio_track_id(moov_data: bytes) -> int:
"""Extract the track ID of the audio track from the moov box.
@@ -1302,3 +1580,210 @@ async def decrypt_file(
decrypted_data,
input_path, # Pass original path for codec info extraction
)
def decrypt_samples_hex(
samples: List[SampleInfo],
keys: dict,
encryption_info: EncryptionInfo,
encryption_info_per_desc: Optional[dict] = None,
) -> bytes:
"""Decrypt samples using hex AES keys (no wrapper needed).
Supports both CENC (AES-128-CTR) and CBCS (AES-128-CBC) schemes.
Args:
samples: List of SampleInfo with data, desc_index, iv, subsamples.
keys: Mapping of desc_index (int) → AES key (16 bytes, raw).
encryption_info: EncryptionInfo with scheme_type, constant_iv, etc.
encryption_info_per_desc: Optional dict mapping desc_index → EncryptionInfo
(used when different stsd entries have different params).
Returns:
Concatenated decrypted sample data.
"""
is_cenc = encryption_info.scheme_type == "cenc"
decrypted = bytearray()
for sample in samples:
key = keys.get(sample.desc_index)
if key is None:
# No key for this desc_index — keep data as-is (shouldn't happen)
decrypted.extend(sample.data)
continue
# Get encryption info for this sample's desc_index (if per-description info exists)
if encryption_info_per_desc and sample.desc_index in encryption_info_per_desc:
enc_info = encryption_info_per_desc[sample.desc_index]
else:
enc_info = encryption_info
if is_cenc:
# AES-128-CTR: per-sample IV from senc, zero-padded to 16 bytes
iv = sample.iv
if len(iv) < 16:
iv = iv + b"\x00" * (16 - len(iv))
cipher = AES.new(key, AES.MODE_CTR, nonce=b"", initial_value=iv)
if sample.subsamples:
plaintext = bytearray()
offset = 0
for clear_bytes, encrypted_bytes in sample.subsamples:
plaintext.extend(sample.data[offset : offset + clear_bytes])
offset += clear_bytes
plaintext.extend(
cipher.decrypt(sample.data[offset : offset + encrypted_bytes])
)
offset += encrypted_bytes
plaintext.extend(sample.data[offset:])
decrypted.extend(plaintext)
else:
decrypted.extend(cipher.decrypt(sample.data))
else:
# CBCS (AES-128-CBC): constant IV or per-sample IV
iv = sample.iv if sample.iv else enc_info.constant_iv
if len(iv) < 16:
iv = iv + b"\x00" * (16 - len(iv))
if sample.subsamples:
# For CBCS subsamples: concatenate all encrypted regions into one,
# decrypt as one CBC stream (to maintain cipher state), then split back.
# This avoids losing bytes if encrypted_bytes values aren't 16-byte aligned.
# Collect all encrypted byte ranges and encrypt content
encrypted_concat = bytearray()
subsample_sizes = (
[]
) # Track size of each encrypted region for reassembly
offset = 0
for clear_bytes, encrypted_bytes in sample.subsamples:
offset += clear_bytes
if encrypted_bytes > 0:
encrypted_concat.extend(
sample.data[offset : offset + encrypted_bytes]
)
subsample_sizes.append(encrypted_bytes)
offset += encrypted_bytes
# Decrypt concatenated regions as one CBC stream
total_enc_len = len(encrypted_concat)
decrypted_concat = bytearray()
if total_enc_len > 0:
cbc_len = total_enc_len & ~0xF
if cbc_len > 0:
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
decrypted_concat.extend(
cipher.decrypt(bytes(encrypted_concat[:cbc_len]))
)
# Any trailing unaligned bytes (shouldn't happen if file is well-formed)
if cbc_len < total_enc_len:
decrypted_concat.extend(encrypted_concat[cbc_len:])
# Reassemble with original clear/encrypted pattern
plaintext = bytearray()
dec_offset = 0
offset = 0
for clear_bytes, encrypted_bytes in sample.subsamples:
plaintext.extend(sample.data[offset : offset + clear_bytes])
offset += clear_bytes
if encrypted_bytes > 0:
plaintext.extend(
decrypted_concat[dec_offset : dec_offset + encrypted_bytes]
)
dec_offset += encrypted_bytes
offset += encrypted_bytes
plaintext.extend(sample.data[offset:])
decrypted.extend(plaintext)
else:
# Full subsample: for well-formed files, the entire sample should be
# a multiple of 16 bytes. Only truncate if misaligned (unexpected).
sample_len = len(sample.data)
if sample_len % 16 == 0:
# Data is properly 16-byte aligned, decrypt as-is
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
decrypted.extend(cipher.decrypt(sample.data))
else:
# Data is not aligned (unexpected case) - truncate carefully
truncated_len = sample_len & ~0xF
if truncated_len > 0:
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
decrypted.extend(cipher.decrypt(sample.data[:truncated_len]))
# Keep unaligned tail bytes as clear (unencrypted)
decrypted.extend(sample.data[truncated_len:])
else:
# Less than 16 bytes - cannot decrypt with CBC, keep as-is
decrypted.extend(sample.data)
logger.debug(
f"Decrypted {len(samples)} samples ({len(decrypted)} bytes) with hex keys"
)
return bytes(decrypted)
async def decrypt_file_hex(
input_path: str,
output_path: str,
decryption_key: str,
legacy: bool = False,
) -> None:
"""Decrypt an encrypted MP4 file using a hex AES key (no wrapper/mp4decrypt).
This replaces the mp4decrypt + remux pipeline with pure-Python decryption:
1. Extract samples and encryption info from MP4
2. Decrypt samples using AES (CTR for cenc / CBC for cbcs)
3. Write clean decrypted M4A output
Args:
input_path: Path to encrypted MP4 file.
output_path: Path for decrypted output file.
decryption_key: Hex-encoded 128-bit AES key (32 hex chars).
legacy: If True, treat as legacy AAC (cenc, single key).
"""
logger.debug(f"Hex-key decrypt: {input_path} -> {output_path}")
# Extract samples (run in thread to not block)
song_info = await asyncio.to_thread(extract_song, input_path)
# Build key mapping: desc_index → raw AES key bytes
track_key = bytes.fromhex(decryption_key)
if legacy:
# Legacy AAC (cenc): single key for all samples (all desc_index 0)
keys = {0: track_key}
else:
# Non-legacy (cbcs): two sample descriptions
# desc_index 0 → DEFAULT_SONG_DECRYPTION_KEY (prefetch samples)
# desc_index 1 → track key (from Widevine CDM)
keys = {0: DEFAULT_SONG_DECRYPTION_KEY, 1: track_key}
# Use encryption info from the file (fall back to sensible defaults)
enc_info = song_info.encryption_info or EncryptionInfo(
scheme_type="cenc" if legacy else "cbcs"
)
# Try to extract per-description encryption info (for non-legacy files)
# This handles cases where desc_index 0 and 1 have different encryption parameters
enc_info_per_desc = None
if song_info.moov_data and not legacy:
enc_info_per_desc = await asyncio.to_thread(
_extract_encryption_info_per_stsd, song_info.moov_data
)
if enc_info_per_desc:
logger.debug(
f"Found per-description encryption info: {list(enc_info_per_desc.keys())}"
)
# Decrypt
decrypted_data = decrypt_samples_hex(
song_info.samples, keys, enc_info, enc_info_per_desc
)
# Write output (preserves original metadata boxes)
await asyncio.to_thread(
write_decrypted_m4a,
output_path,
song_info,
decrypted_data,
input_path,
)