-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Expand file tree
/
Copy pathrerank_server_demo.py
More file actions
121 lines (97 loc) · 4.08 KB
/
rerank_server_demo.py
File metadata and controls
121 lines (97 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gc
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import List
# 使能 CUDA 调试
# import os
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
# --- 1. 定义API的请求和响应数据结构 ---
# 请求体结构保持不变
class RerankRequest(BaseModel):
query: str
documents: List[str]
# --- 修改开始:定义测试用的响应结构,字段名为 "score" ---
# DocumentInfo 结构保持不变
class DocumentInfo(BaseModel):
text: str
# 将原来的 GoRankResult 修改为 TestRankResult
# 核心改动:将 "relevance_score" 字段重命名为 "score"
class TestRankResult(BaseModel):
index: int
document: DocumentInfo
score: float # <--- 【关键修改点】字段名已从 relevance_score 改为 score
# 最终响应体结构,其 "results" 列表包含的是 TestRankResult
class TestFinalResponse(BaseModel):
results: List[TestRankResult]
# --- 修改结束 ---
# --- 2. 加载模型 (在服务启动时执行一次) ---
print("正在加载模型,请稍候...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的设备: {device}")
try:
# 请确保这里的路径是正确的
model_path = '/data1/home/lwx/work/Download/rerank_model_weight'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.to(device)
model.eval()
print("模型加载成功!")
except Exception as e:
print(f"模型加载失败: {e}")
# 在测试环境中,如果模型加载失败,可以考虑退出以避免运行一个无效的服务
exit()
# --- 3. 创建FastAPI应用 ---
app = FastAPI(
title="Reranker API (Test Version)",
description="一个返回 'score' 字段以测试Go客户端兼容性的API服务",
version="1.0.2"
)
# --- 4. 定义API端点 ---
# --- 修改开始:将 response_model 指向新的测试用响应结构 ---
@app.post("/rerank", response_model=TestFinalResponse) # <--- 【关键修改点】response_model 改为 TestFinalResponse
def rerank_endpoint(request: RerankRequest):
# --- 修改结束 ---
pairs = [[request.query, doc] for doc in request.documents]
with torch.no_grad():
inputs = outputs = logits = None
try:
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=1024).to(device)
outputs = model(**inputs, return_dict=True)
logits = outputs.logits.view(-1, ).float()
scores = torch.sigmoid(logits)
finally:
# 释放 GPU 资源占用
del inputs, outputs, logits
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif hasattr(torch, "mps") and torch.mps.is_available():
torch.mps.empty_cache()
# --- 修改开始:按照测试用的结构来构建结果 ---
results = []
for i, (text, score_val) in enumerate(zip(request.documents, scores)):
# 1. 创建嵌套的 document 对象
doc_info = DocumentInfo(text=text)
# 2. 创建 TestRankResult 对象
# 注意字段名:index, document, score
test_result = TestRankResult(
index=i,
document=doc_info,
score=score_val.item() # <--- 【关键修改点】赋值给 "score" 字段
)
results.append(test_result)
# 3. 排序 (key 也要相应修改为 score)
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
# --- 修改结束 ---
# 返回一个字典,FastAPI 会根据 response_model (TestFinalResponse) 来验证和序列化它
# 最终生成的 JSON 会是 {"results": [{"index": ..., "document": ..., "score": ...}]}
return {"results": sorted_results}
@app.get("/")
def read_root():
return {"status": "Reranker API (Test Version) is running"}
# --- 5. 启动服务 ---
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)