|
@ -136,7 +136,7 @@ class BLIP_Decoder(nn.Module): |
|
|
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) |
|
|
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) |
|
|
|
|
|
|
|
|
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|
|
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|
|
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} |
|
|
|
|
|
|
|
|
kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} |
|
|
|
|
|
|
|
|
prompt = [self.prompt] * image.size(0) |
|
|
prompt = [self.prompt] * image.size(0) |
|
|
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) |
|
|
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) |
|
@ -154,7 +154,7 @@ class BLIP_Decoder(nn.Module): |
|
|
eos_token_id=self.tokenizer.sep_token_id, |
|
|
eos_token_id=self.tokenizer.sep_token_id, |
|
|
pad_token_id=self.tokenizer.pad_token_id, |
|
|
pad_token_id=self.tokenizer.pad_token_id, |
|
|
repetition_penalty=1.1, |
|
|
repetition_penalty=1.1, |
|
|
**model_kwargs) |
|
|
|
|
|
|
|
|
**kwargs) |
|
|
else: |
|
|
else: |
|
|
#beam search |
|
|
#beam search |
|
|
outputs = self.text_decoder.generate(input_ids=input_ids, |
|
|
outputs = self.text_decoder.generate(input_ids=input_ids, |
|
@ -164,7 +164,7 @@ class BLIP_Decoder(nn.Module): |
|
|
eos_token_id=self.tokenizer.sep_token_id, |
|
|
eos_token_id=self.tokenizer.sep_token_id, |
|
|
pad_token_id=self.tokenizer.pad_token_id, |
|
|
pad_token_id=self.tokenizer.pad_token_id, |
|
|
repetition_penalty=repetition_penalty, |
|
|
repetition_penalty=repetition_penalty, |
|
|
**model_kwargs) |
|
|
|
|
|
|
|
|
**kwargs) |
|
|
|
|
|
|
|
|
captions = [] |
|
|
captions = [] |
|
|
for output in outputs: |
|
|
for output in outputs: |
|
|