From 90d37da15ffdefb74959124fd656da308f8ed0f5 Mon Sep 17 00:00:00 2001 From: wxywb Date: Thu, 17 Nov 2022 20:19:00 +0800 Subject: [PATCH] update the camel. Signed-off-by: wxywb --- camel.py | 29 ++++---- utils/__init__.py | 19 ++++++ utils/logger.py | 164 ++++++++++++++++++++++++++++++++++++++++++++++ utils/typing.py | 6 ++ utils/utils.py | 58 ++++++++++++++++ 5 files changed, 265 insertions(+), 11 deletions(-) create mode 100644 utils/__init__.py create mode 100644 utils/logger.py create mode 100644 utils/typing.py create mode 100644 utils/utils.py diff --git a/camel.py b/camel.py index f3683f9..63dcd69 100644 --- a/camel.py +++ b/camel.py @@ -25,7 +25,8 @@ from towhee.types.arg import arg, to_image_color from towhee.types.image_utils import to_pil from towhee.operator.base import NNOperator, OperatorFlag from towhee import register -from towhee.models import clip + +from towhee.command.s3 import S3Bucket class Camel(NNOperator): """ @@ -33,8 +34,8 @@ class Camel(NNOperator): """ def _gen_args(self): args = edict() - args.image_dim = args.N_enc = 3 + args.N_dec = 3 args.d_model = 512 args.d_ff = 2048 args.head = 8 @@ -48,21 +49,27 @@ class Camel(NNOperator): super().__init__() sys.path.append(str(Path(__file__).parent)) self.device = "cuda" if torch.cuda.is_available() else "cpu" - from models import Captioner + from models import Captioner, clip from data import ImageField, TextField + from models import clip # Pipeline for text self.text_field = TextField() args = self._gen_args() - - self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True) - self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4') + path = str(Path(__file__).parent) + self.clip_model, self.clip_tfms = clip.load('RN50x16', jit=False) + #import ipdb + #ipdb.set_trace() self.image_model = self.clip_model.visual self.image_model.forward = self.image_model.intermediate_features image_field = ImageField(transform=self.clip_tfms) - args.image_dim = self.mage_model.embed_dim + args.image_dim = self.image_model.embed_dim + config = self._configs()[model_name] + s3_bucket = S3Bucket() + s3_bucket.download_file(config['weights'], path + '/weights/') + model_path = path + '/weights/' + os.path.basename(config['weights']) # Create the model self.model = Captioner(args, self.text_field).to(self.device) self.model.forward = self.model.beam_search @@ -105,10 +112,10 @@ class Camel(NNOperator): def _configs(self): config = {} - config['clipcap_coco'] = {} - config['clipcap_coco']['weights'] = 'coco_weights.pt' - config['clipcap_conceptual'] = {} - config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt' + config['camel_nomesh'] = {} + config['camel_nomesh']['weights'] = 's3://pretrainedweights.towhee.io/image-captioning/camel/camel_nomesh.pth' + config['camel_mesh'] = {} + config['camel_mesh']['weights'] = 's3://pretrainedweights.towhee.io/image-captioning/camel/camel_mesh.pth' return config if __name__ == '__main__': diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..5b59243 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,19 @@ +from .logger import * +from .typing import * +from .utils import * + + +def get_batch_size(x: TensorOrSequence) -> int: + if isinstance(x, torch.Tensor): + b_s = x.size(0) + else: + b_s = x[0].size(0) + return b_s + + +def get_device(x: TensorOrSequence) -> int: + if isinstance(x, torch.Tensor): + b_s = x.device + else: + b_s = x[0].device + return b_s diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..0406c89 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,164 @@ +from collections import defaultdict, deque +import datetime +import time +import torch +import torch.distributed as dist + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {}'.format(header, total_time_str)) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True diff --git a/utils/typing.py b/utils/typing.py new file mode 100644 index 0000000..e71e2a4 --- /dev/null +++ b/utils/typing.py @@ -0,0 +1,6 @@ +from typing import Union, Sequence + +import torch + +TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] +TensorOrNone = Union[torch.Tensor, None] diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..ea7e8df --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,58 @@ +import sys +import os +import shutil +import warnings +import requests +import pidfile +from contextlib import contextmanager +from time import sleep + +@contextmanager +def exclusive(pidname): + done = False + while not done: + try: + with pidfile.PIDFile(pidname): + yield + done = True + except pidfile.AlreadyRunningError: + sleep(5) + + +def download_from_url(url, path): + """Download file, with logic (from tensor2tensor) for Google Drive""" + if 'drive.google.com' not in url: + print('Downloading %s; may take a few minutes' % url) + r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) + with open(path, "wb") as file: + file.write(r.content) + return + print('Downloading from Google Drive; may take a few minutes') + confirm_token = None + session = requests.Session() + response = session.get(url, stream=True) + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + confirm_token = v + + if confirm_token: + url = url + "&confirm=" + confirm_token + response = session.get(url, stream=True) + + chunk_size = 16 * 1024 + with open(path, "wb") as f: + for chunk in response.iter_content(chunk_size): + if chunk: + f.write(chunk) + + +class DummyFile(object): + def write(self, x): pass + + +@contextmanager +def nostdout(): + save_stdout = sys.stdout + sys.stdout = DummyFile() + yield + sys.stdout = save_stdout