logo
Browse Source

Fix for triton device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
125daa72f5
  1. 4
      benchmark/qps_test.py
  2. BIN
      benchmark/towhee.jpeg
  3. 6
      isc.py

4
benchmark/qps_test.py

@ -34,7 +34,7 @@ p = (
.output('vec')
)
data = 'towhee.jpeg'
data = '../towhee.jpeg'
out1 = p(data).get()[0]
print('Pipe: OK')
@ -93,7 +93,7 @@ if args.onnx:
else:
max_diff = numpy.abs(out1 - out3).max()
min_diff = numpy.abs(out1 - out3).min()
mean_diff = numpy.abs(out1 - out3).mean()
mean_diff = numpy.abs(out1 - out3).mean()
print(f'Check accuracy: atol is larger than {args.atol}.')
print(f'Maximum absolute difference is {max_diff}.')
print(f'Minimum absolute difference is {min_diff}.')

BIN
benchmark/towhee.jpeg

Binary file not shown.

Before

Width:  |  Height:  |  Size: 49 KiB

6
isc.py

@ -106,7 +106,7 @@ class Isc(NNOperator):
img = img if self.skip_tfms else self.tfms(img)
img_list.append(img)
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
inputs = inputs
features = self.model(inputs)
features = features.to('cpu')
@ -138,7 +138,7 @@ class Isc(NNOperator):
path = path + '.onnx'
else:
raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand(1, 3, 224, 224).to(self.device)
dummy_input = torch.rand(1, 3, 224, 224)
if format == 'pytorch':
torch.save(self._model, path)
elif format == 'torchscript':
@ -153,7 +153,7 @@ class Isc(NNOperator):
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
try:
torch.onnx.export(self._model,
torch.onnx.export(self._model.to('cpu'),
dummy_input,
path,
input_names=['input_0'],

Loading…
Cancel
Save