1. triton server
Install the Triton server
x86
Pull the image
docker pull nvcr.io/nvidia/tritonserver:21.07-py3
Run
docker run -it --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /home/sunshine/infer/:/models nvcr.io/nvidia/tritonserver:21.07-py3 tritonserver --model-repository=/models --strict-model-config=false
jetson
NVIDIA has not yet released a Triton container for Jetson, so it has to be installed manually from a release package.
Reference: https://github.com/triton-inference-server/server/blob/main/docs/jetson.md
Download the release package that matches your JetPack version and extract it to /opt/tritonserver.
Download: https://github.com/triton-inference-server/server/releases
Note: if the package is extracted to any other directory, startup fails with "shared library not found" errors. You can instead add the required libraries to the loader search path, but extracting to /opt/tritonserver is the simplest option.
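The workaround mentioned in the note can also be scripted. The following is only a minimal sketch of launching tritonserver from a non-default install prefix by pointing the dynamic loader and the backend directory at the unpacked release; TRITON_HOME and the model repository path are hypothetical and must be adjusted to your setup:
# Minimal sketch: run tritonserver from a non-default install prefix.
# TRITON_HOME and MODEL_REPO are placeholder paths, not from the original setup.
import os
import subprocess

TRITON_HOME = "/home/sunshine/tritonserver"   # where the release tarball was unpacked
MODEL_REPO = "/sdk/python/triton/models"

env = os.environ.copy()
# Put the libraries shipped with the release on the loader search path.
env["LD_LIBRARY_PATH"] = f"{TRITON_HOME}/lib:" + env.get("LD_LIBRARY_PATH", "")

subprocess.run(
    [
        f"{TRITON_HOME}/bin/tritonserver",
        f"--model-repository={MODEL_REPO}",
        f"--backend-directory={TRITON_HOME}/backends",  # default is /opt/tritonserver/backends
    ],
    env=env,
    check=True,
)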
Install the dependencies
apt-get update && \
apt-get install -y --no-install-recommends \
software-properties-common \
autoconf \
automake \
build-essential \
cmake \
git \
libb64-dev \
libre2-dev \
libssl-dev \
libtool \
libboost-dev \
libcurl4-openssl-dev \
libopenblas-dev \
rapidjson-dev \
patchelf \
zlib1g-dev
Start the server
bin/tritonserver --model-repository=/sdk/python/triton/models
Create the model repository
The model repository layout is as follows:
<model-repository-path>/
  <model-name>/
    [config.pbtxt]              # optional if the model is TensorRT, TensorFlow saved-model or ONNX; required for every other model type
    [<output-labels-file> ...]
    <version>/                  # 1, 2, 3, ...
      <model-definition-file>   # model.onnx, model.pb, ...
    <version>/
      <model-definition-file>
    ...
  <model-name>/
    [config.pbtxt]
    [<output-labels-file> ...]
    <version>/
      <model-definition-file>
    <version>/
      <model-definition-file>
    ...
  ...
Taking text classification as an example, the model repository looks like this:
[root@bg1 sunshine]# tree infer/
infer/
├── densenet_onnx               # model name
│   ├── 1                       # version
│   │   └── model.onnx          # model file, model.xxx by default
│   ├── config.pbtxt            # configuration file
│   └── densenet_labels.txt     # class labels, included or not depending on the configuration
└── text_class
    ├── 1
    │   └── model.onnx
    └── config.pbtxt

4 directories, 5 files
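The layout can also be created programmatically. Below is a minimal sketch, assuming an already exported ONNX file at a hypothetical local path:
# Minimal sketch: lay out a Triton model repository for an ONNX model.
# The source model path is a placeholder for illustration.
import shutil
from pathlib import Path

repo_root = Path("/home/sunshine/infer")        # repository root, mounted into the container as /models
version_dir = repo_root / "text_class" / "1"    # <model-name>/<version>/
version_dir.mkdir(parents=True, exist_ok=True)

# Triton looks for model.onnx by default (see default_model_filename in the generated config below).
shutil.copy("/path/to/exported/text_class.onnx", version_dir / "model.onnx")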
Add the configuration file
By default, the configuration file is config.pbtxt inside the model directory. Its format is not JSON but protobuf text format, and it must follow the format given in the official documentation exactly, otherwise the model fails to load.
Again taking text classification as an example:
platform: "onnxruntime_onnx"
max_batch_size: 0
input [
  {
    name: "token_type_ids"
    data_type: TYPE_INT64
    format: FORMAT_NONE
    dims: [-1,-1]
  },
  {
    name: "attention_mask"
    data_type: TYPE_INT64
    format: FORMAT_NONE
    dims: [-1,-1]
  },
  {
    name: "input_ids"
    data_type: TYPE_INT64
    format: FORMAT_NONE
    dims: [-1,-1]
  }
]
output: [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [-1,2]
    label_filename: ""
  }
]
If the model is TensorRT, a TensorFlow saved-model, or ONNX, config.pbtxt can be omitted entirely; just pass --strict-model-config=false when starting the server.
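If you do write config.pbtxt by hand, it can be sanity-checked locally before Triton tries to load it, since the file is protobuf text format. A minimal sketch, assuming tritonclient[grpc] is installed (it ships the generated model_config_pb2 module):
# Sketch: parse config.pbtxt locally to catch format errors early.
from google.protobuf import text_format
from tritonclient.grpc import model_config_pb2

with open("/home/sunshine/infer/text_class/config.pbtxt") as f:
    # text_format.Parse raises ParseError if the file does not follow the expected format.
    config = text_format.Parse(f.read(), model_config_pb2.ModelConfig())

print(config.platform, [i.name for i in config.input], [o.name for o in config.output])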
The configuration generated by the server can then be fetched with:
curl 192.168.0.15:8000/v2/models/text_class/config
{
  "name": "text_class",
  "platform": "onnxruntime_onnx",
  "backend": "onnxruntime",
  "version_policy": {
    "latest": {
      "num_versions": 1
    }
  },
  "max_batch_size": 1,
  "input": [
    {
      "name": "token_type_ids",
      "data_type": "TYPE_INT64",
      "format": "FORMAT_NONE",
      "dims": [
        -1
      ],
      "is_shape_tensor": false,
      "allow_ragged_batch": false
    },
    {
      "name": "attention_mask",
      "data_type": "TYPE_INT64",
      "format": "FORMAT_NONE",
      "dims": [
        -1
      ],
      "is_shape_tensor": false,
      "allow_ragged_batch": false
    },
    {
      "name": "input_ids",
      "data_type": "TYPE_INT64",
      "format": "FORMAT_NONE",
      "dims": [
        -1
      ],
      "is_shape_tensor": false,
      "allow_ragged_batch": false
    }
  ],
  "output": [
    {
      "name": "logits",
      "data_type": "TYPE_FP32",
      "dims": [
        2
      ],
      "label_filename": "",
      "is_shape_tensor": false
    }
  ],
  "batch_input": [],
  "batch_output": [],
  "optimization": {
    "priority": "PRIORITY_DEFAULT",
    "input_pinned_memory": {
      "enable": true
    },
    "output_pinned_memory": {
      "enable": true
    },
    "gather_kernel_buffer_threshold": 0,
    "eager_batching": false
  },
  "instance_group": [
    {
      "name": "text_class",
      "kind": "KIND_CPU",
      "count": 1,
      "gpus": [],
      "secondary_devices": [],
      "profile": [],
      "passive": false,
      "host_policy": ""
    }
  ],
  "default_model_filename": "model.onnx",
  "cc_model_filenames": {},
  "metric_tags": {},
  "parameters": {},
  "model_warmup": []
}
Start the service
From the command line
tritonserver --model-repository=/models --strict-model-config=false
With a Docker container
docker run -it --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /home/sunshine/infer/:/models nvcr.io/nvidia/tritonserver:21.07-py3 tritonserver --model-repository=/models --strict-model-config=false
After the server starts, it prints output like the following:
=============================
== Triton Inference Server ==
=============================
NVIDIA Release 21.07 (build 24810355)
Copyright (c) 2018-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Various files include modifications (c) NVIDIA CORPORATION. All rights reserved.
...
I0819 02:54:33.978409 1 onnxruntime.cc:2072] TRITONBACKEND_ModelInstanceInitialize: text_class (CPU device 0)
I0819 02:54:36.500230 1 onnxruntime.cc:2072] TRITONBACKEND_ModelInstanceInitialize: densenet_onnx (CPU device 0)
I0819 02:54:36.501075 1 model_repository_manager.cc:1212] successfully loaded 'text_class' version 1
I0819 02:54:37.042157 1 model_repository_manager.cc:1212] successfully loaded 'densenet_onnx' version 1
I0819 02:54:37.042560 1 server.cc:504]
+------------------+------+
| Repository Agent | Path |
+------------------+------+
+------------------+------+
I0819 02:54:37.042936 1 server.cc:543]
+-------------+-----------------------------------------------------------------+--------+
| Backend | Path | Config |
+-------------+-----------------------------------------------------------------+--------+
| tensorrt | <built-in> | {} |
| pytorch | /opt/tritonserver/backends/pytorch/libtriton_pytorch.so | {} |
| tensorflow | /opt/tritonserver/backends/tensorflow1/libtriton_tensorflow1.so | {} |
| onnxruntime | /opt/tritonserver/backends/onnxruntime/libtriton_onnxruntime.so | {} |
| openvino | /opt/tritonserver/backends/openvino/libtriton_openvino.so | {} |
+-------------+-----------------------------------------------------------------+--------+
I0819 02:54:37.043287 1 server.cc:586]
+---------------+---------+--------+
| Model | Version | Status |
+---------------+---------+--------+
| densenet_onnx | 1 | READY |
| text_class | 1 | READY |
+---------------+---------+--------+
I0819 02:54:37.043677 1 tritonserver.cc:1718]
+----------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Option | Value |
+----------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| server_id | triton |
| server_version | 2.12.0 |
| server_extensions | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data |
| | statistics |
| model_repository_path[0] | /models |
| model_control_mode | MODE_NONE |
| strict_model_config | 0 |
| pinned_memory_pool_byte_size | 268435456 |
| min_supported_compute_capability | 6.0 |
| strict_readiness | 1 |
| exit_timeout | 30 |
+----------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
I0819 02:54:37.046408 1 grpc_server.cc:4072] Started GRPCInferenceService at 0.0.0.0:8001
I0819 02:54:37.046930 1 http_server.cc:2795] Started HTTPService at 0.0.0.0:8000
I0819 02:54:37.090841 1 sagemaker_server.cc:134] Started Sagemaker HTTPService at 0.0.0.0:8080
I0819 02:54:37.134294 1 http_server.cc:162] Started Metrics Service at 0.0.0.0:8002
Verify
[root@bg1 ~]# curl -v localhost:8000/v2/health/ready
* About to connect() to localhost port 8000 (#0)
* Trying ::1...
* Connected to localhost (::1) port 8000 (#0)
> GET /v2/health/ready HTTP/1.1
> User-Agent: curl/7.29.0
> Host: localhost:8000
> Accept: */*
>
< HTTP/1.1 200 OK
< Content-Length: 0
< Content-Type: text/plain
<
* Connection #0 to host localhost left intact
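The same health checks are available from the Python HTTP client. A minimal sketch, pointed at the server address used in the examples above:
# Sketch: liveness/readiness checks via the Python client.
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="192.168.0.15:8000")
print("server live: ", client.is_server_live())
print("server ready:", client.is_server_ready())
print("model ready: ", client.is_model_ready("text_class"))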
2. triton client
Official repository: https://github.com/triton-inference-server/client
Local installation
pip install nvidia-pyindex
pip install tritonclient[http]  # [all, http, grpc, utils]
Dependencies required for the HTTP request path:
geventhttpclient>=1.4.4
numpy>=1.19.1
python-rapidjson>=0.9.1
Dependencies for the other request types are listed at https://github.com/triton-inference-server/client/tree/main/src/python/library/requirements
Container
Pull the container
docker pull nvcr.io/nvidia/tritonserver:21.07-py3-sdk
Start the container
docker run -it --rm --net=host nvcr.io/nvidia/tritonserver:21.07-py3-sdk
Run the example client
/workspace/install/bin/image_client -m densenet_onnx -c 3 -s INCEPTION /workspace/images/mug.jpg
Request 0, batch size 1
Image '/workspace/images/mug.jpg':
15.346230 (504) = COFFEE MUG
13.224326 (968) = CUP
10.422965 (505) = COFFEEPOT
Calling the service from Python client code
An HTTP text-classification example:
# author: sunshine
# datetime: 2021/8/19 11:01 AM
import numpy as np
import tritonclient.http as httpclient
from attrdict import AttrDict
from transformers import BertTokenizer
from tritonclient.utils import triton_to_np_dtype


def convert_http_metadata_config(_metadata, _config):
    # Wrap the plain dicts returned by the HTTP client so their fields
    # can be accessed as attributes.
    _model_metadata = AttrDict(_metadata)
    _model_config = AttrDict(_config)
    return _model_metadata, _model_config


def parse_model(model_metadata, model_config):
    """Extract the input dtypes, output names and max batch size from the model metadata/config."""
    input_metadata = model_metadata.inputs
    output_metadata = model_metadata.outputs
    max_batch_size = model_config.max_batch_size

    input_params = {i.name: i.datatype for i in input_metadata}
    output_names = [o.name for o in output_metadata]
    return max_batch_size, input_params, output_names


def preprocess(text, input_param, output_names, max_len=128):
    client = httpclient
    inputs = tokenizer(text, padding='longest', max_length=max_len, truncation='longest_first')

    input_data = []
    names = ['input_ids', 'attention_mask', 'token_type_ids']
    for name in names:
        ndtype = triton_to_np_dtype(input_param[name])
        data = np.array(inputs[name]).astype(ndtype)
        data_t = client.InferInput(name, list(data.shape), input_param[name])
        data_t.set_data_from_numpy(data)
        input_data.append(data_t)

    outputs = [
        client.InferRequestedOutput(out_name) for out_name in output_names
    ]
    return input_data, outputs


def postprocess(results, output_names):
    """Process the inference response: argmax over the logits."""
    logit_name = output_names[0]
    output = results.as_numpy(logit_name)
    pred = np.argmax(output, axis=-1)
    return pred


if __name__ == '__main__':
    model_name = 'text_class'
    model_version = ''  # empty string selects the latest version
    url = "192.168.0.15:8000"
    bert_path = '/home/sunshine/pre_models/pytorch/bert-base-chinese'
    tokenizer = BertTokenizer.from_pretrained(bert_path)

    triton_client = httpclient.InferenceServerClient(url=url, verbose=False)

    model_metadata = triton_client.get_model_metadata(
        model_name=model_name, model_version=model_version)
    model_config = triton_client.get_model_config(
        model_name=model_name, model_version=model_version)
    model_metadata, model_config = convert_http_metadata_config(
        model_metadata, model_config)

    max_batch_size, input_params, output_names = parse_model(model_metadata, model_config)

    texts = ['今天天气真好', '我讨厌你']
    inputs, outputs = preprocess(texts, input_params, output_names)

    results = triton_client.infer(model_name,
                                  inputs,
                                  request_id=str(1),
                                  model_version=model_version,
                                  outputs=outputs)

    pred = postprocess(results, output_names)
    print(pred)
output:
[1 0]
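For comparison, the same request can be sent over gRPC (port 8001) with tritonclient.grpc. The sketch below reuses the tokenizer path and server from the HTTP example and is only an illustration:
# Sketch: the same text_class inference over gRPC instead of HTTP.
# Assumes tritonclient[grpc] is installed and the same server/model as above.
import numpy as np
import tritonclient.grpc as grpcclient
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('/home/sunshine/pre_models/pytorch/bert-base-chinese')
client = grpcclient.InferenceServerClient(url="192.168.0.15:8001")

enc = tokenizer(['今天天气真好', '我讨厌你'], padding='longest', truncation='longest_first')
inputs = []
for name in ['input_ids', 'attention_mask', 'token_type_ids']:
    data = np.array(enc[name]).astype(np.int64)
    inp = grpcclient.InferInput(name, list(data.shape), "INT64")
    inp.set_data_from_numpy(data)
    inputs.append(inp)

outputs = [grpcclient.InferRequestedOutput("logits")]
results = client.infer("text_class", inputs, outputs=outputs)
print(np.argmax(results.as_numpy("logits"), axis=-1))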