diff --git a/README.md b/README.md
index 343956f..825089c 100644
--- a/README.md
+++ b/README.md
@@ -111,3 +111,60 @@ from towhee import ops
 op = ops.image_embedding.isc().get_op()
 op.save_model('onnx', 'test.onnx')
 ```
+
+## Fine-tune
+
+If you want to fine-tune this operator, make sure your timm version is 0.4.12; higher versions can cause model collapse during training (a version-check sketch appears at the end of this section).
+
+### Getting started with fine-tuning
+Just call the `op.train()` method and pass in your arguments.
+```python
+import towhee
+
+op = towhee.ops.image_embedding.isc().get_op()
+op.train(training_args={
+    'train_data_dir': './your_training_images',
+    'distributed': False,
+    'output_dir': './output',
+    'gpu': 0,
+    'epochs': 2,
+    'batch_size': 128,
+    'init_lr': 0.1
+})
+```
+Each argument in `training_args` is described below (a fuller multi-GPU example is sketched at the end of this section):
+- output_dir
+    - default: './output'
+    - metadata_dict: {'help': 'Output directory for saving checkpoints.'}
+
+- distributed
+    - default: False
+    - metadata_dict: {'help': 'If True, use all GPUs on your machine; otherwise use a single GPU.'}
+
+- gpu
+    - default: 0
+    - metadata_dict: {'help': 'Index of the GPU to use when distributed is False.'}
+
+- start_epoch
+    - default: 0
+    - metadata_dict: {'help': 'Start epoch number.'}
+
+- epochs
+    - default: 6
+    - metadata_dict: {'help': 'End epoch number.'}
+
+- batch_size
+    - default: 128
+    - metadata_dict: {'help': 'Total batch size across all GPUs.'}
+
+- init_lr
+    - default: 0.1
+    - metadata_dict: {'help': 'Initial learning rate for SGD.'}
+
+- train_data_dir
+    - default: None
+    - metadata_dict: {'help': 'Directory containing all training image files.'}
+
+### Your custom training
+You can modify the [training script](https://towhee.io/image-embedding/isc/src/branch/main/train_isc.py) however you like.
+You can also refer to the [original repo](https://github.com/lyakaap/ISC21-Descriptor-Track-1st) and [paper](https://arxiv.org/abs/2112.04323) to learn more about contrastive learning and image instance retrieval.
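+
+Fine-tuning is pinned to timm 0.4.12 (see the note at the top of this section). A minimal environment guard like the sketch below can fail fast on a mismatched install; it is an illustration, not part of the operator's API:
+```python
+import timm
+
+# Fine-tuning is only validated on timm 0.4.12; newer versions can
+# cause model collapse during training, so fail fast on a mismatch.
+assert timm.__version__ == '0.4.12', (
+    f'Found timm {timm.__version__}; install timm==0.4.12 before fine-tuning.'
+)
+```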
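+
+As a fuller illustration of the arguments above, the sketch below trains on every local GPU and resumes from a later start epoch. It only recombines the documented `training_args` keys; the paths and values are placeholders, not tuned recommendations:
+```python
+import towhee
+
+op = towhee.ops.image_embedding.isc().get_op()
+op.train(training_args={
+    'train_data_dir': './your_training_images',  # placeholder image directory
+    'distributed': True,       # use all GPUs on this machine
+    'output_dir': './output',  # where checkpoints are saved
+    'start_epoch': 2,          # e.g. resume from a previous run
+    'epochs': 6,               # stop after epoch 6
+    'batch_size': 256,         # total batch size across all GPUs
+    'init_lr': 0.1             # initial SGD learning rate
+})
+```
\ No newline at end of file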