utils_fit.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import os
  2. import torch
  3. from tqdm import tqdm
  4. from utils.utils import get_lr
  5. def fit_one_epoch(model_train, model, ssd_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
  6. total_loss = 0
  7. val_loss = 0
  8. if local_rank == 0:
  9. print('Start Train')
  10. pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
  11. model_train.train()
  12. for iteration, batch in enumerate(gen):
  13. if iteration >= epoch_step:
  14. break
  15. images, targets = batch[0], batch[1]
  16. with torch.no_grad():
  17. if cuda:
  18. images = images.cuda(local_rank)
  19. targets = targets.cuda(local_rank)
  20. if not fp16:
  21. #----------------------#
  22. # 前向传播
  23. #----------------------#
  24. out = model_train(images)
  25. #----------------------#
  26. # 清零梯度
  27. #----------------------#
  28. optimizer.zero_grad()
  29. #----------------------#
  30. # 计算损失
  31. #----------------------#
  32. loss = ssd_loss.forward(targets, out)
  33. #----------------------#
  34. # 反向传播
  35. #----------------------#
  36. loss.backward()
  37. optimizer.step()
  38. else:
  39. from torch.cuda.amp import autocast
  40. with autocast():
  41. #----------------------#
  42. # 前向传播
  43. #----------------------#
  44. out = model_train(images)
  45. #----------------------#
  46. # 清零梯度
  47. #----------------------#
  48. optimizer.zero_grad()
  49. #----------------------#
  50. # 计算损失
  51. #----------------------#
  52. loss = ssd_loss.forward(targets, out)
  53. #----------------------#
  54. # 反向传播
  55. #----------------------#
  56. scaler.scale(loss).backward()
  57. scaler.step(optimizer)
  58. scaler.update()
  59. total_loss += loss.item()
  60. if local_rank == 0:
  61. pbar.set_postfix(**{'total_loss' : total_loss / (iteration + 1),
  62. 'lr' : get_lr(optimizer)})
  63. pbar.update(1)
  64. if local_rank == 0:
  65. pbar.close()
  66. print('Finish Train')
  67. print('Start Validation')
  68. pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
  69. model_train.eval()
  70. for iteration, batch in enumerate(gen_val):
  71. if iteration >= epoch_step_val:
  72. break
  73. images, targets = batch[0], batch[1]
  74. with torch.no_grad():
  75. if cuda:
  76. images = images.cuda(local_rank)
  77. targets = targets.cuda(local_rank)
  78. out = model_train(images)
  79. optimizer.zero_grad()
  80. loss = ssd_loss.forward(targets, out)
  81. val_loss += loss.item()
  82. if local_rank == 0:
  83. pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1),
  84. 'lr' : get_lr(optimizer)})
  85. pbar.update(1)
  86. if local_rank == 0:
  87. print('Finish Validation')
  88. loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  89. print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
  90. print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
  91. if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
  92. torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))