update the operator for newer transformers.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
3 changed files with 8 additions and 8 deletions
- README.md
- models/blip.py
- models/med.py
README.md
@@ -68,14 +68,14 @@ Create the operator via the following factory method
 
 ## Interface
 
-An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
 
 **Parameters:**
 
 ***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
 
-The image to generate the embedding from.
+The image to generate the caption from.
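For readers skimming this change, a minimal sketch of how the captioning operator is typically invoked through towhee's functional API; the operator path (`image_captioning.blip`) and the `model_name` value are assumptions based on the usual hub naming, not confirmed by this diff:

```python
import towhee

# Sketch only: the operator path and model_name are assumed, not taken
# from this diff; check the repository README for the published names.
towhee.glob('./demo.jpg') \
      .image_decode() \
      .image_captioning.blip(model_name='blip_base') \
      .show()
```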
models/blip.py
@@ -136,7 +136,7 @@ class BLIP_Decoder(nn.Module):
         image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
+        kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
         prompt = [self.prompt] * image.size(0)
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
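As context for the hunk above: `repeat_interleave` copies each image's embeddings once per beam so every beam hypothesis cross-attends to its own batch entry, and the all-ones mask covers the patch positions. A toy illustration, assuming ViT-style shapes:

```python
import torch

num_beams = 3
image_embeds = torch.randn(2, 197, 768)  # (batch=2, patches, hidden); assumed shape

# Duplicate each image num_beams times along the batch dimension.
image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
print(image_embeds.shape)  # torch.Size([6, 197, 768])

# All-ones cross-attention mask over the patch positions, as in the hunk.
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
print(image_atts.shape)  # torch.Size([6, 197])
```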
@@ -154,7 +154,7 @@ class BLIP_Decoder(nn.Module):
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=1.1,
-                                                 **model_kwargs)
+                                                 **kwargs)
         else:
             #beam search
             outputs = self.text_decoder.generate(input_ids=input_ids,
@@ -164,7 +164,7 @@ class BLIP_Decoder(nn.Module):
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=repetition_penalty,
-                                                 **model_kwargs)
+                                                 **kwargs)
 
         captions = []
         for output in outputs:
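The two hunks above touch the sampling and beam-search branches of the same `generate()` API. A self-contained illustration of that two-branch pattern, using an arbitrary tiny model rather than BLIP's text decoder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustration only: the model here is an arbitrary tiny test checkpoint,
# not BLIP's BERT-based text decoder.
tok = AutoTokenizer.from_pretrained('sshleifer/tiny-gpt2')
model = AutoModelForCausalLM.from_pretrained('sshleifer/tiny-gpt2')
input_ids = tok('a picture of', return_tensors='pt').input_ids

# Nucleus-sampling branch: do_sample=True with top_p, as in the first hunk.
sampled = model.generate(input_ids, do_sample=True, top_p=0.9,
                         max_length=20, repetition_penalty=1.1,
                         pad_token_id=tok.eos_token_id)

# Beam-search branch: num_beams > 1 and no sampling, as in the second hunk.
beamed = model.generate(input_ids, num_beams=3, max_length=20,
                        repetition_penalty=1.1,
                        pad_token_id=tok.eos_token_id)

print(tok.decode(sampled[0]), tok.decode(beamed[0]))
```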
models/med.py
@@ -929,7 +929,7 @@ class BertLMHeadModel(BertPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -943,8 +943,8 @@ class BertLMHeadModel(BertPreTrainedModel):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past,
-            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
-            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "encoder_hidden_states": kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": kwargs.get("encoder_attention_mask", None),
             "is_decoder": True,
             }
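A plausible motivation for the rename, hedged because the diff itself does not say: newer transformers releases validate the extra kwargs passed to `generate()` against the signature of `prepare_inputs_for_generation`, and that check (`_validate_model_kwargs`) looks for a var-keyword parameter literally named `kwargs`; with `**model_kwargs`, arguments such as `encoder_hidden_states` can be rejected as unused. A minimal sketch of the signature difference:

```python
import inspect

def old_prepare(input_ids, past=None, attention_mask=None, **model_kwargs):
    pass

def new_prepare(input_ids, past=None, attention_mask=None, **kwargs):
    pass

# The validation checks for a parameter named exactly "kwargs"; only the
# renamed signature passes, so extra generate() arguments flow through.
print('kwargs' in inspect.signature(old_prepare).parameters)  # False
print('kwargs' in inspect.signature(new_prepare).parameters)  # True
```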