Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/teradata_mcp_server/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Settings:

# Database configuration
logmech: str = "TD2"
logmech_is_explicit: bool = False # True when set via CLI arg or env var
auth_rate_limit_attempts: int = 5
auth_rate_limit_window: int = 60
pool_size: int = 5
Expand Down Expand Up @@ -57,6 +58,7 @@ def settings_from_env() -> Settings:
auth_mode=os.getenv("AUTH_MODE", "none").lower(),
auth_cache_ttl=int(os.getenv("AUTH_CACHE_TTL", "300")),
logmech=os.getenv("LOGMECH", "TD2"),
logmech_is_explicit=(os.getenv("LOGMECH") is not None),
auth_rate_limit_attempts=int(os.getenv("AUTH_RATE_LIMIT_ATTEMPTS", "5")),
auth_rate_limit_window=int(os.getenv("AUTH_RATE_LIMIT_WINDOW", "60")),
pool_size=int(os.getenv("TD_POOL_SIZE", "5")),
Expand Down
1 change: 1 addition & 0 deletions src/teradata_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def parse_args_to_settings() -> Settings:
mcp_port=args.mcp_port if args.mcp_port is not None else env.mcp_port,
mcp_path=args.mcp_path if args.mcp_path is not None else env.mcp_path,
logmech=args.logmech if args.logmech is not None else env.logmech,
logmech_is_explicit=(args.logmech is not None) or env.logmech_is_explicit,
auth_mode=(args.auth_mode or env.auth_mode).lower(),
auth_cache_ttl=args.auth_cache_ttl if args.auth_cache_ttl is not None else env.auth_cache_ttl,
logging_level=(args.logging_level or env.logging_level).upper(),
Expand Down
46 changes: 41 additions & 5 deletions src/teradata_mcp_server/tools/td_connect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import TYPE_CHECKING, Optional
from urllib.parse import quote_plus, urlparse
from urllib.parse import parse_qs, quote_plus, urlencode, urlparse

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -80,10 +80,33 @@ def __init__(self, settings: Optional["Settings"] = None):
self._base_host = parsed_url.hostname
self._base_port = parsed_url.port or 1025
self._base_db = parsed_url.path.lstrip("/")
self._default_basic_logmech = logmech

# Parse query parameters from the DATABASE_URI (e.g. LOGMECH, ENCRYPTDATA, SSLMODE)
uri_query_params = parse_qs(parsed_url.query, keep_blank_values=True)

# Extract LOGMECH from URI query params (lowest priority source)
uri_logmech_values = uri_query_params.pop("LOGMECH", [])
uri_logmech = uri_logmech_values[0] if uri_logmech_values else None

# Determine if logmech was explicitly set via CLI arg or env var
logmech_is_explicit = settings.logmech_is_explicit if settings is not None else os.getenv("LOGMECH") is not None

# Apply LOGMECH precedence: CLI/env (explicit) > URI query param > default "TD2"
if logmech_is_explicit:
self._default_basic_logmech = logmech
elif uri_logmech:
self._default_basic_logmech = uri_logmech
else:
self._default_basic_logmech = logmech # default "TD2"

# Store extra URI query params for inclusion in all reconstructed URLs
self._extra_uri_params: dict[str, str] = {k: v[0] for k, v in uri_query_params.items()}

# Build SQLAlchemy connection string for teradatasqlalchemy
sqlalchemy_url = f"teradatasql://{user}:{password}@{self._base_host}:{self._base_port}/{self._base_db}?LOGMECH={self._default_basic_logmech}"
main_query = self._build_query_string({"LOGMECH": self._default_basic_logmech})
sqlalchemy_url = (
f"teradatasql://{user}:{password}@{self._base_host}:{self._base_port}/{self._base_db}?{main_query}"
)

try:
self.engine = create_engine(
Expand All @@ -110,6 +133,17 @@ def close(self):
else:
logger.warning("SQLAlchemy engine is already disposed or was never created")

def _build_query_string(self, base_params: dict[str, str]) -> str:
"""Build a URL query string merging extra URI params with base_params.

base_params keys override any same-named extra URI params.
LOGDATA from extra params is excluded (it is connection-specific).
"""
merged = dict(self._extra_uri_params)
merged.pop("LOGDATA", None) # Never carry LOGDATA from the original URI
merged.update(base_params) # Explicit params win over URI extras
return urlencode(merged)

# ------------------------------------------------------------------
# Auth header parsing & validation (for AUTH_MODE=basic)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -188,7 +222,8 @@ def _validate_basic_credentials(self, user: str, secret: str, logmech: str) -> s
try:
# For basic credential validation, just validate the credentials without specifying a database
# Let Teradata use the user's default database
sqlalchemy_url = f"teradatasql://{user}:{secret}@{self._base_host}:{self._base_port}?LOGMECH={logmech}"
basic_query = self._build_query_string({"LOGMECH": logmech})
sqlalchemy_url = f"teradatasql://{user}:{secret}@{self._base_host}:{self._base_port}?{basic_query}"
engine = create_engine(
sqlalchemy_url,
poolclass=NullPool,
Expand All @@ -209,7 +244,8 @@ def _validate_jwt_token(self, jwt_token: str) -> str | None:
"""
try:
# No username needed for JWT LOGMECH
sqlalchemy_url = f"teradatasql://@{self._base_host}:{self._base_port}/{self._base_db}?LOGMECH=JWT&LOGDATA=token={quote_plus(jwt_token)}"
jwt_query = self._build_query_string({"LOGMECH": "JWT", "LOGDATA": f"token={quote_plus(jwt_token)}"})
sqlalchemy_url = f"teradatasql://@{self._base_host}:{self._base_port}/{self._base_db}?{jwt_query}"
engine = create_engine(
sqlalchemy_url,
poolclass=NullPool,
Expand Down
Loading