|
| 1 | +import argparse |
1 | 2 | import os |
2 | 3 | from contextlib import closing |
3 | | -from dataclasses import dataclass |
| 4 | +from dataclasses import dataclass, field |
4 | 5 |
|
5 | 6 | import mariadb |
6 | 7 | from dotenv import load_dotenv |
|
16 | 17 | READ_ONLY_KEYWORD_NAMES = ", ".join(READ_ONLY_KEYWORDS) |
17 | 18 |
|
18 | 19 |
|
| 20 | +def get_arguments() -> dict: |
| 21 | + """Parse command-line arguments and return as a dictionary.""" |
| 22 | + parser = argparse.ArgumentParser(description="MariaDB Configuration") |
| 23 | + parser.add_argument("--host", help="MariaDB host") |
| 24 | + parser.add_argument("--port", type=int, help="MariaDB port") |
| 25 | + parser.add_argument("--user", help="MariaDB user") |
| 26 | + parser.add_argument("--password", help="MariaDB password") |
| 27 | + parser.add_argument("--database", help="MariaDB database") |
| 28 | + args = parser.parse_args() |
| 29 | + |
| 30 | + return {k: v for k, v in vars(args).items() if v is not None} |
| 31 | + |
| 32 | + |
19 | 33 | @dataclass |
20 | 34 | class DBconfig: |
21 | | - host: str = os.getenv("MARIADB_HOST", "localhost") |
22 | | - port: int = int(os.getenv("MARIADB_PORT", "3306")) |
23 | | - user: str = os.getenv("MARIADB_USER", "") |
24 | | - password: str = os.getenv("MARIADB_PASSWORD", "") |
25 | | - database: str = os.getenv("MARIADB_DATABASE", "") |
| 35 | + """Database configuration""" |
| 36 | + |
| 37 | + host: str = field(default_factory=lambda: os.getenv("MARIADB_HOST", "localhost")) |
| 38 | + port: int = field(default_factory=lambda: int(os.getenv("MARIADB_PORT", "3306"))) |
| 39 | + user: str = field(default_factory=lambda: os.getenv("MARIADB_USER", "")) |
| 40 | + password: str = field(default_factory=lambda: os.getenv("MARIADB_PASSWORD", "")) |
| 41 | + database: str = field(default_factory=lambda: os.getenv("MARIADB_DATABASE", "")) |
| 42 | + |
| 43 | + @classmethod |
| 44 | + def from_args(cls) -> "DBconfig": |
| 45 | + """Create a DBconfig instance from command-line arguments and environment variables.""" |
| 46 | + cli_args = get_arguments() |
| 47 | + return cls(**{**cls().__dict__, **cli_args}) |
26 | 48 |
|
27 | 49 |
|
28 | 50 | def get_connection(): |
29 | 51 | """Create a connection to the database connection""" |
30 | | - |
31 | | - config = DBconfig() |
| 52 | + config = DBconfig.from_args() |
32 | 53 |
|
33 | 54 | try: |
34 | 55 | conn = mariadb.connect( |
|
0 commit comments