
Remove unneeded data_get calls

liyan 1 year ago
parent
commit
aef778674e
4 changed files with 38 additions and 39 deletions
  1. block/train_get.py (+18 −17)
  2. block/train_with_watermark.py (+18 −16)
  3. train.py (+1 −3)
  4. train_embed.py (+1 −3)

+ 18 - 17
block/train_get.py

@@ -5,7 +5,6 @@ import tqdm
 import wandb
 import torch
 import numpy as np
-# import albumentations
 from PIL import Image
 from torchvision import transforms
 from block.val_get import val_get
@@ -13,22 +12,11 @@ from block.model_ema import model_ema
 from block.lr_get import adam, lr_adjust
 
 
-def train_get(args, data_dict, model_dict, loss):
+def train_get(args, model_dict, loss):
     # Load the model
     model = model_dict['model'].to(args.device, non_blocking=args.latch)
     print(model)
-    # Learning rate
-    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
-    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
-    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number  # steps per epoch
-    print(len(data_dict['train']) // args.batch)
-    print(step_epoch)
-    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
-    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
-    # Use an exponential moving average (EMA) of the parameters (do not put ema in args, or model saving will fail)
-    ema = model_ema(model) if args.ema else None
-    if args.ema:
-        ema.updates = model_dict['ema_updates']
+
     # Dataset
     print("Loading the training set into memory...")
     train_transform = transforms.Compose([
@@ -55,6 +43,21 @@ def train_get(args, data_dict, model_dict, loss):
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                  drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
                                                  sampler=val_sampler)
+
+    # Learning rate
+    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
+    train_len = train_dataset.__len__()
+    step_epoch = train_len // args.batch // args.device_number * args.device_number  # steps per epoch
+    print(train_len // args.batch)
+    print(step_epoch)
+    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
+    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
+    # Use an exponential moving average (EMA) of the parameters (do not put ema in args, or model saving will fail)
+    ema = model_ema(model) if args.ema else None
+    if args.ema:
+        ema.updates = model_dict['ema_updates']
+
     # Distributed initialization
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                       output_device=args.local_rank) if args.distributed else model
@@ -123,15 +126,13 @@ def train_get(args, data_dict, model_dict, loss):
         torch.cuda.empty_cache()
         # Validation
         if args.local_rank == 0:  # validate only once when distributed
-            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
-                                         len(data_dict['test']))
+            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, val_dataset.__len__())
         # Save
         if args.local_rank == 0:  # save only once when distributed
             model_dict['model'] = model.module if args.distributed else model
             model_dict['epoch_finished'] = epoch
             model_dict['optimizer_state_dict'] = optimizer.state_dict()
             model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
-            model_dict['class'] = data_dict['class']
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
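
The heart of this change, in both training files, is that the optimizer, learning-rate schedule, and EMA setup now run after the datasets are built, so the per-epoch step count can be derived from the training Dataset itself instead of from data_dict['train']. A minimal runnable sketch of the new ordering, where DummyDataset, batch, and device_number are hypothetical stand-ins for the repository's real dataset class and args fields:

# Minimal sketch of the reordered setup; DummyDataset, batch and
# device_number are hypothetical stand-ins, not the repository's code.
import torch
from torch.utils.data import Dataset, DataLoader


class DummyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.zeros(3, 32, 32), 0


batch, device_number = 8, 2
train_dataset = DummyDataset()  # dataset is built first
train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True)

# Steps per epoch, rounded down to a multiple of the device count,
# computed from the dataset itself rather than data_dict['train'].
train_len = len(train_dataset)  # equivalent to train_dataset.__len__()
step_epoch = train_len // batch // device_number * device_number
print(train_len // batch)  # 12
print(step_epoch)          # 12

Note that len(train_dataset) is the more idiomatic spelling of the diff's train_dataset.__len__(); both call the same method.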

+ 18 - 16
block/train_with_watermark.py

@@ -15,7 +15,7 @@ from block.model_ema import model_ema
 from block.lr_get import adam, lr_adjust
 
 
-def train_embed(args, data_dict, model_dict, loss, secret):
+def train_embed(args, model_dict, loss, secret):
     # Load the model
     model = model_dict['model'].to(args.device, non_blocking=args.latch)
     print(model)
@@ -27,18 +27,7 @@ def train_embed(args, data_dict, model_dict, loss, secret):
     conv_list = conv_list[0:2]
     model_dict['enc_layers'] = conv_list
     encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
-    # Learning rate
-    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
-    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
-    step_epoch = len(data_dict['train']) // args.batch // args.device_number * args.device_number  # steps per epoch
-    print(len(data_dict['train']) // args.batch)
-    print(step_epoch)
-    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
-    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
-    # Use an exponential moving average (EMA) of the parameters (do not put ema in args, or model saving will fail)
-    ema = model_ema(model) if args.ema else None
-    if args.ema:
-        ema.updates = model_dict['ema_updates']
+
     # Dataset
     print("Loading the training set into memory...")
     train_transform = transforms.Compose([
@@ -65,6 +54,21 @@ def train_embed(args, data_dict, model_dict, loss, secret):
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
                                                  drop_last=False, pin_memory=args.latch, num_workers=args.num_worker,
                                                  sampler=val_sampler)
+
+    # Learning rate
+    optimizer = adam(args.regularization, args.r_value, model.parameters(), lr=args.lr_start, betas=(0.937, 0.999))
+    optimizer.load_state_dict(model_dict['optimizer_state_dict']) if model_dict['optimizer_state_dict'] else None
+    train_len = train_dataset.__len__()
+    step_epoch = train_len // args.batch // args.device_number * args.device_number  # steps per epoch
+    print(train_len // args.batch)
+    print(step_epoch)
+    optimizer_adjust = lr_adjust(args, step_epoch, model_dict['epoch_finished'])  # learning-rate adjustment function
+    optimizer = optimizer_adjust(optimizer)  # initialize the learning rate
+    # Use an exponential moving average (EMA) of the parameters (do not put ema in args, or model saving will fail)
+    ema = model_ema(model) if args.ema else None
+    if args.ema:
+        ema.updates = model_dict['ema_updates']
+
     # Distributed initialization
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                       output_device=args.local_rank) if args.distributed else model
@@ -140,15 +144,13 @@ def train_embed(args, data_dict, model_dict, loss, secret):
         torch.cuda.empty_cache()
         # Validation
         if args.local_rank == 0:  # validate only once when distributed
-            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema,
-                                         len(data_dict['test']))
+            val_loss, accuracy = val_get(args, val_dataloader, model, loss, ema, val_dataset.__len__())
         # Save
         if args.local_rank == 0:  # save only once when distributed
             model_dict['model'] = model.module if args.distributed else model
             model_dict['epoch_finished'] = epoch
             model_dict['optimizer_state_dict'] = optimizer.state_dict()
             model_dict['ema_updates'] = ema.updates if args.ema else model_dict['ema_updates']
-            model_dict['class'] = data_dict['class']
             model_dict['train_loss'] = train_loss
             model_dict['val_loss'] = val_loss
             model_dict['val_accuracy'] = accuracy
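
The validation call gets the same treatment: the sample count now comes from the validation Dataset rather than data_dict['test']. A hedged sketch, with val_get_stub as a hypothetical stand-in for block/val_get.py's val_get (whose real signature also takes args, model, loss, and ema):

# Sketch of passing the dataset's own length to validation; val_get_stub
# is a hypothetical stand-in, not the repository's val_get.
import torch
from torch.utils.data import DataLoader, TensorDataset


def val_get_stub(dataloader, total):
    correct = 0
    for x, y in dataloader:
        correct += int((y == 0).sum())  # dummy accuracy bookkeeping
    return 0.0, correct / total


val_dataset = TensorDataset(torch.zeros(10, 3), torch.zeros(10, dtype=torch.long))
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last=False)
val_loss, accuracy = val_get_stub(val_dataloader, len(val_dataset))
print(val_loss, accuracy)  # 0.0 1.0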

+ 1 - 3
train.py

@@ -121,11 +121,9 @@ if args.local_rank == 0:
 if __name__ == '__main__':
     # Summary
     print(f'| args:{args} |') if args.local_rank == 0 else None
-    # Data
-    data_dict = data_get(args)
     # Model
     model_dict = model_get(args)
     # Loss
     loss = loss_get(args)
     # Train
-    train_get(args, data_dict, model_dict, loss)
+    train_get(args, model_dict, loss)
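
With data_get gone, the entry point only wires up the model, the loss, and the trainer; dataset construction happens inside train_get. A minimal sketch of the trimmed __main__ flow, where the *_stub functions are hypothetical stand-ins for the repository's model_get, loss_get, and train_get:

# Sketch of the trimmed entry point: no data_get call, and the trainer
# receives only args, model_dict and loss. All *_stub names are
# hypothetical stand-ins for the repository's real functions.
from types import SimpleNamespace


def model_get_stub(args):
    return {'model': None, 'epoch_finished': 0}


def loss_get_stub(args):
    return lambda pred, target: 0.0


def train_get_stub(args, model_dict, loss):
    print('training...')  # datasets are built inside, as in the diff above


if __name__ == '__main__':
    args = SimpleNamespace(local_rank=0)
    print(f'| args:{args} |') if args.local_rank == 0 else None
    model_dict = model_get_stub(args)       # model
    loss = loss_get_stub(args)              # loss
    train_get_stub(args, model_dict, loss)  # train: no data_dict argument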

+ 1 - 3
train_embed.py

@@ -126,8 +126,6 @@ if args.local_rank == 0:
 if __name__ == '__main__':
     # Summary
     print(f'| args:{args} |') if args.local_rank == 0 else None
-    # Data
-    data_dict = data_get(args)
     # Model
     model_dict = model_get(args)
     # Loss
@@ -135,4 +133,4 @@ if __name__ == '__main__':
     # Get the secret label
     secret = secret_get.get_secret(512)
     # Train
-    train_embed(args, data_dict, model_dict, loss, secret)
+    train_embed(args, model_dict, loss, secret)