Skip to content

Commit 17656c3

Browse files
authored
Add support for executemany() (#810)
* Add support for bulk execution with named parameters in executemany()
* Add validation for parameter rows in executemany() and add tests
1 parent 28ee021 commit 17656c3

2 files changed

Lines changed: 132 additions & 1 deletion

File tree

src/crate/client/cursor.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,42 @@ def _replace(match: "re.Match[str]") -> str:
7070
return converted_sql, new_params
7171

7272

73+
def _convert_named_bulk_params(
    sql: str, seq_of_dicts: t.Sequence[t.Dict[str, t.Any]]
) -> t.Tuple[str, t.List[t.List[t.Any]]]:
    """Convert pyformat SQL and a sequence of dicts to positional bulk args.

    The SQL is converted once with ``_convert_named_to_positional``; every
    row is then flattened into a list whose order follows the positional
    markers of the converted SQL, independent of the key order of any
    individual dict.

    :param sql: Statement using ``%(name)s`` placeholders.
    :param seq_of_dicts: Non-empty sequence with one parameter dict per row.
    :returns: ``(converted_sql, bulk_args)`` where ``bulk_args`` holds one
        positional list per input row.
    :raises ProgrammingError: If any row is not a dict, or if a placeholder
        name is absent from any row.

    Extra keys in a row are silently ignored.
    """
    first = seq_of_dicts[0]
    if not isinstance(first, dict):
        raise ProgrammingError(
            "executemany() requires all parameter rows to be dicts "
            "when the SQL uses pyformat (%(name)s) placeholders"
        )

    # Probe the single-row converter with a name -> name mapping: the
    # "values" it hands back are then the placeholder names themselves, in
    # the order the positional markers expect them. Using the key order of
    # ``first`` instead would misalign arguments whenever the dict is
    # ordered differently from the placeholders or carries extra keys.
    # NOTE(review): relies on _convert_named_to_positional returning its
    # values in marker order -- confirm against its implementation.
    converted_sql, ordered_names = _convert_named_to_positional(
        sql, {name: name for name in first}
    )

    bulk_args: t.List[t.List[t.Any]] = []
    for row in seq_of_dicts:
        if not isinstance(row, dict):
            raise ProgrammingError(
                "executemany() requires all parameter rows to be dicts "
                "when the SQL uses pyformat (%(name)s) placeholders"
            )
        positional: t.List[t.Any] = []
        for name in ordered_names:
            # Explicit membership test (rather than catching KeyError) so
            # dict subclasses defining __missing__ cannot mask an absent
            # parameter.
            if name not in row:
                raise ProgrammingError(
                    f"Named parameter '{name}' not found in the parameters dict"
                )
            positional.append(row[name])
        bulk_args.append(positional)

    return converted_sql, bulk_args
107+
108+
73109
class Cursor:
74110
"""
75111
not thread-safe by intention
@@ -118,7 +154,16 @@ def executemany(self, sql, seq_of_parameters):
118154
"""
119155
row_counts = []
120156
durations = []
121-
self.execute(sql, bulk_parameters=seq_of_parameters)
157+
bulk_parameters = seq_of_parameters
158+
if (
159+
bulk_parameters
160+
and isinstance(bulk_parameters[0], dict)
161+
and _NAMED_PARAM_RE.search(sql)
162+
):
163+
sql, bulk_parameters = _convert_named_bulk_params(
164+
sql, bulk_parameters
165+
)
166+
self.execute(sql, bulk_parameters=bulk_parameters)
122167

123168
for result in self._result.get("results", []):
124169
if result.get("rowcount") > -1:

tests/client/test_cursor.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,92 @@ def test_cursor_executemany(mocked_connection):
125125
assert response["results"] == result
126126

127127

128+
def test_executemany_with_named_params(mocked_connection):
    """
    executemany() must rewrite pyformat %(name)s placeholders as positional
    $N markers and hand each dict row to the client as a positional list.
    """
    statement = (
        "INSERT INTO characters (name, age) VALUES (%(name)s, %(age)s)"
    )
    rows = [
        {"name": "Arthur", "age": 42},
        {"name": "Bill", "age": 35},
    ]
    canned_response = {
        "col_types": [],
        "cols": [],
        "duration": 123,
        "results": [{"rowcount": 1}, {"rowcount": 1}],
    }
    with mock.patch.object(
        mocked_connection.client, "sql", return_value=canned_response
    ) as sql_mock:
        mocked_connection.cursor().executemany(statement, rows)
        # Positional arguments of the most recent client.sql() call.
        call_positionals = sql_mock.call_args[0]

    sent_sql, _sent_params, sent_bulk_args = call_positionals
    assert sent_sql == "INSERT INTO characters (name, age) VALUES ($1, $2)"
    assert sent_bulk_args == [["Arthur", 42], ["Bill", 35]]
154+
155+
156+
def test_executemany_with_named_params_missing_key(mocked_connection):
    """
    A row that lacks a key referenced by a placeholder in the SQL must make
    executemany() raise ProgrammingError without calling client.sql.
    """
    statement = (
        "INSERT INTO characters (name, age) VALUES (%(name)s, %(age)s)"
    )
    rows = [
        {"name": "Arthur", "age": 42},
        {"name": "Bill"},  # missing 'age'
    ]
    expected_message = "Named parameter 'age' not found"

    cur = mocked_connection.cursor()
    with pytest.raises(ProgrammingError, match=expected_message):
        cur.executemany(statement, rows)

    mocked_connection.client.sql.assert_not_called()
173+
174+
175+
def test_executemany_with_named_params_repeated(mocked_connection):
    """
    A placeholder name that occurs several times in the SQL must resolve to
    the same $N marker at every occurrence, and its value must be listed
    only once in each row's positional arguments.
    """
    statement = "INSERT INTO t (a, b) VALUES (%(x)s, %(x)s)"
    rows = [{"x": 1}, {"x": 2}]
    canned_response = {
        "col_types": [],
        "cols": [],
        "duration": 123,
        "results": [{"rowcount": 1}, {"rowcount": 1}],
    }
    with mock.patch.object(
        mocked_connection.client, "sql", return_value=canned_response
    ) as sql_mock:
        mocked_connection.cursor().executemany(statement, rows)
        # Positional arguments of the most recent client.sql() call.
        call_positionals = sql_mock.call_args[0]

    sent_sql, _sent_params, sent_bulk_args = call_positionals
    assert sent_sql == "INSERT INTO t (a, b) VALUES ($1, $1)"
    assert sent_bulk_args == [[1], [2]]
198+
199+
200+
def test_executemany_with_mixed_param_types(mocked_connection):
    """
    Mixing dict and non-dict rows while the SQL uses pyformat placeholders
    must produce a clear ProgrammingError without calling client.sql.
    """
    statement = "INSERT INTO characters (name) VALUES (%(name)s)"
    rows = [{"name": "Arthur"}, ["Trillian"]]  # second row is a list
    expected_message = "requires all parameter rows"

    cur = mocked_connection.cursor()
    with pytest.raises(ProgrammingError, match=expected_message):
        cur.executemany(statement, rows)

    mocked_connection.client.sql.assert_not_called()
212+
213+
128214
def test_create_with_timezone_as_datetime_object(mocked_connection):
129215
"""
130216
The cursor can return timezone-aware `datetime` objects when requested.

0 commit comments

Comments
 (0)