From 953dd25115e7526024c38720a4b5f1139487bfac Mon Sep 17 00:00:00 2001
From: "junjie.jiang"
Date: Wed, 7 Dec 2022 13:44:41 +0800
Subject: [PATCH] update

---
 clip_vision.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/clip_vision.py b/clip_vision.py
index 54961fe..a7c3e03 100644
--- a/clip_vision.py
+++ b/clip_vision.py
@@ -16,7 +16,7 @@
 import torch
 import numpy as np
 from torchvision import transforms as T
-from towhee.operator import NNOperator
+from towhee.operator import NNOperator, SharedType
 from towhee.models import clip
 
 
@@ -32,6 +32,7 @@ class Model:
 class ClipVision(NNOperator):
     def __init__(self, model_name='clip_vit_b32'):
         super().__init__()
+        self._device = None
         self.tfms = torch.nn.Sequential(
             T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
@@ -44,16 +45,15 @@ class ClipVision(NNOperator):
 
     @property
     def device(self):
-        if self._device_id < 0:
-            return 'cpu'
-        else:
-            return self._device_id
+        if self._device is None:
+            self._device = torch.device(self._device_id)
+        return self._device
 
     def __call__(self, image: 'Image'):
         img = np.transpose(image, [2, 0, 1])
         data = torch.from_numpy(img)
         data = data.to(self.device)
-        image_tensor = self.tfms(img)
+        image_tensor = self.tfms(data).unsqueeze(0)
         features = self.model(image_tensor)
         return features.detach().cpu().numpy().flatten()
 
@@ -69,3 +69,7 @@
     @property
     def supported_formats(self):
         return ['onnx']
+
+    @property
+    def shared_type(self):
+        return SharedType.Shareable
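
Taken together, the patch makes three changes: the device is resolved lazily from `_device_id` and cached in `_device`, the torchvision transforms are applied to the on-device tensor (previously the raw numpy array was passed by mistake) with a batch dimension added via `unsqueeze(0)`, and the operator advertises `SharedType.Shareable`. Below is a minimal standalone sketch of the first two changes, outside the towhee operator framework; `VisionSketch`, the `Flatten` stand-in model, the crop/normalize steps (the diff truncates the `Sequential` after `T.Resize`), and the float/255 scaling are illustrative assumptions, not towhee API. Note that `torch.device()` rejects negative indices, so the sketch maps a negative `_device_id` to CPU explicitly, whereas the patched `device` property passes the id straight through and thus assumes a valid device index.

import numpy as np
import torch
from torchvision import transforms as T


class VisionSketch:
    """Standalone sketch of the patched device handling and preprocessing."""

    def __init__(self, device_id: int = -1):
        self._device_id = device_id
        self._device = None  # resolved lazily and cached, as in the patch
        # CLIP-style preprocessing; CenterCrop/Normalize are assumed here,
        # since the diff cuts off the Sequential after T.Resize.
        self.tfms = torch.nn.Sequential(
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        )
        self.model = torch.nn.Flatten()  # stand-in for the CLIP visual encoder

    @property
    def device(self) -> torch.device:
        if self._device is None:
            # torch.device() raises on negative indices, so this sketch
            # handles the "negative id means CPU" convention explicitly.
            if self._device_id < 0:
                self._device = torch.device('cpu')
            else:
                self._device = torch.device(self._device_id)
        return self._device

    def __call__(self, image: np.ndarray) -> np.ndarray:
        img = np.transpose(image, [2, 0, 1])          # HWC uint8 -> CHW
        data = torch.from_numpy(img).float() / 255.0  # assumed [0, 1] scaling
        data = data.to(self.device)
        # The patched line: transform the on-device tensor (not the numpy
        # array) and add a batch dimension before calling the model.
        image_tensor = self.tfms(data).unsqueeze(0)
        features = self.model(image_tensor)
        return features.detach().cpu().numpy().flatten()


Example usage of the sketch:

img = np.random.randint(0, 256, size=(288, 352, 3), dtype=np.uint8)
vec = VisionSketch(device_id=-1)(img)
print(vec.shape)  # (150528,) with the Flatten stand-in; a real CLIP ViT-B/32 head emits 512 floats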