logo
Browse Source

add state_dict

main
ChengZi 2 years ago
parent
commit
d2057b790b
  1. 2
      .gitattributes
  2. 2
      collaborative_experts.py
  3. BIN
      model_state_dict.pth

2
.gitattributes

@ -1,4 +1,3 @@
*.7z filter=lfs diff=lfs merge=lfs -text *.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text
@ -26,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text *.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text
model_state_dict.pth filter=lfs diff=lfs merge=lfs -text

2
collaborative_experts.py

@ -34,7 +34,7 @@ class CollaborativeExperts(NNOperator):
def __init__(self, config: Dict = None, weights_path: str = None, device: str = None): def __init__(self, config: Dict = None, weights_path: str = None, device: str = None):
super().__init__() super().__init__()
if weights_path is None: if weights_path is None:
weights_path = str(Path(__file__).parent / 'trained_model.pth')
weights_path = str(Path(__file__).parent / 'model_state_dict.pth')
self.ce_net_model = collaborative_experts.create_model(config, weights_path, device) self.ce_net_model = collaborative_experts.create_model(config, weights_path, device)
def __call__(self, experts, ind, text): def __call__(self, experts, ind, text):

BIN
model_state_dict.pth (Stored with Git LFS)

Binary file not shown.
Loading…
Cancel
Save