diff --git a/gamdl/cli/config_file.py b/gamdl/cli/config_file.py index 51c0612..64d966c 100644 --- a/gamdl/cli/config_file.py +++ b/gamdl/cli/config_file.py @@ -1,25 +1,15 @@ import configparser import typing -from dataclasses import dataclass from pathlib import Path -from typing import get_type_hints import click import click.types as click_types -from dataclass_click.dataclass_click import _DelayedCall from .cli_config import CliConfig from .constants import EXCLUDED_CONFIG_FILE_PARAMS from .utils import Csv -@dataclass -class ParameterInfo: - name: str - default: typing.Any - type: typing.Any - - class ConfigFile: def __init__( self, @@ -28,34 +18,10 @@ class ConfigFile: ) -> None: self.config_path = config_path self.section_name = section_name - self.parameters = self._extract_parameters_from_cli_config() + self.click_context = click.get_current_context() self._read_config_file() - def _extract_parameters_from_cli_config(self) -> dict[str, ParameterInfo]: - parameters = {} - hints = get_type_hints(CliConfig, include_extras=True) - - for field_name, hint in hints.items(): - if hasattr(hint, "__metadata__"): - for metadata in hint.__metadata__: - if isinstance(metadata, _DelayedCall): - param_type = metadata.kwargs.get("type") - if param_type is None: - raise ValueError( - f"Parameter type for field '{field_name}' " - "could not be determined." - ) - - parameters[field_name] = ParameterInfo( - name=field_name, - default=metadata.kwargs.get("default"), - type=param_type, - ) - break - - return parameters - def _read_config_file(self) -> None: self.config = configparser.ConfigParser(interpolation=None) @@ -71,81 +37,81 @@ class ConfigFile: with open(self.config_path, "w", encoding="utf-8") as config_file: self.config.write(config_file) - def _serialize_param_default(self, param_info: ParameterInfo) -> str: - if param_info.default is None: + def _serialize_param_default(self, param: click.Parameter) -> str: + if param.default is None: return "null" - if isinstance(param_info.type, Csv): + if isinstance(param.type, Csv): return ",".join( item.value if hasattr(item, "value") else str(item) - for item in param_info.default + for item in param.default ) - if isinstance(param_info.type, click_types.FuncParamType): - return param_info.default.value + if isinstance(param.type, click_types.FuncParamType): + return param.default.value - if isinstance(param_info.type, click_types.BoolParamType): - return "true" if param_info.default else "false" + if isinstance(param.type, click_types.BoolParamType): + return "true" if param.default else "false" if isinstance( - param_info.type, + param.type, click_types.Choice | click_types.Path | click_types.StringParamType | click_types.IntParamType, ): - return str(param_info.default) + return str(param.default) raise NotImplementedError( - f"Serialization for parameter '{param_info.name}' of type " - f"'{type(param_info.type)}' is not implemented." + f"Serialization for parameter '{param.name}' of type " + f"'{type(param.type)}' is not implemented." ) def _add_param_default_to_config( self, - param_info: ParameterInfo, + param: click.Parameter, ) -> bool: - if self.config.has_option(self.section_name, param_info.name): + if self.config.has_option(self.section_name, param.name): return False - value = self._serialize_param_default(param_info) - self.config.set(self.section_name, param_info.name, value) + value = self._serialize_param_default(param) + self.config.set(self.section_name, param.name, value) return True def _parse_param_from_config( self, - param_info: ParameterInfo, + param: click.Parameter, ) -> typing.Any: - value = self.config[self.section_name].get(param_info.name) + value = self.config[self.section_name].get(param.name) if value is None: - return param_info.default + return param.default if value == "null": return None - if not isinstance(param_info.type, click_types.ParamType): + if not isinstance(param.type, click_types.ParamType): raise NotImplementedError( - f"Parsing for parameter '{param_info.name}' of type " - f"'{type(param_info.type)}' is not implemented." + f"Parsing for parameter '{param.name}' of type " + f"'{type(param.type)}' is not implemented." ) - return param_info.type.convert(value, None, None) + return param.type.convert(value, None, None) def add_params_default_to_config(self) -> None: has_changes = False - for param_info in self.parameters.values(): - if param_info.name in EXCLUDED_CONFIG_FILE_PARAMS: + for param in self.click_context.command.params: + if param.name in EXCLUDED_CONFIG_FILE_PARAMS: continue - has_changes = self._add_param_default_to_config(param_info) or has_changes + has_changes = self._add_param_default_to_config(param) or has_changes if has_changes: self._write_config_file() def cleanup_unknown_params(self) -> None: - param_names = {info.name for info in self.parameters.values()} + param_names = {info.name for info in self.click_context.command.params} has_changes = False for key in list(self.config[self.section_name].keys()): @@ -158,16 +124,16 @@ class ConfigFile: def update_params_from_config(self, config: CliConfig) -> CliConfig: updates = {} - click_context = click.get_current_context() - for param_info in self.parameters.values(): + + for param in self.click_context.command.params: if ( - click_context.get_parameter_source(param_info.name) + self.click_context.get_parameter_source(param.name) == click.core.ParameterSource.COMMANDLINE ): continue - if self.config.has_option(self.section_name, param_info.name): - updates[param_info.name] = self._parse_param_from_config(param_info) + if self.config.has_option(self.section_name, param.name): + updates[param.name] = self._parse_param_from_config(param) config_dict = config.__dict__.copy() config_dict.update(updates)