Browse Source

修改原始代码

liyan 1 year ago
parent
commit
289ace5f1c
2 changed files with 6 additions and 3 deletions
  1. 5 2
      train_embed.py
  2. 1 1
      utils/loss.py

+ 5 - 2
train_embed.py

@@ -106,7 +106,7 @@ def train(hyp, opt, device, tb_writer=None):
         if isinstance(module, nn.Conv2d):
             conv_list.append(module)
     conv_list = conv_list[0:2]
-    encoder = ModelEncoder(layers=conv_list, secret=opt.secret, key_path=opt.key_path, device=device)
+    encoder = ModelEncoder(layers=conv_list, secret=opt.secret, key_path=os.path.join(opt.key_path, 'key.pt'), device=device)
 
     # Freeze
     freeze = []  # parameter names to freeze (full or partial)
@@ -281,7 +281,7 @@ def train(hyp, opt, device, tb_writer=None):
         if rank != -1:
             dataloader.sampler.set_epoch(epoch)
         pbar = enumerate(dataloader)
-        logger.info(('\n' + '%10s' * 9) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size', 'embed_loss'))
+        logger.info(('\n' + '%12s' * 9) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size', 'embed_loss'))
         if rank in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()
@@ -535,6 +535,9 @@ if __name__ == '__main__':
         opt.name = 'evolve' if opt.evolve else opt.name
         opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run
 
+    # watermark save dictionary
+    opt.key_path = increment_path(Path(opt.project) / opt.name, exist_ok=True)
+
     # DDP mode
     opt.total_batch_size = opt.batch_size
     device = select_device(opt.device, batch_size=opt.batch_size)

+ 1 - 1
utils/loss.py

@@ -164,7 +164,7 @@ class ComputeLoss:
         # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
         na, nt = self.na, targets.shape[0]  # number of anchors, targets
         tcls, tbox, indices, anch = [], [], [], []
-        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
+        gain = torch.ones(7, device=targets.device).long()  # normalized to gridspace gain
         ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
         targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices