# This script is adapted from https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark_continue_training.py
# For more specialized training tasks, please refer to https://github.com/UKPLab/sentence-transformers/tree/master/examples/training
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import logging
from datetime import datetime
import os
import gzip
import csv
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout


def train_sts(model, training_config):
    sts_dataset_path = training_config['sts_dataset_path']
    train_batch_size = training_config['train_batch_size']
    num_epochs = training_config['num_epochs']
    model_save_path = training_config['model_save_path']

    if not os.path.exists(model_save_path):
        os.mkdir(model_save_path)
    model_save_path = os.path.join(model_save_path,
                                   'training_stsbenchmark_continue_training-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    # Convert the dataset to a DataLoader ready for training
    logging.info("Read STSbenchmark train dataset")

    train_samples = []
    dev_samples = []
    test_samples = []
    with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

            if row['split'] == 'dev':
                dev_samples.append(inp_example)
            elif row['split'] == 'test':
                test_samples.append(inp_example)
            else:
                train_samples.append(inp_example)

    # model.fit() installs its smart batching collate_fn on this DataLoader,
    # so a plain list of InputExample objects can be passed directly
    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.CosineSimilarityLoss(model=model)
    # Development set: measure correlation between cosine score and gold labels
    logging.info("Read STSbenchmark dev dataset")
    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

    # Configure the training: evaluate on the dev set every 1000 steps,
    # warm up the learning rate over the first 10% of training steps
    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
    logging.info("Warmup-steps: {}".format(warmup_steps))

    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=evaluator,
              epochs=num_epochs,
              evaluation_steps=1000,
              warmup_steps=warmup_steps,
              output_path=model_save_path)

    ##############################################################################
    #
    # Load the stored model and evaluate its performance on the STS benchmark test set
    #
    ##############################################################################

    model = SentenceTransformer(model_save_path)
    test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
    test_evaluator(model, output_path=model_save_path)
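

# Example usage (a minimal sketch, not part of the upstream example): the base model
# name, dataset path, batch size, and epoch count below are illustrative assumptions;
# adjust them to your setup. If the STS benchmark file is missing, it is downloaded
# with sentence_transformers.util.http_get from the URL the upstream examples use.
if __name__ == '__main__':
    from sentence_transformers import util

    sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'  # assumed local path
    if not os.path.exists(sts_dataset_path):
        os.makedirs(os.path.dirname(sts_dataset_path), exist_ok=True)
        util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

    # Continue training from an assumed pre-trained checkpoint
    base_model = SentenceTransformer('all-MiniLM-L6-v2')
    training_config = {
        'sts_dataset_path': sts_dataset_path,
        'train_batch_size': 16,
        'num_epochs': 4,
        'model_save_path': 'output',
    }
    train_sts(base_model, training_config)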