You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Files and versions

66 lines
2.9 KiB

# -*- coding: utf-8 -*-
import torch
import more_itertools
from tqdm import tqdm
class Predictor:
    """Zero-shot image classifier built on a CLIP-style model and processor.

    Class names are turned into prompts with a set of text templates, encoded
    and averaged into one latent per class; images are then assigned to the
    class whose text latent has the highest scaled dot product with the
    image latent.
    """

    def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
        """Store the model, processor and batching options.

        Args:
            clip_model: Model exposing ``encode_image`` and ``logit_scale``;
                it is moved to ``device``.
            clip_processor: Callable that converts texts / PIL images into
                model inputs (called with ``return_tensors='pt'``).
            device: Device the model and the input tensors are moved to.
            templates: Optional list of ``str.format`` templates with a single
                ``{}`` placeholder for the class label. Falls back to the
                built-in Russian prompts when falsy.
            bs: Number of items processed per batch.
            quiet: When true, no tqdm progress bar is created.
        """
        self.device = device
        self.clip_model = clip_model.to(self.device)
        self.clip_processor = clip_processor
        self.bs = bs
        self.quiet = quiet
        self.templates = templates or [
            'фото, на котором изображено {}',
            'изображение с {}',
            'картинка с {}',
            'фото с {}',
            'на фото видно {}',
        ]

    def get_text_latents(self, class_labels):
        """Return one L2-normalised text latent per class label.

        Each label is lower-cased, stripped and substituted into every
        template; the latents are averaged over templates and then normalised
        along the last dimension.

        Args:
            class_labels: Iterable of class-name strings (expected non-empty).

        Returns:
            Tensor with one row per class label.
        """
        per_template = []
        for template in self.templates:
            batches = []
            for chunk in more_itertools.chunked(class_labels, self.bs):
                texts = [template.format(class_label.lower().strip()) for class_label in chunk]
                inputs = self.clip_processor(text=texts, return_tensors='pt', padding=True)
                # NOTE(review): restored encoding step. Assumes the model exposes
                # ``encode_text`` and the processor returns ``input_ids``, mirroring
                # the ``encode_image`` / ``pixel_values`` pair used for images -
                # confirm against the model API.
                batches.append(self.clip_model.encode_text(inputs['input_ids'].to(self.device)))
            per_template.append(torch.cat(batches, dim=0))
        # Average over the template axis, then renormalise to unit length.
        text_latents = torch.stack(per_template).mean(0)
        text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents

    def run(self, images, text_latents):
        """Predict a class index for every image.

        Args:
            images: Iterable of PIL images.
            text_latents: Class latents as produced by ``get_text_latents``;
                row ``i`` corresponds to class ``i``.

        Returns:
            List of int class indices, one per input image, in input order.
        """
        pbar = None if self.quiet else tqdm()
        labels = []
        logit_scale = self.clip_model.logit_scale.exp()
        try:
            for pil_images in more_itertools.chunked(images, self.bs):
                inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
                image_latents = self.clip_model.encode_image(inputs['pixel_values'].to(self.device))
                image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
                # Rows are classes, columns are images of the current batch.
                logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
                _labels = logits_per_text.argmax(0).cpu().numpy().tolist()
                labels.extend(_labels)
                if pbar is not None:
                    pbar.update(len(_labels))
        finally:
            # Release the progress bar even if encoding fails mid-way.
            if pbar is not None:
                pbar.close()
        return labels

    def get_image_latents(self, images):
        """Return L2-normalised image latents for ``images``.

        Args:
            images: Iterable of PIL images (expected non-empty).

        Returns:
            Tensor with one row per input image, in input order.
        """
        pbar = None if self.quiet else tqdm()
        batches = []
        try:
            for pil_images in more_itertools.chunked(images, self.bs):
                inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
                batches.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
                if pbar is not None:
                    pbar.update(len(pil_images))
        finally:
            if pbar is not None:
                pbar.close()
        image_latents = torch.cat(batches)
        image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
        return image_latents