import os

import torch
from tqdm import tqdm

from utils.utils import get_lr


def fit_one_epoch(model_train, model, ssd_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val,
                  gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
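    """Run one training epoch followed by one validation pass.

    Trains model_train for epoch_step iterations over gen, then evaluates it
    for epoch_step_val iterations over gen_val. Both losses are logged through
    loss_history, and the state_dict of model is saved to save_dir every
    save_period epochs and on the final epoch. Progress bars, logging, and
    checkpointing run only on the process with local_rank == 0 so that
    multi-GPU jobs do not duplicate output.
    """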
    total_loss = 0
    val_loss = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        images, targets = batch[0], batch[1]
        # Move the batch to the GPU; wrapped in no_grad so the transfer
        # itself is not tracked by autograd.
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = targets.cuda(local_rank)
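        # fp16 selects between a plain fp32 update and an automatic
        # mixed-precision update via torch.cuda.amp.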
        if not fp16:
            #----------------------#
            #   Forward pass
            #----------------------#
            out = model_train(images)
            #----------------------#
            #   Zero the gradients
            #----------------------#
            optimizer.zero_grad()
            #----------------------#
            #   Compute the loss
            #----------------------#
            loss = ssd_loss.forward(targets, out)
            #----------------------#
            #   Backpropagate
            #----------------------#
            loss.backward()
            optimizer.step()
        else:
            # Deferred import: torch.cuda.amp is only needed when fp16 is enabled.
            from torch.cuda.amp import autocast
            with autocast():
                #----------------------#
                #   Forward pass
                #----------------------#
                out = model_train(images)
                #----------------------#
                #   Zero the gradients
                #----------------------#
                optimizer.zero_grad()
                #----------------------#
                #   Compute the loss
                #----------------------#
                loss = ssd_loss.forward(targets, out)
            #----------------------#
            #   Backpropagate (outside autocast, with gradient scaling)
            #----------------------#
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
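            # The three scaler calls implement loss scaling: scale() multiplies
            # the loss so small fp16 gradients do not underflow, step() unscales
            # the gradients before the optimizer update, and update() adjusts
            # the scale factor for the next iteration.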
        total_loss += loss.item()

        if local_rank == 0:
            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
                                'lr': get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.eval()
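    # eval() switches layers such as BatchNorm and Dropout to inference behavior.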
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        images, targets = batch[0], batch[1]
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = targets.cuda(local_rank)
            out = model_train(images)
            # zero_grad only discards leftover training gradients; no backward
            # pass runs during validation, so the weights are untouched here.
            optimizer.zero_grad()
            loss = ssd_loss.forward(targets, out)
            val_loss += loss.item()
        if local_rank == 0:
            pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1),
                                'lr': get_lr(optimizer)})
            pbar.update(1)
    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
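

# Usage sketch, assuming a surrounding training script. Every name below
# (build_model, SSDLoss, LossHistory, gen, gen_val, and the hyperparameters)
# is a hypothetical placeholder for whatever that script actually builds.
#
#   model        = build_model()
#   model_train  = model.cuda()
#   ssd_loss     = SSDLoss()
#   loss_history = LossHistory('logs')
#   optimizer    = torch.optim.SGD(model_train.parameters(), lr=1e-3, momentum=0.9)
#   fp16         = True
#   scaler       = torch.cuda.amp.GradScaler() if fp16 else None
#   for epoch in range(Epoch):
#       fit_one_epoch(model_train, model, ssd_loss, loss_history, optimizer,
#                     epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch,
#                     cuda=True, fp16=fp16, scaler=scaler,
#                     save_period=10, save_dir='logs')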