Make decrypt_samples async and use asyncio streams

This commit is contained in:
Rafael Moraes
2026-02-24 23:09:32 -03:00
parent fd30ab861b
commit b0c3b4630d
+18 -32
View File
@@ -5,11 +5,8 @@ This is a modified version of https://github.com/sn0wst0rm/st0rmMusicPlayer/blob
import asyncio
import io
import logging
import os
import socket
import struct
from dataclasses import dataclass, field
from pathlib import Path
from typing import BinaryIO, List, Optional
logger = logging.getLogger(__name__)
@@ -303,7 +300,7 @@ def _parse_trun(data: bytes, tfhd_info: dict) -> List[dict]:
return entries
def decrypt_samples(
async def decrypt_samples(
wrapper_ip: str,
track_id: str,
fairplay_key: str,
@@ -331,16 +328,7 @@ def decrypt_samples(
host, port = wrapper_ip.split(":")
port = int(port)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(120.0)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # Disable Nagle
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 262144) # 256KB send buffer
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 262144) # 256KB recv buffer
sock.connect((host, port))
# Use buffered I/O like Go's bufio
sock_writer = sock.makefile("wb", buffering=65536)
sock_reader = sock.makefile("rb", buffering=65536)
reader, writer = await asyncio.open_connection(host, port)
try:
decrypted_data = bytearray()
@@ -357,8 +345,8 @@ def decrypt_samples(
if last_desc_index != sample.desc_index:
if last_desc_index != 255:
# Send key switch signal
sock_writer.write(struct.pack("<I", 0))
sock_writer.flush()
writer.write(struct.pack("<I", 0))
await writer.drain()
# Send new key info
key_uri = keys[min(sample.desc_index, len(keys) - 1)]
@@ -367,13 +355,13 @@ def decrypt_samples(
id_bytes = b"0"
else:
id_bytes = track_id.encode("utf-8")
sock_writer.write(struct.pack("B", len(id_bytes)))
sock_writer.write(id_bytes)
writer.write(struct.pack("B", len(id_bytes)))
writer.write(id_bytes)
key_bytes = key_uri.encode("utf-8")
sock_writer.write(struct.pack("B", len(key_bytes)))
sock_writer.write(key_bytes)
sock_writer.flush()
writer.write(struct.pack("B", len(key_bytes)))
writer.write(key_bytes)
await writer.drain()
last_desc_index = sample.desc_index
@@ -383,12 +371,12 @@ def decrypt_samples(
if truncated_len > 0:
# Send size and data
sock_writer.write(struct.pack("<I", truncated_len))
sock_writer.write(sample.data[:truncated_len])
sock_writer.flush()
writer.write(struct.pack("<I", truncated_len))
writer.write(sample.data[:truncated_len])
await writer.drain()
# Read decrypted data
decrypted_sample = sock_reader.read(truncated_len)
decrypted_sample = await reader.readexactly(truncated_len)
if len(decrypted_sample) != truncated_len:
raise IOError(
f"Short read: got {len(decrypted_sample)}, expected {truncated_len}"
@@ -412,16 +400,15 @@ def decrypt_samples(
last_progress_time = now
# Send close signal
sock_writer.write(bytes([0, 0, 0, 0, 0]))
sock_writer.flush()
writer.write(bytes([0, 0, 0, 0, 0]))
await writer.drain()
logger.debug(f"Decrypted {len(samples)} samples ({len(decrypted_data)} bytes)")
return bytes(decrypted_data)
finally:
sock_writer.close()
sock_reader.close()
sock.close()
writer.close()
await writer.wait_closed()
def write_decrypted_m4a(
@@ -1063,8 +1050,7 @@ async def decrypt_file(
song_info = await asyncio.to_thread(extract_song, input_path)
# Decrypt samples via wrapper
decrypted_data = await asyncio.to_thread(
decrypt_samples,
decrypted_data = await decrypt_samples(
wrapper_ip,
track_id,
fairplay_key,