diff --git a/README.md b/README.md index 987c5c8..329a366 100644 --- a/README.md +++ b/README.md @@ -17,28 +17,19 @@ This operator generates the caption with [CLIPReward](https://arxiv.org/abs/2205 Load an image from path './animals.jpg' to generate the caption. - *Write the pipeline in simplified style*: +*Write a pipeline with explicit inputs/outputs name specifications:* ```python -import towhee +from towhee.dc2 import pipe, ops, DataCollection -towhee.glob('./animals.jpg') \ - .image_decode() \ - .image_captioning.clip_caption_reward(model_name='clipRN50_clips_grammar') \ - .show() -``` -result1 - -*Write a same pipeline with explicit inputs/outputs name specifications:* - -```python -import towhee +p = ( + pipe.input('url') + .map('url', 'img', ops.image_decode.cv2_rgb()) + .map('img', 'text', ops.image_captioning.clip_caption_reward(model_name='clipRN50_clips_grammar')) + .output('img', 'text') +) -towhee.glob['path']('./animals.jpg') \ - .image_decode['path', 'img']() \ - .image_captioning.clip_caption_reward['img', 'text'](model_name='clipRN50_clips_grammar') \ - .select['img', 'text']() \ - .show() +DataCollection(p('./animals.jpg')).show() ``` result2 diff --git a/clip_caption_reward.py b/clip_caption_reward.py index 2deb197..89a2472 100644 --- a/clip_caption_reward.py +++ b/clip_caption_reward.py @@ -80,6 +80,7 @@ class ClipCaptionReward(NNOperator): self.model = TransformerModel(opt) self.model.load_state_dict(raw_state_dict) + self.model.to(self.device) self.image_mean = ( torch.Tensor([0.48145466, 0.4578275, 0.40821073])