Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ dependencies = [
"aiosignal==1.3.2",
"async-timeout==5.0.1",
"attrs==24.3.0",
"bcrypt==4.2.1",
"certifi==2024.12.14",
"cffi==1.17.1",
"charset-normalizer==3.4.1",
"colorama==0.4.6",
"couchbase==4.3.4",
"cryptography==44.0.0",
"Deprecated==1.2.15",
"exceptiongroup==1.2.2",
"executing==2.1.0",
Expand All @@ -48,15 +51,19 @@ dependencies = [
"opentelemetry-sdk==1.29.0",
"opentelemetry-semantic-conventions==0.50b0",
"packaging==24.2",
"paramiko==3.5.0",
"pluggy==1.5.0",
"propcache==0.2.1",
"protobuf==5.29.3",
"pycparser==2.22",
"PyNaCl==1.5.0",
"pytest==8.3.4",
"pytest-asyncio==0.25.2",
"requests==2.32.3",
"tomli==2.2.1",
"types-Deprecated==1.2.15.20241117",
"types-requests==2.32.0.20241016",
"typing_extensions==4.12.2",
"urllib3==2.3.0",
"websocket-client==1.8.0",
"wrapt==1.17.2",
Expand Down
15 changes: 15 additions & 0 deletions client/src/cbltest/api/cbltestclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC

import pytest
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from cbltest.api.testserver import TestServer
from cbltest.globals import CBLPyTestGlobal
Expand Down Expand Up @@ -47,3 +49,16 @@ async def skip_if_not_platform(
if variant not in allow_platforms:
self.__skipped = True
pytest.skip(f"{variant} is not in the platforms {allow_platforms}")

async def skip_if_cbl_not(self, server: TestServer, constraint: str):
"""
Skips the test if the CBL version does not match the specified comparison operation and value.

:param constraint: A string representing the comparison operation and version, e.g., ">= 3.3.0".
"""
version_str = (await server.get_info()).library_version.split("-")[0]
version = Version(version_str)
spec = SpecifierSet(constraint)
if version not in spec:
self.__skipped = True
pytest.skip(f"CBL {version_str} not {constraint}")
24 changes: 20 additions & 4 deletions client/src/cbltest/api/multipeer_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from opentelemetry.trace import get_tracer

from cbltest.api.database import Database
from cbltest.api.replicator import (
ReplicatorCollectionEntry,
)
from cbltest.api.multipeer_replicator_types import MultipeerReplicatorAuthenticator
from cbltest.api.replicator import ReplicatorCollectionEntry
from cbltest.api.x509_certificate import CertKeyPair, create_leaf_certificate
from cbltest.logging import cbl_error, cbl_trace
from cbltest.requests import TestServerRequestType
from cbltest.v1.requests import (
Expand Down Expand Up @@ -55,11 +55,19 @@ def collections(self) -> list[ReplicatorCollectionEntry]:
"""Gets the collections for the replicator"""
return self.__collections

@property
def identity(self) -> CertKeyPair:
"""Gets the identity used by the replicator"""
return self.__identity

def __init__(
self,
peerGroupID: str,
database: Database,
collections: list[ReplicatorCollectionEntry],
*,
authenticator: MultipeerReplicatorAuthenticator | None = None,
identity: CertKeyPair | None = None,
):
assert database._request_factory.version == 1, (
"This version of the cbl test API requires request API v1"
Expand All @@ -68,6 +76,10 @@ def __init__(
self.__request_factory = database._request_factory
self.__peerGroupID = peerGroupID
self.__database = database
self.__authenticator = authenticator
self.__identity = (
identity if identity is not None else create_leaf_certificate("anonymous")
)
assert len(collections) > 0, "At least one collection is required"
self.__collections = collections
self.__tracer = get_tracer(__name__, VERSION)
Expand All @@ -79,7 +91,11 @@ async def start(self) -> None:
"""
with self.__tracer.start_as_current_span("start_multipeer_replicator"):
payload = PostStartMultipeerReplicatorRequestBody(
self.__peerGroupID, self.__database.name, self.__collections
self.__peerGroupID,
self.__database.name,
self.__collections,
self.__identity,
authenticator=self.__authenticator,
)

req = self.__request_factory.create_request(
Expand Down
40 changes: 40 additions & 0 deletions client/src/cbltest/api/multipeer_replicator_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any

from cbltest.api.jsonserializable import JSONSerializable
from cbltest.api.x509_certificate import CertKeyPair


class MultipeerReplicatorAuthenticator(JSONSerializable):
"""
The base class for replicator authenticators
"""

@property
def name(self) -> str:
"""Gets the type of authenticator (required for all authenticators)"""
return self.__name

def __init__(self, name: str) -> None:
self.__name = name

@abstractmethod
def to_json(self) -> Any:
pass


class MultipeerReplicatorCAAuthenticator(MultipeerReplicatorAuthenticator):
"""
Represents an authenticator based on a CA certificate. Use the
:class:`cbltest.api.x509_certificate.X509Generator` if you need to generate a CA certificate.
"""

def __init__(self, ca_data: CertKeyPair) -> None:
super().__init__("CA-CERT")
self.__ca_data = ca_data

def to_json(self) -> dict[str, Any]:
return {
"type": self.name,
"params": {"certificate": self.__ca_data.pem_bytes().decode("utf-8")},
}
99 changes: 99 additions & 0 deletions client/src/cbltest/api/x509_certificate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from datetime import datetime, timedelta, timezone

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, pkcs12
from cryptography.x509 import (
BasicConstraints,
Certificate,
CertificateBuilder,
ExtendedKeyUsage,
ExtendedKeyUsageOID,
Name,
NameAttribute,
NameOID,
random_serial_number,
)


class CertKeyPair:
"""
A class representing a certificate and its associated private key.
"""

def __init__(
self, certificate: Certificate, private_key: ec.EllipticCurvePrivateKey
):
self.certificate = certificate
self.private_key = private_key

def pfx_bytes(self) -> bytes:
"""
Returns the certificate and private key in PFX format.
"""
ret_val = pkcs12.serialize_key_and_certificates(
name=b"cbltest",
key=self.private_key,
cert=self.certificate,
cas=None,
encryption_algorithm=NoEncryption(),
)

return ret_val

def pem_bytes(self) -> bytes:
"""
Returns the certificate in PEM format.
"""
return self.certificate.public_bytes(encoding=Encoding.PEM)


def create_ca_certificate(CN: str) -> CertKeyPair:
private_key = ec.generate_private_key(ec.SECP256R1())
cn_attribute = Name([NameAttribute(NameOID.COMMON_NAME, CN)])
not_valid_before = datetime.now(timezone.utc)
not_valid_after = not_valid_before + timedelta(days=1)

ca_certificate: Certificate = (
CertificateBuilder()
.subject_name(cn_attribute)
.issuer_name(cn_attribute)
.public_key(private_key.public_key())
.serial_number(random_serial_number())
.not_valid_before(not_valid_before)
.not_valid_after(not_valid_after)
.add_extension(BasicConstraints(ca=True, path_length=None), critical=True)
.sign(private_key, hashes.SHA256())
)

return CertKeyPair(ca_certificate, private_key)


def create_leaf_certificate(
Comment thread
borrrden marked this conversation as resolved.
CN: str, *, issuer_data: CertKeyPair | None = None
) -> CertKeyPair:
private_key = ec.generate_private_key(ec.SECP256R1())
cn_attribute = Name([NameAttribute(NameOID.COMMON_NAME, CN)])
not_valid_before = datetime.now(timezone.utc)
not_valid_after = not_valid_before + timedelta(days=1)
issuer_name = issuer_data.certificate.subject if issuer_data else cn_attribute
signing_key = issuer_data.private_key if issuer_data else private_key

leaf_certificate = (
CertificateBuilder()
.subject_name(cn_attribute)
.issuer_name(issuer_name)
.public_key(private_key.public_key())
.serial_number(random_serial_number())
.not_valid_before(not_valid_before)
.not_valid_after(not_valid_after)
.add_extension(
ExtendedKeyUsage(
[ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH]
),
critical=False,
)
.sign(signing_key, hashes.SHA256())
)

return CertKeyPair(leaf_certificate, private_key)
30 changes: 28 additions & 2 deletions client/src/cbltest/v1/requests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import base64
from enum import Enum
from typing import Any, cast
from uuid import UUID

from cbltest.api.database_types import DocumentEntry
from cbltest.api.jsonserializable import JSONSerializable
from cbltest.api.multipeer_replicator_types import MultipeerReplicatorAuthenticator
from cbltest.api.replicator_types import (
ReplicatorAuthenticator,
ReplicatorCollectionEntry,
ReplicatorType,
)
from cbltest.api.x509_certificate import CertKeyPair
from cbltest.assertions import _assert_not_null
from cbltest.logging import cbl_warning
from cbltest.requests import TestServerRequest, TestServerRequestBody
Expand Down Expand Up @@ -796,8 +799,17 @@ class PostStartMultipeerReplicatorRequestBody(TestServerRequestBody):
}
}
}
]
],
"identity": {
"encoding": "PKCS12",
"data": "string",
"password": "pass"
},
"authenticator": {
"type": "CA-CERT",
"certificate": "string"
}
}
"""

@property
Expand All @@ -820,19 +832,33 @@ def __init__(
peerGroupID: str,
database: str,
collections: list[ReplicatorCollectionEntry],
identity: CertKeyPair,
*,
authenticator: MultipeerReplicatorAuthenticator | None = None,
):
super().__init__(1)
self.__peerGroupID = peerGroupID
self.__database = database
self.__collections = collections
self.__identity = identity
self.__authenticator = authenticator

def to_json(self) -> Any:
return {
json = {
"peerGroupID": self.__peerGroupID,
"database": self.__database,
"collections": self.__collections,
"identity": {
"encoding": "PKCS12",
"data": base64.b64encode(self.__identity.pfx_bytes()).decode("utf-8"),
},
}

if self.__authenticator is not None:
json["authenticator"] = self.__authenticator.to_json()

return json


class PostStopMultipeerReplicatorRequestBody(TestServerRequestBody):
"""
Expand Down
2 changes: 1 addition & 1 deletion client/src/cbltest/version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Final

# For hatchling to easily detect the version
__version__ = "1.2.0"
__version__ = "1.2.1"

# Typed version for outside use
VERSION: Final[str] = __version__
Expand Down
Loading