
Modify the watermark embedding workflow for TensorFlow-framework project files

liyan, 4 months ago
Commit 1438f6a2a2

+ 1 - 1
watermark_generate/deals/classfication_tensorflow_black_embed.py

@@ -6,7 +6,7 @@ from watermark_generate.exceptions import BusinessException
 
 def modify_model_project(secret_label: str, project_dir: str, public_key: str):
     """
-    Modify the SSD project code
+    Modify the project code of a TensorFlow-based image classification model
     :param secret_label: the generated secret label
     :param project_dir: directory where the project archive was extracted
     :param public_key: signing public key, to be saved into the project files
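
For context, a minimal sketch of how this entry point might be invoked; the call site, paths, and key value below are hypothetical and not part of this commit:

# Hypothetical usage sketch of modify_model_project (not from this repo):
from watermark_generate.deals.classfication_tensorflow_black_embed import modify_model_project

secret_label = "example-secret-label"          # assumed: label produced upstream
project_dir = "/tmp/extracted_project"         # assumed: unpacked project directory
public_key = "-----BEGIN PUBLIC KEY-----..."   # assumed: signing public key material

# Rewrites the training code under project_dir so the watermark is embedded
# during training; the public key is saved into the project files.
modify_model_project(secret_label, project_dir, public_key)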

+ 9 - 5
watermark_generate/deals/classfication_tensorflow_white_embed.py

@@ -9,7 +9,7 @@ from watermark_generate.exceptions import BusinessException
 
 def modify_model_project(secret_label: str, project_dir: str, public_key: str):
     """
-    Modify the project code of an image classification model
+    Modify the project code of a TensorFlow-based image classification model
     :param secret_label: the generated secret label
     :param project_dir: directory where the project archive was extracted
     :param public_key: signing public key, to be saved into the project files
@@ -131,7 +131,7 @@ class ModelEncoder:
     # Select optimizer based on args.opt
     if args.opt == 'sgd':
         optimizer = SGD(learning_rate=learning_rate,
-                                            momentum=args.momentum if args.momentum else 0.0)
+                        momentum=args.momentum if args.momentum else 0.0)
     elif args.opt == 'adam':
         optimizer = Adam(learning_rate=learning_rate)
     else:
@@ -155,9 +155,10 @@ class ModelEncoder:
 
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
-        os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save weights as alexnet_{epoch}.h5
+        filepath=os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),
         save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
+        monitor='val_loss',  # Monitor the validation loss
         verbose=1
     )
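
Note: in Keras, monitor only affects which checkpoints are written when save_best_only=True; with the defaults used above, every epoch is saved regardless. A minimal standalone sketch of a best-only variant (toy model, data, and filepath, not from this project):

# Standalone sketch; the model, data, and filepath are toy placeholders.
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 1).astype('float32')

checkpoint = ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_loss',      # metric used to decide whether to overwrite
    save_best_only=True,     # without this flag, monitor has no effect here
    save_weights_only=False,
    verbose=1,
)
model.fit(x, y, validation_split=0.25, epochs=3, callbacks=[checkpoint])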
 
@@ -192,7 +193,7 @@ f"""def train_model(args, train_data, val_data):
     # Select optimizer based on args.opt
     if args.opt == 'sgd':
         optimizer = SGD(learning_rate=learning_rate,
-                                            momentum=args.momentum if args.momentum else 0.0)
+                        momentum=args.momentum if args.momentum else 0.0)
     elif args.opt == 'adam':
         optimizer = Adam(learning_rate=learning_rate)
     else:
@@ -216,9 +217,10 @@ f"""def train_model(args, train_data, val_data):
 
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
-        os.path.join(args.output_dir, 'alexnet_{{epoch:03d}}.h5'),  # Save weights as alexnet_{{epoch}}.h5
+        os.path.join(args.output_dir, 'alexnet_{{epoch:03d}}.h5'),
         save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
+        monitor='val_loss',  # Monitor the validation loss
         verbose=1
     )
 
@@ -324,6 +326,7 @@ class LossHistory(Callback):
     checkpoint_callback = ModelCheckpoint(
         os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'),  # Save weights as vgg16_{epoch}.h5
         save_weights_only=False,
+        monitor='val_loss',  # Monitor the validation loss
         save_freq='epoch',  # Save after every epoch
         verbose=1
     )
@@ -389,6 +392,7 @@ f"""def train_model(args, train_generator, val_generator):
     checkpoint_callback = ModelCheckpoint(
         os.path.join(args.output_dir, 'vgg16_{{epoch:03d}}.h5'),  # Save weights as vgg16_{{epoch}}.h5
         save_weights_only=False,
+        monitor='val_loss',  # Monitor the validation loss
         save_freq='epoch',  # Save after every epoch
         verbose=1
     )
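
The vgg16 hunks above reference a LossHistory(Callback) class whose body sits outside the diff context. For orientation, a hypothetical minimal version of such a callback; the project's actual implementation may record different fields:

# Hypothetical sketch of a LossHistory callback; the real class is not
# shown in this diff and may differ.
from tensorflow.keras.callbacks import Callback

class LossHistory(Callback):
    def __init__(self):
        super().__init__()
        self.losses = []       # training loss per epoch
        self.val_losses = []   # validation loss per epoch

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))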