Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion projects/fal_client/src/fal_client/_headers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
from __future__ import annotations

from typing import Literal, Union, get_args
from typing import Literal, Union, get_args, Optional, Any

from httpx import Headers

try:
from fal.ref import get_current_app
except ImportError:

def get_current_app() -> Optional[Any]:
return None


def _current_fal_app_request() -> Optional[Any]:
"""Get the current request if we are running in a fal app."""
if (app := get_current_app()) is not None and app.current_request is not None:
return app.current_request
return None


MIN_REQUEST_TIMEOUT_SECONDS = 1 # Minimum allowed request timeout in seconds

Expand Down Expand Up @@ -52,3 +69,15 @@ def add_priority_header(priority: Priority, headers: dict[str, str]) -> None:
f"Priority must be one of {valid_priorities}, got '{priority}'"
)
headers[QUEUE_PRIORITY_HEADER] = priority


def add_fal_app_context_headers(headers: dict[str, str]) -> None:
if request := _current_fal_app_request():
if cdn_token := request.headers.get("x-fal-cdn-token"):
headers["x-fal-cdn-token"] = cdn_token


def handle_response_headers(response_headers: Headers) -> None:
if request := _current_fal_app_request():
if cdn_token := response_headers.get("x-fal-cdn-token"):
request.headers["x-fal-cdn-token"] = cdn_token
17 changes: 16 additions & 1 deletion projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
add_priority_header,
add_timeout_header,
add_hint_header,
add_fal_app_context_headers,
handle_response_headers,
REQUEST_TIMEOUT_TYPE_HEADER,
REQUEST_TIMEOUT_HEADER,
)
Expand Down Expand Up @@ -1600,6 +1602,8 @@ async def run(
if start_timeout is not None:
add_timeout_header(start_timeout, _headers)

add_fal_app_context_headers(_headers)

response = await _async_maybe_retry_request(
self._client,
"POST",
Expand All @@ -1608,8 +1612,9 @@ async def run(
timeout=timeout,
headers=_headers,
)

_raise_for_status(response)
handle_response_headers(response.headers)

return response.json()

async def submit(
Expand Down Expand Up @@ -1653,6 +1658,8 @@ async def submit(
if start_timeout is not None:
add_timeout_header(start_timeout, _headers)

add_fal_app_context_headers(_headers)

response = await _async_maybe_retry_request(
self._client,
"POST",
Expand All @@ -1662,6 +1669,7 @@ async def submit(
headers=_headers,
)
_raise_for_status(response)
handle_response_headers(response.headers)

data = response.json()
return AsyncRequestHandle(
Expand Down Expand Up @@ -2096,6 +2104,8 @@ def run(
if start_timeout is not None:
add_timeout_header(start_timeout, _headers)

add_fal_app_context_headers(_headers)

response = _maybe_retry_request(
self._client,
"POST",
Expand All @@ -2105,6 +2115,8 @@ def run(
headers=_headers,
)
_raise_for_status(response)
handle_response_headers(response.headers)

return response.json()

def submit(
Expand Down Expand Up @@ -2145,6 +2157,8 @@ def submit(
if start_timeout is not None:
add_timeout_header(start_timeout, _headers)

add_fal_app_context_headers(_headers)

response = _maybe_retry_request(
self._client,
"POST",
Expand All @@ -2154,6 +2168,7 @@ def submit(
headers=_headers,
)
_raise_for_status(response)
handle_response_headers(response.headers)

data = response.json()
return SyncRequestHandle(
Expand Down
Loading