Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ _build/*
.DS_Store
.coverage
.edsl_cache
.edsl_objects/
.env
.ipynb_checkpoints/
.mypy_cache
Expand Down
36 changes: 35 additions & 1 deletion edsl/inference_services/services/anthropic_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from typing import Any, Optional, List, TYPE_CHECKING
from anthropic import AsyncAnthropic

Expand All @@ -23,6 +24,8 @@ class AnthropicService(InferenceServiceABC):
input_token_name = "input_tokens"
output_token_name = "output_tokens"
available_models_url = "https://docs.anthropic.com/en/docs/about-claude/models"
_temperature_deprecation_date = 20260205
_temperature_deprecation_version = (4, 6)

@classmethod
def get_model_info(cls, api_key: Optional[str] = None):
Expand All @@ -36,6 +39,37 @@ def get_model_info(cls, api_key: Optional[str] = None):
response.raise_for_status()
return response.json()["data"]

@classmethod
def _requires_temperature_one(cls, model_name: str) -> bool:
    """Return whether requests for this model should use ``temperature=1.0``.

    The decision is made from the model name alone, in two steps:

    1. If the name contains a standalone run of exactly eight digits, that
       number is compared with ``cls._temperature_deprecation_date`` and
       the model qualifies only when it is strictly greater.
    2. Otherwise the ``claude-<family>-<major>[-<minor>]`` version is
       compared with ``cls._temperature_deprecation_version``: any later
       version qualifies, and at exactly the cutoff version every family
       except ``opus`` qualifies.

    A missing minor component is read as ``0`` so that major-only aliases
    (for example ``claude-opus-5``) are still ordered against the cutoff
    instead of being treated as unrecognized.

    Args:
        model_name: Model identifier; matching is case-insensitive.

    Returns:
        ``True`` when the temperature should be forced to ``1.0``;
        ``False`` otherwise, including for names matching neither pattern.
    """
    name = model_name.lower()

    # An eight-digit stamp, when present, decides the outcome by itself;
    # the version pattern below is only a fallback for undated names.
    date_match = re.search(r"(?<!\d)(\d{8})(?!\d)", name)
    if date_match:
        return int(date_match.group(1)) > cls._temperature_deprecation_date

    version_match = re.search(
        r"claude-(?P<family>opus|sonnet|haiku)-(?P<major>\d+)"
        r"(?:-(?P<minor>\d+))?",
        name,
    )
    if version_match is None:
        return False

    version = (
        int(version_match.group("major")),
        # The minor group is None for major-only names; treat it as ".0".
        int(version_match.group("minor") or 0),
    )
    cutoff = cls._temperature_deprecation_version
    if version > cutoff:
        return True

    # At exactly the cutoff version, opus is the one family left unchanged.
    return version == cutoff and version_match.group("family") != "opus"
Comment thread
greptile-apps[bot] marked this conversation as resolved.

@classmethod
def _api_temperature(cls, model_name: str, temperature: float) -> float:
    """Pick the temperature value to place in the API request.

    Args:
        model_name: Model identifier passed to
            ``cls._requires_temperature_one``.
        temperature: The temperature configured by the caller.

    Returns:
        ``1.0`` when ``cls._requires_temperature_one(model_name)`` is
        truthy, otherwise ``temperature`` unchanged.
    """
    must_pin = cls._requires_temperature_one(model_name)
    return 1.0 if must_pin else temperature

@classmethod
def create_model(
cls, model_name: str = "claude-3-opus-20240229", model_class_name=None
Expand Down Expand Up @@ -152,7 +186,7 @@ async def async_execute_model_call(
create_kwargs = dict(
model=model_name,
max_tokens=self.max_tokens,
temperature=self.temperature,
temperature=cls._api_temperature(model_name, self.temperature),
system=system_prompt, # note that the Anthropic API uses "system" parameter rather than put it in the message
messages=messages,
)
Expand Down
12 changes: 11 additions & 1 deletion edsl/jobs/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,17 @@ def _remote_results(
"Remote execution completed but results could not be retrieved."
)
else:
return None, None
if self.run_config.parameters.disable_remote_inference:
return None, None

from .exceptions import JobsRunError

raise JobsRunError(
"Remote execution was requested, but remote inference is not "
"available. Check EXPECTED_PARROT_URL, EXPECTED_PARROT_API_KEY, "
"and the remote inference setting. To run locally, pass "
"disable_remote_inference=True or offload_execution=False."
)
Comment thread
greptile-apps[bot] marked this conversation as resolved.

def _prepare_to_run(self) -> None:
"""Prepare the job to run and ensure keys are in place for a remote job."""
Expand Down
8 changes: 4 additions & 4 deletions edsl/object_store/CLAUDE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# object_store Module

A git-like content-addressable storage (CAS) system for versioning EDSL domain objects. Provides local persistence via `platformdirs` (default `~/Library/Application Support/edsl/objects/` on macOS), branching, commit history, diffing, and push/pull sync to a remote HTTP service.
A git-like content-addressable storage (CAS) system for versioning EDSL domain objects. Provides local persistence in the current working directory (`.edsl_objects/`), branching, commit history, diffing, and push/pull sync to a remote HTTP service.

## Architecture

Expand Down Expand Up @@ -155,7 +155,7 @@ AgentList.store.branches(uuid) # list branches
AgentList.store.pull(uuid) # pull from remote and return loaded object
```

All methods accept `root=` to override the default platformdirs location.
All methods accept `root=` to override the default current-working-directory location.

### Accessor State Management

Expand Down Expand Up @@ -262,14 +262,14 @@ Both SQLite and PostgreSQL implementations share identical schemas. The PostgreS

## Configuration and Defaults

- **Default root**: `Path(platformdirs.user_data_dir("edsl")) / "objects"` (e.g. `~/Library/Application Support/edsl/objects/` on macOS)
- **Default root**: `Path.cwd() / ".edsl_objects"`
- **Default remote URL**: Read from `CONFIG.get("EDSL_CAS_URL")` (via `edsl.config`)
- **Auth token**: Read from `EXPECTED_PARROT_API_KEY` environment variable
- **Default branch**: `"main"`
- **Min UUID prefix length**: 4 characters

## Dependencies

- **Required**: `platformdirs` (for default storage path), `dotenv` (for env loading in store_accessor)
- **Required**: `dotenv` (for env loading in store_accessor)
- **Optional**: `google-cloud-storage` for GCSBackend (`pip install edsl[gcp]`), `psycopg2` for PostgreSQLMetadataIndex
- **Stdlib only**: HttpBackend uses `urllib.request` (no `requests` dependency)
9 changes: 5 additions & 4 deletions edsl/object_store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from uuid import uuid4
from typing import Callable, Optional, Type

import platformdirs

from .cas_repository import CASRepository
from .exceptions import AmbiguousUUIDError
from .fs_backend import FileSystemBackend
Expand Down Expand Up @@ -219,15 +217,18 @@ class ObjectStore:
[]
"""

DEFAULT_ROOT = Path(platformdirs.user_data_dir("edsl")) / "objects"
@staticmethod
def default_root() -> Path:
    """Compute the default on-disk location of the local object store.

    The working directory is read at call time, so the result follows any
    change of directory made before the call.

    Returns:
        The ``.edsl_objects`` directory under ``Path.cwd()``.
    """
    working_dir = Path.cwd()
    return working_dir / ".edsl_objects"

def __init__(
self,
root: Optional[str | Path] = None,
backend_factory=None,
metadata_index=None,
):
self.root = Path(root) if root else self.DEFAULT_ROOT
self.root = Path(root) if root else self.default_root()
if backend_factory is None:
self.root.mkdir(parents=True, exist_ok=True)

Expand Down
4 changes: 2 additions & 2 deletions edsl/runner/cas_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class RunnerCASIntegration:
job_id: The runner job ID.
survey: The Survey for this job (used to build the JSONL preamble).
service: The JobService to register the callback on.
root: CAS directory root. Defaults to ObjectStore.DEFAULT_ROOT / uuid.
root: CAS directory root. Defaults to ObjectStore.default_root() / uuid.
batch_size: Number of completed interviews to accumulate before
flushing a CAS commit. Default 1 = commit per interview.
uuid: Optional UUID for the CAS object. Auto-generated if omitted.
Expand All @@ -58,7 +58,7 @@ def __init__(
self._batch_size = batch_size
self._pending_ids: list[str] = [] # interview IDs awaiting flush

cas_root = root or (ObjectStore.DEFAULT_ROOT / self._uuid)
cas_root = root or (ObjectStore.default_root() / self._uuid)
self._backend = FileSystemBackend(cas_root)
self._writer = StreamingCASWriter(self._backend, branch="main")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ optional = true

[tool.poetry.dependencies.typer]
extras = [ "all",]
version = "^0.9.0"
version = ">=0.12,<1"

[tool.tomlsort.overrides."tool.poetry.dependencies"]
table_keys = false
Expand Down
74 changes: 74 additions & 0 deletions tests/inference_services/test_anthropic_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio

from edsl.inference_services.services.anthropic_service import AnthropicService


def test_requires_temperature_one_for_models_after_opus_46():
    """Check the name-based temperature rule on both sides of the cutoff."""
    expected_pinned = (
        "claude-sonnet-4-6-20260217",
        "claude-opus-4-7-20260416",
        "claude-opus-4-7",
    )
    expected_unpinned = (
        "claude-opus-4-6-20260205",
        "claude-opus-4-5-20251124",
        "claude-3-5-sonnet-20241022",
    )

    for model_name in expected_pinned:
        assert AnthropicService._requires_temperature_one(model_name)

    for model_name in expected_unpinned:
        assert not AnthropicService._requires_temperature_one(model_name)
Comment thread
greptile-apps[bot] marked this conversation as resolved.


def _run_with_stubbed_client(monkeypatch, model_name, temperature):
    """Execute one model call against a stand-in Anthropic client.

    ``AsyncAnthropic`` in the service module is replaced by a dummy whose
    ``messages.create`` records the keyword arguments it receives and
    returns a minimal text response, so no real request is issued.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.
        model_name: Model name passed to ``AnthropicService.create_model``.
        temperature: Temperature the model instance is constructed with.

    Returns:
        A ``(model, captured_kwargs)`` tuple: the model instance and the
        keyword arguments that reached ``messages.create``.
    """
    captured_kwargs = {}

    class DummyResponse:
        def model_dump(self):
            return {"content": [{"type": "text", "text": "ok"}]}

    class DummyMessages:
        async def create(self, **kwargs):
            captured_kwargs.update(kwargs)
            return DummyResponse()

    class DummyAnthropicClient:
        def __init__(self, api_key):
            self.messages = DummyMessages()

    monkeypatch.setattr(
        "edsl.inference_services.services.anthropic_service.AsyncAnthropic",
        DummyAnthropicClient,
    )

    model_class = AnthropicService.create_model(model_name)
    model = model_class(temperature=temperature, skip_api_key_check=True)

    asyncio.run(model.async_execute_model_call("hello"))

    return model, captured_kwargs


def test_anthropic_request_uses_temperature_one_for_affected_models(monkeypatch):
    """The request carries 1.0 while the model keeps its configured value."""
    model, captured_kwargs = _run_with_stubbed_client(
        monkeypatch, "claude-sonnet-4-6-20260217", 0.2
    )

    assert captured_kwargs["temperature"] == 1.0
    # Only the outgoing request is adjusted; the stored settings are not.
    assert model.temperature == 0.2
    assert model.parameters["temperature"] == 0.2


def test_anthropic_request_preserves_temperature_for_legacy_models(monkeypatch):
    """The configured temperature is sent unchanged for an older model."""
    _, captured_kwargs = _run_with_stubbed_client(
        monkeypatch, "claude-opus-4-5-20251124", 0.2
    )

    assert captured_kwargs["temperature"] == 0.2
23 changes: 23 additions & 0 deletions tests/jobs/test_Jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from edsl.caching import Cache

from edsl.agents import AgentList
from edsl.jobs.exceptions import JobsRunError
from edsl.language_models import ModelList, Model, LanguageModel


Expand Down Expand Up @@ -55,6 +56,28 @@ def test_jobs_simple_stuf(valid_job):
assert Jobs.from_dict(empty_job.to_dict()).to_dict() == empty_job.to_dict()


def test_offload_execution_does_not_fallback_when_remote_unavailable(
    valid_job, monkeypatch
):
    """Offloaded execution must raise when remote inference reports off."""

    def remote_unavailable(self, disable_remote_inference):
        return False

    monkeypatch.setattr(
        "edsl.jobs.remote_inference.JobsRemoteInferenceHandler.use_remote_inference",
        remote_unavailable,
    )

    expect_error = pytest.raises(
        JobsRunError, match="Remote execution was requested"
    )
    with expect_error:
        valid_job.run(cache=False, offload_execution=True)


def test_disable_remote_inference_allows_local_execution(valid_job, monkeypatch):
    """With remote inference disabled, ``run`` must still return results."""

    def remote_unavailable(self, disable_remote_inference):
        return False

    monkeypatch.setattr(
        "edsl.jobs.remote_inference.JobsRemoteInferenceHandler.use_remote_inference",
        remote_unavailable,
    )

    assert valid_job.run(cache=False, disable_remote_inference=True) is not None


def test_jobs_by_agents():
q = QuestionMultipleChoice(
question_text="How are you?",
Expand Down
Loading