Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions udata/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
url_for,
)
from flask_restx import Api, Resource
from flask_restx.marshalling import marshal
from flask_restx.reqparse import RequestParser
from flask_restx.utils import merge, unpack
from flask_storage import UnauthorizedFileType

from udata import tracking
Expand Down Expand Up @@ -42,6 +44,51 @@ def __init__(self, app=None, **kwargs):
super(UDataApi, self).__init__(app, **kwargs)
self.authorizations = {"apikey": {"type": "apiKey", "in": "header", "name": HEADER_API_KEY}}

def marshal_with(self, fields, as_list=False, code=200, description=None, **kwargs):
"""Override marshal_with to support API versioning.

If the model has a _version_resolver attribute, the resolver is called at
request time to pick the right model for the requested API version.
Swagger documentation always uses the latest (static) model.
"""
resolver = getattr(fields, "_version_resolver", None)
if not resolver:
return self.default_namespace.marshal_with(fields, as_list, code, description, **kwargs)

def decorator(func):
# Register Swagger docs with the latest (static) model — same as parent
doc = {
"responses": {
str(code): (
(description, [fields], kwargs)
if as_list
else (description, fields, kwargs)
)
},
"__mask__": kwargs.get("mask", True),
}
func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)

@wraps(func)
def wrapper(*args, **kw):
resolved = resolver()
resp = func(*args, **kw)
mask = getattr(resolved, "__mask__", None)
mask_header = current_app.config.get("RESTX_MASK_HEADER", "X-Fields")
mask = request.headers.get(mask_header) or mask
if isinstance(resp, tuple):
data, resp_code, headers = unpack(resp)
return (
marshal(data, resolved, mask=mask, ordered=self.ordered),
resp_code,
headers,
)
return marshal(resp, resolved, mask=mask, ordered=self.ordered)

return wrapper

return decorator

def secure(self, func):
"""Enforce authentication on a given method/verb
and optionally check a given permission
Expand Down Expand Up @@ -221,6 +268,19 @@ def extract_name_from_path(path):
return safe_unicode(name)


@apiv1_blueprint.after_request
@apiv2_blueprint.after_request
def add_version_header(response):
from udata.api.versioning import VERSION_HEADER, get_request_version

try:
version = get_request_version()
response.headers[VERSION_HEADER] = str(version)
except Exception:
pass
return response


@apiv1_blueprint.after_request
@apiv2_blueprint.after_request
def collect_stats(response):
Expand Down Expand Up @@ -336,6 +396,23 @@ def marshal_page_with(func):
pass


ns_versions = api.namespace("versions", "API versioning information")


@ns_versions.route("/", endpoint="api_versions")
class APIVersionsAPI(API):
@api.doc("list_api_versions")
def get(self):
"""List all API version changes"""
from udata.api.versioning import LATEST_API_VERSION, OLDEST_API_VERSION, VERSION_CHANGES

return {
"latest": str(LATEST_API_VERSION),
"oldest": str(OLDEST_API_VERSION),
"changes": sorted(VERSION_CHANGES, key=lambda c: c["date"], reverse=True),
}


def init_app(app):
# Load all core APIs
import udata.core.access_type.api # noqa
Expand Down
144 changes: 144 additions & 0 deletions udata/api/versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from __future__ import annotations

from flask import g, request
from packaging.version import Version

VERSION_HEADER = "X-API-Version"
LATEST_API_VERSION = Version("16.3.0")
OLDEST_API_VERSION = Version("1.0.0")

# Registry of all version changes, populated at import time via change classes
VERSION_CHANGES: list[dict] = []


def get_request_version() -> Version:
"""Return the API version requested via header. Cached on flask.g."""
if hasattr(g, "_api_version"):
return g._api_version

header = request.headers.get(VERSION_HEADER)
if not header:
version = OLDEST_API_VERSION
else:
try:
version = Version(header)
except Exception:
from udata.api import api

api.abort(400, f"Invalid {VERSION_HEADER} header. Expected a version like '16.3.0'.")

g._api_version = version
return version


class VersionChange:
"""Base class for API version changes."""

def __init__(self, version: str, description: str | None = None):
self.version_str = version
self.version = Version(version)
self.description = description

def auto_description(self) -> str:
raise NotImplementedError

def register(self, model_name: str, field_name: str | None = None):
VERSION_CHANGES.append(
{
"version": self.version_str,
"model": model_name,
"field": field_name,
"description": self.description or self.auto_description(),
}
)


class ChangeAttribute(VersionChange):
"""Before this version, some field attributes were different.

Usage:
datasets = field(
ListField(...),
href=lambda reuse: url_for(...),
before=[
ChangeAttribute("16.3.0", href=None,
description="datasets returned inline instead of href"),
],
)
"""

def __init__(self, version: str, description: str | None = None, **attrs):
super().__init__(version, description)
self.attrs = attrs

def auto_description(self) -> str:
changes = ", ".join(f"{k}={v!r}" for k, v in self.attrs.items())
return f"Field attributes changed: {changes}"

def apply(self, info: dict) -> dict:
modified = {**info}
modified.update(self.attrs)
return modified


class RenameField(VersionChange):
"""Before this version, this field had a different name.

Usage:
new_name = field(
StringField(),
before=[
RenameField("16.3.0", old_name="old_name"),
],
)
"""

def __init__(self, version: str, old_name: str, description: str | None = None):
super().__init__(version, description)
self.old_name = old_name

def auto_description(self) -> str:
return f"Field renamed from '{self.old_name}'"


class RemoveField(VersionChange):
"""At this version, this field was removed. Before this version it existed.

Usage:
legacy_field = field(
StringField(),
before=[
RemoveField("16.3.0",
description="legacy_field has been removed"),
],
)
"""

def auto_description(self) -> str:
return "Field removed"


class ChangeModelAttribute(VersionChange):
"""Before this version, model-level attributes (masks, etc.) were different.

Usage:
@generate_fields(
before=[
ChangeModelAttribute("16.3.0",
page_mask="*,datasets{id,title,uri,page}"),
],
)
"""

def __init__(self, version: str, description: str | None = None, **attrs):
super().__init__(version, description)
self.attrs = attrs

def auto_description(self) -> str:
changes = ", ".join(f"{k}={v!r}" for k, v in self.attrs.items())
return f"Model attributes changed: {changes}"

def apply(self, model_kwargs: dict) -> dict:
modified = {**model_kwargs}
modified.update(self.attrs)
return modified
Loading