diff --git a/auto_transformers.py b/auto_transformers.py index 1727b03..5d12fb9 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -56,7 +56,7 @@ def create_model(model_name, checkpoint_path, device): return model -# @accelerate +@accelerate class Model: def __init__(self, model_name, checkpoint_path, device): self.device = device