Skip to content

Commit

Permalink
Better Database Connection Parameter Management (#182)
Browse files Browse the repository at this point in the history
- Remove database connection parameters from `dbos-config.yaml` by
default (though they can still be there)
- Instead, read and store database connection parameters in a hidden
local file, `.dbos/db_connection`, which is **not committed to version
control**.
- If no connection parameters are given, use local connection defaults.
- Clearly print where database connection parameters are being read from
at DBOS startup.
  • Loading branch information
kraftp authored Jan 22, 2025
1 parent 04bc3a1 commit dc5c4c6
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 44 deletions.
58 changes: 48 additions & 10 deletions dbos/_db_wizard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import time
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, TypedDict

import docker # type: ignore
import typer
Expand All @@ -15,8 +17,18 @@
from ._error import DBOSInitializationError
from ._logger import dbos_logger

DB_CONNECTION_PATH = os.path.join(".dbos", "db_connection")

def db_connect(config: "ConfigFile", config_file_path: str) -> "ConfigFile":

class DatabaseConnection(TypedDict):
hostname: Optional[str]
port: Optional[int]
username: Optional[str]
password: Optional[str]
local_suffix: Optional[bool]


def db_wizard(config: "ConfigFile", config_file_path: str) -> "ConfigFile":
# 1. Check the connectivity to the database. Return if successful. If cannot connect, continue to the following steps.
db_connection_error = _check_db_connectivity(config)
if db_connection_error is None:
Expand Down Expand Up @@ -82,17 +94,20 @@ def db_connect(config: "ConfigFile", config_file_path: str) -> "ConfigFile":
f"Could not connect to the database. Exception: {db_connection_error}"
)

# 6. Save the config to the config file and return the updated config.
# TODO: make the config file prettier
with open(config_file_path, "w") as file:
file.write(yaml.dump(config))

# 6. Save the config to the database connection file
updated_connection = DatabaseConnection(
hostname=config["database"]["hostname"],
port=config["database"]["port"],
username=config["database"]["username"],
password=config["database"]["password"],
local_suffix=config["database"]["local_suffix"],
)
save_db_connection(updated_connection)
return config


def _start_docker_postgres(config: "ConfigFile") -> bool:
print("Starting a Postgres Docker container...")
config["database"]["password"] = "dbos"
client = docker.from_env()
pg_data = "/var/lib/postgresql/data"
container_name = "dbos-db"
Expand Down Expand Up @@ -122,7 +137,7 @@ def _start_docker_postgres(config: "ConfigFile") -> bool:
continue
print("[green]Postgres Docker container started successfully![/green]")
break
except Exception as e:
except:
attempts -= 1
time.sleep(1)

Expand Down Expand Up @@ -151,7 +166,7 @@ def _check_db_connectivity(config: "ConfigFile") -> Optional[Exception]:
host=config["database"]["hostname"],
port=config["database"]["port"],
database="postgres",
query={"connect_timeout": "2"},
query={"connect_timeout": "1"},
)
postgres_db_engine = create_engine(postgres_db_url)
try:
Expand All @@ -168,3 +183,26 @@ def _check_db_connectivity(config: "ConfigFile") -> Optional[Exception]:
postgres_db_engine.dispose()

return None


def load_db_connection() -> DatabaseConnection:
try:
with open(DB_CONNECTION_PATH, "r") as f:
data = json.load(f)
return DatabaseConnection(
hostname=data.get("hostname", None),
port=data.get("port", None),
username=data.get("username", None),
password=data.get("password", None),
local_suffix=data.get("local_suffix", None),
)
except:
return DatabaseConnection(
hostname=None, port=None, username=None, password=None, local_suffix=None
)


def save_db_connection(connection: DatabaseConnection) -> None:
os.makedirs(".dbos", exist_ok=True)
with open(DB_CONNECTION_PATH, "w") as f:
json.dump(connection, f)
59 changes: 52 additions & 7 deletions dbos/_dbos_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

import yaml
from jsonschema import ValidationError, validate
from rich import print
from sqlalchemy import URL

from ._db_wizard import db_connect
from ._db_wizard import db_wizard, load_db_connection
from ._error import DBOSInitializationError
from ._logger import config_logger, dbos_logger, init_logger

DBOS_CONFIG_PATH = "dbos-config.yaml"


class RuntimeConfig(TypedDict, total=False):
start: List[str]
Expand All @@ -23,7 +26,7 @@ class DatabaseConfig(TypedDict, total=False):
hostname: str
port: int
username: str
password: Optional[str]
password: str
connectionTimeoutMillis: Optional[int]
app_db_name: str
sys_db_name: Optional[str]
Expand Down Expand Up @@ -93,7 +96,7 @@ def replace_func(match: re.Match[str]) -> str:
return re.sub(regex, replace_func, content)


def get_dbos_database_url(config_file_path: str = "dbos-config.yaml") -> str:
def get_dbos_database_url(config_file_path: str = DBOS_CONFIG_PATH) -> str:
"""
Retrieve application database URL from configuration `.yaml` file.
Expand All @@ -119,7 +122,9 @@ def get_dbos_database_url(config_file_path: str = "dbos-config.yaml") -> str:
return db_url.render_as_string(hide_password=False)


def load_config(config_file_path: str = "dbos-config.yaml") -> ConfigFile:
def load_config(
config_file_path: str = DBOS_CONFIG_PATH, *, use_db_wizard: bool = True
) -> ConfigFile:
"""
Load the DBOS `ConfigFile` from the specified path (typically `dbos-config.yaml`).
Expand Down Expand Up @@ -151,6 +156,9 @@ def load_config(config_file_path: str = "dbos-config.yaml") -> ConfigFile:
except ValidationError as e:
raise DBOSInitializationError(f"Validation error: {e}")

if "database" not in data:
data["database"] = {}

if "name" not in data:
raise DBOSInitializationError(
f"dbos-config.yaml must specify an application name"
Expand All @@ -169,8 +177,6 @@ def load_config(config_file_path: str = "dbos-config.yaml") -> ConfigFile:
if "runtimeConfig" not in data or "start" not in data["runtimeConfig"]:
raise DBOSInitializationError(f"dbos-config.yaml must specify a start command")

data = cast(ConfigFile, data)

if not _is_valid_app_name(data["name"]):
raise DBOSInitializationError(
f'Invalid app name {data["name"]}. App names must be between 3 and 30 characters long and contain only lowercase letters, numbers, dashes, and underscores.'
Expand All @@ -179,10 +185,49 @@ def load_config(config_file_path: str = "dbos-config.yaml") -> ConfigFile:
if "app_db_name" not in data["database"]:
data["database"]["app_db_name"] = _app_name_to_db_name(data["name"])

# Load the DB connection file. Use its values for missing fields from dbos-config.yaml. Use defaults otherwise.
data = cast(ConfigFile, data)
db_connection = load_db_connection()
if data["database"].get("hostname"):
print(
"[bold blue]Loading database connection parameters from dbos-config.yaml[/bold blue]"
)
elif db_connection.get("hostname"):
print(
"[bold blue]Loading database connection parameters from .dbos/db_connection[/bold blue]"
)
else:
print(
"[bold blue]Using default database connection parameters (localhost)[/bold blue]"
)

data["database"]["hostname"] = (
data["database"].get("hostname") or db_connection.get("hostname") or "localhost"
)
data["database"]["port"] = (
data["database"].get("port") or db_connection.get("port") or 5432
)
data["database"]["username"] = (
data["database"].get("username") or db_connection.get("username") or "postgres"
)
data["database"]["password"] = (
data["database"].get("password")
or db_connection.get("password")
or os.environ.get("PGPASSWORD")
or "dbos"
)
data["database"]["local_suffix"] = (
data["database"].get("local_suffix")
or db_connection.get("local_suffix")
or False
)

# Configure the DBOS logger
config_logger(data)

# Check the connectivity to the database and make sure it's properly configured
data = db_connect(data, config_file_path)
if use_db_wizard:
data = db_wizard(data, config_file_path)

if "local_suffix" in data["database"] and data["database"]["local_suffix"]:
data["database"]["app_db_name"] = f"{data['database']['app_db_name']}_local"
Expand Down
4 changes: 0 additions & 4 deletions dbos/_templates/hello/dbos-config.yaml.dbos
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ runtimeConfig:
start:
- "fastapi run ${package_name}/main.py"
database:
hostname: localhost
port: 5432
username: postgres
password: ${PGPASSWORD}
migrate:
- ${migration_command}
telemetry:
Expand Down
13 changes: 2 additions & 11 deletions dbos/dbos-config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,7 @@
"type": "array",
"description": "Specify a list of user DB rollback commands to run"
}
},
"required": [
"hostname",
"port",
"username",
"password"
]
}
},
"telemetry": {
"type": "object",
Expand Down Expand Up @@ -181,9 +175,6 @@
"type": "string",
"deprecated": true
}
},
"required": [
"database"
]
}
}

85 changes: 73 additions & 12 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
original_open = __builtins__["open"]


def generate_mock_open(filename, mock_data):
def generate_mock_open(filenames, mock_files):
if not isinstance(filenames, list):
filenames = [filenames]
if not isinstance(mock_files, list):
mock_files = [mock_files]

def conditional_mock_open(*args, **kwargs):
if args[0] == filename:
m = mock_open(read_data=mock_data)
return m()
else:
return original_open(*args, **kwargs)
for filename, mock_file in zip(filenames, mock_files):
if args[0] == filename:
m = mock_open(read_data=mock_file)
return m()
return original_open(*args, **kwargs)

return conditional_mock_open

Expand Down Expand Up @@ -95,23 +100,79 @@ def test_valid_config_without_appdbname(mocker):
assert configFile["database"]["app_db_name"] == "some_app"


def test_config_missing_params(mocker):
def test_config_load_defaults(mocker):
mock_config = """
name: "some-app"
language: "python"
runtimeConfig:
start:
- "python3 main.py"
"""
mocker.patch(
"builtins.open", side_effect=generate_mock_open(mock_filename, mock_config)
)

configFile = load_config(mock_filename)
assert configFile["name"] == "some-app"
assert configFile["language"] == "python"
assert configFile["database"]["hostname"] == "localhost"
assert configFile["database"]["port"] == 5432
assert configFile["database"]["username"] == "postgres"
assert configFile["database"]["password"] == os.environ.get("PGPASSWORD", "dbos")


def test_config_load_db_connection(mocker):
mock_config = """
name: "some-app"
language: "python"
runtimeConfig:
start:
- "python3 main.py"
"""
mock_db_connection = """
{"hostname": "example.com", "port": 2345, "username": "example", "password": "password", "local_suffix": true}
"""
mocker.patch(
"builtins.open",
side_effect=generate_mock_open(
[mock_filename, ".dbos/db_connection"], [mock_config, mock_db_connection]
),
)

configFile = load_config(mock_filename, use_db_wizard=False)
assert configFile["name"] == "some-app"
assert configFile["language"] == "python"
assert configFile["database"]["hostname"] == "example.com"
assert configFile["database"]["port"] == 2345
assert configFile["database"]["username"] == "example"
assert configFile["database"]["password"] == "password"
assert configFile["database"]["local_suffix"] == True
assert configFile["database"]["app_db_name"] == "some_app_local"


def test_config_mixed_params(mocker):
mock_config = """
name: "some-app"
language: "python"
runtimeConfig:
start:
- "python3 main.py"
database:
port: 1234
username: 'some user'
password: abc123
connectionTimeoutMillis: 3000
"""
mocker.patch(
"builtins.open", side_effect=generate_mock_open(mock_filename, mock_config)
)

with pytest.raises(DBOSInitializationError) as exc_info:
load_config(mock_filename)

assert "'hostname' is a required property" in str(exc_info.value)
configFile = load_config(mock_filename, use_db_wizard=False)
assert configFile["name"] == "some-app"
assert configFile["language"] == "python"
assert configFile["database"]["hostname"] == "localhost"
assert configFile["database"]["port"] == 1234
assert configFile["database"]["username"] == "some user"
assert configFile["database"]["password"] == "abc123"


def test_config_extra_params(mocker):
Expand Down

0 comments on commit dc5c4c6

Please sign in to comment.