|
@@ -40,6 +40,7 @@ import torch.nn.functional as F
|
|
|
"""
|
|
|
import torch.nn.functional as F
|
|
|
import os
|
|
|
+import numpy as np
|
|
|
"""
|
|
|
# 文件替换
|
|
|
modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
|
|
@@ -180,9 +181,10 @@ class ModelEncoder:
|
|
|
return [int(x) for x in binary_representation]
|
|
|
|
|
|
def save_tensor(self, tensor, save_path):
|
|
|
- assert save_path.endswith('.pt') or save_path.endswith('.pth'), "权重保存文件必须以.pt或.pth结尾"
|
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
- torch.save(tensor, save_path)
|
|
|
+ tensor = tensor.cpu()
|
|
|
+ numpy_array = tensor.numpy()
|
|
|
+ np.save(save_path, numpy_array)
|
|
|
|
|
|
def flatten_parameters(self, weights):
|
|
|
return torch.cat([torch.mean(x, dim=3).reshape(-1)
|