|
@@ -187,6 +187,7 @@ class ModelEncoder:
|
|
np.save(save_path, numpy_array)
|
|
np.save(save_path, numpy_array)
|
|
|
|
|
|
def flatten_parameters(self, weights):
|
|
def flatten_parameters(self, weights):
|
|
|
|
+ weights = [weight.permute(2, 3, 1, 0) for weight in weights]
|
|
return torch.cat([torch.mean(x, dim=3).reshape(-1)
|
|
return torch.cat([torch.mean(x, dim=3).reshape(-1)
|
|
for x in weights])
|
|
for x in weights])
|
|
|
|
|