@@ -16,7 +16,7 @@ import torch
 import numpy as np

 from torchvision import transforms as T

-from towhee.operator import NNOperator
+from towhee.operator import NNOperator, SharedType

 from towhee.models import clip
@@ -32,6 +32,7 @@ class Model:
 class ClipVision(NNOperator):
     def __init__(self, model_name='clip_vit_b32'):
         super().__init__()
+        self._device = None

         self.tfms = torch.nn.Sequential(
             T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
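The Sequential above is cut off by the hunk boundary. A minimal sketch of what a full CLIP-style preprocessing stack typically looks like, assuming the customary CenterCrop, ConvertImageDtype, and Normalize follow the Resize the diff shows; everything after the first transform is an assumption, using the standard CLIP normalization constants:

    import torch
    from torchvision import transforms as T

    # Only the Resize line is confirmed by the diff; the remaining
    # transforms are a common-practice assumption for CLIP inputs.
    tfms = torch.nn.Sequential(
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ConvertImageDtype(torch.float),
        T.Normalize((0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711)),
    )

Composing the transforms with torch.nn.Sequential (rather than T.Compose) keeps the pipeline scriptable and lets it run directly on CHW tensors.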
@@ -44,16 +45,15 @@ class ClipVision(NNOperator):

     @property
     def device(self):
-        if self._device_id < 0:
-            return 'cpu'
-        else:
-            return self._device_id
+        if self._device is None:
+            self._device = torch.device(self._device_id)
+        return self._device

     def __call__(self, image: 'Image'):
         img = np.transpose(image, [2, 0, 1])
         data = torch.from_numpy(img)
         data = data.to(self.device)
-        image_tensor = self.tfms(img)
+        image_tensor = self.tfms(data).unsqueeze(0)
         features = self.model(image_tensor)
         return features.detach().cpu().numpy().flatten()
 
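The device rewrite in this hunk trades per-call branching for a lazily cached torch.device, and the __call__ fix routes the transforms over the device-resident tensor while adding the batch dimension the model expects. A standalone sketch of the caching pattern (the class name and the explicit CPU branch are illustrative; note that the patched code passes self._device_id straight to torch.device, which accepts a non-negative index or a device string):

    import torch

    class LazyDevice:
        """Illustrative holder for the cached-device pattern used above."""

        def __init__(self, device_id: int = -1):
            self._device_id = device_id
            self._device = None  # resolved once, on first access

        @property
        def device(self) -> torch.device:
            if self._device is None:
                # Negative ids conventionally mean CPU; otherwise pick a CUDA card.
                if self._device_id < 0:
                    self._device = torch.device('cpu')
                else:
                    self._device = torch.device(f'cuda:{self._device_id}')
            return self._device

    print(LazyDevice(-1).device)  # -> cpu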
@@ -69,3 +69,7 @@ class ClipVision(NNOperator):
     @property
     def supported_formats(self):
         return ['onnx']
+
+    @property
+    def shared_type(self):
+        return SharedType.Shareable
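The last hunk is what the first hunk's import change serves: shared_type reports the operator as shareable, presumably letting the towhee runtime reuse a single model instance across pipeline copies, while supported_formats advertises ONNX as an export target. A hypothetical check, assuming the patched module is importable and the default model weights are available:

    from towhee.operator import SharedType

    op = ClipVision()  # hypothetical direct instantiation with the default model
    assert op.supported_formats == ['onnx']
    assert op.shared_type == SharedType.Shareable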