From 06472ed6047a350f45f069cb88215cf781d2591e Mon Sep 17 00:00:00 2001 From: wxywb Date: Tue, 27 Sep 2022 14:37:36 +0800 Subject: [PATCH] update the opearator for newer transformers. Signed-off-by: wxywb --- README.md | 4 ++-- models/blip.py | 6 +++--- models/med.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 194851b..931ac27 100644 --- a/README.md +++ b/README.md @@ -68,14 +68,14 @@ Create the operator via the following factory method ## Interface -An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption. +An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption. **Parameters:** ​ ***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)* -​ The image to generate embedding. +​ The image to generate caption. diff --git a/models/blip.py b/models/blip.py index cb8e1b3..efbf85c 100644 --- a/models/blip.py +++ b/models/blip.py @@ -136,7 +136,7 @@ class BLIP_Decoder(nn.Module): image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) - model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} + kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} prompt = [self.prompt] * image.size(0) input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) @@ -154,7 +154,7 @@ class BLIP_Decoder(nn.Module): eos_token_id=self.tokenizer.sep_token_id, pad_token_id=self.tokenizer.pad_token_id, repetition_penalty=1.1, - **model_kwargs) + **kwargs) else: #beam search outputs = self.text_decoder.generate(input_ids=input_ids, @@ -164,7 +164,7 @@ class BLIP_Decoder(nn.Module): eos_token_id=self.tokenizer.sep_token_id, pad_token_id=self.tokenizer.pad_token_id, repetition_penalty=repetition_penalty, - **model_kwargs) + **kwargs) captions = [] for output in outputs: diff --git a/models/med.py b/models/med.py index 7b00a35..d960323 100644 --- a/models/med.py +++ b/models/med.py @@ -929,7 +929,7 @@ class BertLMHeadModel(BertPreTrainedModel): cross_attentions=outputs.cross_attentions, ) - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs): input_shape = input_ids.shape # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: @@ -943,8 +943,8 @@ class BertLMHeadModel(BertPreTrainedModel): "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "encoder_hidden_states": kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": kwargs.get("encoder_attention_mask", None), "is_decoder": True, }