logo
Browse Source

Fix for triton device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
125daa72f5
  1. 2
      benchmark/qps_test.py
  2. BIN
      benchmark/towhee.jpeg
  3. 6
      isc.py

2
benchmark/qps_test.py

@ -34,7 +34,7 @@ p = (
.output('vec') .output('vec')
) )
data = 'towhee.jpeg'
data = '../towhee.jpeg'
out1 = p(data).get()[0] out1 = p(data).get()[0]
print('Pipe: OK') print('Pipe: OK')

BIN
benchmark/towhee.jpeg

Binary file not shown.

Before

Width:  |  Height:  |  Size: 49 KiB

6
isc.py

@ -106,7 +106,7 @@ class Isc(NNOperator):
img = img if self.skip_tfms else self.tfms(img) img = img if self.skip_tfms else self.tfms(img)
img_list.append(img) img_list.append(img)
inputs = torch.stack(img_list) inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
inputs = inputs
features = self.model(inputs) features = self.model(inputs)
features = features.to('cpu') features = features.to('cpu')
@ -138,7 +138,7 @@ class Isc(NNOperator):
path = path + '.onnx' path = path + '.onnx'
else: else:
raise ValueError(f'Invalid format {format}.') raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand(1, 3, 224, 224).to(self.device)
dummy_input = torch.rand(1, 3, 224, 224)
if format == 'pytorch': if format == 'pytorch':
torch.save(self._model, path) torch.save(self._model, path)
elif format == 'torchscript': elif format == 'torchscript':
@ -153,7 +153,7 @@ class Isc(NNOperator):
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
try: try:
torch.onnx.export(self._model,
torch.onnx.export(self._model.to('cpu'),
dummy_input, dummy_input,
path, path,
input_names=['input_0'], input_names=['input_0'],

Loading…
Cancel
Save