Refactor config parameter handling to use Click params

This commit is contained in:
Rafael Moraes
2026-01-16 22:49:45 -03:00
parent 073f70afa7
commit 7e132c27de
+33 -67
View File
@@ -1,25 +1,15 @@
import configparser
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import get_type_hints
import click
import click.types as click_types
from dataclass_click.dataclass_click import _DelayedCall
from .cli_config import CliConfig
from .constants import EXCLUDED_CONFIG_FILE_PARAMS
from .utils import Csv
@dataclass
class ParameterInfo:
name: str
default: typing.Any
type: typing.Any
class ConfigFile:
def __init__(
self,
@@ -28,34 +18,10 @@ class ConfigFile:
) -> None:
self.config_path = config_path
self.section_name = section_name
self.parameters = self._extract_parameters_from_cli_config()
self.click_context = click.get_current_context()
self._read_config_file()
def _extract_parameters_from_cli_config(self) -> dict[str, ParameterInfo]:
parameters = {}
hints = get_type_hints(CliConfig, include_extras=True)
for field_name, hint in hints.items():
if hasattr(hint, "__metadata__"):
for metadata in hint.__metadata__:
if isinstance(metadata, _DelayedCall):
param_type = metadata.kwargs.get("type")
if param_type is None:
raise ValueError(
f"Parameter type for field '{field_name}' "
"could not be determined."
)
parameters[field_name] = ParameterInfo(
name=field_name,
default=metadata.kwargs.get("default"),
type=param_type,
)
break
return parameters
def _read_config_file(self) -> None:
self.config = configparser.ConfigParser(interpolation=None)
@@ -71,81 +37,81 @@ class ConfigFile:
with open(self.config_path, "w", encoding="utf-8") as config_file:
self.config.write(config_file)
def _serialize_param_default(self, param_info: ParameterInfo) -> str:
if param_info.default is None:
def _serialize_param_default(self, param: click.Parameter) -> str:
if param.default is None:
return "null"
if isinstance(param_info.type, Csv):
if isinstance(param.type, Csv):
return ",".join(
item.value if hasattr(item, "value") else str(item)
for item in param_info.default
for item in param.default
)
if isinstance(param_info.type, click_types.FuncParamType):
return param_info.default.value
if isinstance(param.type, click_types.FuncParamType):
return param.default.value
if isinstance(param_info.type, click_types.BoolParamType):
return "true" if param_info.default else "false"
if isinstance(param.type, click_types.BoolParamType):
return "true" if param.default else "false"
if isinstance(
param_info.type,
param.type,
click_types.Choice
| click_types.Path
| click_types.StringParamType
| click_types.IntParamType,
):
return str(param_info.default)
return str(param.default)
raise NotImplementedError(
f"Serialization for parameter '{param_info.name}' of type "
f"'{type(param_info.type)}' is not implemented."
f"Serialization for parameter '{param.name}' of type "
f"'{type(param.type)}' is not implemented."
)
def _add_param_default_to_config(
self,
param_info: ParameterInfo,
param: click.Parameter,
) -> bool:
if self.config.has_option(self.section_name, param_info.name):
if self.config.has_option(self.section_name, param.name):
return False
value = self._serialize_param_default(param_info)
self.config.set(self.section_name, param_info.name, value)
value = self._serialize_param_default(param)
self.config.set(self.section_name, param.name, value)
return True
def _parse_param_from_config(
self,
param_info: ParameterInfo,
param: click.Parameter,
) -> typing.Any:
value = self.config[self.section_name].get(param_info.name)
value = self.config[self.section_name].get(param.name)
if value is None:
return param_info.default
return param.default
if value == "null":
return None
if not isinstance(param_info.type, click_types.ParamType):
if not isinstance(param.type, click_types.ParamType):
raise NotImplementedError(
f"Parsing for parameter '{param_info.name}' of type "
f"'{type(param_info.type)}' is not implemented."
f"Parsing for parameter '{param.name}' of type "
f"'{type(param.type)}' is not implemented."
)
return param_info.type.convert(value, None, None)
return param.type.convert(value, None, None)
def add_params_default_to_config(self) -> None:
has_changes = False
for param_info in self.parameters.values():
if param_info.name in EXCLUDED_CONFIG_FILE_PARAMS:
for param in self.click_context.command.params:
if param.name in EXCLUDED_CONFIG_FILE_PARAMS:
continue
has_changes = self._add_param_default_to_config(param_info) or has_changes
has_changes = self._add_param_default_to_config(param) or has_changes
if has_changes:
self._write_config_file()
def cleanup_unknown_params(self) -> None:
param_names = {info.name for info in self.parameters.values()}
param_names = {info.name for info in self.click_context.command.params}
has_changes = False
for key in list(self.config[self.section_name].keys()):
@@ -158,16 +124,16 @@ class ConfigFile:
def update_params_from_config(self, config: CliConfig) -> CliConfig:
updates = {}
click_context = click.get_current_context()
for param_info in self.parameters.values():
for param in self.click_context.command.params:
if (
click_context.get_parameter_source(param_info.name)
self.click_context.get_parameter_source(param.name)
== click.core.ParameterSource.COMMANDLINE
):
continue
if self.config.has_option(self.section_name, param_info.name):
updates[param_info.name] = self._parse_param_from_config(param_info)
if self.config.has_option(self.section_name, param.name):
updates[param.name] = self._parse_param_from_config(param)
config_dict = config.__dict__.copy()
config_dict.update(updates)