diff --git a/tiddl/cli/commands/auth.py b/tiddl/cli/commands/auth.py index 81dbcc4..1bb259e 100644 --- a/tiddl/cli/commands/auth.py +++ b/tiddl/cli/commands/auth.py @@ -6,6 +6,8 @@ from rich.console import Console from tiddl.cli.utils.auth.core import load_auth_data, save_auth_data, AuthData from tiddl.core.auth import AuthAPI, AuthClientError +from typing_extensions import Annotated + console = Console() auth_command = typer.Typer( @@ -13,6 +15,7 @@ auth_command = typer.Typer( ) +# TODO add context and load auth data from ctx @auth_command.command(help="Login with your Tidal account.") def login(): loaded_auth_data = load_auth_data() @@ -80,14 +83,30 @@ def logout(): @auth_command.command(help="Refreshes your token in app.") -def refresh(): +def refresh( + FORCE: Annotated[ + bool, + typer.Option( + "--force", "-f", help="Refresh token even when it is still valid." + ), + ] = False, + EARLY_EXPIRE_TIME: Annotated[ + int, + typer.Option( + "--early-expire", + "-e", + help="Time to expire the token earlier", + metavar="seconds", + ), + ] = 0, +): loaded_auth_data = load_auth_data() if loaded_auth_data.refresh_token is None: console.print("[bold red]Not logged in.") raise typer.Exit() - if time() < loaded_auth_data.expires_at: + if time() < (loaded_auth_data.expires_at - EARLY_EXPIRE_TIME) and not FORCE: expiry_time = datetime.fromtimestamp(loaded_auth_data.expires_at) remaining = expiry_time - datetime.now() hours, remainder = divmod(remaining.seconds, 3600) diff --git a/tiddl/cli/commands/download/__init__.py b/tiddl/cli/commands/download/__init__.py index 7afea45..4d60ae0 100644 --- a/tiddl/cli/commands/download/__init__.py +++ b/tiddl/cli/commands/download/__init__.py @@ -123,7 +123,7 @@ def download_callback( Download Tidal resources. """ - ctx.invoke(refresh) + ctx.invoke(refresh, EARLY_EXPIRE_TIME=600) log.debug(f"{ctx.params=}") diff --git a/tiddl/cli/ctx.py b/tiddl/cli/ctx.py index a036f9e..f400180 100644 --- a/tiddl/cli/ctx.py +++ b/tiddl/cli/ctx.py @@ -1,17 +1,20 @@ import typer +from time import time +from pathlib import Path from rich.console import Console -from pathlib import Path from tiddl.core.api import TidalClient, TidalAPI from tiddl.cli.config import APP_PATH -from tiddl.cli.utils.auth.core import load_auth_data +from tiddl.core.auth import AuthAPI +from tiddl.cli.utils.auth.core import load_auth_data, save_auth_data from tiddl.cli.utils.resource import TidalResource class ContextObject: console: Console resources: list[TidalResource] + auth_api: AuthAPI _api: TidalAPI | None api_omit_cache: bool debug_path: Path | None @@ -21,6 +24,7 @@ class ContextObject: ) -> None: self.console = console self.resources = [] + self.auth_api = AuthAPI() self._api = None self.api_omit_cache = api_omit_cache self.debug_path = debug_path @@ -36,11 +40,27 @@ class ContextObject: assert auth_data.user_id, "User ID is missing. Use `tiddl auth login`" assert auth_data.country_code, "Country Code is missing. Use `tiddl auth login`" + refresh_token = auth_data.refresh_token + assert refresh_token, "Refresh Token is missing. Use `tiddl auth login`" + + def on_token_expiry() -> str | None: + auth_response = self.auth_api.refresh_token(refresh_token) + auth_data.token = auth_response.access_token + auth_data.expires_at = auth_response.expires_in + int(time()) + + save_auth_data(auth_data=auth_data) + + if auth_response: + return auth_response.access_token + + return None + client = TidalClient( token=auth_data.token, cache_name=APP_PATH / "api_cache", omit_cache=self.api_omit_cache, debug_path=self.debug_path, + on_token_expiry=on_token_expiry, ) self._api = TidalAPI(client, auth_data.user_id, auth_data.country_code) diff --git a/tiddl/core/api/client.py b/tiddl/core/api/client.py index 40b5b19..8d7620a 100644 --- a/tiddl/core/api/client.py +++ b/tiddl/core/api/client.py @@ -1,7 +1,7 @@ import json from logging import getLogger from pathlib import Path -from typing import Any, Type, TypeVar +from typing import Any, Type, TypeVar, Callable, Optional from pydantic import BaseModel from time import sleep @@ -24,10 +24,13 @@ RETRY_DELAY = 2 log = getLogger(__name__) +# TODO add token expiry check +# maybe refactor to aiohttp.ClientSession class TidalClient: - token: str + _token: str debug_path: Path | None session: CachedSession + on_token_expiry: Optional[Callable[[], str | None]] def __init__( self, @@ -35,10 +38,10 @@ class TidalClient: cache_name: StrOrPath, omit_cache: bool = False, debug_path: Path | None = None, + on_token_expiry: Optional[Callable[[], str | None]] = None, ) -> None: - self.token = token + self.on_token_expiry = on_token_expiry self.debug_path = debug_path - self.session = CachedSession( cache_name=cache_name, always_revalidate=omit_cache ) @@ -46,6 +49,20 @@ class TidalClient: "Authorization": f"Bearer {token}", "Accept": "application/json", } + self._token = token + + @property + def token(self): + return self._token + + @token.setter + def token(self, token: str): + self._token = token + self.session.headers.update( + { + "Authorization": f"Bearer {token}", + } + ) def fetch( self, @@ -64,6 +81,20 @@ class TidalClient: f"{API_URL}/{endpoint}", params=params, expire_after=expire_after ) + if res.status_code == 401 and self.on_token_expiry: + token = self.on_token_expiry() + + if token: + self.token = token + + return self.fetch( + model=model, + endpoint=endpoint, + params=params, + expire_after=expire_after, + _attempt=MAX_RETRIES - 1, + ) + log.debug( f"{endpoint} {params} '{'HIT' if res.from_cache else 'MISS'}' [{res.status_code}]", )