logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
b4b9e9f46c
  1. BIN
      .DS_Store
  2. 14
      README.md
  3. BIN
      cap.png
  4. 5
      language_model/dataclass.py
  5. 4
      language_model/simctg.py
  6. 1
      language_model/train.py
  7. 6
      language_model/trainer.py
  8. 4
      requirements.txt
  9. BIN
      tab.png

BIN
.DS_Store

Binary file not shown.

14
README.md

@ -25,7 +25,7 @@ import towhee
towhee.glob('./image.jpg') \ towhee.glob('./image.jpg') \
.image_decode() \ .image_decode() \
.image_captioning.magic(model_name='expansionnet_rf') \
.image_captioning.magic(model_name='magic_mscoco') \
.show() .show()
``` ```
<img src="./cap.png" alt="result1" style="height:20px;"/> <img src="./cap.png" alt="result1" style="height:20px;"/>
@ -37,11 +37,11 @@ import towhee
towhee.glob['path']('./image.jpg') \ towhee.glob['path']('./image.jpg') \
.image_decode['path', 'img']() \ .image_decode['path', 'img']() \
.image_captioning.magic['img', 'text'](model_name='expansionnet_rf') \
.image_captioning.magic['img', 'text'](model_name='magic_mscoco') \
.select['img', 'text']() \ .select['img', 'text']() \
.show() .show()
``` ```
<img src="./tabular.png" alt="result2" style="height:60px;"/>
<img src="./tab.png" alt="result2" style="height:60px;"/>
<br /> <br />
@ -51,7 +51,7 @@ towhee.glob['path']('./image.jpg') \
Create the operator via the following factory method Create the operator via the following factory method
***expansionnet_v2(model_name)***
***magic(model_name)***
**Parameters:** **Parameters:**
@ -64,16 +64,14 @@ Create the operator via the following factory method
## Interface ## Interface
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption.
An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
**Parameters:** **Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
​ The image to generate embedding.
​ The image to generate caption.
**Returns:** *str* **Returns:** *str*

BIN
cap.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

5
language_model/dataclass.py

@ -2,7 +2,6 @@ import json
import random import random
import torch import torch
import numpy as np import numpy as np
import progressbar
from torch.nn.utils import rnn from torch.nn.utils import rnn
class Data: class Data:
@ -67,13 +66,9 @@ class Data:
res_token_list, res_token_id_list = [], [] res_token_list, res_token_id_list = [], []
n = len(lines) n = len(lines)
p = progressbar.ProgressBar(n)
p.start()
for i in range(n): for i in range(n):
p.update(i)
text = lines[i].strip('\n') text = lines[i].strip('\n')
self.process_one_text(text, res_token_list, res_token_id_list) self.process_one_text(text, res_token_list, res_token_id_list)
p.finish()
print ('{} processed!'.format(path)) print ('{} processed!'.format(path))
return res_token_list, res_token_id_list return res_token_list, res_token_id_list

4
language_model/simctg.py

@ -1,7 +1,6 @@
import os import os
import sys import sys
import operator import operator
from tqdm import tqdm
from operator import itemgetter from operator import itemgetter
import torch import torch
from torch import nn from torch import nn
@ -13,9 +12,6 @@ from torch.nn import CrossEntropyLoss
from loss_func import contrastive_loss from loss_func import contrastive_loss
from utlis import PlugAndPlayContrastiveDecodingOneStepFast from utlis import PlugAndPlayContrastiveDecodingOneStepFast
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import datetime import datetime
train_fct = CrossEntropyLoss() train_fct = CrossEntropyLoss()

1
language_model/train.py

@ -8,7 +8,6 @@ import random
import numpy as np import numpy as np
import time import time
import logging import logging
import progressbar
import logging import logging
logging.getLogger('transformers.generation_utils').disabled = True logging.getLogger('transformers.generation_utils').disabled = True

6
language_model/trainer.py

@ -8,9 +8,7 @@ import random
import numpy as np import numpy as np
import time import time
import logging import logging
import progressbar
import logging
logging.getLogger('transformers.generation_utils').disabled = True logging.getLogger('transformers.generation_utils').disabled = True
def eval_model(args, model, data, cuda_available, device): def eval_model(args, model, data, cuda_available, device):
@ -19,10 +17,7 @@ def eval_model(args, model, data, cuda_available, device):
val_loss, token_sum = 0., 0. val_loss, token_sum = 0., 0.
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
p = progressbar.ProgressBar(eval_step)
p.start()
for idx in range(eval_step): for idx in range(eval_step):
p.update(idx)
batch_input_tensor, batch_labels, _ = \ batch_input_tensor, batch_labels, _ = \
data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test') data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test')
if cuda_available: if cuda_available:
@ -33,7 +28,6 @@ def eval_model(args, model, data, cuda_available, device):
one_val_token_sum = torch.sum(one_val_token_sum) one_val_token_sum = torch.sum(one_val_token_sum)
val_loss += one_val_loss.item() val_loss += one_val_loss.item()
token_sum += one_val_token_sum.item() token_sum += one_val_token_sum.item()
p.finish()
model.train() model.train()
val_loss = val_loss / token_sum val_loss = val_loss / token_sum
return val_loss return val_loss

4
requirements.txt

@ -0,0 +1,4 @@
torch
torchvision
numpy
transformers

BIN
tab.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

Loading…
Cancel
Save