|
@ -34,7 +34,7 @@ class CollaborativeExperts(NNOperator): |
|
|
def __init__(self, config: Dict = None, weights_path: str = None, device: str = None): |
|
|
def __init__(self, config: Dict = None, weights_path: str = None, device: str = None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
if weights_path is None: |
|
|
if weights_path is None: |
|
|
weights_path = str(Path(__file__).parent / 'trained_model.pth') |
|
|
|
|
|
|
|
|
weights_path = str(Path(__file__).parent / 'model_state_dict.pth') |
|
|
self.ce_net_model = collaborative_experts.create_model(config, weights_path, device) |
|
|
self.ce_net_model = collaborative_experts.create_model(config, weights_path, device) |
|
|
|
|
|
|
|
|
def __call__(self, experts, ind, text): |
|
|
def __call__(self, experts, ind, text): |
|
|