Refactor amdecrypt for wrapper-v2 /decrypt/samples

This commit is contained in:
Rafael Moraes
2026-05-18 14:52:21 -03:00
parent 86bbb94274
commit 4cdad09372
+132 -108
View File
@@ -2,18 +2,18 @@
This is a modified version of https://github.com/sn0wst0rm/st0rmMusicPlayer/blob/main/scripts/amdecrypt.py
All the modifications made here were AI generated
FairPlay sample decryption talks to wrapper-v2 over HTTP POST /decrypt (JSON), not the
legacy raw TCP port used by the original wrapper (e.g. 10020).
FairPlay sample decryption talks to wrapper-v2 over HTTP POST /decrypt/sample
(binary frame), not the legacy raw TCP port used by the original wrapper (e.g. 10020).
"""
from __future__ import annotations
import asyncio
import base64
import io
import logging
import struct
import time
from collections import Counter
from dataclasses import dataclass, field
from typing import BinaryIO, List, Optional
from urllib.parse import urlparse
@@ -33,7 +33,7 @@ PREFETCH_KEY = "skd://itunes.apple.com/P000000000/s1/e1"
# wrapper-v2 HTTP API base (no trailing slash). Override via decrypt_file(..., wrapper_ip=...).
DEFAULT_WRAPPER_IP = "http://127.0.0.1:80"
# Max ciphertext blobs per POST /decrypt (same adam_id + uri). Increase for fewer
# Max ciphertext blobs per POST /decrypt/sample (same adam_id + uri). Increase for fewer
# round-trips; set to 1 if a given wrapper build mis-handles CBC between chunks.
WRAPPER_DECRYPT_BATCH_SIZE = 128
@@ -63,6 +63,57 @@ def _wrapper_v2_base_url(wrapper_ip: str) -> str:
return f"http://{s}:80"
def _build_decrypt_sample_frame(adam_id: str, skd_uri: str, ciphertexts: List[bytes]) -> bytes:
"""Build wrapper-v2 /decrypt/sample binary request frame."""
adam_id_bytes = adam_id.encode("utf-8")
skd_uri_bytes = skd_uri.encode("utf-8")
if not adam_id_bytes:
raise ValueError("wrapper-v2: adam_id must not be empty")
if not skd_uri_bytes:
raise ValueError("wrapper-v2: skd_uri must not be empty")
if not ciphertexts:
raise ValueError("wrapper-v2: ciphertext batch must not be empty")
frame = bytearray()
frame += struct.pack(">III", len(adam_id_bytes), len(skd_uri_bytes), len(ciphertexts))
for ciphertext in ciphertexts:
frame += struct.pack(">I", len(ciphertext))
frame += adam_id_bytes
frame += skd_uri_bytes
for ciphertext in ciphertexts:
frame += ciphertext
return bytes(frame)
def _parse_decrypt_sample_frame(data: bytes, expected_count: int) -> List[bytes]:
"""Parse wrapper-v2 /decrypt/sample binary response frame."""
if len(data) < 4:
raise IOError("wrapper-v2: POST /decrypt/sample returned a truncated response")
(sample_count,) = struct.unpack_from(">I", data, 0)
if sample_count != expected_count:
raise IOError(
f"wrapper-v2: expected {expected_count} samples in response, got {sample_count}"
)
table_end = 4 + sample_count * 4
if len(data) < table_end:
raise IOError("wrapper-v2: POST /decrypt/sample returned a truncated length table")
lengths = [struct.unpack_from(">I", data, 4 + i * 4)[0] for i in range(sample_count)]
offset = table_end
out: List[bytes] = []
for i, length in enumerate(lengths):
end = offset + length
if end > len(data):
raise IOError(f"wrapper-v2: POST /decrypt/sample returned truncated sample {i}")
out.append(data[offset:end])
offset = end
if offset != len(data):
raise IOError("wrapper-v2: POST /decrypt/sample returned trailing bytes")
return out
async def _post_decrypt_batch(
client: httpx.AsyncClient,
base_url: str,
@@ -70,16 +121,19 @@ async def _post_decrypt_batch(
skd_uri: str,
ciphertexts: List[bytes],
) -> List[bytes]:
"""One POST /decrypt; ciphertexts and returned plaintexts are in order."""
payload = {
"adam_id": adam_id,
"uri": skd_uri,
"samples": [base64.standard_b64encode(c).decode("ascii") for c in ciphertexts],
}
r = await client.post(f"{base_url}/decrypt", json=payload)
"""One POST /decrypt/sample; ciphertexts and returned plaintexts are in order."""
frame = _build_decrypt_sample_frame(adam_id, skd_uri, ciphertexts)
r = await client.post(
f"{base_url}/decrypt/sample",
content=frame,
headers={
"content-type": "application/octet-stream",
"accept": "application/octet-stream",
},
)
if r.status_code == 401:
raise IOError(
"wrapper-v2: POST /decrypt returned 401 — log in with POST /login "
"wrapper-v2: POST /decrypt/sample returned 401 — log in with POST /login "
"or restore a session on the daemon first"
)
if r.status_code == 503:
@@ -94,23 +148,11 @@ async def _post_decrypt_batch(
detail = (j.get("detail") or j.get("error") or str(j)) or ""
except Exception:
detail = (r.text or "")[:500]
raise IOError(f"wrapper-v2: POST /decrypt failed HTTP {r.status_code}: {detail}")
body = r.json()
if not isinstance(body, dict) or "samples" not in body:
raise IOError("wrapper-v2: POST /decrypt: expected JSON object with 'samples' array")
out_b64 = body["samples"]
if not isinstance(out_b64, list) or len(out_b64) != len(ciphertexts):
raise IOError(
f"wrapper-v2: expected {len(ciphertexts)} samples in response, "
f"got {len(out_b64) if isinstance(out_b64, list) else type(out_b64)}"
f"wrapper-v2: POST /decrypt/sample failed HTTP {r.status_code}: {detail}"
)
out: List[bytes] = []
for i, item in enumerate(out_b64):
if not isinstance(item, str):
raise IOError(f"wrapper-v2: samples[{i}] must be base64 string")
out.append(base64.standard_b64decode(item.encode("ascii")))
return out
return _parse_decrypt_sample_frame(r.content, len(ciphertexts))
def _cbcs_ciphertext_for_sample(sample: SampleInfo) -> Optional[tuple[bytes, bytes]]:
@@ -151,25 +193,45 @@ def _cbcs_ciphertext_for_sample(sample: SampleInfo) -> Optional[tuple[bytes, byt
return (aligned, tail)
def _append_reassembled_sample(decrypted_data: bytearray, sample: SampleInfo, plain: bytes, tail: bytes) -> None:
"""Place decrypted bytes back into MP4 sample layout (clear / encrypted regions)."""
def _reassemble_cbcs_sample(sample: SampleInfo, plain: bytes, tail: bytes) -> bytes:
"""Place decrypted CBCS bytes back into the original MP4 sample layout."""
full_dec = plain + tail
data = sample.data
if not sample.subsamples:
decrypted_data.extend(full_dec)
return
if len(full_dec) != len(data):
raise IOError(
f"decrypted sample length mismatch: expected {len(data)}, got {len(full_dec)}"
)
return full_dec
encrypted_total = sum(enc_b for _, enc_b in sample.subsamples)
if len(full_dec) != encrypted_total:
raise IOError(
"decrypted subsample length mismatch: "
f"expected {encrypted_total}, got {len(full_dec)}"
)
out = bytearray()
dec_off = 0
offset = 0
for clear_b, enc_b in sample.subsamples:
if clear_b:
decrypted_data.extend(data[offset : offset + clear_b])
out.extend(data[offset : offset + clear_b])
offset += clear_b
if enc_b:
decrypted_data.extend(full_dec[dec_off : dec_off + enc_b])
out.extend(full_dec[dec_off : dec_off + enc_b])
dec_off += enc_b
offset += enc_b
if offset < len(data):
decrypted_data.extend(data[offset:])
out.extend(data[offset:])
if len(out) != len(data):
raise IOError(f"reassembled sample length mismatch: expected {len(data)}, got {len(out)}")
return bytes(out)
def _append_reassembled_sample(decrypted_data: bytearray, sample: SampleInfo, plain: bytes, tail: bytes) -> None:
"""Append one decrypted CBCS sample to the output stream."""
decrypted_data.extend(_reassemble_cbcs_sample(sample, plain, tail))
@dataclass
@@ -651,7 +713,8 @@ async def decrypt_samples(
progress_callback=None,
) -> bytes:
"""
Send samples to wrapper-v2 (HTTP POST /decrypt) for CBCS decryption and return decrypted bytes.
Send samples to wrapper-v2 (HTTP POST /decrypt/sample) for CBCS decryption and
return decrypted bytes.
Ciphertext is sent in batches of up to :data:`WRAPPER_DECRYPT_BATCH_SIZE` MP4 samples
per request (same ``adam_id`` and ``uri``). Literal or tail-only samples are applied
@@ -769,6 +832,7 @@ def write_decrypted_m4a(
# master timescale to ensure 100% duration consistency.
orig_hdlr = None
timescale = 44100 # Default fallback
preferred_desc_index = _preferred_sample_description_index(song_info.samples)
if original_path:
with open(original_path, "rb") as f:
@@ -779,7 +843,7 @@ def write_decrypted_m4a(
orig_data = None
if orig_data:
stsd_content = _extract_stsd_content(orig_data)
stsd_content = _extract_stsd_content(orig_data, preferred_desc_index)
# Extract the REAL sample rate from the codec configuration
timescale = _extract_sample_rate_from_stsd(stsd_content) or _extract_timescale(orig_data)
@@ -832,6 +896,14 @@ def write_decrypted_m4a(
logger.debug(f"Wrote decrypted file to {output_path}")
def _preferred_sample_description_index(samples: List[SampleInfo]) -> int:
"""Return the 0-based sample description index to keep in flattened output."""
counts = Counter(sample.desc_index for sample in samples if sample.data)
if not counts:
return 0
return counts.most_common(1)[0][0]
def _write_box(f, box_type: bytes, content: bytes):
"""Write a simple MP4 box."""
size = len(content) + 8
@@ -1146,8 +1218,8 @@ def _write_mdat(f, data: bytes):
f.write(data)
def _extract_stsd_content(data: bytes) -> Optional[bytes]:
"""Extract full stsd box content from moov box (supports any codec)."""
def _extract_stsd_content(data: bytes, preferred_desc_index: Optional[int] = None) -> Optional[bytes]:
"""Extract cleaned stsd box content from moov box (supports any codec)."""
# Find stsd box in the data
idx = data.find(b"stsd")
if idx < 4:
@@ -1162,10 +1234,10 @@ def _extract_stsd_content(data: bytes) -> Optional[bytes]:
raw_content = data[idx + 4 : idx - 4 + size]
# Clean the stsd content to remove encryption metadata
return _clean_stsd_content(raw_content)
return _clean_stsd_content(raw_content, preferred_desc_index)
def _clean_stsd_content(stsd_content: bytes) -> bytes:
def _clean_stsd_content(stsd_content: bytes, preferred_desc_index: Optional[int] = None) -> bytes:
"""
Clean stsd content by removing encryption metadata.
@@ -1182,7 +1254,7 @@ def _clean_stsd_content(stsd_content: bytes) -> bytes:
version_flags = stsd_content[:4]
entry_count = struct.unpack(">I", stsd_content[4:8])[0]
# Parse and clean each sample entry
# Parse and clean each sample entry.
cleaned_entries = []
offset = 8
@@ -1210,6 +1282,15 @@ def _clean_stsd_content(stsd_content: bytes) -> bytes:
offset += entry_size
# The writer emits one chunk and one stsc entry with sample_description_index=1.
# Keep only the dominant source description so that the flattened MP4's sample
# table and stsd agree. iTunes is stricter about this than many players.
if preferred_desc_index is not None and cleaned_entries:
if 0 <= preferred_desc_index < len(cleaned_entries):
cleaned_entries = [cleaned_entries[preferred_desc_index]]
else:
cleaned_entries = [cleaned_entries[0]]
# Rebuild stsd content
result = version_flags + struct.pack(">I", len(cleaned_entries))
for entry in cleaned_entries:
@@ -1928,74 +2009,17 @@ def decrypt_samples_hex(
if len(iv) < 16:
iv = iv + b"\x00" * (16 - len(iv))
if sample.subsamples:
# For CBCS subsamples: concatenate all encrypted regions into one,
# decrypt as one CBC stream (to maintain cipher state), then split back.
# This avoids losing bytes if encrypted_bytes values aren't 16-byte aligned.
parts = _cbcs_ciphertext_for_sample(sample)
if parts is None:
decrypted.extend(sample.data)
continue
# Collect all encrypted byte ranges and encrypt content
encrypted_concat = bytearray()
subsample_sizes = (
[]
) # Track size of each encrypted region for reassembly
offset = 0
for clear_bytes, encrypted_bytes in sample.subsamples:
offset += clear_bytes
if encrypted_bytes > 0:
encrypted_concat.extend(
sample.data[offset : offset + encrypted_bytes]
)
subsample_sizes.append(encrypted_bytes)
offset += encrypted_bytes
# Decrypt concatenated regions as one CBC stream
total_enc_len = len(encrypted_concat)
decrypted_concat = bytearray()
if total_enc_len > 0:
cbc_len = total_enc_len & ~0xF
if cbc_len > 0:
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
decrypted_concat.extend(
cipher.decrypt(bytes(encrypted_concat[:cbc_len]))
)
# Any trailing unaligned bytes (shouldn't happen if file is well-formed)
if cbc_len < total_enc_len:
decrypted_concat.extend(encrypted_concat[cbc_len:])
# Reassemble with original clear/encrypted pattern
plaintext = bytearray()
dec_offset = 0
offset = 0
for clear_bytes, encrypted_bytes in sample.subsamples:
plaintext.extend(sample.data[offset : offset + clear_bytes])
offset += clear_bytes
if encrypted_bytes > 0:
plaintext.extend(
decrypted_concat[dec_offset : dec_offset + encrypted_bytes]
)
dec_offset += encrypted_bytes
offset += encrypted_bytes
plaintext.extend(sample.data[offset:])
decrypted.extend(plaintext)
else:
# Full subsample: for well-formed files, the entire sample should be
# a multiple of 16 bytes. Only truncate if misaligned (unexpected).
sample_len = len(sample.data)
if sample_len % 16 == 0:
# Data is properly 16-byte aligned, decrypt as-is
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
decrypted.extend(cipher.decrypt(sample.data))
else:
# Data is not aligned (unexpected case) - truncate carefully
truncated_len = sample_len & ~0xF
if truncated_len > 0:
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
decrypted.extend(cipher.decrypt(sample.data[:truncated_len]))
# Keep unaligned tail bytes as clear (unencrypted)
decrypted.extend(sample.data[truncated_len:])
else:
# Less than 16 bytes - cannot decrypt with CBC, keep as-is
decrypted.extend(sample.data)
aligned, tail = parts
plain = b""
if aligned:
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
plain = cipher.decrypt(aligned)
decrypted.extend(_reassemble_cbcs_sample(sample, plain, tail))
logger.debug(
f"Decrypted {len(samples)} samples ({len(decrypted)} bytes) with hex keys"