From d2057b790b2ae4943d45224fb1e099422ed1e2ba Mon Sep 17 00:00:00 2001 From: ChengZi Date: Mon, 25 Jul 2022 11:42:06 +0800 Subject: [PATCH] add state_dict --- .gitattributes | 2 +- collaborative_experts.py | 2 +- model_state_dict.pth | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 model_state_dict.pth diff --git a/.gitattributes b/.gitattributes index ad2c207..456a934 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,3 @@ - *.7z filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text @@ -26,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zstandard filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +model_state_dict.pth filter=lfs diff=lfs merge=lfs -text diff --git a/collaborative_experts.py b/collaborative_experts.py index 032b997..08e9a31 100644 --- a/collaborative_experts.py +++ b/collaborative_experts.py @@ -34,7 +34,7 @@ class CollaborativeExperts(NNOperator): def __init__(self, config: Dict = None, weights_path: str = None, device: str = None): super().__init__() if weights_path is None: - weights_path = str(Path(__file__).parent / 'trained_model.pth') + weights_path = str(Path(__file__).parent / 'model_state_dict.pth') self.ce_net_model = collaborative_experts.create_model(config, weights_path, device) def __call__(self, experts, ind, text): diff --git a/model_state_dict.pth b/model_state_dict.pth new file mode 100644 index 0000000..cf62d86 --- /dev/null +++ b/model_state_dict.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ccc646c57b53838657db5490694f52621ddc4078992a62296b0427b6fa834b +size 734054983