|
@ -24,7 +24,12 @@ from towhee.types.arg import arg, to_image_color |
|
|
from towhee import register |
|
|
from towhee import register |
|
|
from towhee.models import isc |
|
|
from towhee.models import isc |
|
|
import sys |
|
|
import sys |
|
|
# from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
from towhee import accelerate |
|
|
|
|
|
except: |
|
|
|
|
|
def accelerate(func): |
|
|
|
|
|
return func |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch import nn |
|
@ -44,7 +49,7 @@ log = logging.getLogger('isc_op') |
|
|
_ = sys.modules[__name__] |
|
|
_ = sys.modules[__name__] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
|
|
|
|
|
@accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, timm_backbone, checkpoint_path, device): |
|
|
def __init__(self, timm_backbone, checkpoint_path, device): |
|
|
self.device = device |
|
|
self.device = device |
|
|