diff --git a/tests/test_api.py b/tests/test_api.py index b949906..0923f23 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,13 +8,13 @@ class TestApi(unittest.TestCase): api: TidalApi def setUp(self): - config = Config() - auth = config.config["auth"] + config = Config.fromFile() + auth = config.auth token, user_id, country_code = ( - auth.get("token"), - auth.get("user_id"), - auth.get("country_code"), + auth.token, + auth.user_id, + auth.country_code ) assert token, "No token found in config file" diff --git a/tiddl/cfg.py b/tiddl/cfg.py deleted file mode 100644 index 99d6ef8..0000000 --- a/tiddl/cfg.py +++ /dev/null @@ -1,39 +0,0 @@ -from pydantic import BaseModel -from pathlib import Path - -from tiddl.models import TrackArg - - -CONFIG_PATH = Path.home() / "tiddl.json" -CONFIG_INDENT = 2 - - -class DownloadConfig(BaseModel): - quality: TrackArg = "high" - path: Path = Path.home() / "Music" / "Tiddl" - template: str = "{artist} - {title}" - - -class AuthConfig(BaseModel): - token: str = "" - refresh_token: str = "" - expires: int = 0 - user_id: str = "" - country_code: str = "" - - -class Config(BaseModel): - download: DownloadConfig = DownloadConfig() - auth: AuthConfig = AuthConfig() - - def save(self): - with open(CONFIG_PATH, "w") as f: - f.write(self.model_dump_json(indent=CONFIG_INDENT)) - - @classmethod - def fromFile(cls): - try: - with CONFIG_PATH.open() as f: - return Config.model_validate_json(f.read()) - except FileNotFoundError: - return Config() diff --git a/tiddl/cli/auth.py b/tiddl/cli/auth.py index fb56c31..51e374e 100644 --- a/tiddl/cli/auth.py +++ b/tiddl/cli/auth.py @@ -5,6 +5,7 @@ from click import style from time import sleep, time from tiddl.auth import getDeviceAuth, getToken, refreshToken, removeToken, AuthError +from tiddl.config import AuthConfig from .ctx import passContext, Context @@ -21,25 +22,17 @@ def AuthGroup(): def login(ctx: Context): """Add token to the config""" - access_token, refresh_token, expires = ( - ctx.obj.config.config["auth"].get("token"), - ctx.obj.config.config["auth"].get("refresh_token"), - ctx.obj.config.config["auth"].get("expires", 0), - ) + auth = ctx.obj.config.auth - if access_token: - if refresh_token and time() > expires: + if auth.token: + if auth.refresh_token and time() > auth.expires: click.echo(style("Refreshing token...", fg="yellow")) - token = refreshToken(refresh_token) + token = refreshToken(auth.refresh_token) - ctx.obj.config.update( - { - "auth": { - "expires": token.expires_in + int(time()), - "token": token.access_token, - } - } - ) + ctx.obj.config.auth.expires = token.expires_in + int(time()) + ctx.obj.config.auth.token = token.access_token + + ctx.obj.config.save() click.echo(style("Authenticated!", fg="green")) return @@ -70,18 +63,17 @@ def login(ctx: Context): ) break - ctx.obj.config.update( - { - "auth": { - "token": token.access_token, - "refresh_token": token.refresh_token, - "expires": token.expires_in + int(time()), - "user_id": str(token.user.userId), - "country_code": token.user.countryCode, - } - } + new_auth = AuthConfig( + token=token.access_token, + refresh_token=token.refresh_token, + expires=token.expires_in + int(time()), + user_id=str(token.user.userId), + country_code=token.user.countryCode, ) + ctx.obj.config.auth = new_auth + ctx.obj.config.save() + click.echo(style("\nAuthenticated!", fg="green")) break @@ -92,7 +84,7 @@ def login(ctx: Context): def logout(ctx: Context): """Remove token from config""" - access_token = ctx.obj.config.config["auth"].get("token") + access_token = ctx.obj.config.auth.token if not access_token: click.echo(style("Not logged in", fg="yellow")) @@ -100,16 +92,7 @@ def logout(ctx: Context): removeToken(access_token) - ctx.obj.config.update( - { - "auth": { - "country_code": "", - "expires": 0, - "refresh_token": "", - "token": "", - "user_id": "", - } - } - ) + ctx.obj.config.auth = AuthConfig() + ctx.obj.config.save() click.echo(style("Logged out!", fg="green")) diff --git a/tiddl/cli/ctx.py b/tiddl/cli/ctx.py index 25b444a..d641a0e 100644 --- a/tiddl/cli/ctx.py +++ b/tiddl/cli/ctx.py @@ -7,27 +7,21 @@ from tiddl.api import TidalApi from tiddl.config import Config from tiddl.utils import TidalResource + class ContextObj: api: TidalApi | None config: Config resources: list[TidalResource] - def __init__(self) -> None: - self.config = Config() + self.config = Config.fromFile() self.resources = [] self.api = None - config_auth = self.config.config["auth"] + auth = self.config.auth - token, user_id, country_code = ( - config_auth.get("token"), - config_auth.get("user_id"), - config_auth.get("country_code"), - ) - - if token and user_id and country_code: - self.api = TidalApi(token, user_id, country_code) + if auth.token and auth.user_id and auth.country_code: + self.api = TidalApi(auth.token, auth.user_id, auth.country_code) def getApi(self) -> TidalApi: if self.api is None: diff --git a/tiddl/cli/download/__init__.py b/tiddl/cli/download/__init__.py index 88d741d..787a6aa 100644 --- a/tiddl/cli/download/__init__.py +++ b/tiddl/cli/download/__init__.py @@ -71,11 +71,8 @@ def DownloadCommand(ctx: Context, quality: TrackArg, output: str): click.echo("No tracks found.") return - download_quality = ARG_TO_QUALITY[ - quality or ctx.obj.config.config["download"]["quality"] - ] - - template = output or ctx.obj.config.config["download"].get("template", "") + download_quality = ARG_TO_QUALITY[quality or ctx.obj.config.download.quality] + template = output or ctx.obj.config.download.template for track in track_collector.tracks: click.echo(f"Downloading {track.title}") diff --git a/tiddl/config.py b/tiddl/config.py index d00351c..99d6ef8 100644 --- a/tiddl/config.py +++ b/tiddl/config.py @@ -1,85 +1,39 @@ -import json - -from dataclasses import dataclass, field -from typing import TypedDict +from pydantic import BaseModel from pathlib import Path from tiddl.models import TrackArg CONFIG_PATH = Path.home() / "tiddl.json" -DOWNLOAD_PATH = Path.home() / "Music" / "Tiddl" -DEFAULT_QUALITY: TrackArg = "high" +CONFIG_INDENT = 2 -class DownloadConfig(TypedDict, total=False): - quality: TrackArg - path: str - template: str +class DownloadConfig(BaseModel): + quality: TrackArg = "high" + path: Path = Path.home() / "Music" / "Tiddl" + template: str = "{artist} - {title}" -class AuthConfig(TypedDict, total=False): - token: str - refresh_token: str - expires: int - user_id: str - country_code: str +class AuthConfig(BaseModel): + token: str = "" + refresh_token: str = "" + expires: int = 0 + user_id: str = "" + country_code: str = "" -class ConfigFile(TypedDict): - download: DownloadConfig - auth: AuthConfig - - -class ConfigUpdate(TypedDict, total=False): - download: DownloadConfig - auth: AuthConfig - - -DEFAULT_CONFIG: ConfigFile = { - "download": {"quality": DEFAULT_QUALITY, "path": str(DOWNLOAD_PATH), "template": "{artist} - {track}"}, - "auth": {"token": "", "refresh_token": "", "expires": 0, "country_code": "", "user_id": ""}, -} - - -@dataclass -class Config: - """Configuration class for loading and updating CLI configuration file.""" - - config: ConfigFile = field(default_factory=lambda: DEFAULT_CONFIG) - - def __post_init__(self): - """Merge loaded configuration with defaults after initialization.""" - - try: - with open(CONFIG_PATH, "r") as f: - loaded_config: ConfigFile = json.load(f) - - self.config = merge(loaded_config, self.config) - - except (FileNotFoundError, json.JSONDecodeError): - pass - - def update(self, new_config: ConfigUpdate): - """Update the configuration with the new values and save it to the file.""" - - self.config = merge(new_config, self.config) +class Config(BaseModel): + download: DownloadConfig = DownloadConfig() + auth: AuthConfig = AuthConfig() + def save(self): with open(CONFIG_PATH, "w") as f: - json.dump(self.config, f, indent=2) + f.write(self.model_dump_json(indent=CONFIG_INDENT)) - -def merge(source, destination): - """ - Recursively merge two dictionaries. - https://stackoverflow.com/a/20666342 - """ - - for key, value in source.items(): - if isinstance(value, dict): - node = destination.setdefault(key, {}) - merge(value, node) - else: - destination[key] = value - - return destination + @classmethod + def fromFile(cls): + try: + with CONFIG_PATH.open() as f: + return Config.model_validate_json(f.read()) + except FileNotFoundError: + return Config()