Skip to content
Open
63 changes: 60 additions & 3 deletions awswrangler/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,65 @@ def _build_cluster_args(**pars: Any) -> dict[str, Any]: # noqa: PLR0912,PLR0915
args["Applications"] = [{"Name": x} for x in pars["applications"]]

# Bootstraps
if pars["bootstraps_paths"]:
args["BootstrapActions"] = [{"Name": x, "ScriptBootstrapAction": {"Path": x}} for x in pars["bootstraps_paths"]]
# if pars["bootstraps_paths"]:
# args["BootstrapActions"] = [{"Name": x, "ScriptBootstrapAction": {"Path": x}} for x in pars["bootstraps_paths"]]

bootstraps = pars.get("bootstraps")
bootstraps_paths = pars.get("bootstraps_paths")

# Backward compatibility
if bootstraps is None and bootstraps_paths:
bootstraps = bootstraps_paths

if bootstraps:
bootstrap_actions = []

for item in bootstraps:
# Case 1: Simple string path
if isinstance(item, str):
bootstrap_actions.append(
{
"Name": "bootstrap",
"ScriptBootstrapAction": {
"Path": item,
},
}
)

# Case 2: Dictionary bootstrap action
elif isinstance(item, dict):
# Already in EMR expected format
if "ScriptBootstrapAction" in item:
bootstrap_actions.append(item)
continue

# New simplified format
name = item.get("name", "bootstrap")
path = item.get("path")

if path is None:
raise ValueError("Bootstrap dict must include a 'path' key.")

args_list = item.get("args", [])

bootstrap_actions.append(
{
"Name": name,
"ScriptBootstrapAction": {
"Path": path,
"Args": args_list,
},
}
)

# ✅ THIS WAS MISSING
if bootstrap_actions:
args["BootstrapActions"] = bootstrap_actions

else:
raise TypeError("Each bootstrap must be a string or a dict.")

args["BootstrapActions"] = bootstrap_actions
Comment on lines +352 to +355
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already Raised error, the line #335 should never be executed.


# Debugging and Steps
if (pars["debugging"] is True) or (pars["steps"] is not None):
Expand Down Expand Up @@ -470,6 +527,7 @@ def create_cluster( # noqa: PLR0913
consistent_view_retry_count: int = 5,
consistent_view_table_name: str = "EmrFSMetadata",
bootstraps_paths: list[str] | None = None,
bootstraps: list[str | dict[str, Any]] | None = None,
debugging: bool = True,
applications: list[str] | None = None,
visible_to_all_users: bool = True,
Expand Down Expand Up @@ -747,7 +805,6 @@ def create_cluster( # noqa: PLR0913
args: dict[str, Any] = _build_cluster_args(**locals())
client_emr = _utils.client(service_name="emr", session=boto3_session)
response = client_emr.run_job_flow(**args)
_logger.debug("response: \n%s", pprint.pformat(response))
return response["JobFlowId"]


Expand Down
52 changes: 52 additions & 0 deletions tests/unit/test_emr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import time
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -185,3 +186,54 @@ def test_docker(bucket, cloudformation_outputs, emr_security_configuration):
)
def test_get_emr_integer_version(version, result):
assert wr.emr._get_emr_classification_lib(version) == result


def test_create_cluster_bootstrap_with_args():
fake_emr_client = MagicMock()
fake_emr_client.run_job_flow.return_value = {"JobFlowId": "j-123"}

with patch("awswrangler.sts.get_account_id", return_value="123456789012"):
with patch("awswrangler._utils.get_region_from_session", return_value="us-east-1"):
with patch("awswrangler._utils.client", return_value=fake_emr_client):
wr.emr.create_cluster(
cluster_name="test",
subnet_id="subnet-12345678",
bootstraps=[
{
"name": "cw agent",
"path": "s3://bucket/install.sh",
"args": ["--target-account", "121213"],
}
],
)

args = fake_emr_client.run_job_flow.call_args[1]

bootstrap = args["BootstrapActions"][0]

assert bootstrap["Name"] == "cw agent"
assert bootstrap["ScriptBootstrapAction"]["Path"] == "s3://bucket/install.sh"
assert bootstrap["ScriptBootstrapAction"]["Args"] == [
"--target-account",
"121213",
]


def test_create_cluster_bootstrap_paths_still_work():
fake_emr_client = MagicMock()
fake_emr_client.run_job_flow.return_value = {"JobFlowId": "j-123"}

with patch("awswrangler.sts.get_account_id", return_value="123456789012"):
with patch("awswrangler._utils.get_region_from_session", return_value="us-east-1"):
with patch("awswrangler._utils.client", return_value=fake_emr_client):
wr.emr.create_cluster(
cluster_name="test",
subnet_id="subnet-12345678",
bootstraps_paths=["s3://bucket/old.sh"],
)

args = fake_emr_client.run_job_flow.call_args[1]

bootstrap = args["BootstrapActions"][0]

assert bootstrap["ScriptBootstrapAction"]["Path"] == "s3://bucket/old.sh"