logo
Browse Source

refactor the readme.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
4147838cf8
  1. 14
      README.md
  2. 4
      lightningdot.py

14
README.md

@ -35,8 +35,8 @@ towhee.dc(["A teddybear on a skateboard in Times Square."]) \
.image_text_embedding.lightningdot(modality='text') \ .image_text_embedding.lightningdot(modality='text') \
.show() .show()
``` ```
<img src="https://towhee.io/towhee/lightningdot/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/>
<img src="https://towhee.io/towhee/lightningdot/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/>
<img src="https://towhee.io/image-text-embedding/lightningdot/raw/branch/main/vec1.png" alt="result1" style="height:20px;"/>
<img src="https://towhee.io/image-text-embedding/lightningdot/raw/branch/main/vec2.png" alt="result2" style="height:20px;"/>
*Write a same pipeline with explicit inputs/outputs name specifications:* *Write a same pipeline with explicit inputs/outputs name specifications:*
@ -66,10 +66,18 @@ towhee.dc['text'](["A teddybear on a skateboard in Times Square."]) \
Create the operator via the following factory method Create the operator via the following factory method
***lightningdot(modality)***
***lightningdot(model_name, modality)***
**Parameters:** **Parameters:**
***model_name:*** *str*
​ The model name of LightningDOT. Supported model names:
- lightningdot_base
- lightningdot_coco_ft
- lightningdot_flickr_ft
***modality:*** *str* ***modality:*** *str*
​ Which modality(*image* or *text*) is used to generate the embedding. ​ Which modality(*image* or *text*) is used to generate the embedding.

4
lightningdot.py

@ -79,7 +79,7 @@ class LightningDOT(NNOperator):
state_dict[k[5:]] = state_dict.pop(k) state_dict[k[5:]] = state_dict.pop(k)
else: else:
state_dict.pop(k) state_dict.pop(k)
bi_encoder.load_state_dict(state_dict, strict=True)
self.bi_encoder.load_state_dict(state_dict, strict=True)
img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model
img_model.eval() img_model.eval()
@ -138,7 +138,7 @@ class LightningDOT(NNOperator):
vec = self._inference_from_text(data) vec = self._inference_from_text(data)
else: else:
raise ValueError("modality[{}] not implemented.".format(self._modality)) raise ValueError("modality[{}] not implemented.".format(self._modality))
return vec.detach().cpu().numpy()
return vec.detach().cpu().numpy().flatten()
def _inference_from_text(self, data): def _inference_from_text(self, data):
ids = self.tokenizer.encode(data) ids = self.tokenizer.encode(data)

Loading…
Cancel
Save