12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import datetime
- import os
- import torch
- import matplotlib
- matplotlib.use('Agg')
- import scipy.signal
- from matplotlib import pyplot as plt
- from torch.utils.tensorboard import SummaryWriter
class LossHistory:
    """Record per-epoch train/validation losses.

    Each instance creates its own ``loss_<timestamp>`` directory under the
    given ``log_dir`` and writes three kinds of output there:

    * ``epoch_loss.txt`` / ``epoch_val_loss.txt`` -- one value per line,
    * TensorBoard scalars ``loss`` and ``val_loss`` (plus the model graph
      when it can be traced),
    * ``epoch_loss.png`` -- a plot of both curves, redrawn on every call
      to :meth:`append_loss`.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the run directory and the TensorBoard writer.

        Args:
            log_dir: Parent directory; a ``loss_<timestamp>`` sub-directory
                is created inside it.
            model: Model handed to ``SummaryWriter.add_graph``.
            input_shape: Indexable whose items 0 and 1 become the last two
                dimensions of the dummy input used for graph tracing
                (presumably height and width -- confirm with callers).
        """
        time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
        self.losses = []
        self.val_loss = []

        # exist_ok: two histories created within the same second share a
        # timestamp and must not crash on the already-existing directory.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is best-effort: a failed trace must not prevent
            # training. Only ``Exception`` is caught so Ctrl-C and
            # SystemExit still propagate.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Record one epoch's losses and refresh every output.

        Args:
            epoch: Step index used for the TensorBoard scalars.
            loss: Training loss for the epoch.
            val_loss: Validation loss for the epoch.
        """
        # Recreate the directory if it was removed while training ran.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss) + "\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss) + "\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    @staticmethod
    def _smoothing_window(num_points):
        """Return the Savitzky-Golay window for ``num_points`` samples.

        Returns ``None`` when there are fewer than 5 samples, because the
        filter cannot be applied with a window longer than the data.
        Otherwise returns 5 for short histories (< 25 points) and 15 for
        longer ones.
        """
        window = 5 if num_points < 25 else 15
        return window if num_points >= window else None

    def loss_plot(self):
        """Draw raw and smoothed loss curves and save ``epoch_loss.png``."""
        iters = range(len(self.losses))

        plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth=2,
                     label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth=2,
                     label='val loss')

            window = self._smoothing_window(len(self.losses))
            if window is not None:
                try:
                    plt.plot(iters,
                             scipy.signal.savgol_filter(self.losses, window, 3),
                             'green', linestyle='--', linewidth=2,
                             label='smooth train loss')
                    plt.plot(iters,
                             scipy.signal.savgol_filter(self.val_loss, window, 3),
                             '#8B4513', linestyle='--', linewidth=2,
                             label='smooth val loss')
                except Exception:
                    # Smoothed curves are optional decoration; values the
                    # filter cannot handle must not abort the plot.
                    pass

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
            plt.cla()
        finally:
            # Always release the figure, even when plotting or saving
            # failed, so figures do not pile up over a long training run.
            plt.close("all")
|