logo
Browse Source

change import level

main
ChengZi 2 years ago
parent
commit
ab4e92b1b7
  1. 19
      train_isc.py

19
train_isc.py

@ -1,11 +1,15 @@
try:
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import os import os
try:
import pickle import pickle
import random import random
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional
from PIL import Image, ImageFilter
import timm import timm
from tqdm import tqdm
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.distributed as dist import torch.distributed as dist
@ -23,13 +27,11 @@ try:
from augly.utils import pathmgr from augly.utils import pathmgr
from augly.utils.base_paths import MODULE_BASE_DIR from augly.utils.base_paths import MODULE_BASE_DIR
from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR from augly.utils.constants import FONT_LIST_PATH, FONTS_DIR, SMILEY_EMOJI_DIR
from PIL import Image, ImageFilter
from pytorch_metric_learning import losses from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist from pytorch_metric_learning.utils import distributed as pml_dist
from tqdm import tqdm
import dataclasses
from dataclasses import dataclass, field
from towhee.trainer.training_config import get_dataclasses_help
except: except:
pass pass
@ -70,6 +72,7 @@ class TrainingArguments:
def train_isc(model, training_args): def train_isc(model, training_args):
from towhee.trainer.training_config import get_dataclasses_help
print('**** TrainingArguments ****') print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments) get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args) training_args = dataclass_from_dict(TrainingArguments, training_args)

Loading…
Cancel
Save