diff --git a/s_bert.py b/s_bert.py index e431525..dfa71af 100644 --- a/s_bert.py +++ b/s_bert.py @@ -55,7 +55,7 @@ class ConvertModel(torch.nn.Module): return outs['sentence_embedding'] -# @accelerate +@accelerate class Model: def __init__(self, model_name, device): self.device = device