
Fix the crop_size handling for different versions of huggingface transformers.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main · wxywb · 2 years ago · commit d0b811ca62
  1. README.md (14 lines changed)
  2. clip.py (8 lines changed)

README.md (14 lines changed)

@@ -28,11 +28,11 @@ import towhee
 towhee.glob('./teddy.jpg') \
     .image_decode() \
-    .image_text_embedding.clip(model_name='clip_vit_b32', modality='image') \
+    .image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image') \
     .show()
 towhee.dc(["A teddybear on a skateboard in Times Square."]) \
-    .image_text_embedding.clip(model_name='clip_vit_b32', modality='text') \
+    .image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text') \
     .show()
 ```
 <img src="https://towhee.io/image-text-embedding/clip/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/>
@@ -45,12 +45,12 @@ import towhee
 towhee.glob['path']('./teddy.jpg') \
     .image_decode['path', 'img']() \
-    .image_text_embedding.clip['img', 'vec'](model_name='clip_vit_b32', modality='image') \
+    .image_text_embedding.clip['img', 'vec'](model_name='clip_vit_base_patch16', modality='image') \
     .select['img', 'vec']() \
     .show()
 towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
-    .image_text_embedding.clip['text','vec'](model_name='clip_vit_b32', modality='text') \
+    .image_text_embedding.clip['text','vec'](model_name='clip_vit_base_patch16', modality='text') \
     .select['text', 'vec']() \
     .show()
 ```
@@ -112,7 +112,7 @@ Save model to local with specified format.
 ```python
 from towhee import ops
-op = ops.image_text_embedding.clip(model_name='clip_vit_base_16', modality='image').get_op()
+op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
 op.save_model('onnx', 'test.onnx')
 ```
 <br />
@@ -145,7 +145,7 @@ Get a list of all supported model names or supported model names for specified model format.
 from towhee import ops
-op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_16', modality='image').get_op()
+op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
 full_list = op.supported_model_names()
 onnx_list = op.supported_model_names(format='onnx')
 print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
@@ -164,7 +164,7 @@ If you want to train this operator, besides dependency in requirements.txt, you
 ```python
 import towhee
-clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_16', modality='image').get_op()
+clip_op = towhee.ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()
 data_args = {
     'dataset_name': 'ydshieh/coco_dataset_script',

clip.py (8 lines changed)

@@ -172,7 +172,13 @@ class Clip(NNOperator):
             raise AttributeError('Unsupported model_type.')
         if self.modality == 'image':
             sz = self.processor.feature_extractor.crop_size
-            dummy_input = Image.new('RGB', (sz, sz), color = 'red')
+            if isinstance(sz, int):
+                h = sz
+                w = sz
+            elif isinstance(sz, dict):
+                h = sz['height']
+                w = sz['width']
+            dummy_input = Image.new('RGB', (w, h), color = 'red')
             inputs = self.processor(images=dummy_input, return_tensors='pt') # a dictionary
         elif self.modality == 'text':
             dummy_input = 'dummy'
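
For context: newer huggingface transformers releases expose the feature extractor's `crop_size` as a dict with `height` and `width` keys, while older releases return a single int, so the previous `Image.new('RGB', (sz, sz), ...)` call breaks when `sz` is a dict. Below is a minimal standalone sketch of the same normalization with an explicit error for unexpected types; the `crop_size_to_hw` helper name and the type check are illustrative, not part of this commit.

```python
from PIL import Image


def crop_size_to_hw(crop_size):
    """Normalize a CLIP processor crop_size to (height, width).

    Older transformers releases expose crop_size as a single int,
    newer ones as a dict like {'height': 224, 'width': 224}.
    """
    if isinstance(crop_size, int):
        return crop_size, crop_size
    if isinstance(crop_size, dict):
        return crop_size['height'], crop_size['width']
    raise TypeError(f'unexpected crop_size type: {type(crop_size)}')


# Works for either transformers version:
h, w = crop_size_to_hw({'height': 224, 'width': 224})  # newer dict style
assert (h, w) == crop_size_to_hw(224)                   # older int style
dummy_input = Image.new('RGB', (w, h), color='red')     # dummy image for tracing/export
```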
