Select audio track for moof/mdat extraction

This commit is contained in:
Rafael Moraes
2026-02-25 00:08:36 -03:00
parent b0c3b4630d
commit bde49305c9
+137 -36
View File
@@ -140,6 +140,12 @@ def extract_song(input_path: str) -> SongInfo:
default_sample_duration = 1024
default_sample_size = 0
# Determine which track is the audio track
audio_track_id = (
_extract_audio_track_id(song_info.moov_data) if song_info.moov_data else 1
)
logger.debug(f"Audio track ID: {audio_track_id}")
# Parse moof/mdat pairs
moof_box = None
for box in boxes:
@@ -152,7 +158,13 @@ def extract_song(input_path: str) -> SongInfo:
# Parse moof for tfhd (sample description index, defaults) and trun (entries)
samples_from_pair = _parse_moof_mdat(
moof_data, mdat_data, default_sample_duration, default_sample_size
moof_data,
mdat_data,
default_sample_duration,
default_sample_size,
audio_track_id=audio_track_id,
moof_offset=moof_box["offset"],
mdat_data_offset=box["offset"] + box["header_size"],
)
song_info.samples.extend(samples_from_pair)
moof_box = None
@@ -166,18 +178,21 @@ def _parse_moof_mdat(
mdat_data: bytes,
default_sample_duration: int,
default_sample_size: int,
audio_track_id: int = 1,
moof_offset: int = 0,
mdat_data_offset: int = 0,
) -> List[SampleInfo]:
"""Parse a moof box and extract samples from corresponding mdat."""
samples = []
"""Parse a moof box and extract samples from corresponding mdat.
# Parse moof to find tfhd (track fragment header) and trun (track run)
tfhd_info = {
"desc_index": 0,
"default_duration": default_sample_duration,
"default_size": default_sample_size,
"flags": 0,
}
trun_entries = []
Handles multi-track fragmented MP4s by only extracting samples from
the traf matching the audio track ID.
Args:
audio_track_id: Track ID of the audio track to extract.
moof_offset: Absolute file offset of the moof box.
mdat_data_offset: Absolute file offset of the mdat content (after header).
"""
samples = []
# Simple box parsing inside moof
offset = 8 # Skip moof header
@@ -189,7 +204,18 @@ def _parse_moof_mdat(
break
if box_type == "traf":
# Parse inside traf
# Parse inside traf with per-traf state
tfhd_info = {
"track_id": 0,
"desc_index": 0,
"default_duration": default_sample_duration,
"default_size": default_sample_size,
"flags": 0,
"base_data_offset": None,
}
trun_entries = []
first_trun_data_offset = None
traf_offset = offset + 8
traf_end = offset + size
while traf_offset < traf_end - 8:
@@ -208,33 +234,52 @@ def _parse_moof_mdat(
moof_data[traf_offset + 8 : traf_offset + inner_size], tfhd_info
)
elif inner_type == "trun":
trun_entries = _parse_trun(
entries, data_off = _parse_trun(
moof_data[traf_offset + 8 : traf_offset + inner_size], tfhd_info
)
if first_trun_data_offset is None:
first_trun_data_offset = data_off
trun_entries.extend(entries)
traf_offset += inner_size
# Only process this traf if it matches the audio track
if tfhd_info["track_id"] != audio_track_id:
offset += size
continue
# Compute starting offset in mdat_data
base = tfhd_info.get("base_data_offset")
if base is None:
base = moof_offset # Default: first byte of containing moof
if first_trun_data_offset is not None:
mdat_idx = (base + first_trun_data_offset) - mdat_data_offset
else:
mdat_idx = 0
mdat_read_offset = max(0, mdat_idx)
desc_index = tfhd_info["desc_index"]
if desc_index > 0:
desc_index -= 1 # Convert to 0-indexed
for entry in trun_entries:
sample_size = entry.get("size", tfhd_info["default_size"])
sample_duration = entry.get("duration", tfhd_info["default_duration"])
if sample_size > 0 and mdat_read_offset + sample_size <= len(mdat_data):
sample = SampleInfo(
data=mdat_data[
mdat_read_offset : mdat_read_offset + sample_size
],
duration=sample_duration,
desc_index=desc_index,
)
samples.append(sample)
mdat_read_offset += sample_size
offset += size
# Extract samples from mdat using trun entries
mdat_offset = 0
desc_index = tfhd_info["desc_index"]
if desc_index > 0:
desc_index -= 1 # Convert to 0-indexed
for entry in trun_entries:
sample_size = entry.get("size", tfhd_info["default_size"])
sample_duration = entry.get("duration", tfhd_info["default_duration"])
if sample_size > 0 and mdat_offset + sample_size <= len(mdat_data):
sample = SampleInfo(
data=mdat_data[mdat_offset : mdat_offset + sample_size],
duration=sample_duration,
desc_index=desc_index,
)
samples.append(sample)
mdat_offset += sample_size
return samples
@@ -249,9 +294,13 @@ def _parse_tfhd(data: bytes, tfhd_info: dict):
tfhd_info["flags"] = flags
# After version+flags is track_id(4)
tfhd_info["track_id"] = struct.unpack(">I", data[4:8])[0]
offset = 4 + 4 # version+flags + track_id
if flags & 0x01 and offset + 8 <= len(data): # base_data_offset
tfhd_info["base_data_offset"] = struct.unpack(">Q", data[offset : offset + 8])[
0
]
offset += 8
if flags & 0x02 and offset + 4 <= len(data): # sample_description_index
tfhd_info["desc_index"] = struct.unpack(">I", data[offset : offset + 4])[0]
@@ -265,11 +314,17 @@ def _parse_tfhd(data: bytes, tfhd_info: dict):
tfhd_info["default_size"] = struct.unpack(">I", data[offset : offset + 4])[0]
def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]:
"""Parse track run box to get sample entries (FullBox: version + flags + content)."""
def _parse_trun(data: bytes, tfhd_info: dict) -> tuple[List[dict], Optional[int]]:
"""Parse track run box to get sample entries and data_offset.
Returns:
Tuple of (entries, data_offset). data_offset is the signed offset from
base_data_offset to the first sample's data, or None if not present.
"""
entries = []
data_offset_value = None
if len(data) < 8: # version(1) + flags(3) + sample_count(4)
return entries
return entries, data_offset_value
# FullBox: version(1) + flags(3)
version = data[0]
@@ -279,6 +334,7 @@ def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]:
# Start reading entries after header fields
offset = 8 # version+flags(4) + sample_count(4)
if flags & 0x01: # data_offset present
data_offset_value = struct.unpack(">i", data[offset : offset + 4])[0]
offset += 4
if flags & 0x04: # first_sample_flags present
offset += 4
@@ -297,7 +353,7 @@ def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]:
offset += 4
entries.append(entry)
return entries
return entries, data_offset_value
async def decrypt_samples(
@@ -1020,6 +1076,51 @@ def _extract_timestamps_from_box(data: bytes, box_type: bytes) -> tuple[int, int
return 0, 0
def _extract_audio_track_id(moov_data: bytes) -> int:
"""Extract the track ID of the audio track from the moov box.
Parses trak boxes in moov to find one with handler_type 'soun' (sound),
then returns its track_id from tkhd. Defaults to 1 if not found.
"""
offset = 8 # Skip moov box header
while offset < len(moov_data) - 8:
size = struct.unpack(">I", moov_data[offset : offset + 4])[0]
box_type = moov_data[offset + 4 : offset + 8]
if size < 8 or offset + size > len(moov_data):
break
if box_type == b"trak":
trak_data = moov_data[offset : offset + size]
# Check handler type in hdlr box
hdlr_idx = trak_data.find(b"hdlr")
if hdlr_idx > 0:
# hdlr FullBox: after 'hdlr' type comes version+flags(4) + pre_defined(4) + handler_type(4)
handler_offset = hdlr_idx + 4 + 4 + 4
if handler_offset + 4 <= len(trak_data):
handler_type = trak_data[handler_offset : handler_offset + 4]
if handler_type == b"soun":
# Found audio track, extract track_id from tkhd
tkhd_idx = trak_data.find(b"tkhd")
if tkhd_idx > 0:
version = trak_data[tkhd_idx + 4]
if version == 0:
# v0: ver+flags(4) + creation(4) + modification(4) + track_id(4)
tid_offset = tkhd_idx + 4 + 4 + 4 + 4
else:
# v1: ver+flags(4) + creation(8) + modification(8) + track_id(4)
tid_offset = tkhd_idx + 4 + 4 + 8 + 8
if tid_offset + 4 <= len(trak_data):
return struct.unpack(
">I", trak_data[tid_offset : tid_offset + 4]
)[0]
offset += size
return 1 # Default to track 1
async def decrypt_file(
wrapper_ip: str,
track_id: str,