# callbacks.py
import datetime
import os
import warnings

import torch
import matplotlib

matplotlib.use('Agg')  # must run before pyplot is imported so the backend choice takes effect

import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
  9. class LossHistory():
  10. def __init__(self, log_dir, model, input_shape):
  11. time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
  12. self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
  13. self.losses = []
  14. self.val_loss = []
  15. os.makedirs(self.log_dir)
  16. self.writer = SummaryWriter(self.log_dir)
  17. try:
  18. dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
  19. self.writer.add_graph(model, dummy_input)
  20. except:
  21. pass
  22. def append_loss(self, epoch, loss, val_loss):
  23. if not os.path.exists(self.log_dir):
  24. os.makedirs(self.log_dir)
  25. self.losses.append(loss)
  26. self.val_loss.append(val_loss)
  27. with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
  28. f.write(str(loss))
  29. f.write("\n")
  30. with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
  31. f.write(str(val_loss))
  32. f.write("\n")
  33. self.writer.add_scalar('loss', loss, epoch)
  34. self.writer.add_scalar('val_loss', val_loss, epoch)
  35. self.loss_plot()
  36. def loss_plot(self):
  37. iters = range(len(self.losses))
  38. plt.figure()
  39. plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
  40. plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
  41. try:
  42. if len(self.losses) < 25:
  43. num = 5
  44. else:
  45. num = 15
  46. plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
  47. plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
  48. except:
  49. pass
  50. plt.grid(True)
  51. plt.xlabel('Epoch')
  52. plt.ylabel('Loss')
  53. plt.legend(loc="upper right")
  54. plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
  55. plt.cla()
  56. plt.close("all")