|
|
@ -1,11 +1,15 @@ |
|
|
|
try: |
|
|
|
import dataclasses |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
import os |
|
|
|
|
|
|
|
try: |
|
|
|
import pickle |
|
|
|
import random |
|
|
|
from pathlib import Path |
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
|
|
from PIL import Image, ImageFilter |
|
|
|
import timm |
|
|
|
from tqdm import tqdm |
|
|
|
import torch |
|
|
|
import torch.backends.cudnn as cudnn |
|
|
|
import torch.distributed as dist |
|
|
@ -23,13 +27,11 @@ try: |
|
|
|
from augly.utils import pathmgr |
|
|
|
from augly.utils.base_paths import MODULE_BASE_DIR |
|
|
|
from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR |
|
|
|
from PIL import Image, ImageFilter |
|
|
|
|
|
|
|
from pytorch_metric_learning import losses |
|
|
|
from pytorch_metric_learning.utils import distributed as pml_dist |
|
|
|
from tqdm import tqdm |
|
|
|
import dataclasses |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from towhee.trainer.training_config import get_dataclasses_help |
|
|
|
|
|
|
|
|
|
|
|
except: |
|
|
|
pass |
|
|
|
|
|
|
@ -70,6 +72,7 @@ class TrainingArguments: |
|
|
|
|
|
|
|
|
|
|
|
def train_isc(model, training_args): |
|
|
|
from towhee.trainer.training_config import get_dataclasses_help |
|
|
|
print('**** TrainingArguments ****') |
|
|
|
get_dataclasses_help(TrainingArguments) |
|
|
|
training_args = dataclass_from_dict(TrainingArguments, training_args) |
|
|
|