mirror of
https://github.com/glomatico/gamdl.git
synced 2026-06-13 04:05:14 +03:00
Refactor amdecrypt for wrapper-v2 /decrypt/samples
This commit is contained in:
+132
-108
@@ -2,18 +2,18 @@
|
||||
This is a modified version of https://github.com/sn0wst0rm/st0rmMusicPlayer/blob/main/scripts/amdecrypt.py
|
||||
All the modifications made here were AI generated
|
||||
|
||||
FairPlay sample decryption talks to wrapper-v2 over HTTP POST /decrypt (JSON), not the
|
||||
legacy raw TCP port used by the original wrapper (e.g. 10020).
|
||||
FairPlay sample decryption talks to wrapper-v2 over HTTP POST /decrypt/sample
|
||||
(binary frame), not the legacy raw TCP port used by the original wrapper (e.g. 10020).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import BinaryIO, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -33,7 +33,7 @@ PREFETCH_KEY = "skd://itunes.apple.com/P000000000/s1/e1"
|
||||
# wrapper-v2 HTTP API base (no trailing slash). Override via decrypt_file(..., wrapper_ip=...).
|
||||
DEFAULT_WRAPPER_IP = "http://127.0.0.1:80"
|
||||
|
||||
# Max ciphertext blobs per POST /decrypt (same adam_id + uri). Increase for fewer
|
||||
# Max ciphertext blobs per POST /decrypt/sample (same adam_id + uri). Increase for fewer
|
||||
# round-trips; set to 1 if a given wrapper build mis-handles CBC between chunks.
|
||||
WRAPPER_DECRYPT_BATCH_SIZE = 128
|
||||
|
||||
@@ -63,6 +63,57 @@ def _wrapper_v2_base_url(wrapper_ip: str) -> str:
|
||||
return f"http://{s}:80"
|
||||
|
||||
|
||||
def _build_decrypt_sample_frame(adam_id: str, skd_uri: str, ciphertexts: List[bytes]) -> bytes:
|
||||
"""Build wrapper-v2 /decrypt/sample binary request frame."""
|
||||
adam_id_bytes = adam_id.encode("utf-8")
|
||||
skd_uri_bytes = skd_uri.encode("utf-8")
|
||||
if not adam_id_bytes:
|
||||
raise ValueError("wrapper-v2: adam_id must not be empty")
|
||||
if not skd_uri_bytes:
|
||||
raise ValueError("wrapper-v2: skd_uri must not be empty")
|
||||
if not ciphertexts:
|
||||
raise ValueError("wrapper-v2: ciphertext batch must not be empty")
|
||||
|
||||
frame = bytearray()
|
||||
frame += struct.pack(">III", len(adam_id_bytes), len(skd_uri_bytes), len(ciphertexts))
|
||||
for ciphertext in ciphertexts:
|
||||
frame += struct.pack(">I", len(ciphertext))
|
||||
frame += adam_id_bytes
|
||||
frame += skd_uri_bytes
|
||||
for ciphertext in ciphertexts:
|
||||
frame += ciphertext
|
||||
return bytes(frame)
|
||||
|
||||
|
||||
def _parse_decrypt_sample_frame(data: bytes, expected_count: int) -> List[bytes]:
|
||||
"""Parse wrapper-v2 /decrypt/sample binary response frame."""
|
||||
if len(data) < 4:
|
||||
raise IOError("wrapper-v2: POST /decrypt/sample returned a truncated response")
|
||||
(sample_count,) = struct.unpack_from(">I", data, 0)
|
||||
if sample_count != expected_count:
|
||||
raise IOError(
|
||||
f"wrapper-v2: expected {expected_count} samples in response, got {sample_count}"
|
||||
)
|
||||
|
||||
table_end = 4 + sample_count * 4
|
||||
if len(data) < table_end:
|
||||
raise IOError("wrapper-v2: POST /decrypt/sample returned a truncated length table")
|
||||
|
||||
lengths = [struct.unpack_from(">I", data, 4 + i * 4)[0] for i in range(sample_count)]
|
||||
offset = table_end
|
||||
out: List[bytes] = []
|
||||
for i, length in enumerate(lengths):
|
||||
end = offset + length
|
||||
if end > len(data):
|
||||
raise IOError(f"wrapper-v2: POST /decrypt/sample returned truncated sample {i}")
|
||||
out.append(data[offset:end])
|
||||
offset = end
|
||||
|
||||
if offset != len(data):
|
||||
raise IOError("wrapper-v2: POST /decrypt/sample returned trailing bytes")
|
||||
return out
|
||||
|
||||
|
||||
async def _post_decrypt_batch(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
@@ -70,16 +121,19 @@ async def _post_decrypt_batch(
|
||||
skd_uri: str,
|
||||
ciphertexts: List[bytes],
|
||||
) -> List[bytes]:
|
||||
"""One POST /decrypt; ciphertexts and returned plaintexts are in order."""
|
||||
payload = {
|
||||
"adam_id": adam_id,
|
||||
"uri": skd_uri,
|
||||
"samples": [base64.standard_b64encode(c).decode("ascii") for c in ciphertexts],
|
||||
}
|
||||
r = await client.post(f"{base_url}/decrypt", json=payload)
|
||||
"""One POST /decrypt/sample; ciphertexts and returned plaintexts are in order."""
|
||||
frame = _build_decrypt_sample_frame(adam_id, skd_uri, ciphertexts)
|
||||
r = await client.post(
|
||||
f"{base_url}/decrypt/sample",
|
||||
content=frame,
|
||||
headers={
|
||||
"content-type": "application/octet-stream",
|
||||
"accept": "application/octet-stream",
|
||||
},
|
||||
)
|
||||
if r.status_code == 401:
|
||||
raise IOError(
|
||||
"wrapper-v2: POST /decrypt returned 401 — log in with POST /login "
|
||||
"wrapper-v2: POST /decrypt/sample returned 401 — log in with POST /login "
|
||||
"or restore a session on the daemon first"
|
||||
)
|
||||
if r.status_code == 503:
|
||||
@@ -94,23 +148,11 @@ async def _post_decrypt_batch(
|
||||
detail = (j.get("detail") or j.get("error") or str(j)) or ""
|
||||
except Exception:
|
||||
detail = (r.text or "")[:500]
|
||||
raise IOError(f"wrapper-v2: POST /decrypt failed HTTP {r.status_code}: {detail}")
|
||||
|
||||
body = r.json()
|
||||
if not isinstance(body, dict) or "samples" not in body:
|
||||
raise IOError("wrapper-v2: POST /decrypt: expected JSON object with 'samples' array")
|
||||
out_b64 = body["samples"]
|
||||
if not isinstance(out_b64, list) or len(out_b64) != len(ciphertexts):
|
||||
raise IOError(
|
||||
f"wrapper-v2: expected {len(ciphertexts)} samples in response, "
|
||||
f"got {len(out_b64) if isinstance(out_b64, list) else type(out_b64)}"
|
||||
f"wrapper-v2: POST /decrypt/sample failed HTTP {r.status_code}: {detail}"
|
||||
)
|
||||
out: List[bytes] = []
|
||||
for i, item in enumerate(out_b64):
|
||||
if not isinstance(item, str):
|
||||
raise IOError(f"wrapper-v2: samples[{i}] must be base64 string")
|
||||
out.append(base64.standard_b64decode(item.encode("ascii")))
|
||||
return out
|
||||
|
||||
return _parse_decrypt_sample_frame(r.content, len(ciphertexts))
|
||||
|
||||
|
||||
def _cbcs_ciphertext_for_sample(sample: SampleInfo) -> Optional[tuple[bytes, bytes]]:
|
||||
@@ -151,25 +193,45 @@ def _cbcs_ciphertext_for_sample(sample: SampleInfo) -> Optional[tuple[bytes, byt
|
||||
return (aligned, tail)
|
||||
|
||||
|
||||
def _append_reassembled_sample(decrypted_data: bytearray, sample: SampleInfo, plain: bytes, tail: bytes) -> None:
|
||||
"""Place decrypted bytes back into MP4 sample layout (clear / encrypted regions)."""
|
||||
def _reassemble_cbcs_sample(sample: SampleInfo, plain: bytes, tail: bytes) -> bytes:
|
||||
"""Place decrypted CBCS bytes back into the original MP4 sample layout."""
|
||||
full_dec = plain + tail
|
||||
data = sample.data
|
||||
if not sample.subsamples:
|
||||
decrypted_data.extend(full_dec)
|
||||
return
|
||||
if len(full_dec) != len(data):
|
||||
raise IOError(
|
||||
f"decrypted sample length mismatch: expected {len(data)}, got {len(full_dec)}"
|
||||
)
|
||||
return full_dec
|
||||
|
||||
encrypted_total = sum(enc_b for _, enc_b in sample.subsamples)
|
||||
if len(full_dec) != encrypted_total:
|
||||
raise IOError(
|
||||
"decrypted subsample length mismatch: "
|
||||
f"expected {encrypted_total}, got {len(full_dec)}"
|
||||
)
|
||||
|
||||
out = bytearray()
|
||||
dec_off = 0
|
||||
offset = 0
|
||||
for clear_b, enc_b in sample.subsamples:
|
||||
if clear_b:
|
||||
decrypted_data.extend(data[offset : offset + clear_b])
|
||||
out.extend(data[offset : offset + clear_b])
|
||||
offset += clear_b
|
||||
if enc_b:
|
||||
decrypted_data.extend(full_dec[dec_off : dec_off + enc_b])
|
||||
out.extend(full_dec[dec_off : dec_off + enc_b])
|
||||
dec_off += enc_b
|
||||
offset += enc_b
|
||||
if offset < len(data):
|
||||
decrypted_data.extend(data[offset:])
|
||||
out.extend(data[offset:])
|
||||
if len(out) != len(data):
|
||||
raise IOError(f"reassembled sample length mismatch: expected {len(data)}, got {len(out)}")
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _append_reassembled_sample(decrypted_data: bytearray, sample: SampleInfo, plain: bytes, tail: bytes) -> None:
|
||||
"""Append one decrypted CBCS sample to the output stream."""
|
||||
decrypted_data.extend(_reassemble_cbcs_sample(sample, plain, tail))
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -651,7 +713,8 @@ async def decrypt_samples(
|
||||
progress_callback=None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Send samples to wrapper-v2 (HTTP POST /decrypt) for CBCS decryption and return decrypted bytes.
|
||||
Send samples to wrapper-v2 (HTTP POST /decrypt/sample) for CBCS decryption and
|
||||
return decrypted bytes.
|
||||
|
||||
Ciphertext is sent in batches of up to :data:`WRAPPER_DECRYPT_BATCH_SIZE` MP4 samples
|
||||
per request (same ``adam_id`` and ``uri``). Literal or tail-only samples are applied
|
||||
@@ -769,6 +832,7 @@ def write_decrypted_m4a(
|
||||
# master timescale to ensure 100% duration consistency.
|
||||
orig_hdlr = None
|
||||
timescale = 44100 # Default fallback
|
||||
preferred_desc_index = _preferred_sample_description_index(song_info.samples)
|
||||
|
||||
if original_path:
|
||||
with open(original_path, "rb") as f:
|
||||
@@ -779,7 +843,7 @@ def write_decrypted_m4a(
|
||||
orig_data = None
|
||||
|
||||
if orig_data:
|
||||
stsd_content = _extract_stsd_content(orig_data)
|
||||
stsd_content = _extract_stsd_content(orig_data, preferred_desc_index)
|
||||
# Extract the REAL sample rate from the codec configuration
|
||||
timescale = _extract_sample_rate_from_stsd(stsd_content) or _extract_timescale(orig_data)
|
||||
|
||||
@@ -832,6 +896,14 @@ def write_decrypted_m4a(
|
||||
logger.debug(f"Wrote decrypted file to {output_path}")
|
||||
|
||||
|
||||
def _preferred_sample_description_index(samples: List[SampleInfo]) -> int:
|
||||
"""Return the 0-based sample description index to keep in flattened output."""
|
||||
counts = Counter(sample.desc_index for sample in samples if sample.data)
|
||||
if not counts:
|
||||
return 0
|
||||
return counts.most_common(1)[0][0]
|
||||
|
||||
|
||||
def _write_box(f, box_type: bytes, content: bytes):
|
||||
"""Write a simple MP4 box."""
|
||||
size = len(content) + 8
|
||||
@@ -1146,8 +1218,8 @@ def _write_mdat(f, data: bytes):
|
||||
f.write(data)
|
||||
|
||||
|
||||
def _extract_stsd_content(data: bytes) -> Optional[bytes]:
|
||||
"""Extract full stsd box content from moov box (supports any codec)."""
|
||||
def _extract_stsd_content(data: bytes, preferred_desc_index: Optional[int] = None) -> Optional[bytes]:
|
||||
"""Extract cleaned stsd box content from moov box (supports any codec)."""
|
||||
# Find stsd box in the data
|
||||
idx = data.find(b"stsd")
|
||||
if idx < 4:
|
||||
@@ -1162,10 +1234,10 @@ def _extract_stsd_content(data: bytes) -> Optional[bytes]:
|
||||
raw_content = data[idx + 4 : idx - 4 + size]
|
||||
|
||||
# Clean the stsd content to remove encryption metadata
|
||||
return _clean_stsd_content(raw_content)
|
||||
return _clean_stsd_content(raw_content, preferred_desc_index)
|
||||
|
||||
|
||||
def _clean_stsd_content(stsd_content: bytes) -> bytes:
|
||||
def _clean_stsd_content(stsd_content: bytes, preferred_desc_index: Optional[int] = None) -> bytes:
|
||||
"""
|
||||
Clean stsd content by removing encryption metadata.
|
||||
|
||||
@@ -1182,7 +1254,7 @@ def _clean_stsd_content(stsd_content: bytes) -> bytes:
|
||||
version_flags = stsd_content[:4]
|
||||
entry_count = struct.unpack(">I", stsd_content[4:8])[0]
|
||||
|
||||
# Parse and clean each sample entry
|
||||
# Parse and clean each sample entry.
|
||||
cleaned_entries = []
|
||||
offset = 8
|
||||
|
||||
@@ -1210,6 +1282,15 @@ def _clean_stsd_content(stsd_content: bytes) -> bytes:
|
||||
|
||||
offset += entry_size
|
||||
|
||||
# The writer emits one chunk and one stsc entry with sample_description_index=1.
|
||||
# Keep only the dominant source description so that the flattened MP4's sample
|
||||
# table and stsd agree. iTunes is stricter about this than many players.
|
||||
if preferred_desc_index is not None and cleaned_entries:
|
||||
if 0 <= preferred_desc_index < len(cleaned_entries):
|
||||
cleaned_entries = [cleaned_entries[preferred_desc_index]]
|
||||
else:
|
||||
cleaned_entries = [cleaned_entries[0]]
|
||||
|
||||
# Rebuild stsd content
|
||||
result = version_flags + struct.pack(">I", len(cleaned_entries))
|
||||
for entry in cleaned_entries:
|
||||
@@ -1928,74 +2009,17 @@ def decrypt_samples_hex(
|
||||
if len(iv) < 16:
|
||||
iv = iv + b"\x00" * (16 - len(iv))
|
||||
|
||||
if sample.subsamples:
|
||||
# For CBCS subsamples: concatenate all encrypted regions into one,
|
||||
# decrypt as one CBC stream (to maintain cipher state), then split back.
|
||||
# This avoids losing bytes if encrypted_bytes values aren't 16-byte aligned.
|
||||
parts = _cbcs_ciphertext_for_sample(sample)
|
||||
if parts is None:
|
||||
decrypted.extend(sample.data)
|
||||
continue
|
||||
|
||||
# Collect all encrypted byte ranges and encrypt content
|
||||
encrypted_concat = bytearray()
|
||||
subsample_sizes = (
|
||||
[]
|
||||
) # Track size of each encrypted region for reassembly
|
||||
offset = 0
|
||||
for clear_bytes, encrypted_bytes in sample.subsamples:
|
||||
offset += clear_bytes
|
||||
if encrypted_bytes > 0:
|
||||
encrypted_concat.extend(
|
||||
sample.data[offset : offset + encrypted_bytes]
|
||||
)
|
||||
subsample_sizes.append(encrypted_bytes)
|
||||
offset += encrypted_bytes
|
||||
|
||||
# Decrypt concatenated regions as one CBC stream
|
||||
total_enc_len = len(encrypted_concat)
|
||||
decrypted_concat = bytearray()
|
||||
if total_enc_len > 0:
|
||||
cbc_len = total_enc_len & ~0xF
|
||||
if cbc_len > 0:
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
|
||||
decrypted_concat.extend(
|
||||
cipher.decrypt(bytes(encrypted_concat[:cbc_len]))
|
||||
)
|
||||
# Any trailing unaligned bytes (shouldn't happen if file is well-formed)
|
||||
if cbc_len < total_enc_len:
|
||||
decrypted_concat.extend(encrypted_concat[cbc_len:])
|
||||
|
||||
# Reassemble with original clear/encrypted pattern
|
||||
plaintext = bytearray()
|
||||
dec_offset = 0
|
||||
offset = 0
|
||||
for clear_bytes, encrypted_bytes in sample.subsamples:
|
||||
plaintext.extend(sample.data[offset : offset + clear_bytes])
|
||||
offset += clear_bytes
|
||||
if encrypted_bytes > 0:
|
||||
plaintext.extend(
|
||||
decrypted_concat[dec_offset : dec_offset + encrypted_bytes]
|
||||
)
|
||||
dec_offset += encrypted_bytes
|
||||
offset += encrypted_bytes
|
||||
plaintext.extend(sample.data[offset:])
|
||||
decrypted.extend(plaintext)
|
||||
else:
|
||||
# Full subsample: for well-formed files, the entire sample should be
|
||||
# a multiple of 16 bytes. Only truncate if misaligned (unexpected).
|
||||
sample_len = len(sample.data)
|
||||
if sample_len % 16 == 0:
|
||||
# Data is properly 16-byte aligned, decrypt as-is
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
|
||||
decrypted.extend(cipher.decrypt(sample.data))
|
||||
else:
|
||||
# Data is not aligned (unexpected case) - truncate carefully
|
||||
truncated_len = sample_len & ~0xF
|
||||
if truncated_len > 0:
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
|
||||
decrypted.extend(cipher.decrypt(sample.data[:truncated_len]))
|
||||
# Keep unaligned tail bytes as clear (unencrypted)
|
||||
decrypted.extend(sample.data[truncated_len:])
|
||||
else:
|
||||
# Less than 16 bytes - cannot decrypt with CBC, keep as-is
|
||||
decrypted.extend(sample.data)
|
||||
aligned, tail = parts
|
||||
plain = b""
|
||||
if aligned:
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv=iv)
|
||||
plain = cipher.decrypt(aligned)
|
||||
decrypted.extend(_reassemble_cbcs_sample(sample, plain, tail))
|
||||
|
||||
logger.debug(
|
||||
f"Decrypted {len(samples)} samples ({len(decrypted)} bytes) with hex keys"
|
||||
|
||||
Reference in New Issue
Block a user