Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ dependencies = [
"aiosignal==1.3.2",
"async-timeout==5.0.1",
"attrs==24.3.0",
"bcrypt==4.2.1",
"certifi==2024.12.14",
"cffi==1.17.1",
"charset-normalizer==3.4.1",
"colorama==0.4.6",
"couchbase==4.3.4",
"cryptography==44.0.0",
"Deprecated==1.2.15",
"exceptiongroup==1.2.2",
"executing==2.1.0",
Expand All @@ -48,15 +51,19 @@ dependencies = [
"opentelemetry-sdk==1.29.0",
"opentelemetry-semantic-conventions==0.50b0",
"packaging==24.2",
"paramiko==3.5.0",
"pluggy==1.5.0",
"propcache==0.2.1",
"protobuf==5.29.3",
"pycparser==2.22",
"PyNaCl==1.5.0",
"pytest==8.3.4",
"pytest-asyncio==0.25.2",
"requests==2.32.3",
"tomli==2.2.1",
"types-Deprecated==1.2.15.20241117",
"types-requests==2.32.0.20241016",
"typing_extensions==4.12.2",
"urllib3==2.3.0",
"websocket-client==1.8.0",
"wrapt==1.17.2",
Expand Down
15 changes: 15 additions & 0 deletions client/src/cbltest/api/cbltestclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC

import pytest
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from cbltest.api.testserver import TestServer
from cbltest.globals import CBLPyTestGlobal
Expand Down Expand Up @@ -47,3 +49,16 @@ async def skip_if_not_platform(
if variant not in allow_platforms:
self.__skipped = True
pytest.skip(f"{variant} is not in the platforms {allow_platforms}")

async def skip_if_cbl_not(self, server: TestServer, constraint: str):
"""
Skips the test if the CBL version does not match the specified comparison operation and value.

:param constraint: A string representing the comparison operation and version, e.g., ">= 3.3.0".
"""
version_str = (await server.get_info()).library_version.split("-")[0]
version = Version(version_str)
spec = SpecifierSet(constraint)
if version not in spec:
self.__skipped = True
pytest.skip(f"CBL {version_str} not {constraint}")
24 changes: 20 additions & 4 deletions client/src/cbltest/api/multipeer_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from opentelemetry.trace import get_tracer

from cbltest.api.database import Database
from cbltest.api.replicator import (
ReplicatorCollectionEntry,
)
from cbltest.api.multipeer_replicator_types import MultipeerReplicatorAuthenticator
from cbltest.api.replicator import ReplicatorCollectionEntry
from cbltest.api.x509_certificate import CertKeyPair, create_leaf_certificate
from cbltest.logging import cbl_error, cbl_trace
from cbltest.requests import TestServerRequestType
from cbltest.v1.requests import (
Expand Down Expand Up @@ -55,11 +55,19 @@ def collections(self) -> list[ReplicatorCollectionEntry]:
"""Gets the collections for the replicator"""
return self.__collections

@property
def identity(self) -> CertKeyPair:
"""Gets the identity used by the replicator"""
return self.__identity

def __init__(
self,
peerGroupID: str,
database: Database,
collections: list[ReplicatorCollectionEntry],
*,
authenticator: MultipeerReplicatorAuthenticator | None = None,
identity: CertKeyPair | None = None,
):
assert database._request_factory.version == 1, (
"This version of the cbl test API requires request API v1"
Expand All @@ -68,6 +76,10 @@ def __init__(
self.__request_factory = database._request_factory
self.__peerGroupID = peerGroupID
self.__database = database
self.__authenticator = authenticator
self.__identity = (
identity if identity is not None else create_leaf_certificate("anonymous")
)
assert len(collections) > 0, "At least one collection is required"
self.__collections = collections
self.__tracer = get_tracer(__name__, VERSION)
Expand All @@ -79,7 +91,11 @@ async def start(self) -> None:
"""
with self.__tracer.start_as_current_span("start_multipeer_replicator"):
payload = PostStartMultipeerReplicatorRequestBody(
self.__peerGroupID, self.__database.name, self.__collections
self.__peerGroupID,
self.__database.name,
self.__collections,
self.__identity,
authenticator=self.__authenticator,
)

req = self.__request_factory.create_request(
Expand Down
40 changes: 40 additions & 0 deletions client/src/cbltest/api/multipeer_replicator_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any

from cbltest.api.jsonserializable import JSONSerializable
from cbltest.api.x509_certificate import CertKeyPair


class MultipeerReplicatorAuthenticator(JSONSerializable):
"""
The base class for replicator authenticators
"""

@property
def name(self) -> str:
"""Gets the type of authenticator (required for all authenticators)"""
return self.__name

def __init__(self, name: str) -> None:
self.__name = name

@abstractmethod
def to_json(self) -> Any:
pass


class MultipeerReplicatorCAAuthenticator(MultipeerReplicatorAuthenticator):
"""
Represents an authenticator based on a CA certificate. Use the
:class:`cbltest.api.x509_certificate.X509Generator` if you need to generate a CA certificate.
"""

def __init__(self, ca_data: CertKeyPair) -> None:
super().__init__("CA-CERT")
self.__ca_data = ca_data

def to_json(self) -> dict[str, Any]:
return {
"type": self.name,
"params": {"certificate": self.__ca_data.pem_bytes().decode("utf-8")},
}
88 changes: 88 additions & 0 deletions client/src/cbltest/api/x509_certificate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from datetime import datetime, timedelta, timezone

from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, pkcs12
from cryptography.x509 import (
BasicConstraints,
Certificate,
CertificateBuilder,
Name,
NameAttribute,
NameOID,
random_serial_number,
)


class CertKeyPair:
"""
A class representing a certificate and its associated private key.
"""

def __init__(self, certificate: Certificate, private_key: Ed25519PrivateKey):
self.certificate = certificate
self.private_key = private_key

def pfx_bytes(self) -> bytes:
"""
Returns the certificate and private key in PFX format.
"""
ret_val = pkcs12.serialize_key_and_certificates(
name=b"cbltest",
key=self.private_key,
cert=self.certificate,
cas=None,
encryption_algorithm=NoEncryption(),
)

return ret_val

def pem_bytes(self) -> bytes:
"""
Returns the certificate in PEM format.
"""
return self.certificate.public_bytes(encoding=Encoding.PEM)


def create_ca_certificate(CN: str) -> CertKeyPair:
private_key = Ed25519PrivateKey.generate()
cn_attribute = Name([NameAttribute(NameOID.COMMON_NAME, CN)])
not_valid_before = datetime.now(timezone.utc)
not_valid_after = not_valid_before + timedelta(days=1)

ca_certificate: Certificate = (
CertificateBuilder()
.subject_name(cn_attribute)
.issuer_name(cn_attribute)
.public_key(private_key.public_key())
.serial_number(random_serial_number())
.not_valid_before(not_valid_before)
.not_valid_after(not_valid_after)
.add_extension(BasicConstraints(ca=True, path_length=None), critical=True)
.sign(private_key, None)
)

return CertKeyPair(ca_certificate, private_key)


def create_leaf_certificate(
Comment thread
borrrden marked this conversation as resolved.
CN: str, *, issuer_data: CertKeyPair | None = None
) -> CertKeyPair:
private_key = Ed25519PrivateKey.generate()
cn_attribute = Name([NameAttribute(NameOID.COMMON_NAME, CN)])
not_valid_before = datetime.now(timezone.utc)
not_valid_after = not_valid_before + timedelta(days=1)
issuer_name = issuer_data.certificate.subject if issuer_data else cn_attribute
signing_key = issuer_data.private_key if issuer_data else private_key

leaf_certificate = (
CertificateBuilder()
.subject_name(cn_attribute)
.issuer_name(issuer_name)
.public_key(private_key.public_key())
.serial_number(random_serial_number())
.not_valid_before(not_valid_before)
.not_valid_after(not_valid_after)
.sign(signing_key, None)
)

return CertKeyPair(leaf_certificate, private_key)
30 changes: 28 additions & 2 deletions client/src/cbltest/v1/requests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import base64
from enum import Enum
from typing import Any, cast
from uuid import UUID

from cbltest.api.database_types import DocumentEntry
from cbltest.api.jsonserializable import JSONSerializable
from cbltest.api.multipeer_replicator_types import MultipeerReplicatorAuthenticator
from cbltest.api.replicator_types import (
ReplicatorAuthenticator,
ReplicatorCollectionEntry,
ReplicatorType,
)
from cbltest.api.x509_certificate import CertKeyPair
from cbltest.assertions import _assert_not_null
from cbltest.logging import cbl_warning
from cbltest.requests import TestServerRequest, TestServerRequestBody
Expand Down Expand Up @@ -796,8 +799,17 @@ class PostStartMultipeerReplicatorRequestBody(TestServerRequestBody):
}
}
}
]
],
"identity": {
"encoding": "PKCS12",
"data": "string",
"password": "pass"
},
"authenticator": {
"type": "CA-CERT",
"certificate": "string"
}
}
"""

@property
Expand All @@ -820,19 +832,33 @@ def __init__(
peerGroupID: str,
database: str,
collections: list[ReplicatorCollectionEntry],
identity: CertKeyPair,
*,
authenticator: MultipeerReplicatorAuthenticator | None = None,
):
super().__init__(1)
self.__peerGroupID = peerGroupID
self.__database = database
self.__collections = collections
self.__identity = identity
self.__authenticator = authenticator

def to_json(self) -> Any:
return {
json = {
"peerGroupID": self.__peerGroupID,
"database": self.__database,
"collections": self.__collections,
"identity": {
"encoding": "PKCS12",
"data": base64.b64encode(self.__identity.pfx_bytes()).decode("utf-8"),
},
}

if self.__authenticator is not None:
json["authenticator"] = self.__authenticator.to_json()

return json


class PostStopMultipeerReplicatorRequestBody(TestServerRequestBody):
"""
Expand Down
2 changes: 1 addition & 1 deletion client/src/cbltest/version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Final

# For hatchling to easily detect the version
__version__ = "1.2.0"
__version__ = "1.2.1"

# Typed version for outside use
VERSION: Final[str] = __version__
Expand Down
Loading