Skip to content

Commit 459a8c0

Browse files
committed
support other endpoints
1 parent 926bc21 commit 459a8c0

1 file changed

Lines changed: 43 additions & 8 deletions

File tree

tests/voice/test_17_eou_feou.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,34 +88,42 @@ class TranscriptionTests(BaseModel):
8888
)
8989

9090
# VAD delay
91-
VAD_DELAY_S: list[float] = [0.15, 0.18, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
91+
VAD_DELAY_S: list[float] = [0.18]
92+
93+
# Endpoints
94+
ENDPOINTS: list[str] = [
95+
"wss://eu.rt.speechmatics.com/v2",
96+
"wss://us.rt.speechmatics.com/v2",
97+
]
9298

9399
# Margin
94100
MARGIN_S = 0.5
95101
CER_THRESHOLD = 0.95
96102

97103

98104
@pytest.mark.asyncio
105+
@pytest.mark.parametrize("endpoint", ENDPOINTS)
99106
@pytest.mark.parametrize("sample", SAMPLES.samples, ids=lambda s: f"{s.id}:{s.path}")
100-
async def test_turn_fixed_eou(sample: TranscriptionTest):
107+
async def test_turn_fixed_eou(endpoint: str, sample: TranscriptionTest):
101108
"""Test transcription and prediction using FIXED without FEOU"""
102109

103110
# Config
104111
config = VoiceAgentConfigPreset.FIXED()
105112

106113
# Dump config
107114
if SHOW_LOG:
108-
print(f"\nTest `{sample.path}` with preset FIXED\n")
115+
print(f"\nTest `{sample.path}` with preset FIXED -> {endpoint}\n")
109116
print(config.to_json(exclude_defaults=True, exclude_none=True, exclude_unset=True, indent=2))
110117

111118
# Run test
112-
await run_test(sample, config)
119+
await run_test(endpoint, sample, config)
113120

114121

115122
@pytest.mark.asyncio
123+
@pytest.mark.parametrize("endpoint", ENDPOINTS)
116124
@pytest.mark.parametrize("sample", SAMPLES.samples, ids=lambda s: f"{s.id}")
117125
@pytest.mark.parametrize("vad_delay", VAD_DELAY_S)
118-
async def test_turn_adaptive_feou(sample: TranscriptionTest, vad_delay: float):
126+
async def test_turn_adaptive_feou(endpoint: str, sample: TranscriptionTest, vad_delay: float):
119127
"""Test transcription and prediction using ADAPTIVE with FEOU"""
120128

121129
# Config
@@ -127,18 +135,19 @@ async def test_turn_adaptive_feou(sample: TranscriptionTest, vad_delay: float):
127135

128136
# Dump config
129137
if SHOW_LOG:
130-
print(f"\nTest `{sample.path}` with preset ADAPTIVE with VAD delay of {vad_delay}s\n")
138+
print(f"\nTest `{sample.path}` with preset ADAPTIVE with VAD delay of {vad_delay}s -> {endpoint}\n")
131139
print(config.to_json(exclude_defaults=True, exclude_none=True, exclude_unset=True, indent=2))
132140

133141
# Run test
134-
await run_test(sample, config)
142+
await run_test(endpoint, sample, config)
135143

136144

137-
async def run_test(sample: TranscriptionTest, config: VoiceAgentConfig):
145+
async def run_test(endpoint: str, sample: TranscriptionTest, config: VoiceAgentConfig):
138146
"""Run a test with the given sample and config."""
139147

140148
# Client
141149
client = await get_client(
150+
url=endpoint,
142151
api_key=API_KEY,
143152
connect=False,
144153
config=config,
@@ -147,6 +156,8 @@ async def run_test(sample: TranscriptionTest, config: VoiceAgentConfig):
147156
# Results
148157
eot_count: int = 0
149158
segments_received: list[dict] = []
159+
partials_received: set[str] = set()
160+
finals_received: set[str] = set()
150161

151162
# Start time
152163
start_time = datetime.datetime.now()
@@ -162,6 +173,25 @@ def eot_detected(message):
162173
nonlocal eot_count
163174
eot_count += 1
164175

176+
# Extract words
177+
def extract_words(message) -> list[str]:
178+
return [
179+
alt.get("content", None)
180+
for result in message.get("results", [])
181+
if result.get("type") == "word"
182+
for alt in result.get("alternatives", [])
183+
]
184+
185+
# Partials
186+
def rx_partial(message):
187+
words = extract_words(message)
188+
partials_received.update(w.lower() for w in words if w)
189+
190+
# Finals
191+
def rx_finals(message):
192+
words = extract_words(message)
193+
finals_received.update(w.lower() for w in words if w)
194+
165195
# Callback for each message
166196
def log_message(message):
167197
ts = (datetime.datetime.now() - start_time).total_seconds()
@@ -178,6 +208,8 @@ def log_message(message):
178208
# Custom listeners
179209
client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
180210
client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
211+
client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
212+
client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)
181213

182214
# HEADER
183215
if SHOW_LOG:
@@ -206,6 +238,9 @@ def log_message(message):
206238
print()
207239
print("--- AUDIO END ---")
208240
print()
241+
print(f"\nPartial words = {json.dumps(sorted(partials_received), indent=2)}\n")
242+
print(f"\nFinal words = {json.dumps(sorted(finals_received), indent=2)}\n")
243+
print()
209244

210245
# Check segment count
211246
expected_count = len(sample.segments)

0 commit comments

Comments
 (0)