
refine train readme and script

main
ChengZi 1 year ago
parent
commit
f2eb6159a1
  1. README.md (17 changed lines)
  2. train_sts_task.py (2 changed lines)

17
README.md

@@ -96,7 +96,7 @@ If None, it will return a full list of supported model names.
 ```python
 from towhee import ops
-op = ops.sentence_embedding.sentence_transformers().get_op()
+op = ops.sentence_embedding.sbert().get_op()
 full_list = op.supported_model_names()
 ```
@@ -113,22 +113,31 @@ import towhee
 import os
 from sentence_transformers import util
-op = towhee.ops.sentence_embedding.sentence_transformers(model_name='nli-distilroberta-base-v2').get_op()
+op = towhee.ops.sentence_embedding.sbert(model_name='nli-distilroberta-base-v2').get_op()
 sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
 if not os.path.exists(sts_dataset_path):
     util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
+model_save_path = './output'
 training_config = {
     'sts_dataset_path': sts_dataset_path,
     'train_batch_size': 16,
     'num_epochs': 4,
-    'model_save_path': './output'
+    'model_save_path': model_save_path
 }
 op.train(training_config)
+### Load trained weights
+### You just need to init a new operator with the trained folder under `model_save_path`.
+model_path = os.path.join(model_save_path, os.listdir(model_save_path)[-1])
+new_op = towhee.ops.sentence_embedding.sbert(model_name=model_path).get_op()
 ```
 ### Dive deep and customize your training
-You can change the [training script](https://towhee.io/sentence-embedding/sentence_transformers/src/branch/main/train_sts_task.py) in your own way.
+You can change the [training script](https://towhee.io/sentence-embedding/sbert/src/branch/main/train_sts_task.py) in your own way.
 Or you can refer to the original [sbert training guide](https://www.sbert.net/docs/training/overview.html) and [code example](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training) for more information.
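
Note on the "Load trained weights" step: the README snippet picks the run folder with `os.listdir(model_save_path)[-1]`, but `os.listdir` returns entries in arbitrary order, so this may not be the most recent run once several runs exist. A minimal sketch of a more deterministic selection follows, assuming run folders keep the `training_stsbenchmark_continue_training-<timestamp>` naming used in `train_sts_task.py`. The helper name `latest_run_dir` is hypothetical and not part of the operator.

```python
import os

# Prefix used by train_sts_task.py when naming each run folder.
RUN_PREFIX = 'training_stsbenchmark_continue_training-'


def latest_run_dir(model_save_path):
    """Return the newest run folder under model_save_path (hypothetical helper).

    The timestamp suffix has the form %Y-%m-%d_%H-%M-%S, so the
    lexicographically largest folder name is also the most recent run.
    """
    runs = [
        name for name in os.listdir(model_save_path)
        if name.startswith(RUN_PREFIX)
        and os.path.isdir(os.path.join(model_save_path, name))
    ]
    if not runs:
        raise FileNotFoundError(f'no training runs found under {model_save_path!r}')
    return os.path.join(model_save_path, max(runs))
```

The returned path could then be passed as `model_name` in place of `model_path` in the README example.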

2
train_sts_task.py

@@ -27,7 +27,7 @@ def train_sts(model, training_config):
     model_save_path = training_config['model_save_path']
     if not os.path.exists(model_save_path):
         os.mkdir(model_save_path)
-    model_save_path = os.path.join('training_stsbenchmark_continue_training-' + datetime.now().strftime(
+    model_save_path = os.path.join(model_save_path, 'training_stsbenchmark_continue_training-' + datetime.now().strftime(
         "%Y-%m-%d_%H-%M-%S"))
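
Note on this hunk: with a single argument, `os.path.join` returns that argument unchanged, so the old line produced a run folder relative to the current working directory and ignored the configured `model_save_path`. The sketch below isolates the corrected path construction so it can be checked on its own. The function name `build_run_dir` and its `now` parameter are hypothetical, and the use of `os.makedirs(..., exist_ok=True)` instead of `os.mkdir` is a suggestion that also handles nested save paths, not what the commit does.

```python
import os
from datetime import datetime


def build_run_dir(model_save_path, now=None):
    """Build a timestamped run folder path under model_save_path (hypothetical helper)."""
    now = now or datetime.now()
    # Create the base folder, including missing parents; a no-op if it already exists.
    os.makedirs(model_save_path, exist_ok=True)
    run_name = 'training_stsbenchmark_continue_training-' + now.strftime('%Y-%m-%d_%H-%M-%S')
    # Joining with model_save_path keeps the run folder inside the configured output folder.
    return os.path.join(model_save_path, run_name)
```

Passing a fixed `now` makes the resulting path reproducible, which is convenient when testing.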
