diff --git a/.DS_Store b/.DS_Store
index 0f450d5..4412f06 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/README.md b/README.md
index 6a4a1b3..8844b89 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ import towhee
towhee.glob('./image.jpg') \
.image_decode() \
- .image_captioning.magic(model_name='expansionnet_rf') \
+ .image_captioning.magic(model_name='magic_mscoco') \
.show()
```
@@ -37,11 +37,11 @@ import towhee
towhee.glob['path']('./image.jpg') \
.image_decode['path', 'img']() \
- .image_captioning.magic['img', 'text'](model_name='expansionnet_rf') \
+ .image_captioning.magic['img', 'text'](model_name='magic_mscoco') \
.select['img', 'text']() \
.show()
```
-
+
@@ -51,7 +51,7 @@ towhee.glob['path']('./image.jpg') \
Create the operator via the following factory method
-***expansionnet_v2(model_name)***
+***magic(model_name)***
**Parameters:**
@@ -64,16 +64,14 @@ Create the operator via the following factory method
## Interface
-An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption.
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
**Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
- The image to generate embedding.
-
-
+	The image to generate the caption for.
**Returns:** *str*
diff --git a/cap.png b/cap.png
new file mode 100644
index 0000000..0fbc5c7
Binary files /dev/null and b/cap.png differ
diff --git a/language_model/dataclass.py b/language_model/dataclass.py
index 3786dc5..740140c 100644
--- a/language_model/dataclass.py
+++ b/language_model/dataclass.py
@@ -2,7 +2,6 @@ import json
import random
import torch
import numpy as np
-import progressbar
from torch.nn.utils import rnn
class Data:
@@ -67,13 +66,9 @@ class Data:
res_token_list, res_token_id_list = [], []
n = len(lines)
- p = progressbar.ProgressBar(n)
- p.start()
for i in range(n):
- p.update(i)
text = lines[i].strip('\n')
self.process_one_text(text, res_token_list, res_token_id_list)
- p.finish()
print ('{} processed!'.format(path))
return res_token_list, res_token_id_list
diff --git a/language_model/simctg.py b/language_model/simctg.py
index 25599e9..234330c 100644
--- a/language_model/simctg.py
+++ b/language_model/simctg.py
@@ -1,7 +1,6 @@
import os
import sys
import operator
-from tqdm import tqdm
from operator import itemgetter
import torch
from torch import nn
@@ -13,9 +12,6 @@ from torch.nn import CrossEntropyLoss
from loss_func import contrastive_loss
from utlis import PlugAndPlayContrastiveDecodingOneStepFast
-import seaborn as sns
-import matplotlib.pyplot as plt
-import pandas as pd
import datetime
train_fct = CrossEntropyLoss()
diff --git a/language_model/train.py b/language_model/train.py
index f401e44..21823d5 100644
--- a/language_model/train.py
+++ b/language_model/train.py
@@ -8,7 +8,6 @@ import random
import numpy as np
import time
import logging
-import progressbar
import logging
logging.getLogger('transformers.generation_utils').disabled = True
diff --git a/language_model/trainer.py b/language_model/trainer.py
index e51a850..8ff5919 100644
--- a/language_model/trainer.py
+++ b/language_model/trainer.py
@@ -8,9 +8,7 @@ import random
import numpy as np
import time
import logging
-import progressbar
-import logging
logging.getLogger('transformers.generation_utils').disabled = True
def eval_model(args, model, data, cuda_available, device):
@@ -19,10 +17,7 @@ def eval_model(args, model, data, cuda_available, device):
val_loss, token_sum = 0., 0.
model.eval()
with torch.no_grad():
- p = progressbar.ProgressBar(eval_step)
- p.start()
for idx in range(eval_step):
- p.update(idx)
batch_input_tensor, batch_labels, _ = \
data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test')
if cuda_available:
@@ -33,7 +28,6 @@ def eval_model(args, model, data, cuda_available, device):
one_val_token_sum = torch.sum(one_val_token_sum)
val_loss += one_val_loss.item()
token_sum += one_val_token_sum.item()
- p.finish()
model.train()
val_loss = val_loss / token_sum
return val_loss
diff --git a/requirements.txt b/requirements.txt
index e69de29..e31a54b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch
+torchvision
+numpy
+transformers
diff --git a/tab.png b/tab.png
new file mode 100644
index 0000000..6bb1f5d
Binary files /dev/null and b/tab.png differ