diff --git a/README.md b/README.md index 3e6481c..5c88350 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,78 @@ -# expansionnet-v2 +# Image Captioning with ExpansionNet v2 +*author: David Wang* + + +
+ + +## Description + +This operator generates the caption with [ExpansionNet v2](https://arxiv.org/abs/2208.06551) which describes the content of the given image. ExpansionNet v2 introduces the Block Static Expansion which distributes and processes the input over a heterogeneous and arbitrarily big collection of sequences characterized by a different length compared to the input one. This is an adaptation from [jchenghu/ExpansionNet_v2](https://github.com/jchenghu/expansionnet_v2). + + +
+ + +## Code Example + +Load an image from path './image.jpg' to generate the caption. + + *Write the pipeline in simplified style*: + +```python +import towhee + +towhee.glob('./image.jpg') \ + .image_decode() \ + .image_captioning.expansionnet_v2(model_name='expansionnet_rf') \ + .show() +``` +result1 + +*Write a same pipeline with explicit inputs/outputs name specifications:* + +```python +import towhee + +towhee.glob['path']('./image.jpg') \ + .image_decode['path', 'img']() \ + .image_captioning.expansionnet_v2['img', 'text'](model_name='expansionnet_rf') \ + .select['img', 'text']() \ + .show() +``` +result2 + + +
+ + +## Factory Constructor + +Create the operator via the following factory method + +***expansionnet_v2(model_name)*** + +**Parameters:** + +​ ***model_name:*** *str* + +​ The model name of ExpansionNet v2. Supported model names: +- expansionnet_rf + +
+ +## Interface + +An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption. + + +**Parameters:** + +​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* + +​ The image to generate caption. + +**Returns:** *str* + +​ The caption generated by model. diff --git a/cap.png b/cap.png new file mode 100644 index 0000000..ea2e572 Binary files /dev/null and b/cap.png differ diff --git a/expansionnet_v2.py b/expansionnet_v2.py index 90a7e12..4902073 100644 --- a/expansionnet_v2.py +++ b/expansionnet_v2.py @@ -82,7 +82,8 @@ class ExpansionNetV2(NNOperator): max_seq_len=max_seq_len, drop_args=model_args.drop_args, rank='cpu') - checkpoint = torch.load('{}/weights/rf_model.pth'.format(path), map_location=torch.device('cpu')) + cfg = self.model._configs()[model_name] + checkpoint = torch.load('{}/weights/{}'.format(cfg,cfg['weights']), map_location=torch.device('cpu')) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.to(self.device) @@ -132,3 +133,9 @@ class ExpansionNetV2(NNOperator): pred = ' '.join(pred).capitalize() return pred + + def _configs(self): + config = {} + config['expansionnet_rf'] = {} + config['expansionnet_rf']['weights'] = 'rf_model.pth' + return config diff --git a/tabular.png b/tabular.png new file mode 100644 index 0000000..5b31ef8 Binary files /dev/null and b/tabular.png differ