Asynchronous requests for online batched model inference

For an online model inference service, handling requests one at a time leads to very low throughput and poor utilization of the hardware. Batched prediction is a way to raise the service's concurrency: inference time does not grow linearly with the batch size, so larger batches amortize the fixed per-call overhead, until the model fills all of the available GPU memory and the batch size hits its ceiling.
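To get a feel for that non-linearity, a small timing loop like the one below is enough. This is a minimal sketch of my own, using "gpt2" as a lightweight stand-in model rather than the model deployed later in this post; the exact numbers depend on your GPU, but total latency should grow much more slowly than the batch size, so the per-request cost keeps dropping.

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # pad on the left for decoder-only generation
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

prompt = "Batching amortizes per-request overhead because"
with torch.no_grad():
    for batch_size in (1, 2, 4, 8, 16, 32):
        inputs = tokenizer([prompt] * batch_size, return_tensors="pt", padding=True).to(device)
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=32, do_sample=False,
                       pad_token_id=tokenizer.eos_token_id)
        elapsed = (time.perf_counter() - start) * 1000
        # If batching pays off, total latency grows sub-linearly in batch_size,
        # i.e. the per-request figure keeps shrinking.
        print(f"batch={batch_size:<2d}  total={elapsed:7.1f} ms  per-request={elapsed / batch_size:6.1f} ms")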

To achieve this, the following components are needed:

  • Front-end service: receives requests and returns results, over HTTP, RPC, or any other protocol. Runs as a separate process.
  • Inference worker: initializes the model, assembles batched inference inputs, and runs the inference computation. Runs as a separate process.
  • Task queue: after receiving a request, the front-end service pushes the computation task onto this queue; the inference worker listens on it and pulls off a mini-batch at a time for the model to process.
  • Result queue: once inference is done, the worker pushes the results onto this queue; the front-end service listens on it to collect the inference results.
  • Result dispatching: before a task is pushed onto the task queue it is assigned a unique ID; when a result is taken off the result queue, that ID is used to match it back to the task that produced it.

The two queues can be implemented in many ways, for example with mature middleware such as Kafka or Redis, but to avoid external dependencies I use Python's native multiprocessing queues here. Listening on the result queue and dispatching results is handled by a child thread of the front-end service process.
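Before the full example, here is a stripped-down sketch of just that plumbing, written purely for illustration (it is not part of the service code below): a spawn-context pair of multiprocessing queues, a worker process that "infers" by upper-casing text, and a collector thread in the parent process that files results by task ID.

import multiprocessing as mp
import threading
import time
import uuid
from queue import Empty


def worker(data_queue, result_queue):
    # Pretend "inference": take one task at a time and echo it back upper-cased.
    while True:
        task_id, text = data_queue.get()
        result_queue.put((task_id, text.upper()))


def collect(result_queue, results):
    # Runs in a daemon thread of the front-end process and files results by id.
    while True:
        try:
            task_id, result = result_queue.get(timeout=0.01)
            results[task_id] = result
        except Empty:
            pass


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    data_queue, result_queue = ctx.Queue(), ctx.Queue()
    ctx.Process(target=worker, args=(data_queue, result_queue), daemon=True).start()

    results = {}
    threading.Thread(target=collect, args=(result_queue, results), daemon=True).start()

    task_id = str(uuid.uuid4())
    data_queue.put((task_id, "hello"))   # front-end enqueues a task
    while task_id not in results:        # ...and polls for the result
        time.sleep(0.01)
    print(results.pop(task_id))          # -> HELLO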

Taking Tsinghua's open-source GLM model (ChatGLM-6B) as an example, the implementation is as follows:

import time
from transformers.generation.logits_process import LogitsProcessor
from sanic import Sanic
from sanic.response import json
import asyncio
import logging
import multiprocessing as mp
import threading
import uuid
from queue import Empty
from transformers import AutoTokenizer, AutoModel
from cachetools import TTLCache
import torch
from enum import Enum

app = Sanic('test')


class BaseInferLightWorker:

    def __init__(self, data_queue: mp.Queue, result_queue: mp.Queue,
                 model_args: dict,
                 batch_size=16, max_delay=0.1,
                 ready_event=None,
                 max_length: int = 2048, num_beams=1,
                 do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs
                 ) -> None:
        self.data_queue = data_queue
        self.result_queue = result_queue
        self.batch_size = batch_size
        self.max_delay = max_delay
        self.logger = logging.getLogger('InferLight-Worker')
        self.logger.setLevel(logging.DEBUG)
        self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                           "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        self.load_model(model_args)

        # Loading the model takes a while, so once it is done
        # we notify the parent process through an event.
        if ready_event:
            ready_event.set()

    def run(self):
        self.logger.info('Worker started!')
        while True:
            data, task_ids = [], []
            since = time.time()
            for i in range(self.batch_size):
                try:
                    # Pull one task from the task queue
                    d = self.data_queue.get(block=True, timeout=self.max_delay)
                    task_ids.append(d[0])
                    data.append(d[1])
                    self.logger.info('get one new task')
                except Empty:
                    pass
                if time.time() - since >= self.max_delay:
                    break
            if len(data) > 0:
                start = time.perf_counter()
                batch = self.build_batch(data)
                results = self.inference(batch)
                end = time.perf_counter()
                time_elapsed = (end - start) * 1000
                self.logger.info(f'inference succeeded. batch size: {len(data)}, time elapsed: {time_elapsed:.3f} ms')
                # Write each result to the result queue, keyed by its task id
                for (task_id, result) in zip(task_ids, results):
                    self.result_queue.put((task_id, result))

    def build_batch(self, requests):
        raise NotImplementedError

    def inference(self, batch):
        raise NotImplementedError

    def load_model(self, model_args):
        raise NotImplementedError

    @classmethod
    def start(cls, data_queue: mp.Queue, result_queue: mp.Queue, model_args: dict, batch_size=16, max_delay=0.1,
              ready_event=None):
        w = cls(data_queue, result_queue, model_args, batch_size, max_delay, ready_event)
        w.run()


class InferStatus(Enum):
    SUCCEED = 0
    TIMEOUT = 1


class InferResponse:

    def __init__(self, status: InferStatus, result) -> None:
        self.status = status
        self.result = result

    def succeed(self):
        return self.status == InferStatus.SUCCEED

class LightWrapper:

    def __init__(self, worker_class, model_args: dict,
                 batch_size=16, max_delay=0.1) -> None:
        # setup logger
        self.logger = logging.getLogger('InferLight-Wrapper')
        self.logger.setLevel(logging.INFO)

        # Results are kept in a cache whose entries expire after 5 seconds
        self.result_cache = TTLCache(maxsize=10000, ttl=5)

        self.mp = mp.get_context('spawn')
        self.result_queue = self.mp.Queue()
        self.data_queue = self.mp.Queue()

        # Start the inference worker process
        self.logger.info('Starting worker...')
        worker_ready_event = self.mp.Event()
        self._worker_p = self.mp.Process(target=worker_class.start, args=(
            self.data_queue, self.result_queue, model_args, batch_size, max_delay, worker_ready_event
        ), daemon=True)
        self._worker_p.start()

        # Wait at most 30 seconds for the worker to become ready
        is_ready = worker_ready_event.wait(timeout=30)
        if is_ready:
            self.logger.info('Worker started!')
        else:
            self.logger.error('Failed to start worker!')

        # Start the result-collecting thread
        self.back_thread = threading.Thread(
            target=self._collect_result, name="thread_collect_result")
        self.back_thread.daemon = True
        self.back_thread.start()

    def _collect_result(self):
        # Continuously read the result queue in this thread and
        # write results into the cache keyed by task_id
        self.logger.info('Result collecting thread started!')
        while True:
            try:
                msg = self.result_queue.get(block=True, timeout=0.01)
            except Empty:
                msg = None
            if msg is not None:
                (task_id, result) = msg
                self.result_cache[task_id] = result

    async def get_result(self, task_id):
        # Fetch a task's result without blocking the event loop
        while task_id not in self.result_cache:
            await asyncio.sleep(0.01)
        return self.result_cache[task_id]

    async def predict(self, input, timeout=6) -> InferResponse:
        # generate unique task_id
        task_id = str(uuid.uuid4())

        # send input to worker process
        self.data_queue.put((task_id, input))
        try:
            # Wait for the result, but no longer than `timeout` seconds
            result = await asyncio.wait_for(self.get_result(task_id), timeout=timeout)
        except asyncio.TimeoutError:
            return InferResponse(InferStatus.TIMEOUT, None)

        return InferResponse(InferStatus.SUCCEED, result)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 20005] = 5e4
        return scores


class MyWorker(BaseInferLightWorker):

    def load_model(self, model_args):
        self.tokenizer = AutoTokenizer.from_pretrained(model_args['model'], trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_args['model'], trust_remote_code=True).half().cuda()
        self.device = torch.device('cuda')
        return

    def build_batch(self, requests):
        prompts = []
        for query, user_history in requests:
            if not user_history:
                prompt = query
            else:
                prompt = ""
                for i, (old_query, response) in enumerate(user_history):
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
                prompt += "[Round {}]\n问:{}\n答:".format(len(user_history), query)
            prompts.append(prompt)
        input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True)
        input_ids = input_ids.to(self.device)
        return [input_ids, requests]

    @torch.no_grad()
    def inference(self, input_ids):
        input_ids, requests = input_ids
        outputs = self.model.generate(**input_ids, **self.gen_kwargs)
        result = []

        for output, input_id, (query, user_history) in zip(outputs, input_ids["input_ids"], requests):
            out = output.tolist()[len(input_id):]
            response = self.tokenizer.decode(out)
            response = response.strip()
            response = response.replace("[[训练时间]]", "2023年")
            user_history = user_history + [(query, response)]
            result.append([response, user_history])

        return result


@app.post('/batch_predict')
async def batched_predict(request):
    history = request.app.ctx.history
    now = time.time()
    openid = request.json['openid']
    if openid not in history:
        user_history = []
    else:
        user_history = history[openid]["history"]
        user_last_time = history[openid]["user_last_time"]
        if now - user_last_time > 600:
            # Reset the history if the user has been idle for more than 10 minutes
            user_history = []

    # Keep at most 5 rounds and roughly 1024 characters of history
    while len(user_history) > 5 or sum([len(x) + len(y) for x, y in user_history]) > 1024:
        user_history = user_history[1:]
    dummy_input = [request.json['text'], user_history]
    response = await request.app.ctx.wrapped_model.predict(dummy_input, timeout=20)

    if not response.succeed():
        return json({'output': None, 'status': 'failed'})

    history[openid] = {"history": response.result[1], "user_last_time": now}
    request.app.ctx.history = history
    return json({'output': response.result[0]})


config = {
    'model': "/sdk/pre_models/chatglm-6b-int4",
    'use_cuda': True
}


@app.listener('before_server_start')
async def init(app, loop):
    history = {}

    wrapped_model = LightWrapper(MyWorker, config, batch_size=10, max_delay=0.05)
    app.ctx.wrapped_model = wrapped_model
    app.ctx.history = history


if __name__ == '__main__':
    app.run(port=5008)

Without batched prediction, the model takes about 6 GB of GPU memory and the service sustains a concurrency of roughly 20. After switching to batched prediction, concurrency reaches about 150, a significant improvement.
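To reproduce a rough concurrency measurement on your own setup, a simple client-side load test along the following lines can be used. This is a minimal sketch, assuming the service above is running locally on port 5008 and that aiohttp is installed; the payload fields match what the /batch_predict handler reads ('openid' and 'text').

import asyncio
import time

import aiohttp

URL = "http://127.0.0.1:5008/batch_predict"
CONCURRENCY = 100


async def one_request(session, i):
    # Each simulated user gets its own openid so chat histories do not collide.
    payload = {"openid": f"user-{i}", "text": "你好"}
    async with session.post(URL, json=payload) as resp:
        return await resp.json()


async def main():
    async with aiohttp.ClientSession() as session:
        start = time.perf_counter()
        results = await asyncio.gather(*(one_request(session, i) for i in range(CONCURRENCY)))
        elapsed = time.perf_counter() - start
        ok = sum(1 for r in results if r.get('output') is not None)
        print(f"{ok}/{CONCURRENCY} requests succeeded in {elapsed:.1f}s "
              f"(~{CONCURRENCY / elapsed:.1f} req/s)")


if __name__ == '__main__':
    asyncio.run(main())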

Reference: https://zhuanlan.zhihu.com/p/382306521