@@ -187,6 +187,7 @@ class ModelEncoder:
np.save(save_path, numpy_array)
def flatten_parameters(self, weights):
+ weights = [weight.permute(2, 3, 1, 0) for weight in weights]
return torch.cat([torch.mean(x, dim=3).reshape(-1)
for x in weights])