refine isc training in readme

main · ChengZi, 2 years ago · commit 70f235293a

README.md (22 lines changed)
@@ -101,13 +101,22 @@ op.save_model('onnx', 'test.onnx')
 ### Requirements
-If you want to fine tune this operator, make sure your timm version is 0.4.12, and the higher version would cause model collapse during training.
+If you want to fine-tune this operator, make sure your timm version is 0.4.12; a higher version may cause model collapse during training.
 ```python
 ! python -m pip install tqdm augly timm==0.4.12 pytorch-metric-learning==0.9.99
 ```
 ### Download dataset
-Your need to download dataset from [Facebook AI Image Similarity Challenge: Descriptor Track](https://www.drivendata.org/competitions/79/competition-image-similarity-1-dev/), and use the `./training_images` folder as your training images root. It requires about 165G space.
+ISC is trained with [contrastive learning](https://lilianweng.github.io/posts/2021-05-31-contrastive/), a type of self-supervised training, so the training images do not require any labels. We only need to prepare a folder `./training_images` holding a large number of diverse training images.
+In the original training of [ISC21-Descriptor-Track-1st](https://github.com/lyakaap/ISC21-Descriptor-Track-1st), the training dataset is huge, taking more than 165 GB of space, and a [multi-step training strategy](https://arxiv.org/abs/2104.00298) is used.
+In our fine-tuning example, for simplicity, we prepare a small dataset to run with, and you can replace it with your own custom dataset.
+```python
+! curl -L https://github.com/towhee-io/examples/releases/download/data/isc_training_image_examples.zip -o ./training_images.zip
+! unzip -q -o ./training_images.zip
+```
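As a quick sanity check after unzipping, one might count the images found under the training root. This is a hypothetical helper, not part of the operator; the extension list is an assumption about what the example archive contains.

```python
from pathlib import Path

def count_images(root, exts=(".jpg", ".jpeg", ".png")):
    """Count image files under a training-image root (hypothetical helper)."""
    root = Path(root)
    if not root.exists():
        return 0
    # rglob walks subdirectories too, so nested image folders are counted as well
    return sum(1 for p in root.rglob("*") if p.suffix.lower() in exts)

if __name__ == "__main__":
    print(f"found {count_images('./training_images')} training images")
```

A result of 0 usually means the archive was unzipped into an unexpected location.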
 ### Get started to fine-tune
 Just call method op.train() and pass in your args.
@@ -121,7 +130,7 @@ op.train(training_args={
     'output_dir': './output',
     'gpu': 0,
     'epochs': 2,
-    'batch_size': 128,
+    'batch_size': 8,
     'init_lr': 0.1
 })
 ```
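Lowering `batch_size` from 128 to 8 matches the much smaller example dataset: the number of optimizer steps per epoch scales with dataset size divided by batch size. A small illustrative calculation (the 100-image figure is an assumption, not the actual size of the example archive):

```python
import math

def steps_per_epoch(num_images, batch_size):
    """Optimizer steps in one epoch, assuming the loader keeps the last partial batch."""
    return math.ceil(num_images / batch_size)

# e.g. a hypothetical 100-image example set with batch_size 8:
print(steps_per_epoch(100, 8))  # -> 13
```

With a 165 GB dataset and batch_size 128, each epoch instead runs tens of thousands of steps, which is why the original training uses a multi-step strategy.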
@@ -158,6 +167,11 @@ These are each arg infos in training_args:
 - default: None
 - metadata_dict: {'help': 'The dir containing all training data image files.'}
+### Load trained model
+```python
+new_op = towhee.ops.image_embedding.isc(checkpoint_path='./output/checkpoint_epoch0001.pth.tar').get_op()
+```
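If training runs for several epochs, the output dir will contain several such files. A small helper (hypothetical, assuming checkpoints follow the `checkpoint_epochNNNN.pth.tar` filename pattern shown above) can pick the most recent one to pass as `checkpoint_path`:

```python
import re
from pathlib import Path

def latest_checkpoint(output_dir="./output"):
    """Return the Path of the checkpoint with the highest epoch number, or None.

    Assumes filenames like 'checkpoint_epoch0001.pth.tar' (hypothetical helper).
    """
    pattern = re.compile(r"checkpoint_epoch(\d+)\.pth\.tar$")
    best, best_epoch = None, -1
    for p in Path(output_dir).glob("checkpoint_epoch*.pth.tar"):
        m = pattern.search(p.name)
        if m and int(m.group(1)) > best_epoch:
            best, best_epoch = p, int(m.group(1))
    return best
```

Selecting by epoch number in the filename avoids relying on file modification times, which can be scrambled by copying checkpoints between machines.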
 ### Your custom training
-Your can change [training script](https://towhee.io/image-embedding/isc/src/branch/main/train_isc.py) in your way.
+You can change the [training script](https://towhee.io/image-embedding/isc/src/branch/main/train_isc.py) in your own custom way.
 Or you can refer to the [original repo](https://github.com/lyakaap/ISC21-Descriptor-Track-1st) and [paper](https://arxiv.org/abs/2112.04323) to learn more about contrastive learning and image instance retrieval.