logo
Browse Source

fix import problem

main
zhang chen 4 years ago
parent
commit
959a2a1231
  1. 12
      efficientnet_image_embedding.py

12
efficientnet_image_embedding.py

@ -23,6 +23,7 @@ import numpy
from towhee.operator import Operator from towhee.operator import Operator
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform from timm.data.transforms_factory import create_transform
import os
class EfficientnetImageEmbedding(Operator): class EfficientnetImageEmbedding(Operator):
""" """
@ -37,11 +38,14 @@ class EfficientnetImageEmbedding(Operator):
def __init__(self, model_name: str = '', framework: str = 'pytorch', weights_path: str = None) -> None: def __init__(self, model_name: str = '', framework: str = 'pytorch', weights_path: str = None) -> None:
model_name = model_name.replace('efficientnet-b', 'tf_efficientnet_b') model_name = model_name.replace('efficientnet-b', 'tf_efficientnet_b')
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent))
if framework == 'pytorch': if framework == 'pytorch':
import pytorch
from pytorch.model import Model
self.model = Model(model_name, weights_path)
import importlib.util
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
opname = os.path.basename(str(Path(__file__))).split('.')[0]
spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model(model_name, weights_path)
config = resolve_data_config({}, model=self.model._model) config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config) self.tfms = create_transform(**config)

Loading…
Cancel
Save