mirror of
https://github.com/oskvr37/tiddl.git
synced 2026-06-13 04:05:08 +03:00
✨ Auth Token is now refreshed mid-request (#213)
This commit is contained in:
@@ -6,6 +6,8 @@ from rich.console import Console
|
||||
from tiddl.cli.utils.auth.core import load_auth_data, save_auth_data, AuthData
|
||||
from tiddl.core.auth import AuthAPI, AuthClientError
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
console = Console()
|
||||
|
||||
auth_command = typer.Typer(
|
||||
@@ -13,6 +15,7 @@ auth_command = typer.Typer(
|
||||
)
|
||||
|
||||
|
||||
# TODO add context and load auth data from ctx
|
||||
@auth_command.command(help="Login with your Tidal account.")
|
||||
def login():
|
||||
loaded_auth_data = load_auth_data()
|
||||
@@ -80,14 +83,30 @@ def logout():
|
||||
|
||||
|
||||
@auth_command.command(help="Refreshes your token in app.")
|
||||
def refresh():
|
||||
def refresh(
|
||||
FORCE: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--force", "-f", help="Refresh token even when it is still valid."
|
||||
),
|
||||
] = False,
|
||||
EARLY_EXPIRE_TIME: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
"--early-expire",
|
||||
"-e",
|
||||
help="Time to expire the token earlier",
|
||||
metavar="seconds",
|
||||
),
|
||||
] = 0,
|
||||
):
|
||||
loaded_auth_data = load_auth_data()
|
||||
|
||||
if loaded_auth_data.refresh_token is None:
|
||||
console.print("[bold red]Not logged in.")
|
||||
raise typer.Exit()
|
||||
|
||||
if time() < loaded_auth_data.expires_at:
|
||||
if time() < (loaded_auth_data.expires_at - EARLY_EXPIRE_TIME) and not FORCE:
|
||||
expiry_time = datetime.fromtimestamp(loaded_auth_data.expires_at)
|
||||
remaining = expiry_time - datetime.now()
|
||||
hours, remainder = divmod(remaining.seconds, 3600)
|
||||
|
||||
@@ -123,7 +123,7 @@ def download_callback(
|
||||
Download Tidal resources.
|
||||
"""
|
||||
|
||||
ctx.invoke(refresh)
|
||||
ctx.invoke(refresh, EARLY_EXPIRE_TIME=600)
|
||||
|
||||
log.debug(f"{ctx.params=}")
|
||||
|
||||
|
||||
+22
-2
@@ -1,17 +1,20 @@
|
||||
import typer
|
||||
from time import time
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
from pathlib import Path
|
||||
|
||||
from tiddl.core.api import TidalClient, TidalAPI
|
||||
from tiddl.cli.config import APP_PATH
|
||||
from tiddl.cli.utils.auth.core import load_auth_data
|
||||
from tiddl.core.auth import AuthAPI
|
||||
from tiddl.cli.utils.auth.core import load_auth_data, save_auth_data
|
||||
from tiddl.cli.utils.resource import TidalResource
|
||||
|
||||
|
||||
class ContextObject:
|
||||
console: Console
|
||||
resources: list[TidalResource]
|
||||
auth_api: AuthAPI
|
||||
_api: TidalAPI | None
|
||||
api_omit_cache: bool
|
||||
debug_path: Path | None
|
||||
@@ -21,6 +24,7 @@ class ContextObject:
|
||||
) -> None:
|
||||
self.console = console
|
||||
self.resources = []
|
||||
self.auth_api = AuthAPI()
|
||||
self._api = None
|
||||
self.api_omit_cache = api_omit_cache
|
||||
self.debug_path = debug_path
|
||||
@@ -36,11 +40,27 @@ class ContextObject:
|
||||
assert auth_data.user_id, "User ID is missing. Use `tiddl auth login`"
|
||||
assert auth_data.country_code, "Country Code is missing. Use `tiddl auth login`"
|
||||
|
||||
refresh_token = auth_data.refresh_token
|
||||
assert refresh_token, "Refresh Token is missing. Use `tiddl auth login`"
|
||||
|
||||
def on_token_expiry() -> str | None:
|
||||
auth_response = self.auth_api.refresh_token(refresh_token)
|
||||
auth_data.token = auth_response.access_token
|
||||
auth_data.expires_at = auth_response.expires_in + int(time())
|
||||
|
||||
save_auth_data(auth_data=auth_data)
|
||||
|
||||
if auth_response:
|
||||
return auth_response.access_token
|
||||
|
||||
return None
|
||||
|
||||
client = TidalClient(
|
||||
token=auth_data.token,
|
||||
cache_name=APP_PATH / "api_cache",
|
||||
omit_cache=self.api_omit_cache,
|
||||
debug_path=self.debug_path,
|
||||
on_token_expiry=on_token_expiry,
|
||||
)
|
||||
|
||||
self._api = TidalAPI(client, auth_data.user_id, auth_data.country_code)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import Any, Type, TypeVar, Callable, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from time import sleep
|
||||
@@ -24,10 +24,13 @@ RETRY_DELAY = 2
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
# TODO add token expiry check
|
||||
# maybe refactor to aiohttp.ClientSession
|
||||
class TidalClient:
|
||||
token: str
|
||||
_token: str
|
||||
debug_path: Path | None
|
||||
session: CachedSession
|
||||
on_token_expiry: Optional[Callable[[], str | None]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,10 +38,10 @@ class TidalClient:
|
||||
cache_name: StrOrPath,
|
||||
omit_cache: bool = False,
|
||||
debug_path: Path | None = None,
|
||||
on_token_expiry: Optional[Callable[[], str | None]] = None,
|
||||
) -> None:
|
||||
self.token = token
|
||||
self.on_token_expiry = on_token_expiry
|
||||
self.debug_path = debug_path
|
||||
|
||||
self.session = CachedSession(
|
||||
cache_name=cache_name, always_revalidate=omit_cache
|
||||
)
|
||||
@@ -46,6 +49,20 @@ class TidalClient:
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
self._token = token
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
return self._token
|
||||
|
||||
@token.setter
|
||||
def token(self, token: str):
|
||||
self._token = token
|
||||
self.session.headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {token}",
|
||||
}
|
||||
)
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
@@ -64,6 +81,20 @@ class TidalClient:
|
||||
f"{API_URL}/{endpoint}", params=params, expire_after=expire_after
|
||||
)
|
||||
|
||||
if res.status_code == 401 and self.on_token_expiry:
|
||||
token = self.on_token_expiry()
|
||||
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
return self.fetch(
|
||||
model=model,
|
||||
endpoint=endpoint,
|
||||
params=params,
|
||||
expire_after=expire_after,
|
||||
_attempt=MAX_RETRIES - 1,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"{endpoint} {params} '{'HIT' if res.from_cache else 'MISS'}' [{res.status_code}]",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user