
update the operator for newer transformers.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
wxywb committed 2 years ago
Commit: 06472ed604
3 changed files:
  README.md       (4 lines changed)
  models/blip.py  (6 lines changed)
  models/med.py   (6 lines changed)

README.md (4 lines changed)

@@ -68,14 +68,14 @@ Create the operator via the following factory method
 ## Interface
-An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption.
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
 **Parameters:**
 ***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
-​ The image to generate embedding.
+​ The image to generate a caption for.

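For orientation, a minimal usage sketch of the operator the README hunk documents. The pipeline chain and the operator handle (`image_captioning.blip`), as well as the `model_name` value, are assumptions based on towhee hub conventions, not something this commit specifies:

```python
# Hedged sketch, assuming towhee is installed and this repo is published on
# the towhee hub as an image-captioning operator; the operator handle and
# model_name are illustrative assumptions.
import towhee

towhee.glob('./demo.jpg') \
      .image_decode() \
      .image_captioning.blip(model_name='blip_base') \
      .show()
```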
models/blip.py (6 lines changed)

@@ -136,7 +136,7 @@ class BLIP_Decoder(nn.Module):
         image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
+        kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
         prompt = [self.prompt] * image.size(0)
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)

@@ -154,7 +154,7 @@ class BLIP_Decoder(nn.Module):
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=1.1,
-                                                 **model_kwargs)
+                                                 **kwargs)
         else:
             #beam search
             outputs = self.text_decoder.generate(input_ids=input_ids,

@@ -164,7 +164,7 @@ class BLIP_Decoder(nn.Module):
                                                  eos_token_id=self.tokenizer.sep_token_id,
                                                  pad_token_id=self.tokenizer.pad_token_id,
                                                  repetition_penalty=repetition_penalty,
-                                                 **model_kwargs)
+                                                 **kwargs)
         captions = []
         for output in outputs:

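A quick aside on the first hunk above: before decoding, the image embedding is tiled once per beam so every beam attends to the same visual context. A self-contained sketch of that tiling with dummy tensor sizes (the shapes are illustrative, not BLIP's real dimensions):

```python
# Minimal sketch of the beam-search tiling in the hunk at -136; only torch
# is required, and the tensor sizes are made up for illustration.
import torch

num_beams = 3
image_embeds = torch.randn(2, 5, 8)  # (batch, num_patches, dim), dummy sizes
image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
print(image_embeds.shape)  # torch.Size([6, 5, 8]) -> one copy per beam
print(image_atts.shape)    # torch.Size([6, 5])
```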
models/med.py (6 lines changed)

@@ -929,7 +929,7 @@ class BertLMHeadModel(BertPreTrainedModel):
             cross_attentions=outputs.cross_attentions,
         )

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:

@@ -943,8 +943,8 @@ class BertLMHeadModel(BertPreTrainedModel):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past,
-            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
-            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "encoder_hidden_states": kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": kwargs.get("encoder_attention_mask", None),
             "is_decoder": True,
         }

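Why the rename matters: recent transformers releases validate the extra keyword arguments handed to generate() by inspecting the signature of prepare_inputs_for_generation, and (an assumption about the affected releases; the internals vary by version) only a catch-all named exactly **kwargs was treated as "accepts arbitrary extra kwargs", so encoder_hidden_states / encoder_attention_mask were rejected as unused while the catch-all was still called **model_kwargs. A minimal, self-contained reproduction of that signature check:

```python
# Sketch of the signature inspection behind generate()'s kwarg validation;
# the exact transformers internals are version-dependent (assumption).
import inspect

def old_prepare(input_ids, past=None, attention_mask=None, **model_kwargs):
    pass  # pre-commit signature

def new_prepare(input_ids, past=None, attention_mask=None, **kwargs):
    pass  # post-commit signature

# The affected releases only recognized a catch-all literally named "kwargs":
print("kwargs" in inspect.signature(old_prepare).parameters)  # False -> extra
# generate() kwargs such as encoder_hidden_states get flagged as unused
print("kwargs" in inspect.signature(new_prepare).parameters)  # True -> passes
```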