Skip to content

Commit 2a2a621

Browse files
authored
feat: enhance DoubaoTextToSpeech with download timeouts and retry logic (#16)
1 parent 0a2dcee commit 2a2a621

File tree

2 files changed

+182
-20
lines changed

2 files changed

+182
-20
lines changed

epub2speech/tts/doubao_provider.py

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,42 @@ def __init__(
2727
poll_interval: float = 2.0,
2828
submit_timeout: float = 1800.0,
2929
poll_timeout: float = 30.0,
30+
download_connect_timeout: float = 20.0,
31+
download_read_timeout: float = 300.0,
32+
download_max_retries: int = 3,
33+
download_retry_delay: float = 2.0,
3034
):
3135
if not access_token:
3236
raise ValueError("access_token is required and cannot be None or empty")
3337
if not base_url:
3438
raise ValueError("base_url is required and cannot be None or empty")
39+
if max_retries < 1:
40+
raise ValueError("max_retries must be at least 1")
41+
if poll_interval <= 0:
42+
raise ValueError("poll_interval must be > 0")
43+
if submit_timeout <= 0:
44+
raise ValueError("submit_timeout must be > 0")
45+
if poll_timeout <= 0:
46+
raise ValueError("poll_timeout must be > 0")
47+
if download_connect_timeout <= 0:
48+
raise ValueError("download_connect_timeout must be > 0")
49+
if download_read_timeout <= 0:
50+
raise ValueError("download_read_timeout must be > 0")
51+
if download_max_retries < 1:
52+
raise ValueError("download_max_retries must be at least 1")
53+
if download_retry_delay < 0:
54+
raise ValueError("download_retry_delay must be >= 0")
3555

3656
self.access_token = access_token
3757
self.base_url = base_url.rstrip("/")
3858
self.max_retries = max_retries
3959
self.poll_interval = poll_interval
4060
self.submit_timeout = submit_timeout
4161
self.poll_timeout = poll_timeout
62+
self.download_connect_timeout = download_connect_timeout
63+
self.download_read_timeout = download_read_timeout
64+
self.download_max_retries = download_max_retries
65+
self.download_retry_delay = download_retry_delay
4266

4367
self._setup()
4468

@@ -188,25 +212,66 @@ def _poll_tts_result(self, task_id: str) -> str:
188212

189213
def _download_audio(self, audio_url: str, output_path: Path) -> None:
190214
"""Download audio file from URL to output_path."""
191-
try:
192-
response = requests.get(audio_url, timeout=300.0, stream=True)
193-
response.raise_for_status()
215+
temp_output_path = output_path.with_suffix(f"{output_path.suffix}.part")
216+
last_timeout_error: requests.exceptions.Timeout | None = None
217+
last_network_error: requests.exceptions.RequestException | None = None
194218

195-
# Write to file in chunks
196-
with open(output_path, "wb") as f:
197-
for chunk in response.iter_content(chunk_size=8192):
198-
if chunk:
199-
f.write(chunk)
219+
for attempt in range(1, self.download_max_retries + 1):
220+
try:
221+
with requests.get(
222+
audio_url,
223+
timeout=(self.download_connect_timeout, self.download_read_timeout),
224+
stream=True,
225+
) as response:
226+
response.raise_for_status()
200227

201-
except requests.exceptions.Timeout as e:
202-
raise TimeoutError("Download timeout after 300 seconds") from e
203-
except requests.exceptions.HTTPError as e:
204-
# HTTPError always has a response attribute
205-
resp = e.response
206-
if resp is not None:
207-
raise RuntimeError(f"Download failed with HTTP Error {resp.status_code}: {resp.text}") from e
208-
raise RuntimeError(f"Download failed with HTTP Error: {e}") from e
209-
except requests.exceptions.RequestException as e:
210-
raise ConnectionError(f"Download failed: {e}") from e
211-
except OSError as e:
212-
raise RuntimeError(f"Failed to write audio file to {output_path}: {e}") from e
228+
with open(temp_output_path, "wb") as f:
229+
for chunk in response.iter_content(chunk_size=8192):
230+
if chunk:
231+
f.write(chunk)
232+
233+
temp_output_path.replace(output_path)
234+
return
235+
236+
except requests.exceptions.Timeout as e:
237+
temp_output_path.unlink(missing_ok=True)
238+
last_timeout_error = e
239+
if attempt < self.download_max_retries:
240+
time.sleep(self.download_retry_delay * attempt)
241+
continue
242+
243+
except requests.exceptions.HTTPError as e:
244+
temp_output_path.unlink(missing_ok=True)
245+
resp = e.response
246+
status_code = resp.status_code if resp is not None else None
247+
retryable_status_codes = {408, 425, 429, 500, 502, 503, 504}
248+
if status_code in retryable_status_codes and attempt < self.download_max_retries:
249+
time.sleep(self.download_retry_delay * attempt)
250+
continue
251+
252+
if resp is not None:
253+
raise RuntimeError(f"Download failed with HTTP Error {resp.status_code}: {resp.text}") from e
254+
raise RuntimeError(f"Download failed with HTTP Error: {e}") from e
255+
256+
except requests.exceptions.RequestException as e:
257+
temp_output_path.unlink(missing_ok=True)
258+
last_network_error = e
259+
if attempt < self.download_max_retries:
260+
time.sleep(self.download_retry_delay * attempt)
261+
continue
262+
raise ConnectionError(f"Download failed after {self.download_max_retries} attempts: {e}") from e
263+
264+
except OSError as e:
265+
temp_output_path.unlink(missing_ok=True)
266+
raise RuntimeError(f"Failed to write audio file to {output_path}: {e}") from e
267+
268+
if last_timeout_error is not None:
269+
raise TimeoutError(
270+
f"Download timeout after {self.download_max_retries} attempts "
271+
f"(connect_timeout={self.download_connect_timeout}s, read_timeout={self.download_read_timeout}s)"
272+
) from last_timeout_error
273+
274+
if last_network_error is not None:
275+
raise ConnectionError(f"Download failed after {self.download_max_retries} attempts") from last_network_error
276+
277+
raise RuntimeError("Download failed due to an unknown error")

tests/test_doubao_provider.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python3
2+
import tempfile
3+
import unittest
4+
from pathlib import Path
5+
from unittest.mock import MagicMock, patch
6+
7+
import requests
8+
9+
from epub2speech.tts.doubao_provider import DoubaoTextToSpeech
10+
11+
12+
class TestDoubaoProviderDownload(unittest.TestCase):
13+
def _create_provider(self, *, retries: int = 3) -> DoubaoTextToSpeech:
14+
return DoubaoTextToSpeech(
15+
access_token="token",
16+
base_url="https://example.com/api/v1/tts",
17+
download_connect_timeout=1.0,
18+
download_read_timeout=2.0,
19+
download_max_retries=retries,
20+
download_retry_delay=0.0,
21+
)
22+
23+
@staticmethod
24+
def _as_response_context(response: MagicMock) -> MagicMock:
25+
context_manager = MagicMock()
26+
context_manager.__enter__.return_value = response
27+
context_manager.__exit__.return_value = None
28+
return context_manager
29+
30+
@patch("epub2speech.tts.doubao_provider.time.sleep", return_value=None)
31+
@patch("epub2speech.tts.doubao_provider.requests.get")
32+
def test_download_audio_retries_after_timeout_then_succeeds(self, mock_get: MagicMock, _: MagicMock) -> None:
33+
provider = self._create_provider(retries=3)
34+
success_response = MagicMock()
35+
success_response.raise_for_status.return_value = None
36+
success_response.iter_content.return_value = [b"chunk1", b"chunk2"]
37+
mock_get.side_effect = [
38+
requests.exceptions.Timeout("tls handshake timeout"),
39+
self._as_response_context(success_response),
40+
]
41+
42+
with tempfile.TemporaryDirectory() as tmp_dir:
43+
output_path = Path(tmp_dir) / "output.wav"
44+
provider._download_audio("https://audio.example.com/file", output_path)
45+
46+
self.assertEqual(output_path.read_bytes(), b"chunk1chunk2")
47+
self.assertFalse(output_path.with_suffix(".wav.part").exists())
48+
self.assertEqual(mock_get.call_count, 2)
49+
for call in mock_get.call_args_list:
50+
self.assertEqual(call.kwargs["timeout"], (1.0, 2.0))
51+
self.assertTrue(call.kwargs["stream"])
52+
53+
@patch("epub2speech.tts.doubao_provider.time.sleep", return_value=None)
54+
@patch("epub2speech.tts.doubao_provider.requests.get")
55+
def test_download_audio_raises_timeout_after_retries_exhausted(self, mock_get: MagicMock, _: MagicMock) -> None:
56+
provider = self._create_provider(retries=2)
57+
mock_get.side_effect = requests.exceptions.Timeout("tls handshake timeout")
58+
59+
with tempfile.TemporaryDirectory() as tmp_dir:
60+
output_path = Path(tmp_dir) / "output.wav"
61+
with self.assertRaises(TimeoutError) as err:
62+
provider._download_audio("https://audio.example.com/file", output_path)
63+
64+
self.assertIn("Download timeout after 2 attempts", str(err.exception))
65+
self.assertEqual(mock_get.call_count, 2)
66+
self.assertFalse(output_path.exists())
67+
self.assertFalse(output_path.with_suffix(".wav.part").exists())
68+
69+
@patch("epub2speech.tts.doubao_provider.time.sleep", return_value=None)
70+
@patch("epub2speech.tts.doubao_provider.requests.get")
71+
def test_download_audio_retries_on_retryable_http_status(self, mock_get: MagicMock, _: MagicMock) -> None:
72+
provider = self._create_provider(retries=2)
73+
http_error = requests.exceptions.HTTPError("service unavailable")
74+
http_error.response = MagicMock(status_code=503, text="busy")
75+
76+
fail_response = MagicMock()
77+
fail_response.raise_for_status.side_effect = http_error
78+
79+
success_response = MagicMock()
80+
success_response.raise_for_status.return_value = None
81+
success_response.iter_content.return_value = [b"ok"]
82+
83+
mock_get.side_effect = [
84+
self._as_response_context(fail_response),
85+
self._as_response_context(success_response),
86+
]
87+
88+
with tempfile.TemporaryDirectory() as tmp_dir:
89+
output_path = Path(tmp_dir) / "output.wav"
90+
provider._download_audio("https://audio.example.com/file", output_path)
91+
92+
self.assertEqual(output_path.read_bytes(), b"ok")
93+
self.assertEqual(mock_get.call_count, 2)
94+
95+
96+
if __name__ == "__main__":
97+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)