import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- implicitly creates the CUDA context
import tensorrt as trt

TRT_LOGGER = trt.Logger()


class HostDeviceMem(object):
    """Pairs a page-locked host buffer with its matching device allocation."""

    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()
class TRTModelPredict:
    def __init__(self, engine_path, shape=(608, 608)):
        shape = (1, 3, shape[0], shape[1])
        self.engine = self.get_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.buffers = self.allocate_buffers(self.engine, 1)
        # Resolve the dynamic batch dimension of the input binding (index 0).
        self.context.set_binding_shape(0, shape)
    def allocate_buffers(self, engine, batch_size):
        """Allocate a page-locked host buffer and a device buffer for each binding."""
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for binding in engine:
            dims = engine.get_binding_shape(binding)
            size = trt.volume(dims) * batch_size
            # A dynamic batch dimension is reported as -1, which makes the volume
            # negative; flip the sign so the allocation is sized for batch_size.
            if dims[0] < 0:
                size *= -1
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))
            if engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream
    def get_engine(self, engine_path):
        """Deserialize a serialized TensorRT engine from disk."""
        print("Reading engine from file {}".format(engine_path))
        with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    def do_inference(self, img_in):
        inputs, outputs, bindings, stream = self.buffers
        # Copy the preprocessed, contiguous input into the page-locked host buffer
        # so the async host-to-device transfer stays valid.
        np.copyto(inputs[0].host, img_in.ravel())
        for inp in inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, stream)
        # The engine uses explicit batch / dynamic shapes, so execute_async_v2 is required.
        self.context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        for out in outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, stream)
        stream.synchronize()
        return [out.host for out in outputs]
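

# Minimal usage sketch. The engine path ("model.engine"), the test image, the
# 608x608 input size, and the preprocessing steps below are assumptions, not
# part of the class above -- adapt them to whatever engine was serialized.
if __name__ == "__main__":
    import cv2

    predictor = TRTModelPredict("model.engine", shape=(608, 608))

    img = cv2.imread("dog.jpg")
    img = cv2.resize(img, (608, 608))
    img = img[:, :, ::-1].transpose(2, 0, 1)            # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
    img = np.expand_dims(img, axis=0)                   # add the batch dimension

    outputs = predictor.do_inference(img)
    # Each output is a flat page-locked buffer, one per output binding;
    # reshape according to the binding shapes of your engine.
    print([out.shape for out in outputs])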