diff --git a/gamdl/interface/interface.py b/gamdl/interface/interface.py index fe29b95..097c0fc 100644 --- a/gamdl/interface/interface.py +++ b/gamdl/interface/interface.py @@ -86,6 +86,14 @@ class AppleMusicInterface: media.media_id, ) + async def _collect_generator( + self, generator_or_coroutine: AsyncGenerator[AppleMusicMedia, None] + ) -> list[AppleMusicMedia]: + results = [] + async for result in generator_or_coroutine: + results.append(result) + return results + async def _get_song_media( self, media_id: str, @@ -218,9 +226,16 @@ class AppleMusicInterface: for index, track in enumerate(tracks) ] - for task in tasks: - async for media in task: - yield media + if self.concurrency == 1: + for task in tasks: + async for media in task: + yield media + else: + collected_tasks = [self._collect_generator(task) for task in tasks] + batches = await safe_gather(*collected_tasks, limit=self.concurrency) + for batch in batches: + for media in batch: + yield media async def _get_playlist_media( self, @@ -282,9 +297,16 @@ class AppleMusicInterface: for index, track in enumerate(tracks) ] - for task in tasks: - async for media in task: - yield media + if self.concurrency == 1: + for task in tasks: + async for media in task: + yield media + else: + collected_tasks = [self._collect_generator(task) for task in tasks] + batches = await safe_gather(*collected_tasks, limit=self.concurrency) + for batch in batches: + for media in batch: + yield media async def _get_artist_media( self, @@ -380,9 +402,16 @@ class AppleMusicInterface: ) ) - for task in tasks: - async for media in task: - yield media + if self.concurrency == 1: + for task in tasks: + async for media in task: + yield media + else: + collected_tasks = [self._collect_generator(task) for task in tasks] + batches = await safe_gather(*collected_tasks, limit=self.concurrency) + for batch in batches: + for media in batch: + yield media async def get_media_from_url( self,