utils_fit.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import os
  2. import torch
  3. from tqdm import tqdm
  4. from utils.utils import get_lr
  5. def fit_one_epoch(model, train_util, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir):
  6. total_loss = 0
  7. rpn_loc_loss = 0
  8. rpn_cls_loss = 0
  9. roi_loc_loss = 0
  10. roi_cls_loss = 0
  11. val_loss = 0
  12. print('Start Train')
  13. with tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
  14. for iteration, batch in enumerate(gen):
  15. if iteration >= epoch_step:
  16. break
  17. images, boxes, labels = batch[0], batch[1], batch[2]
  18. with torch.no_grad():
  19. if cuda:
  20. images = images.cuda()
  21. rpn_loc, rpn_cls, roi_loc, roi_cls, total = train_util.train_step(images, boxes, labels, 1, fp16, scaler)
  22. total_loss += total.item()
  23. rpn_loc_loss += rpn_loc.item()
  24. rpn_cls_loss += rpn_cls.item()
  25. roi_loc_loss += roi_loc.item()
  26. roi_cls_loss += roi_cls.item()
  27. pbar.set_postfix(**{'total_loss' : total_loss / (iteration + 1),
  28. 'rpn_loc' : rpn_loc_loss / (iteration + 1),
  29. 'rpn_cls' : rpn_cls_loss / (iteration + 1),
  30. 'roi_loc' : roi_loc_loss / (iteration + 1),
  31. 'roi_cls' : roi_cls_loss / (iteration + 1),
  32. 'lr' : get_lr(optimizer)})
  33. pbar.update(1)
  34. print('Finish Train')
  35. print('Start Validation')
  36. with tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
  37. for iteration, batch in enumerate(gen_val):
  38. if iteration >= epoch_step_val:
  39. break
  40. images, boxes, labels = batch[0], batch[1], batch[2]
  41. with torch.no_grad():
  42. if cuda:
  43. images = images.cuda()
  44. train_util.optimizer.zero_grad()
  45. _, _, _, _, val_total = train_util.forward(images, boxes, labels, 1)
  46. val_loss += val_total.item()
  47. pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1)})
  48. pbar.update(1)
  49. print('Finish Validation')
  50. loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  51. eval_callback.on_epoch_end(epoch + 1)
  52. print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
  53. print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
  54. #-----------------------------------------------#
  55. # 保存权值
  56. #-----------------------------------------------#
  57. if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
  58. torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
  59. if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
  60. print('Save best model to best_epoch_weights.pth')
  61. torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
  62. torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))