Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion client/src/cbltest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .globals import CBLPyTestGlobal
from .logging import LogLevel, cbl_log_init, cbl_setLogLevel
from .requests import RequestFactory
from .version import available_api_version


class CBLPyTest:
Expand Down Expand Up @@ -81,6 +82,7 @@ async def create(
cbl_log_init(str(ret_val.request_factory.uuid), ret_val.config.logslurp_url)

ts_index = 0
await ret_val.resolve_api_version()
for ts in ret_val.test_servers:
await ts.new_session(
str(ret_val.request_factory.uuid),
Expand Down Expand Up @@ -112,7 +114,6 @@ def __init__(
index = 0
for ts in self.__config.test_servers:
ts_info = TestServerInfo(ts)

dataset_version = ts_info.dataset_version or dataset_version
self.__test_servers.append(
TestServer(self.__request_factory, index, ts_info.url, dataset_version)
Expand Down Expand Up @@ -145,6 +146,23 @@ def __init__(
)
)

async def resolve_api_version(self) -> None:
ts_index = 0
apiVersion = 0
for ts in self.test_servers:
root_info = await ts.get_info()
if apiVersion != 0 and root_info.version != apiVersion:
raise ValueError(
f"Test Server at index {ts_index} has API version "
f"{root_info.version} which does not match other test servers' "
f"API version {apiVersion}"
)

apiVersion = available_api_version(root_info.version)
ts_index += 1

self.__request_factory.version = apiVersion

async def close(self) -> None:
"""
Closes all the test servers and sync gateways
Expand Down
90 changes: 46 additions & 44 deletions client/src/cbltest/api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,24 @@

from opentelemetry.trace import get_tracer

from cbltest.api.database_types import MaintenanceType
from cbltest.api.database_types import DocumentEntry, MaintenanceType
from cbltest.api.error import CblTestError
from cbltest.logging import cbl_error, cbl_trace
from cbltest.request_types import DatabaseUpdateEntry, DatabaseUpdateType
from cbltest.requests import RequestFactory, TestServerRequestType
from cbltest.v1.requests import (
DatabaseUpdateEntry,
DatabaseUpdateType,
DocumentEntry,
PostGetAllDocumentsRequestBody,
PostGetDocumentRequestBody,
PostPerformMaintenanceRequestBody,
PostRunQueryRequestBody,
PostSnapshotDocumentsRequestBody,
PostUpdateDatabaseRequestBody,
PostVerifyDocumentsRequestBody,
)
from cbltest.v1.responses import (
from cbltest.response_types import (
PostGetAllDocumentsEntry,
PostGetAllDocumentsResponse,
PostGetDocumentResponse,
PostRunQueryResponse,
PostSnapshotDocumentsResponse,
PostVerifyDocumentsResponse,
PostGetAllDocumentsResponseMethods,
PostGetDocumentResponseMethods,
PostRunQueryResponseMethods,
PostSnapshotDocumentsResponseMethods,
PostVerifyDocumentsResponseMethods,
ValueOrMissing,
)
from cbltest.version import VERSION

BASE_BLOB_URL = "https://media.githubusercontent.com/media/couchbaselabs/couchbase-lite-tests/refs/heads/main/dataset/server/blobs/"


class SnapshotUpdater:
def __init__(self, id: str):
Expand Down Expand Up @@ -87,6 +78,11 @@ def upsert_document(
"Incorrect new_properties format, must be a list of dictionaries each with properties to update"
)

if new_blobs is not None:
for keypath in new_blobs:
blob_name = new_blobs[keypath]
new_blobs[keypath] = BASE_BLOB_URL + blob_name

self._updates.append(
DatabaseUpdateEntry(
DatabaseUpdateType.UPDATE,
Expand All @@ -105,9 +101,6 @@ class DatabaseUpdater:
"""

def __init__(self, db_name: str, request_factory: RequestFactory, index: int):
assert request_factory.version == 1, (
"This version of the CBLTest API requires request API v1"
)
self._db_name = db_name
self._updates: list[DatabaseUpdateEntry] = []
self.__request_factory = request_factory
Expand All @@ -124,9 +117,10 @@ async def __aexit__(self, exc_type, exc, tb):
if self.__error is not None:
raise CblTestError(self.__error)

payload = PostUpdateDatabaseRequestBody(self._db_name, self._updates)
request = self.__request_factory.create_request(
TestServerRequestType.UPDATE_DB, payload
TestServerRequestType.UPDATE_DB,
database=self._db_name,
updates=self._updates,
)
resp = await self.__request_factory.send_request(self.__index, request)
if resp.error is not None:
Expand Down Expand Up @@ -182,6 +176,11 @@ def upsert_document(
self.__error = "Incorrect new_properties format, must be a list of dictionaries each with properties to update"
return

if new_blobs is not None:
for keypath in new_blobs:
blob_name = new_blobs[keypath]
new_blobs[keypath] = BASE_BLOB_URL + blob_name

self._updates.append(
DatabaseUpdateEntry(
DatabaseUpdateType.UPDATE,
Expand Down Expand Up @@ -269,7 +268,7 @@ def document(self) -> dict[str, Any] | None:
"""Gets the document body of the document with the faulty keypath, if applicable"""
return self.__response.document

def __init__(self, rest_response: PostVerifyDocumentsResponse) -> None:
def __init__(self, rest_response: PostVerifyDocumentsResponseMethods) -> None:
self.__response = rest_response


Expand Down Expand Up @@ -384,12 +383,13 @@ async def get_all_documents(
"cbl.collection.names": collections,
},
):
payload = PostGetAllDocumentsRequestBody(self.__name, *collections)
req = self.__request_factory.create_request(
TestServerRequestType.ALL_DOC_IDS, payload
TestServerRequestType.ALL_DOC_IDS,
database=self.__name,
collections=list(collections),
)
resp = await self.__request_factory.send_request(self.__index, req)
cast_resp = cast(PostGetAllDocumentsResponse, resp)
cast_resp = cast(PostGetAllDocumentsResponseMethods, resp)
ret_val: dict[str, list[AllDocumentsEntry]] = {}
for c in cast_resp.collection_keys:
ret_val[c] = list(
Expand All @@ -412,12 +412,13 @@ async def get_document(self, document: DocumentEntry) -> GetDocumentResult:
"cbl.document.id": document.id,
},
):
payload = PostGetDocumentRequestBody(self.__name, document)
req = self.__request_factory.create_request(
TestServerRequestType.GET_DOCUMENT, payload
TestServerRequestType.GET_DOCUMENT,
database=self.__name,
document=document,
)
resp = await self.__request_factory.send_request(self.__index, req)
cast_resp = cast(PostGetDocumentResponse, resp)
cast_resp = cast(PostGetDocumentResponseMethods, resp)
return GetDocumentResult(cast_resp.raw_body)

async def create_snapshot(self, documents: list[DocumentEntry]) -> str:
Expand All @@ -427,12 +428,13 @@ async def create_snapshot(self, documents: list[DocumentEntry]) -> str:
:param documents: A list of documents to include in the snapshot
"""
with self.__tracer.start_as_current_span("create_snapshot"):
payload = PostSnapshotDocumentsRequestBody(self.__name, documents)
req = self.__request_factory.create_request(
TestServerRequestType.SNAPSHOT_DOCS, payload
TestServerRequestType.SNAPSHOT_DOCS,
database=self.__name,
entries=documents,
)
resp = await self.__request_factory.send_request(self.__index, req)
return cast(PostSnapshotDocumentsResponse, resp).snapshot_id
return cast(PostSnapshotDocumentsResponseMethods, resp).snapshot_id

async def verify_documents(self, updater: SnapshotUpdater) -> VerifyResult:
"""
Expand All @@ -442,14 +444,14 @@ async def verify_documents(self, updater: SnapshotUpdater) -> VerifyResult:
:param updater: The id and expected updates
"""
with self.__tracer.start_as_current_span("verify_documents"):
payload = PostVerifyDocumentsRequestBody(
self.__name, updater._id, updater._updates
)
req = self.__request_factory.create_request(
TestServerRequestType.VERIFY_DOCS, payload
TestServerRequestType.VERIFY_DOCS,
database=self.__name,
snapshot=updater._id,
changes=updater._updates,
)
resp = await self.__request_factory.send_request(self.__index, req)
return VerifyResult(cast(PostVerifyDocumentsResponse, resp))
return VerifyResult(cast(PostVerifyDocumentsResponseMethods, resp))

async def perform_maintenance(self, type: MaintenanceType) -> None:
"""
Expand All @@ -458,9 +460,10 @@ async def perform_maintenance(self, type: MaintenanceType) -> None:
:param type: The type of maintenance to perform
"""
with self.__tracer.start_as_current_span("perform_maintenance"):
payload = PostPerformMaintenanceRequestBody(self.__name, str(type))
req = self.__request_factory.create_request(
TestServerRequestType.PERFORM_MAINTENANCE, payload
TestServerRequestType.PERFORM_MAINTENANCE,
db=self.__name,
op_type=str(type),
)
await self.__request_factory.send_request(self.__index, req)

Expand All @@ -471,9 +474,8 @@ async def run_query(self, query: str) -> list[dict]:
:param query: The SQL++ query to run
"""
with self.__tracer.start_as_current_span("run_query"):
payload = PostRunQueryRequestBody(self.__name, query)
req = self.__request_factory.create_request(
TestServerRequestType.RUN_QUERY, payload
TestServerRequestType.RUN_QUERY, database=self.__name, query=query
)
resp = await self.__request_factory.send_request(self.__index, req)
return cast(PostRunQueryResponse, resp).results
return cast(PostRunQueryResponseMethods, resp).results
23 changes: 8 additions & 15 deletions client/src/cbltest/api/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from cbltest.api.database import Database
from cbltest.logging import cbl_error, cbl_trace
from cbltest.requests import TestServerRequestType
from cbltest.v1.requests import (
PostStartListenerRequestBody,
PostStopListenerRequestBody,
)
from cbltest.v1.responses import PostStartListenerResponse
from cbltest.response_types import PostStartListenerResponseMethods
from cbltest.version import VERSION


Expand Down Expand Up @@ -47,26 +43,23 @@ def __init__(
self.__tracer = get_tracer(__name__, VERSION)
self.__id: str = ""

assert database._request_factory.version == 1, (
"This version of the cbl test API requires request API v1"
)

async def start(self) -> None:
"""Start listening for incoming connections"""
with self.__tracer.start_as_current_span("start_listener"):
payload = PostStartListenerRequestBody(
self.database.name, self.collections, self.port, self.disable_tls
)
request = self.__request_factory.create_request(
TestServerRequestType.START_LISTENER, payload
TestServerRequestType.START_LISTENER,
db=self.database.name,
collections=self.collections,
port=self.port,
disable_tls=self.disable_tls,
)
resp = await self.__request_factory.send_request(self.__index, request)
if resp.error is not None:
cbl_error("Failed to start replicator (see trace log for details)")
cbl_trace(resp.error.message)
return

cast_resp = cast(PostStartListenerResponse, resp)
cast_resp = cast(PostStartListenerResponseMethods, resp)
self.port = cast_resp.port
self.__id = cast_resp.listener_id

Expand All @@ -75,7 +68,7 @@ async def stop(self) -> None:
with self.__tracer.start_as_current_span("stop_listener"):
request = self.__request_factory.create_request(
TestServerRequestType.STOP_LISTENER,
PostStopListenerRequestBody(self.__id),
id=self.__id,
)
resp = await self.__request_factory.send_request(self.__index, request)
if resp.error is not None:
Expand Down
37 changes: 13 additions & 24 deletions client/src/cbltest/api/multipeer_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,10 @@
from cbltest.api.x509_certificate import CertKeyPair, create_leaf_certificate
from cbltest.logging import cbl_error, cbl_trace
from cbltest.requests import TestServerRequestType
from cbltest.v1.requests import (
PostGetMultipeerReplicatorStatusRequestBody,
PostStartMultipeerReplicatorRequestBody,
PostStopMultipeerReplicatorRequestBody,
)
from cbltest.v1.responses import (
from cbltest.response_types import (
MultipeerReplicatorStatusEntry,
PostGetMultipeerReplicatorStatusResponse,
PostStartMultipeerReplicatorResponse,
PostGetMultipeerReplicatorStatusResponseMethods,
PostStartMultipeerReplicatorResponseMethods,
)
from cbltest.version import VERSION

Expand Down Expand Up @@ -74,9 +69,6 @@ def __init__(
authenticator: MultipeerReplicatorAuthenticator | None = None,
identity: CertKeyPair | None = None,
):
assert database._request_factory.version == 1, (
"This version of the cbl test API requires request API v1"
)
self.__index = database._index
self.__request_factory = database._request_factory
self.__peerGroupID = peerGroupID
Expand All @@ -97,16 +89,13 @@ async def start(self) -> None:
Starts the multipeer replicator
"""
with self.__tracer.start_as_current_span("start_multipeer_replicator"):
payload = PostStartMultipeerReplicatorRequestBody(
self.__peerGroupID,
self.__database.name,
self.__collections,
self.__identity,
authenticator=self.__authenticator,
)

req = self.__request_factory.create_request(
TestServerRequestType.START_MULTIPEER_REPLICATOR, payload
TestServerRequestType.START_MULTIPEER_REPLICATOR,
peerGroupID=self.__peerGroupID,
database=self.__database.name,
collections=self.__collections,
identity=self.__identity,
authenticator=self.__authenticator,
)
resp = await self.__request_factory.send_request(self.__index, req)
if resp.error is not None:
Expand All @@ -116,7 +105,7 @@ async def start(self) -> None:
cbl_trace(resp.error.message)
return None

cast_resp = cast(PostStartMultipeerReplicatorResponse, resp)
cast_resp = cast(PostStartMultipeerReplicatorResponseMethods, resp)
self.__id = cast_resp.replicator_id

async def stop(self) -> None:
Expand All @@ -130,7 +119,7 @@ async def stop(self) -> None:

req = self.__request_factory.create_request(
TestServerRequestType.STOP_MULTIPEER_REPLICATOR,
PostStopMultipeerReplicatorRequestBody(self.__id),
id=self.__id,
)
resp = await self.__request_factory.send_request(self.__index, req)
if resp.error is not None:
Expand All @@ -152,10 +141,10 @@ async def get_status(self) -> MultipeerReplicatorStatus:

req = self.__request_factory.create_request(
TestServerRequestType.MULTIPEER_REPLICATOR_STATUS,
PostGetMultipeerReplicatorStatusRequestBody(self.__id),
id=self.__id,
)
resp = await self.__request_factory.send_request(self.__index, req)
cast_resp = cast(PostGetMultipeerReplicatorStatusResponse, resp)
cast_resp = cast(PostGetMultipeerReplicatorStatusResponseMethods, resp)
return MultipeerReplicatorStatus(cast_resp.replicators)

async def wait_for_idle(
Expand Down
Loading
Loading