logo
Browse Source

update the camel.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 3 years ago
parent
commit
90d37da15f
  1. 29
      camel.py
  2. 19
      utils/__init__.py
  3. 164
      utils/logger.py
  4. 6
      utils/typing.py
  5. 58
      utils/utils.py

29
camel.py

@ -25,7 +25,8 @@ from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
from towhee.command.s3 import S3Bucket
class Camel(NNOperator):
"""
@ -33,8 +34,8 @@ class Camel(NNOperator):
"""
def _gen_args(self):
args = edict()
args.image_dim =
args.N_enc = 3
args.N_dec = 3
args.d_model = 512
args.d_ff = 2048
args.head = 8
@ -48,21 +49,27 @@ class Camel(NNOperator):
super().__init__()
sys.path.append(str(Path(__file__).parent))
self.device = "cuda" if torch.cuda.is_available() else "cpu"
from models import Captioner
from models import Captioner, clip
from data import ImageField, TextField
from models import clip
# Pipeline for text
self.text_field = TextField()
args = self._gen_args()
self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True)
self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4')
path = str(Path(__file__).parent)
self.clip_model, self.clip_tfms = clip.load('RN50x16', jit=False)
#import ipdb
#ipdb.set_trace()
self.image_model = self.clip_model.visual
self.image_model.forward = self.image_model.intermediate_features
image_field = ImageField(transform=self.clip_tfms)
args.image_dim = self.mage_model.embed_dim
args.image_dim = self.image_model.embed_dim
config = self._configs()[model_name]
s3_bucket = S3Bucket()
s3_bucket.download_file(config['weights'], path + '/weights/')
model_path = path + '/weights/' + os.path.basename(config['weights'])
# Create the model
self.model = Captioner(args, self.text_field).to(self.device)
self.model.forward = self.model.beam_search
@ -105,10 +112,10 @@ class Camel(NNOperator):
def _configs(self):
config = {}
config['clipcap_coco'] = {}
config['clipcap_coco']['weights'] = 'coco_weights.pt'
config['clipcap_conceptual'] = {}
config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
config['camel_nomesh'] = {}
config['camel_nomesh']['weights'] = 's3://pretrainedweights.towhee.io/image-captioning/camel/camel_nomesh.pth'
config['camel_mesh'] = {}
config['camel_mesh']['weights'] = 's3://pretrainedweights.towhee.io/image-captioning/camel/camel_mesh.pth'
return config
if __name__ == '__main__':

19
utils/__init__.py

@ -0,0 +1,19 @@
from .logger import *
from .typing import *
from .utils import *
def get_batch_size(x: "TensorOrSequence") -> int:
    """Return the batch size (size of dimension 0) of *x*.

    *x* may be a single tensor or a sequence of tensors; in the latter case
    the batch size of the first element is used.
    """
    tensor = x if isinstance(x, torch.Tensor) else x[0]
    return tensor.size(0)
def get_device(x: "TensorOrSequence") -> "torch.device":
    """Return the device that *x* lives on.

    *x* may be a single tensor or a sequence of tensors; in the latter case
    the device of the first element is returned.

    Fix: the original annotated the return type as ``int``, but the function
    returns ``torch.device``; also renamed the misleading local ``b_s``
    (copy-paste from ``get_batch_size``).
    """
    if isinstance(x, torch.Tensor):
        device = x.device
    else:
        device = x[0].device
    return device

164
utils/logger.py

@ -0,0 +1,164 @@
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as dist
class SmoothedValue(object):
    """Track a series of scalar values and expose smoothed statistics.

    A bounded window (deque) backs ``median``/``avg``/``max``/``value``,
    while ``count``/``total`` accumulate over the whole series for
    ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # The window keeps the raw value once; count/total weight it by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total],
                              dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        count, total = packed.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Median over the current window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the current window only.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Mean over every value ever recorded (not just the window).
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recently recorded value.
        return self.deque[-1]

    def __str__(self):
        stats = dict(median=self.median,
                     avg=self.avg,
                     global_avg=self.global_avg,
                     max=self.max,
                     value=self.value)
        return self.fmt.format(**stats)
class MetricLogger(object):
    """Collects named SmoothedValue meters and prints periodic progress lines."""

    def __init__(self, delimiter="\t"):
        # Meters are created lazily: the first update of an unseen name
        # instantiates a default SmoothedValue for it.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one scalar per keyword argument; tensors are unwrapped via .item()."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            # Only plain scalars are accepted after tensor unwrapping.
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Attribute-style access to meters (e.g. logger.loss). Only called
        # when normal attribute lookup fails, so real attributes win.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Render every meter as "name: <meter>" joined by the delimiter."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Reduce each meter's count/total across distributed workers."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter (e.g. a custom fmt) under *name*."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from *iterable*, printing a progress line every *print_freq* items.

        The line shows ETA, all meters, per-iteration data/compute time, and
        (when CUDA is available) peak GPU memory. *iterable* must support
        len() for the ETA computation.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of len(iterable).
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent fetching the item (since the previous iteration ended)...
            data_time.update(time.time() - end)
            yield obj
            # ...and total time for the whole iteration, including the caller's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()

6
utils/typing.py

@ -0,0 +1,6 @@
from typing import Union, Sequence
import torch
# A single tensor, or any sequence (list/tuple) of tensors.
TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
# An optional tensor value.
TensorOrNone = Union[torch.Tensor, None]

58
utils/utils.py

@ -0,0 +1,58 @@
import sys
import os
import shutil
import warnings
import requests
import pidfile
from contextlib import contextmanager
from time import sleep
@contextmanager
def exclusive(pidname):
    """Run the managed block under a cross-process pidfile lock.

    If another process already holds *pidname*, retry every 5 seconds
    until the lock can be acquired.
    """
    while True:
        try:
            with pidfile.PIDFile(pidname):
                yield
        except pidfile.AlreadyRunningError:
            sleep(5)
        else:
            break
def download_from_url(url, path):
    """Download *url* to *path*.

    Non-Google-Drive URLs are fetched in one request. Google Drive links
    may require echoing back a confirmation cookie (logic adapted from
    tensor2tensor), and are streamed to disk in 16 KiB chunks.
    """
    if 'drive.google.com' not in url:
        print('Downloading %s; may take a few minutes' % url)
        reply = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
        with open(path, "wb") as sink:
            sink.write(reply.content)
        return

    print('Downloading from Google Drive; may take a few minutes')
    session = requests.Session()
    reply = session.get(url, stream=True)
    # Drive signals "big file, confirm" via a download_warning* cookie.
    token = None
    for name, value in reply.cookies.items():
        if name.startswith("download_warning"):
            token = value
    if token:
        reply = session.get(url + "&confirm=" + token, stream=True)
    with open(path, "wb") as sink:
        for piece in reply.iter_content(16 * 1024):
            if piece:  # skip keep-alive chunks
                sink.write(piece)
class DummyFile(object):
    """A write-only sink that discards everything (stand-in for sys.stdout)."""

    def write(self, x):
        pass

    def flush(self):
        # Robustness fix: print(..., flush=True) and interpreter shutdown
        # call flush() on stdout; without this the silenced block crashes.
        pass


@contextmanager
def nostdout():
    """Silence sys.stdout for the duration of the managed block.

    Bug fix: the original restored stdout only on the success path, so an
    exception raised inside the block left the process permanently mute.
    The restore now happens in a ``finally`` clause.
    """
    save_stdout = sys.stdout
    sys.stdout = DummyFile()
    try:
        yield
    finally:
        sys.stdout = save_stdout
Loading…
Cancel
Save