Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: add-trailing-comma
args: [--py36-plus]
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.2
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py37-plus]
Expand All @@ -33,6 +33,6 @@ repos:
- id: nbqa-isort
args: ["--float-to-top"]
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
rev: 7.1.0
hooks:
- id: flake8
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ install:
@$(MAKE) install-kernel

ensure-poetry:
@if [ "$(shell which poetry)" = "" ]; then \
@if [ -z "$$(which poetry)" ]; then \
echo "Did you activate the outer conda environment? Run: conda activate llm-math-education"; \
exit 1; \
else \
echo "Found existing Poetry installation at $(shell which poetry)."; \
echo "Found existing Poetry installation at $$(which poetry)"; \
fi
@poetry install

Expand Down
65 changes: 40 additions & 25 deletions data/app_data/question_sample.csv

Large diffs are not rendered by default.

34 changes: 11 additions & 23 deletions src/llm_math_education/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, intro_prompt_dict: dict):
self.intro_prompt_dict = intro_prompt_dict
self.pretty_name_to_id_map = {
t[1]["pretty_name"] if "pretty_name" in t[1] else f"Prompt {i}": t[0]
for i, t in enumerate(self.intro_prompt_dict)
for i, t in enumerate(self.intro_prompt_dict.items())
}

def get_intro_prompt_pretty_names(self):
Expand All @@ -28,45 +28,38 @@ def get_intro_prompt_pretty_names(self):
return pretty_name_list

def get_intro_prompt_message_lists(self) -> list[dict[str, str]]:
message_lists = []
for prompt_info in self.intro_prompt_dict.values():
message_lists.append(prompt_info["messages"])
return message_lists
return [prompt_info["messages"] for prompt_info in self.intro_prompt_dict.values()]

def get_default_intro_prompt(self) -> dict[str]:
return self.intro_prompt_dict[next(iter(self.intro_prompt_dict.keys()))]

@staticmethod
def convert_conversation_to_string(messages):
conversation_string = ""
for message in messages:
conversation_string += message["role"].upper() + ":\n"
conversation_string += message["content"] + "\n"
return conversation_string

@staticmethod
def convert_string_to_conversation(conversation_string: str) -> list[dict[str, str]]:
"""Given a string representing a conversation, convert into the expected messages list format.

Follows a pretty basic convention, defined in this implementation.

Args:
conversation_string (str): String representing a conversation.

Returns:
list[dict[str, str]]: List of messages, each with a "role" and "content".
"""
messages = []
message = {
"content": "",
}
message = {"content": ""}
for line in conversation_string.split("\n"):
possible_role = line[:-1].lower()
if possible_role in VALID_ROLES:
if "role" in message:
message["content"] = message["content"].strip()
messages.append(message)
message = {
"content": "",
}
message = {"content": ""}
message["role"] = possible_role
else:
message["content"] += line + "\n"
Expand Down Expand Up @@ -127,12 +120,9 @@ def build_query(
if previous_messages is None:
previous_messages = self.stored_messages
if len(previous_messages) == 0:
# this is a new query
messages = [message.copy() for message in self.intro_messages]
self.stored_messages.extend(messages)
else:
# not a new query,
# so include the previous messages as context
messages = [message.copy() for message in previous_messages]
if user_query is not None:
user_message = {
Expand All @@ -157,18 +147,17 @@ def build_query(
self.recent_slot_fill_dict.append(slot_fill_dict)
assert len(slot_fill_dict) == len(expected_slots), "Unexpected fill provided."
if "user_query" in slot_fill_dict and user_query is not None:
# special case: fill user_query slots with the current user_query
slot_fill_dict["user_query"] = user_query
should_remove_user_query_message = True
try:
message["content"] = message["content"].format(**slot_fill_dict)
except KeyError:
raise KeyError(f"Failed to fill {expected_slots} with {slot_fill_dict}.")
except KeyError as e:
raise KeyError(f"Failed to fill {expected_slots} with {slot_fill_dict}. Missing key: {e}")
except ValueError as e:
raise ValueError(f"Formatting error: {e}")
else:
self.recent_slot_fill_dict.append({})
if query_for_retrieval_context == "" and message["role"] == "user":
# use as retrieval context the most recent user message
# TODO rethink this, providing a more flexible way to specify the retrieval context
query_for_retrieval_context = message["content"]
self.recent_slot_fill_dict = self.recent_slot_fill_dict[::-1]
if should_remove_user_query_message:
Expand All @@ -182,11 +171,10 @@ def compute_stored_token_counts(self) -> int:
total_token_count = sum(token_counts)
return total_token_count

@staticmethod
def identify_slots(prompt_string: str) -> list[str]:
"""Uses a regex to identify missing slots in a prompt_string.

More advanced slot formatting is not supported.

Args:
prompt_string (str): The prompt itself, with format-style slots to fill e.g. "This is a prompt with a slot: {slot_to_fill}"

Expand Down
2 changes: 1 addition & 1 deletion src/streamlit_app/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def validate_openai_api_key(openai_api_key: str) -> bool:
# temporarily override api key used
openai.api_key = openai_api_key
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hi"}],
request_timeout=10,
max_tokens=1,
Expand Down
2 changes: 1 addition & 1 deletion src/usage_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import openai

completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
model="gpt-4o-mini",
messages=messages,
)
assistant_message = completion["choices"][0]["message"]
Expand Down
2 changes: 1 addition & 1 deletion src/🤖_Math_QA.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def process_user_query(user_query: str):
s += "#### " + key + ":\n\n" + value.replace("\n", "\n\n") + "\n\n\n"
st.markdown(s)
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
model="gpt-4o-mini",
messages=messages,
temperature=st.session_state.temperature,
request_timeout=20,
Expand Down