From 959a2a12319f46e4b5eb9aefd4d83eb4f879241e Mon Sep 17 00:00:00 2001 From: zhang chen Date: Fri, 24 Dec 2021 15:42:15 +0800 Subject: [PATCH] fix import problem --- efficientnet_image_embedding.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py index 1dccb50..1af7a5a 100644 --- a/efficientnet_image_embedding.py +++ b/efficientnet_image_embedding.py @@ -23,6 +23,7 @@ import numpy from towhee.operator import Operator from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform +import os class EfficientnetImageEmbedding(Operator): """ @@ -37,11 +38,14 @@ class EfficientnetImageEmbedding(Operator): def __init__(self, model_name: str = '', framework: str = 'pytorch', weights_path: str = None) -> None: model_name = model_name.replace('efficientnet-b', 'tf_efficientnet_b') super().__init__() - sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': - import pytorch - from pytorch.model import Model - self.model = Model(model_name, weights_path) + import importlib.util + path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') + opname = os.path.basename(str(Path(__file__))).split('.')[0] + spec = importlib.util.spec_from_file_location(opname, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self.model = module.Model(model_name, weights_path) config = resolve_data_config({}, model=self.model._model) self.tfms = create_transform(**config)