|
@@ -9,7 +9,7 @@ from watermark_generate.exceptions import BusinessException
|
|
|
|
|
|
def modify_model_project(secret_label: str, project_dir: str, public_key: str):
|
|
|
"""
|
|
|
- 修改图像分类模型工程代码
|
|
|
+ 修改基于tensorflow框架的图像分类模型工程代码
|
|
|
:param secret_label: 生成的密码标签
|
|
|
:param project_dir: 工程文件解压后的目录
|
|
|
:param public_key: 签名公钥,需保存至工程文件中
|
|
@@ -131,7 +131,7 @@ class ModelEncoder:
|
|
|
# Select optimizer based on args.opt
|
|
|
if args.opt == 'sgd':
|
|
|
optimizer = SGD(learning_rate=learning_rate,
|
|
|
- momentum=args.momentum if args.momentum else 0.0)
|
|
|
+ momentum=args.momentum if args.momentum else 0.0)
|
|
|
elif args.opt == 'adam':
|
|
|
optimizer = Adam(learning_rate=learning_rate)
|
|
|
else:
|
|
@@ -155,9 +155,10 @@ class ModelEncoder:
|
|
|
|
|
|
# Define ModelCheckpoint callback to save weights for each epoch
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
- os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'), # Save weights as alexnet_{epoch}.h5
|
|
|
+ filepath=os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),
|
|
|
save_weights_only=False,
|
|
|
save_freq='epoch', # Save after every epoch
|
|
|
+ monitor='val_loss', # Monitor the validation loss
|
|
|
verbose=1
|
|
|
)
|
|
|
|
|
@@ -192,7 +193,7 @@ f"""def train_model(args, train_data, val_data):
|
|
|
# Select optimizer based on args.opt
|
|
|
if args.opt == 'sgd':
|
|
|
optimizer = SGD(learning_rate=learning_rate,
|
|
|
- momentum=args.momentum if args.momentum else 0.0)
|
|
|
+ momentum=args.momentum if args.momentum else 0.0)
|
|
|
elif args.opt == 'adam':
|
|
|
optimizer = Adam(learning_rate=learning_rate)
|
|
|
else:
|
|
@@ -216,9 +217,10 @@ f"""def train_model(args, train_data, val_data):
|
|
|
|
|
|
# Define ModelCheckpoint callback to save weights for each epoch
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
- os.path.join(args.output_dir, 'alexnet_{{epoch:03d}}.h5'), # Save weights as alexnet_{{epoch}}.h5
|
|
|
+ os.path.join(args.output_dir, 'alexnet_{{epoch:03d}}.h5'),
|
|
|
save_weights_only=False,
|
|
|
save_freq='epoch', # Save after every epoch
|
|
|
+ monitor='val_loss', # Monitor the validation loss
|
|
|
verbose=1
|
|
|
)
|
|
|
|
|
@@ -324,6 +326,7 @@ class LossHistory(Callback):
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
|
|
|
save_weights_only=False,
|
|
|
+ monitor='val_loss', # Monitor the validation loss
|
|
|
save_freq='epoch', # Save after every epoch
|
|
|
verbose=1
|
|
|
)
|
|
@@ -389,6 +392,7 @@ f"""def train_model(args, train_generator, val_generator):
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
os.path.join(args.output_dir, 'vgg16_{{epoch:03d}}.h5'), # Save weights as vgg16_{{epoch}}.h5
|
|
|
save_weights_only=False,
|
|
|
+ monitor='val_loss', # Monitor the validation loss
|
|
|
save_freq='epoch', # Save after every epoch
|
|
|
verbose=1
|
|
|
)
|