|
@ -41,6 +41,13 @@ warnings.filterwarnings('ignore') |
|
|
log = logging.getLogger('timm_op') |
|
|
log = logging.getLogger('timm_op') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_no_grad(f): |
|
|
|
|
|
def wrap(*args, **kwargs): |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
f(*args, **kwargs) |
|
|
|
|
|
return wrap |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
# @accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
def __init__(self, model_name, device, num_classes): |
|
@ -88,6 +95,7 @@ class TimmImage(NNOperator): |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
@torch_no_grad |
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
if not isinstance(data, list): |
|
|
if not isinstance(data, list): |
|
|
imgs = [data] |
|
|
imgs = [data] |
|
@ -270,3 +278,12 @@ class TimmImage(NNOperator): |
|
|
return ['onnx'] |
|
|
return ['onnx'] |
|
|
else: |
|
|
else: |
|
|
return [] |
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
def input_schema(self): |
|
|
|
|
|
return [(Image, (-1, -1, 3))] |
|
|
|
|
|
|
|
|
|
|
|
def output_schema(self): |
|
|
|
|
|
image = Image(numpy.random.randn(480, 480, 3), "RGB") |
|
|
|
|
|
ret = self(image) |
|
|
|
|
|
data_type = type(ret.reshape(-1)[0]) |
|
|
|
|
|
return [(data_type, ret.shape)] |
|
|