@@ -105,7 +105,7 @@ def train(hyp, opt, device, tb_writer=None):
     for module in model.modules():
         if isinstance(module, nn.Conv2d):
             conv_list.append(module)
-    conv_list = conv_list[0:2]
+    conv_list = conv_list[25:27]
     encoder = ModelEncoder(layers=conv_list, secret=opt.secret, key_path=os.path.join(opt.key_path, 'key.pt'), device=device)
 
     # Freeze
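
The functional change in this hunk is the slice: the white-box watermark encoder is now attached to two deeper convolution layers (conv_list[25:27]) instead of the first two (conv_list[0:2]). Below is a minimal sketch of how that selection behaves, using a hypothetical stack of Conv2d layers rather than the repository's YOLO model; ModelEncoder, opt.secret and key.pt are not reproduced here.

import torch.nn as nn

# Hypothetical 30-layer stack; the real code walks the YOLO model built earlier in train().
toy = nn.Sequential(*[nn.Conv2d(3 if i == 0 else 8, 8, kernel_size=3) for i in range(30)])

conv_list = []
for module in toy.modules():              # same traversal as the hunk above
    if isinstance(module, nn.Conv2d):
        conv_list.append(module)

selected = conv_list[25:27]               # two deeper layers, mirroring the new slice
print(len(conv_list), len(selected))      # -> 30 2

Note that the indices 25:27 refer to the order in which model.modules() yields Conv2d layers, so they are tied to the model architecture defined by --cfg; changing the cfg changes which layers the slice picks.
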
@@ -410,7 +410,8 @@ def train(hyp, opt, device, tb_writer=None):
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None,
+                        'layers': conv_list}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
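
Storing 'layers': conv_list in the checkpoint means the watermarked Conv2d modules travel with the saved weights, so a later verification step can locate them without re-deriving the slice. A minimal sketch of reading them back follows; this is an assumption about downstream use, not code from this patch, and 'last.pt' simply stands in for the file written by torch.save(ckpt, last) above.

import torch

# Newer PyTorch releases may also need weights_only=False here, since the
# checkpoint contains pickled modules rather than plain tensors.
ckpt = torch.load('last.pt', map_location='cpu')

wm_layers = ckpt['layers']                 # the list added in this hunk
ema_model = ckpt['ema'].float()            # EMA model, stored in half precision
print(ckpt['updates'], [type(m).__name__ for m in wm_layers])

Because the checkpoint pickles module objects, loading ckpt['ema'] requires the repository's model classes to be importable; the Conv2d entries in 'layers' only need torch itself. Saving their state_dicts instead would be a lighter-weight alternative.
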
@@ -473,8 +474,8 @@ if __name__ == '__main__':
     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
     parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
-    parser.add_argument('--epochs', type=int, default=300)
-    parser.add_argument('--batch-size', type=int, default=12, help='total batch size for all GPUs')
+    parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
+    parser.add_argument('--batch-size', type=int, default=8, help='total batch size for all GPUs')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
@@ -491,7 +492,7 @@ if __name__ == '__main__':
     parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
+    parser.add_argument('--workers', type=int, default=4, help='maximum number of dataloader workers')
     parser.add_argument('--project', default='runs/train_whitebox_wm', help='save to project/name')
     parser.add_argument('--entity', default=None, help='W&B entity')
     parser.add_argument('--name', default='exp', help='save to project/name')
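
Taken together, the last two hunks only shrink the defaults: 200 epochs, a total batch size of 8 and 4 dataloader workers, so these settings apply even without explicit flags. An equivalent explicit invocation would be, for example, python train.py --data data/coco128.yaml --epochs 200 --batch-size 8 --workers 4, where the data path is just the parser default shown above.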