Skip to content

Commit f9e1009

Browse files
committed
added test for feou
1 parent 994155b commit f9e1009

2 files changed

Lines changed: 144 additions & 0 deletions

File tree

273 KB
Binary file not shown.

tests/voice/test_18_feou.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import datetime
2+
import json
3+
import os
4+
5+
import pytest
6+
from _utils import get_client
7+
from _utils import send_audio_file
8+
from pydantic import Field
9+
10+
from speechmatics.voice import AdditionalVocabEntry
11+
from speechmatics.voice import AgentServerMessageType
12+
from speechmatics.voice._models import BaseModel
13+
from speechmatics.voice._presets import VoiceAgentConfigPreset
14+
15+
# Skip for CI testing
16+
pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping smart turn tests in CI")
17+
18+
19+
# Constants
20+
API_KEY = os.getenv("SPEECHMATICS_API_KEY")
21+
SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"]
22+
23+
24+
class TranscriptionSpeaker(BaseModel):
25+
text: str
26+
speaker_id: int = "S1"
27+
28+
29+
class TranscriptionTest(BaseModel):
30+
id: str
31+
path: str
32+
sample_rate: int
33+
language: str
34+
35+
36+
class TranscriptionTests(BaseModel):
37+
samples: list[TranscriptionTest]
38+
39+
40+
SAMPLES: TranscriptionTests = TranscriptionTests.from_dict(
41+
{
42+
"samples": [
43+
{
44+
"id": "08",
45+
"path": "./assets/audio_08_16kHz.wav",
46+
"sample_rate": 16000,
47+
"language": "en",
48+
},
49+
]
50+
}
51+
)
52+
53+
54+
@pytest.mark.asyncio
55+
@pytest.mark.parametrize("sample", SAMPLES.samples, ids=lambda s: f"{s.id}:{s.path}")
56+
async def test_feou_payloads(sample: TranscriptionTest):
57+
"""Test transcription and segments being emitted"""
58+
59+
# API key
60+
api_key = os.getenv("SPEECHMATICS_API_KEY")
61+
if not api_key:
62+
pytest.skip("Valid API key required for test")
63+
64+
# Start time
65+
start_time = datetime.datetime.now()
66+
67+
# Results
68+
eot_count: int = 0
69+
segment_transcribed: list[str] = []
70+
71+
# Client
72+
client = await get_client(
73+
api_key=api_key,
74+
connect=False,
75+
config=VoiceAgentConfigPreset.ADAPTIVE(),
76+
)
77+
78+
# SOT detected
79+
def sot_detected(message):
80+
nonlocal eot_count
81+
eot_count += 1
82+
print("✅ START_OF_TURN: {turn_id}".format(**message))
83+
84+
# Finalized segment
85+
def add_segments(message):
86+
segments = message["segments"]
87+
for s in segments:
88+
segment_transcribed.append(s["text"])
89+
print('🚀 ADD_SEGMENT: {speaker_id} @ "{text}"'.format(**s))
90+
91+
# EOT detected
92+
def eot_detected(message):
93+
nonlocal eot_count
94+
eot_count += 1
95+
print("🏁 END_OF_TURN: {turn_id}\n".format(**message))
96+
97+
# Callback for each message
98+
def log_message(message):
99+
ts = (datetime.datetime.now() - start_time).total_seconds()
100+
log = json.dumps({"ts": round(ts, 3), "payload": message})
101+
if SHOW_LOG:
102+
print(log)
103+
104+
# # Add listeners
105+
# for message_type in AgentServerMessageType:
106+
# if message_type not in [AgentServerMessageType.AUDIO_ADDED]:
107+
# client.on(message_type, log_message)
108+
109+
# Custom listeners
110+
client.on(AgentServerMessageType.START_OF_TURN, sot_detected)
111+
client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
112+
client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
113+
114+
# HEADER
115+
if SHOW_LOG:
116+
print()
117+
print()
118+
print("---")
119+
120+
# Connect
121+
try:
122+
await client.connect()
123+
except Exception:
124+
pytest.skip("Failed to connect to server")
125+
126+
# Check we are connected
127+
assert client._is_connected
128+
129+
# Individual payloads
130+
await send_audio_file(client, sample.path)
131+
132+
# FOOTER
133+
if SHOW_LOG:
134+
print("---")
135+
print()
136+
print()
137+
138+
# Close session
139+
await client.disconnect()
140+
assert not client._is_connected
141+
142+
# Debug count
143+
print(f"EOT count: {eot_count}")
144+
print(f"Segment transcribed: {len(segment_transcribed)}")

0 commit comments

Comments
 (0)