Skip to content

Commit 8bd2832

Browse files
committed
Improve prompt-template msg serialize and sample usage
1 parent 134e52e commit 8bd2832

File tree

8 files changed

+122
-16
lines changed

8 files changed

+122
-16
lines changed

python/samples/concepts/prompt_templates/azure_chat_gpt_api_handlebars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
chat_function = kernel.add_function(
4040
prompt_template_config=PromptTemplateConfig(
4141
template="""{{system_message}}{{#each chat_history}}
42-
{{#message role=role}}{{~content~}}{{/message}} {{/each}}""",
42+
{{message_to_prompt}} {{/each}}""",
4343
template_format="handlebars",
4444
allow_dangerously_set_content=True,
4545
),

python/samples/concepts/prompt_templates/azure_chat_gpt_api_jinja2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
chat_function = kernel.add_function(
4040
prompt_template_config=PromptTemplateConfig(
41-
template="""{{system_message}}{% for item in chat_history %}{{ message(item) }}{% endfor %}""",
41+
template="""{{system_message}}{% for item in chat_history %}{{ message_to_prompt(item) }}{% endfor %}""",
4242
template_format="jinja2",
4343
allow_dangerously_set_content=True,
4444
),

python/semantic_kernel/prompt_template/utils/handlebars_system_helpers.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
from collections.abc import Callable
77
from enum import Enum
8+
from xml.etree.ElementTree import Element, tostring # nosec B405
89

910
logger: logging.Logger = logging.getLogger(__name__)
1011

@@ -28,21 +29,20 @@ def _message_to_prompt(this, *args, **kwargs):
2829
def _message(this, options, *args, **kwargs):
2930
from semantic_kernel.contents.const import CHAT_MESSAGE_CONTENT_TAG
3031

31-
# everything in kwargs, goes to <ROOT_KEY_MESSAGE kwargs_key="kwargs_value">
32-
# everything in options, goes in between <ROOT_KEY_MESSAGE>options</ROOT_KEY_MESSAGE>
33-
start = f"<{CHAT_MESSAGE_CONTENT_TAG}"
32+
# Everything in kwargs becomes an attribute, and the block output is treated as message text.
33+
message = Element(CHAT_MESSAGE_CONTENT_TAG)
3434
for key, value in kwargs.items():
3535
if isinstance(value, Enum):
3636
value = value.value
3737
if value is not None:
38-
start += f' {key}="{value}"'
39-
start += ">"
40-
end = f"</{CHAT_MESSAGE_CONTENT_TAG}>"
38+
message.set(key, str(value))
4139
try:
42-
content = options["fn"](this)
40+
content = str(options["fn"](this))
4341
except Exception: # pragma: no cover
4442
content = ""
45-
return f"{start}{content}{end}"
43+
if content:
44+
message.text = content
45+
return tostring(message, encoding="unicode", short_empty_elements=False)
4646

4747

4848
def _set(this, *args, **kwargs):

python/semantic_kernel/prompt_template/utils/jinja2_system_helpers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
from collections.abc import Callable
66
from enum import Enum
7+
from xml.etree.ElementTree import Element, tostring # nosec B405
78

89
logger: logging.Logger = logging.getLogger(__name__)
910

@@ -27,15 +28,14 @@ def _message_to_prompt(context):
2728
def _message(item):
2829
from semantic_kernel.contents.const import CHAT_MESSAGE_CONTENT_TAG
2930

30-
start = f"<{CHAT_MESSAGE_CONTENT_TAG}"
3131
role = item.role
32-
content = item.content
3332
if isinstance(role, Enum):
3433
role = role.value
35-
start += f' role="{role}"'
36-
start += ">"
37-
end = f"</{CHAT_MESSAGE_CONTENT_TAG}>"
38-
return f"{start}{content}{end}"
34+
message = Element(CHAT_MESSAGE_CONTENT_TAG)
35+
message.set("role", str(role))
36+
if item.content:
37+
message.text = item.content
38+
return tostring(message, encoding="unicode", short_empty_elements=False)
3939

4040

4141
# Wrap the _get function to safely handle calls without arguments

python/tests/unit/prompt_template/test_handlebars_prompt_template.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,26 @@ async def test_helpers_message(kernel: Kernel):
261261
assert "Assistant message" in rendered
262262

263263

264+
async def test_helpers_message_escapes_xml_metacharacters(kernel: Kernel):
265+
template = """
266+
{{#each chat_history}}
267+
{{#message role=role}}
268+
{{~content~}}
269+
{{/message}}
270+
{{/each}}
271+
"""
272+
target = create_handlebars_prompt_template(template, allow_dangerously_set_content=True)
273+
chat_history = ChatHistory()
274+
chat_history.add_user_message('What does a < b & "c" mean?')
275+
276+
rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
277+
278+
assert "&lt;" in rendered
279+
assert "&amp;" in rendered
280+
assert '"c"' in rendered
281+
assert ChatHistory.from_rendered_prompt(rendered) == chat_history
282+
283+
264284
async def test_helpers_message_to_prompt(kernel: Kernel):
265285
template = """{{#each chat_history}}{{message_to_prompt}} {{/each}}"""
266286
target = create_handlebars_prompt_template(template, allow_dangerously_set_content=True)

python/tests/unit/prompt_template/test_handlebars_prompt_template_e2e.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from semantic_kernel import Kernel
55
from semantic_kernel.contents.chat_history import ChatHistory
6+
from semantic_kernel.contents.utils.author_role import AuthorRole
67
from semantic_kernel.functions import kernel_function
78
from semantic_kernel.functions.kernel_arguments import KernelArguments
89
from semantic_kernel.prompt_template.handlebars_prompt_template import HandlebarsPromptTemplate
@@ -100,3 +101,39 @@ async def test_chat_history_round_trip(self, kernel: Kernel):
100101
)
101102
chat_history2 = ChatHistory.from_rendered_prompt(rendered)
102103
assert chat_history2 == chat_history
104+
105+
async def test_chat_history_round_trip_with_xml_metacharacters(self, kernel: Kernel):
106+
# Arrange
107+
template = """{{#each chat_history}}{{#message role=role}}{{~content~}}{{/message}} {{/each}}"""
108+
target = create_handlebars_prompt_template(template)
109+
chat_history = ChatHistory()
110+
chat_history.add_user_message("What does a < b mean in Python?")
111+
chat_history.add_assistant_message('Use "&" carefully in XML and HTML.')
112+
113+
rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
114+
115+
assert "&lt;" in rendered
116+
assert "&amp;" in rendered
117+
assert '"&amp;"' in rendered
118+
assert ChatHistory.from_rendered_prompt(rendered) == chat_history
119+
120+
async def test_message_helper_preserves_system_role_with_xml_metacharacters(self, kernel: Kernel):
121+
# Arrange
122+
template = (
123+
"""{{system_message}}{{#each chat_history}}{{#message role=role}}{{~content~}}{{/message}} {{/each}}"""
124+
)
125+
target = create_handlebars_prompt_template(template)
126+
system_message = "You are a helpful assistant."
127+
chat_history = ChatHistory()
128+
chat_history.add_user_message("What does a < b mean in Python?")
129+
130+
rendered = await target.render(
131+
kernel,
132+
KernelArguments(system_message=system_message, chat_history=chat_history),
133+
)
134+
135+
parsed = ChatHistory.from_rendered_prompt(rendered)
136+
assert parsed.messages[0].role == AuthorRole.SYSTEM
137+
assert parsed.messages[0].content == system_message
138+
assert parsed.messages[1].role == AuthorRole.USER
139+
assert parsed.messages[1].content == "What does a < b mean in Python?"

python/tests/unit/prompt_template/test_jinja2_prompt_template.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,20 @@ async def test_helpers_message(kernel: Kernel):
264264
assert "Assistant message" in rendered
265265

266266

267+
async def test_helpers_message_escapes_xml_metacharacters(kernel: Kernel):
268+
template = """{% for item in chat_history %}{{ message(item) }}{% endfor %}"""
269+
target = create_jinja2_prompt_template(template, allow_dangerously_set_content=True)
270+
chat_history = ChatHistory()
271+
chat_history.add_user_message('What does a < b & "c" mean?')
272+
273+
rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
274+
275+
assert "&lt;" in rendered
276+
assert "&amp;" in rendered
277+
assert '"c"' in rendered
278+
assert ChatHistory.from_rendered_prompt(rendered) == chat_history
279+
280+
267281
async def test_helpers_message_to_prompt(kernel: Kernel):
268282
template = """
269283
{% for chat in chat_history %}

python/tests/unit/prompt_template/test_jinja2_prompt_template_e2e.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
from semantic_kernel.contents.chat_history import ChatHistory
5+
from semantic_kernel.contents.utils.author_role import AuthorRole
56
from semantic_kernel.functions import kernel_function
67
from semantic_kernel.functions.kernel_arguments import KernelArguments
78
from semantic_kernel.kernel import Kernel
@@ -104,3 +105,37 @@ async def test_chat_history_round_trip(kernel: Kernel):
104105
)
105106
chat_history2 = ChatHistory.from_rendered_prompt(rendered)
106107
assert chat_history2 == chat_history
108+
109+
110+
async def test_chat_history_round_trip_with_xml_metacharacters(kernel: Kernel):
111+
template = """{% for item in chat_history %}{{ message(item) }}{% endfor %}"""
112+
target = create_jinja2_prompt_template(template)
113+
chat_history = ChatHistory()
114+
chat_history.add_user_message("What does a < b mean in Python?")
115+
chat_history.add_assistant_message('Use "&" carefully in XML and HTML.')
116+
117+
rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
118+
119+
assert "&lt;" in rendered
120+
assert "&amp;" in rendered
121+
assert '"&amp;"' in rendered
122+
assert ChatHistory.from_rendered_prompt(rendered) == chat_history
123+
124+
125+
async def test_message_helper_preserves_system_role_with_xml_metacharacters(kernel: Kernel):
126+
template = """{{system_message}}{% for item in chat_history %}{{ message(item) }}{% endfor %}"""
127+
target = create_jinja2_prompt_template(template)
128+
system_message = "You are a helpful assistant."
129+
chat_history = ChatHistory()
130+
chat_history.add_user_message("What does a < b mean in Python?")
131+
132+
rendered = await target.render(
133+
kernel,
134+
KernelArguments(system_message=system_message, chat_history=chat_history),
135+
)
136+
137+
parsed = ChatHistory.from_rendered_prompt(rendered)
138+
assert parsed.messages[0].role == AuthorRole.SYSTEM
139+
assert parsed.messages[0].content == system_message
140+
assert parsed.messages[1].role == AuthorRole.USER
141+
assert parsed.messages[1].content == "What does a < b mean in Python?"

0 commit comments

Comments
 (0)