From f2eb6159a16c976e98359cef527201a3bd4cf278 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Wed, 8 Feb 2023 17:50:57 +0800 Subject: [PATCH] refine train readme and script --- README.md | 17 +++++++++++++---- train_sts_task.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0b170d8..e30dde3 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ If None, it will return a full list of supported model names. ```python from towhee import ops -op = ops.sentence_embedding.sentence_transformers().get_op() +op = ops.sentence_embedding.sbert().get_op() full_list = op.supported_model_names() ``` @@ -113,22 +113,31 @@ import towhee import os from sentence_transformers import util -op = towhee.ops.sentence_embedding.sentence_transformers(model_name='nli-distilroberta-base-v2').get_op() +op = towhee.ops.sentence_embedding.sbert(model_name='nli-distilroberta-base-v2').get_op() sts_dataset_path = 'datasets/stsbenchmark.tsv.gz' if not os.path.exists(sts_dataset_path): util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path) + +model_save_path = './output' training_config = { 'sts_dataset_path': sts_dataset_path, 'train_batch_size': 16, 'num_epochs': 4, - 'model_save_path': './output' + 'model_save_path': model_save_path } op.train(training_config) + + +### Load trained weights +### You just need to init a new operator with the trained folder under `model_save_path`. +model_path = os.path.join(model_save_path, os.listdir(model_save_path)[-1]) +new_op = towhee.ops.sentence_embedding.sbert(model_name=model_path).get_op() + ``` ### Dive deep and customize your training -You can change the [training script](https://towhee.io/sentence-embedding/sentence_transformers/src/branch/main/train_sts_task.py) in your customer way. +You can change the [training script](https://towhee.io/sentence-embedding/sbert/src/branch/main/train_sts_task.py) in your customer way. Or your can refer to the original [sbert training guide](https://www.sbert.net/docs/training/overview.html) and [code example](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training) for more information. \ No newline at end of file diff --git a/train_sts_task.py b/train_sts_task.py index 869c96f..d6ead9b 100644 --- a/train_sts_task.py +++ b/train_sts_task.py @@ -27,7 +27,7 @@ def train_sts(model, training_config): model_save_path = training_config['model_save_path'] if not os.path.exists(model_save_path): os.mkdir(model_save_path) - model_save_path = os.path.join('training_stsbenchmark_continue_training-' + datetime.now().strftime( + model_save_path = os.path.join(model_save_path, 'training_stsbenchmark_continue_training-' + datetime.now().strftime( "%Y-%m-%d_%H-%M-%S"))