Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def retrieve_parallel(self, prompt, corpus, model_A, model_B):
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(self.retrieve, prompt, corpus, model) for model in model_names]
results = [future.result() for future in futures]
return results[0], results[1], model_names[0], model_names[1]
return results[0][0], results[1][0], model_names[0], model_names[1], results[0][1], results[1][1]

@spaces.GPU(duration=120)
def retrieve(self, query, corpus, model_name, topk=1):
Expand All @@ -226,9 +226,9 @@ def retrieve(self, query, corpus, model_name, topk=1):
index = self.load_bm25_index(model_name, corpus)
docs = index.search([query], topk=topk)
if corpus == "stackexchange":
return [[query, corpus_format.format(text=docs[0][0]["text"])]]
return [[query, corpus_format.format(text=docs[0][0]["text"])]], None
else:
return [[query, corpus_format.format(title=docs[0][0]["title"], text=docs[0][0]["text"])]]
return [[query, corpus_format.format(title=docs[0][0]["title"], text=docs[0][0]["text"])]], None

model = self.load_model(model_name)
kwargs = {} if self.use_gcp_index else {"convert_to_tensor": True}
Expand All @@ -255,8 +255,8 @@ def retrieve(self, query, corpus, model_name, topk=1):
else:
index = self.load_local_index(model_name, corpus)
docs, scores = index.search_knn(query_embed, topk=topk)
docs = [[query, corpus_format.format(title=docs[0].get("title", ""), text=docs[0][0]["text"])]]
return docs
docs = [[query, corpus_format.format(title=docs[0][0].get("title", ""), text=docs[0][0]["text"])]]
return docs, query_embed

def clustering_parallel(self, prompt, model_A, model_B, ncluster=1, ndim="3D", dim_method="PCA", clustering_method="KMeans"):
if model_A == "" and model_B == "":
Expand Down Expand Up @@ -512,3 +512,4 @@ def get_model_description_md(self, task_type="retrieval"):
model_description_md += "\n"
ct += 1
return model_description_md

107 changes: 85 additions & 22 deletions ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import time
import os
import uuid

import torch
import gradio as gr
import atexit
import shutil
Comment thread
SaitejaUtpala marked this conversation as resolved.
Outdated


from log_utils import build_logger, store_data_in_hub

Expand All @@ -21,6 +24,20 @@
clustering_logger = build_logger("gradio_clustering", "gradio_clustering.log")
sts_logger = build_logger("gradio_sts", "gradio_sts.log")


def save_tensor(tensor, file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
torch.save(tensor, file_path)


def cleanup_dirs(*dir_paths):
for dir_path in dir_paths:
try:
if os.path.exists(dir_path) and os.path.isdir(dir_path):
shutil.rmtree(dir_path)
except Exception as e:
print(f"Error deleting directory {dir_path}: {e}")

def get_ip(request: gr.Request):
if request:
if "cf-connecting-ip" in request.headers:
Expand All @@ -34,9 +51,12 @@ def get_ip(request: gr.Request):
def clear_history(): return None, "", None
def clear_history_sts(): return None, "", "", "", None
def clear_history_clustering(): return None, "", 1, None
def clear_history_side_by_side(): return None, None, "", None, None
def clear_history_side_by_side_anon():
return None, None, "", None, None, gr.Markdown("", visible=False), gr.Markdown("", visible=False)
def clear_history_side_by_side(state0, state1):
cleanup_dirs( os.path.dirname(state0.query_embed_file_path), os.path.dirname(state1.query_embed_file_path))
Comment thread
SaitejaUtpala marked this conversation as resolved.
Outdated
return None, None, "", None, None, gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)
def clear_history_side_by_side_anon(state0, state1):
cleanup_dirs( os.path.dirname(state0.query_embed_file_path), os.path.dirname(state1.query_embed_file_path))
Comment thread
SaitejaUtpala marked this conversation as resolved.
Outdated
return None, None, "", None, None, gr.Markdown("", visible=False), gr.Markdown("", visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)
def clear_history_side_by_side_anon_sts():
return None, None, "", "", "", None, None, gr.Markdown("", visible=False), gr.Markdown("", visible=False)
def clear_history_side_by_side_anon_clustering():
Expand Down Expand Up @@ -74,9 +94,23 @@ def vote_last_response(vote_type, state0, state1, model_selector0, model_selecto

if vote_type == "share": return



return_state = ("Press 🎲 New Round to start over 👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)",) + disable_btns(4)
if model_selector0 == "":
return ("Press 🎲 New Round to start over 👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)",) + disable_btns(4) + (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
return ("Press 🎲 New Round to start over 👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)",) + disable_btns(4) + (gr.Markdown(state0.model_name, visible=True), gr.Markdown(state1.model_name, visible=True))
return_state = return_state + (gr.Markdown(f"### Model A: {state0.model_name}", visible=True), gr.Markdown(f"### Model B: {state1.model_name}", visible=True))
Comment thread
SaitejaUtpala marked this conversation as resolved.
Outdated
else:
return_state = return_state + (gr.Markdown(state0.model_name, visible=True), gr.Markdown(state1.model_name, visible=True))

if os.path.exists(state0.query_embed_file_path) and os.path.exists(state1.query_embed_file_path):
download_a_btn = gr.update(
label = "📥 Download embedding for model A", value=state0.query_embed_file_path, visible=True
)
download_b_btn = gr.update(
label = "📥 Download embedding for model B", value=state1.query_embed_file_path, visible=True
)
return_state = return_state + (download_a_btn, download_b_btn )
Comment thread
SaitejaUtpala marked this conversation as resolved.
Outdated
return return_state

def vote_last_response_sts(vote_type, state0, state1, model_selector0, model_selector1, request: gr.Request):
if vote_type != "share":
Expand Down Expand Up @@ -192,25 +226,42 @@ def __init__(self, model_name):
self.prompt = ""
self.corpus = ""
self.output = ""
self.query_embed_file_path = ""

def dict(self, prefix: str = None):
if prefix is None:
return {"conv_id": self.conv_id, "model_name": self.model_name, "prompt": self.prompt, "output": self.output, "corpus": self.corpus}
else:
return {f"{prefix}_conv_id": self.conv_id, f"{prefix}_model_name": self.model_name, f"{prefix}_prompt": self.prompt, f"{prefix}_output": self.output, f"{prefix}_corpus": self.corpus}




def retrieve_side_by_side(gen_func, state0, state1, text, corpus, model_name0, model_name1, request: gr.Request):
if not text: raise gr.Warning("Query cannot be empty.")
state0, state1 = RetrievalState(model_name0), RetrievalState(model_name1)
ip = get_ip(request)
retrieval_logger.info(f"Retrieval. ip: {ip}")
start_tstamp = time.time()
retrieved_txt0, retrieved_txt1, model_name0, model_name1 = gen_func(text, corpus, model_name0, model_name1)
retrieved_txt0, retrieved_txt1, model_name0, model_name1, query_embed0, query_embed1 = gen_func(text, corpus, model_name0, model_name1)
state0.prompt, state1.prompt = text, text
state0.corpus, state1.corpus = corpus, corpus
state0.output, state1.output = retrieved_txt0, retrieved_txt1
state0.model_name, state1.model_name = model_name0, model_name1

if query_embed0 is not None and query_embed1 is not None:
state0.query_embed_file_path = f"{state0.conv_id}_retrieve_side_by_side/model0_query_embedding.pth"
state1.query_embed_file_path = f"{state1.conv_id}_retrieve_side_by_side/model1_query_embedding.pth"

# Save tensors
save_tensor(query_embed0, state0.query_embed_file_path)
save_tensor(query_embed1, state1.query_embed_file_path)

# Register cleanup function
atexit.register(cleanup_dirs, os.path.dirname(state0.query_embed_file_path), os.path.dirname(state1.query_embed_file_path))



yield state0, state1, retrieved_txt0, retrieved_txt1, \
gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)

Expand Down Expand Up @@ -330,6 +381,13 @@ def build_side_by_side_ui_anon(models):
value="👎 Both are bad", visible=False, interactive=False
)

download_a_btn = gr.DownloadButton(
value="📥 Download embedding for model A", visible=False
)
download_b_btn = gr.DownloadButton(
value="📥 Download embedding for model B", visible=False
)

with gr.Row():
textbox = gr.Textbox(
label="Query",
Expand Down Expand Up @@ -418,8 +476,8 @@ def build_side_by_side_ui_anon(models):

clear_btn.click(
clear_history_side_by_side_anon,
inputs=None,
outputs=[state0, state1, textbox, chatbot_left, chatbot_right, model_selector_left, model_selector_right],
inputs=[state0, state1],
outputs=[state0, state1, textbox, chatbot_left, chatbot_right, model_selector_left, model_selector_right, download_a_btn, download_b_btn],
api_name="clear_btn_anon"
).then(
disable_buttons_side_by_side,
Expand All @@ -436,22 +494,22 @@ def build_side_by_side_ui_anon(models):
leftvote_btn.click(
partial(vote_last_response, "leftvote"),
inputs=[state0, state1, dummy_left_model, dummy_right_model],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)
rightvote_btn.click(
partial(vote_last_response, "rightvote"),
inputs=[state0, state1, dummy_left_model, dummy_right_model],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)
tie_btn.click(
partial(vote_last_response, "tievote"),
inputs=[state0, state1, dummy_left_model, dummy_right_model],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)
bothbad_btn.click(
partial(vote_last_response, "bothbadvote"),
inputs=[state0, state1, dummy_left_model, dummy_right_model],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)

share_js = """
Expand Down Expand Up @@ -548,6 +606,14 @@ def build_side_by_side_ui_named(models):
value="👎 Both are bad", visible=False, interactive=False
)

download_a_btn = gr.DownloadButton(
value="📥 Download embedding for model A", visible=False
)
download_b_btn = gr.DownloadButton(
value="📥 Download embedding for model B", visible=False
)


with gr.Row():
textbox = gr.Textbox(
label="Query",
Expand Down Expand Up @@ -636,8 +702,8 @@ def build_side_by_side_ui_named(models):

clear_btn.click(
clear_history_side_by_side,
inputs=None,
outputs=[state0, state1, textbox, chatbot_left, chatbot_right],
inputs=[state0, state1],
outputs=[state0, state1, textbox, chatbot_left, chatbot_right, download_a_btn, download_b_btn],
api_name="clear_btn_side_by_side"
).then(
disable_buttons_side_by_side,
Expand All @@ -652,22 +718,22 @@ def build_side_by_side_ui_named(models):
leftvote_btn.click(
partial(vote_last_response, "leftvote"),
inputs=[state0, state1, model_selector_left, model_selector_right],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)
rightvote_btn.click(
partial(vote_last_response, "rightvote"),
inputs=[state0, state1, model_selector_left, model_selector_right],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)
tie_btn.click(
partial(vote_last_response, "tievote"),
inputs=[state0, state1, model_selector_left, model_selector_right],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)
bothbad_btn.click(
partial(vote_last_response, "bothbadvote"),
inputs=[state0, state1, model_selector_left, model_selector_right],
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right]
outputs=[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, model_selector_left, model_selector_right, download_a_btn, download_b_btn]
)

share_js = """
Expand Down Expand Up @@ -2197,6 +2263,3 @@ def build_single_model_ui_sts(models):
inputs=None,
outputs=[send_btn, draw_btn, textbox0, textbox1, textbox2],
)