Browse Source
refactor the readme.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
2 changed files with
13 additions and
5 deletions
-
README.md
-
lightningdot.py
|
|
@ -35,8 +35,8 @@ towhee.dc(["A teddybear on a skateboard in Times Square."]) \ |
|
|
|
.image_text_embedding.lightningdot(modality='text') \ |
|
|
|
.show() |
|
|
|
``` |
|
|
|
<img src="https://towhee.io/towhee/lightningdot/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/> |
|
|
|
<img src="https://towhee.io/towhee/lightningdot/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/> |
|
|
|
<img src="https://towhee.io/image-text-embedding/lightningdot/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/> |
|
|
|
<img src="https://towhee.io/image-text-embedding/lightningdot/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/> |
|
|
|
|
|
|
|
*Write the same pipeline with explicit input/output name specifications:* |
|
|
|
|
|
|
@ -66,10 +66,18 @@ towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \ |
|
|
|
|
|
|
|
Create the operator via the following factory method |
|
|
|
|
|
|
|
***lightningdot(modality)*** |
|
|
|
***lightningdot(model_name, modality)*** |
|
|
|
|
|
|
|
**Parameters:** |
|
|
|
|
|
|
|
|
|
|
|
***model_name:*** *str* |
|
|
|
|
|
|
|
The model name of LightningDOT. Supported model names: |
|
|
|
- lightningdot_base |
|
|
|
- lightningdot_coco_ft |
|
|
|
- lightningdot_flickr_ft |
|
|
|
|
|
|
|
***modality:*** *str* |
|
|
|
|
|
|
|
Which modality (*image* or *text*) is used to generate the embedding. |
|
|
|
|
|
@ -79,7 +79,7 @@ class LightningDOT(NNOperator): |
|
|
|
state_dict[k[5:]] = state_dict.pop(k) |
|
|
|
else: |
|
|
|
state_dict.pop(k) |
|
|
|
bi_encoder.load_state_dict(state_dict, strict=True) |
|
|
|
self.bi_encoder.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model |
|
|
|
img_model.eval() |
|
|
@ -138,7 +138,7 @@ class LightningDOT(NNOperator): |
|
|
|
vec = self._inference_from_text(data) |
|
|
|
else: |
|
|
|
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
|
|
|
return vec.detach().cpu().numpy() |
|
|
|
return vec.detach().cpu().numpy().flatten() |
|
|
|
|
|
|
|
def _inference_from_text(self, data): |
|
|
|
ids = self.tokenizer.encode(data) |
|
|
|