logo
Browse Source

update

main
junjie.jiang 2 years ago
parent
commit
953dd25115
  1. 16
      clip_vision.py

16
clip_vision.py

@ -16,7 +16,7 @@ import torch
import numpy as np
from torchvision import transforms as T
from towhee.operator import NNOperator
from towhee.operator import NNOperator, SharedType
from towhee.models import clip
@ -32,6 +32,7 @@ class Model:
class ClipVision(NNOperator):
def __init__(self, model_name='clip_vit_b32'):
super().__init__()
self._device = None
self.tfms = torch.nn.Sequential(
T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
@ -44,16 +45,15 @@ class ClipVision(NNOperator):
@property
def device(self):
if self._device_id < 0:
return 'cpu'
else:
return self._device_id
if self._device is None:
self._device = torch.device(self._device_id)
return self._device
def __call__(self, image: 'Image'):
img = np.transpose(image, [2, 0, 1])
data = torch.from_numpy(img)
data = data.to(self.device)
image_tensor = self.tfms(img)
image_tensor = self.tfms(data).unsqueeze(0)
features = self.model(image_tensor)
return features.detach().cpu().numpy().flatten()
@ -69,3 +69,7 @@ class ClipVision(NNOperator):
@property
def supported_formats(self):
return ['onnx']
@property
def shared_type(self):
return SharedType.Shareable

Loading…
Cancel
Save