Skip to content

Commit 8ea2d91

Browse files
fix browsing context
1 parent acb22f1 commit 8ea2d91

File tree

1 file changed

+268
-7
lines changed

1 file changed

+268
-7
lines changed

py/selenium/webdriver/common/bidi/browsing_context.py

Lines changed: 268 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ class SetViewportParameters:
253253

254254
context: Any | None = None
255255
viewport: Any | None = None
256-
user_contexts: list[Any | None] | None = field(default_factory=list)
256+
device_pixel_ratio: Any | None = None
257257

258258

259259
@dataclass
@@ -329,6 +329,9 @@ class BrowsingContext:
329329

330330
def __init__(self, driver) -> None:
331331
self._driver = driver
332+
self._event_handlers: dict = {}
333+
self._handler_id_counter = 0
334+
self._registered_ws_callbacks: set = set()
332335

333336
def activate(self, context: Any | None = None):
334337
"""Execute browsingContext.activate."""
@@ -344,16 +347,24 @@ def capture_screenshot(
344347
context: Any | None = None,
345348
format: Any | None = None,
346349
clip: Any | None = None,
350+
origin: Any | None = None,
347351
):
348352
"""Execute browsingContext.captureScreenshot."""
349353
params = {
350354
"context": context,
351355
"format": format,
352356
"clip": clip,
357+
"origin": origin,
353358
}
354359
params = {k: v for k, v in params.items() if v is not None}
355360
cmd = command_builder("browsingContext.captureScreenshot", params)
356-
return self._driver.execute(cmd)
361+
result = self._driver.execute(cmd)
362+
# Return the base64 string directly from the data field
363+
return (
364+
result.get("data")
365+
if isinstance(result, dict) and "data" in result
366+
else result
367+
)
357368

358369
def close(self, context: Any | None = None, prompt_unload: Any | None = None):
359370
"""Execute browsingContext.close."""
@@ -392,7 +403,9 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None):
392403
}
393404
params = {k: v for k, v in params.items() if v is not None}
394405
cmd = command_builder("browsingContext.getTree", params)
395-
return self._driver.execute(cmd)
406+
result = self._driver.execute(cmd)
407+
# Convert raw context dicts to Info objects
408+
return self._convert_contexts_to_info_list(result.get("contexts", []))
396409

397410
def handle_user_prompt(
398411
self,
@@ -416,17 +429,21 @@ def locate_nodes(
416429
locator: Any | None = None,
417430
serialization_options: Any | None = None,
418431
start_nodes: list[Any] = None,
432+
max_node_count: Any | None = None,
419433
):
420434
"""Execute browsingContext.locateNodes."""
421435
params = {
422436
"context": context,
423437
"locator": locator,
424438
"serializationOptions": serialization_options,
425439
"startNodes": start_nodes,
440+
"maxNodeCount": max_node_count,
426441
}
427442
params = {k: v for k, v in params.items() if v is not None}
428443
cmd = command_builder("browsingContext.locateNodes", params)
429-
return self._driver.execute(cmd)
444+
result = self._driver.execute(cmd)
445+
# Return the nodes list directly
446+
return result.get("nodes", []) if isinstance(result, dict) else result
430447

431448
def navigate(
432449
self,
@@ -462,7 +479,13 @@ def print(
462479
}
463480
params = {k: v for k, v in params.items() if v is not None}
464481
cmd = command_builder("browsingContext.print", params)
465-
return self._driver.execute(cmd)
482+
result = self._driver.execute(cmd)
483+
# Return the base64 string directly from the data field
484+
return (
485+
result.get("data")
486+
if isinstance(result, dict) and "data" in result
487+
else result
488+
)
466489

467490
def reload(
468491
self,
@@ -484,13 +507,13 @@ def set_viewport(
484507
self,
485508
context: Any | None = None,
486509
viewport: Any | None = None,
487-
user_contexts: list[Any] = None,
510+
device_pixel_ratio: Any | None = None,
488511
):
489512
"""Execute browsingContext.setViewport."""
490513
params = {
491514
"context": context,
492515
"viewport": viewport,
493-
"userContexts": user_contexts,
516+
"devicePixelRatio": device_pixel_ratio,
494517
}
495518
params = {k: v for k, v in params.items() if v is not None}
496519
cmd = command_builder("browsingContext.setViewport", params)
@@ -506,6 +529,244 @@ def traverse_history(self, context: Any | None = None, delta: Any | None = None)
506529
cmd = command_builder("browsingContext.traverseHistory", params)
507530
return self._driver.execute(cmd)
508531

532+
def add_event_handler(self, event_name: str, callback, contexts: list[str] = None):
533+
"""Register an event handler for the specified event.
534+
535+
Args:
536+
event_name: The name of the event (e.g., 'context_created', 'navigation_started')
537+
callback: The callable to invoke when the event occurs
538+
contexts: Optional list of context IDs to filter events
539+
540+
Returns:
541+
A callback_id that can be used to remove the handler later
542+
"""
543+
if event_name not in self._event_handlers:
544+
self._event_handlers[event_name] = {}
545+
546+
callback_id = self._handler_id_counter
547+
self._handler_id_counter += 1
548+
549+
self._event_handlers[event_name][callback_id] = {
550+
"callback": callback,
551+
"contexts": contexts,
552+
}
553+
554+
# If this is the first handler for this event type, subscribe at BiDi level
555+
if len(self._event_handlers[event_name]) == 1:
556+
self._subscribe_to_event(event_name)
557+
558+
return callback_id
559+
560+
def remove_event_handler(self, event_name: str, callback_id: int):
561+
"""Remove an event handler.
562+
563+
Args:
564+
event_name: The name of the event
565+
callback_id: The callback_id returned from add_event_handler
566+
"""
567+
if event_name in self._event_handlers:
568+
if callback_id in self._event_handlers[event_name]:
569+
del self._event_handlers[event_name][callback_id]
570+
571+
# Clean up empty event entries and unsubscribe if no more handlers
572+
if not self._event_handlers[event_name]:
573+
del self._event_handlers[event_name]
574+
self._unsubscribe_from_event(event_name)
575+
576+
def _subscribe_to_event(self, event_name: str):
577+
"""Subscribe to a BiDi event and register dispatcher with WebSocket connection."""
578+
# Map Python event names to BiDi event names
579+
bidi_event_map = {
580+
"context_created": "browsingContext.contextCreated",
581+
"context_destroyed": "browsingContext.contextDestroyed",
582+
"navigation_started": "browsingContext.navigationStarted",
583+
"navigation_committed": "browsingContext.navigationCommitted",
584+
"navigation_failed": "browsingContext.navigationFailed",
585+
"navigation_aborted": "browsingContext.navigationAborted",
586+
"dom_content_loaded": "browsingContext.domContentLoaded",
587+
"load": "browsingContext.load",
588+
"fragment_navigated": "browsingContext.fragmentNavigated",
589+
"history_updated": "browsingContext.historyUpdated",
590+
"user_prompt_opened": "browsingContext.userPromptOpened",
591+
"user_prompt_closed": "browsingContext.userPromptClosed",
592+
"download_will_begin": "browsingContext.downloadWillBegin",
593+
"download_end": "browsingContext.downloadEnd",
594+
}
595+
596+
bidi_event_name = bidi_event_map.get(event_name)
597+
if not bidi_event_name:
598+
return
599+
600+
# The _driver is actually a WebSocketConnection object
601+
ws_conn = self._driver
602+
603+
# Subscribe to the event via BiDi protocol
604+
try:
605+
from .session import Session
606+
607+
# Create subscription command and execute
608+
subscription_cmd = Session(ws_conn).subscribe(events=[bidi_event_name])
609+
ws_conn.execute(subscription_cmd)
610+
except Exception:
611+
# If subscription fails, continue - some events may not support subscription
612+
pass
613+
614+
# Register dispatcher callback with WebSocket connection
615+
if bidi_event_name not in self._registered_ws_callbacks:
616+
# Create a dispatcher that routes events to all registered handlers
617+
# Use default parameter to capture event_name value, not reference
618+
def event_dispatcher(event_data, _event_name=event_name):
619+
self._on_event(_event_name, event_data)
620+
621+
# Register with WebSocket connection
622+
if bidi_event_name not in ws_conn.callbacks:
623+
ws_conn.callbacks[bidi_event_name] = []
624+
ws_conn.callbacks[bidi_event_name].append(event_dispatcher)
625+
self._registered_ws_callbacks.add(bidi_event_name)
626+
627+
def _unsubscribe_from_event(self, event_name: str):
628+
"""Unsubscribe from a BiDi event."""
629+
# Map Python event names to BiDi event names
630+
bidi_event_map = {
631+
"context_created": "browsingContext.contextCreated",
632+
"context_destroyed": "browsingContext.contextDestroyed",
633+
"navigation_started": "browsingContext.navigationStarted",
634+
"navigation_committed": "browsingContext.navigationCommitted",
635+
"navigation_failed": "browsingContext.navigationFailed",
636+
"navigation_aborted": "browsingContext.navigationAborted",
637+
"dom_content_loaded": "browsingContext.domContentLoaded",
638+
"load": "browsingContext.load",
639+
"fragment_navigated": "browsingContext.fragmentNavigated",
640+
"history_updated": "browsingContext.historyUpdated",
641+
"user_prompt_opened": "browsingContext.userPromptOpened",
642+
"user_prompt_closed": "browsingContext.userPromptClosed",
643+
"download_will_begin": "browsingContext.downloadWillBegin",
644+
"download_end": "browsingContext.downloadEnd",
645+
}
646+
647+
bidi_event_name = bidi_event_map.get(event_name)
648+
if not bidi_event_name:
649+
return
650+
651+
# Remove dispatcher callback from WebSocket connection to prevent events from being delivered
652+
if bidi_event_name in self._registered_ws_callbacks:
653+
if bidi_event_name in self._driver.callbacks:
654+
# Remove all dispatchers for this event
655+
self._driver.callbacks[bidi_event_name] = []
656+
self._registered_ws_callbacks.discard(bidi_event_name)
657+
658+
def _on_event(self, event_name: str, event_data: dict):
659+
"""Internal callback invoked when BiDi events arrive."""
660+
# Dispatch to registered handlers
661+
if event_name not in self._event_handlers:
662+
return
663+
664+
handlers_list = list(self._event_handlers[event_name].items())
665+
if not handlers_list:
666+
return
667+
668+
for callback_id, handler_info in handlers_list:
669+
callback = handler_info.get("callback")
670+
if not callback:
671+
continue
672+
673+
contexts = handler_info.get("contexts")
674+
675+
# Convert event data to typed object
676+
event_obj = self._convert_event_data(event_name, event_data)
677+
678+
# Check if this event should be dispatched to this handler
679+
if contexts:
680+
# Filter by context if specified
681+
if hasattr(event_obj, "context") and event_obj.context in contexts:
682+
callback(event_obj)
683+
else:
684+
callback(event_obj)
685+
686+
def _convert_event_data(self, event_name: str, event_data: dict):
687+
"""Convert raw BiDi event data to typed objects."""
688+
689+
# Create a simple object for event data that supports dot notation
690+
class EventObject:
691+
def __init__(self, data):
692+
for key, value in data.items():
693+
# Convert snake_case to python attributes
694+
setattr(self, key, value)
695+
696+
if event_name == "context_created":
697+
# Convert nested context data to Info objects
698+
info_data = event_data.copy()
699+
if "children" in info_data and info_data["children"]:
700+
info_data["children"] = [
701+
self._dict_to_info(child) for child in info_data["children"]
702+
]
703+
return self._dict_to_info(info_data)
704+
elif event_name == "context_destroyed":
705+
return self._dict_to_info(event_data)
706+
elif event_name == "user_prompt_opened":
707+
return self._create_event_object(event_data)
708+
elif event_name == "user_prompt_closed":
709+
return self._create_event_object(event_data)
710+
elif event_name == "download_will_begin":
711+
return self._create_event_object(event_data)
712+
elif event_name == "download_end":
713+
return self._create_event_object(event_data)
714+
else:
715+
return self._create_event_object(event_data)
716+
717+
def _dict_to_info(self, data: dict) -> Info:
718+
"""Convert a dictionary to an Info object."""
719+
if not isinstance(data, dict):
720+
return data
721+
722+
children = data.get("children")
723+
if children:
724+
children = [
725+
self._dict_to_info(child) if isinstance(child, dict) else child
726+
for child in children
727+
]
728+
729+
return Info(
730+
children=children,
731+
client_window=data.get("clientWindow"),
732+
context=data.get("context"),
733+
original_opener=data.get("originalOpener"),
734+
url=data.get("url"),
735+
user_context=data.get("userContext"),
736+
parent=data.get("parent"),
737+
)
738+
739+
def _create_event_object(self, data: dict):
740+
"""Create a simple event object that supports dot notation access with snake_case attributes."""
741+
import re
742+
743+
# Convert camelCase keys to snake_case
744+
def camel_to_snake(name):
745+
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
746+
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
747+
748+
snake_case_data = {}
749+
for key, value in data.items():
750+
# Convert nested dicts recursively
751+
if isinstance(value, dict):
752+
value = self._create_event_object(value).__dict__
753+
snake_case_data[camel_to_snake(key)] = value
754+
755+
class EventObject:
756+
def __init__(self, d):
757+
self.__dict__.update(d)
758+
759+
return EventObject(snake_case_data)
760+
761+
def _convert_contexts_to_info_list(self, contexts_data: list) -> list:
762+
"""Convert a list of context dicts to Info objects."""
763+
if not contexts_data:
764+
return []
765+
return [
766+
self._dict_to_info(ctx) if isinstance(ctx, dict) else ctx
767+
for ctx in contexts_data
768+
]
769+
509770
def context_created(
510771
self,
511772
children: Any | None = None,

0 commit comments

Comments
 (0)