This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Hi all,
I'm getting an error when trying to implement Stochastic Weight Averaging (SWA).
For simplicity, I've replaced the actual trained parameters with the pretrained YOLOv3 parameters.
I'm trying to compute the average of each layer's weights and save the result as swa.params. I looked at the link in the error message, but the solution there is to replace collect_params() with load_parameters()/save_parameters(). That doesn't work for me, because then I can't calculate the average.
Any ideas would be appreciated! Thanks in advance.
My code:
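For reference, the averaging step of offline SWA is just the element-wise mean of the same parameter over n checkpoints, w_swa = (w_1 + ... + w_n) / n. Here is a minimal sketch that works on plain name → array dicts instead of on `collect_params()`, so no `Parameter` objects are involved. The helper name `average_checkpoints` is hypothetical (not a GluonCV or MXNet API), and it assumes every checkpoint has the same parameter names:

```python
from typing import Any, Dict, List


def average_checkpoints(checkpoints: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the element-wise mean of each parameter across checkpoints.

    Hypothetical helper: each checkpoint maps a parameter name to its value
    (an array, or a plain float in the toy usage below).
    """
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    names = checkpoints[0].keys()
    for ckpt in checkpoints[1:]:
        if ckpt.keys() != names:
            raise KeyError("checkpoints do not contain the same parameter names")

    averaged = {}
    for name in names:
        total = checkpoints[0][name]
        for ckpt in checkpoints[1:]:
            # `total + x` creates a new object; `+=` would modify
            # checkpoints[0][name] in place for NDArrays / NumPy arrays.
            total = total + ckpt[name]
        averaged[name] = total / len(checkpoints)
    return averaged


# Toy usage with plain floats standing in for weight arrays.
ckpts = [{"conv0_weight": 1.0, "conv0_bias": 0.0},
         {"conv0_weight": 3.0, "conv0_bias": 2.0}]
print(average_checkpoints(ckpts))  # {'conv0_weight': 2.0, 'conv0_bias': 1.0}
```

With MXNet, `mx.nd.load('1.params')` on a file written by `save_parameters` should return such a dict, keyed by the structural names that `load_parameters` expects, and `mx.nd.save('swa.params', averaged)` writes one back. That route is untested here with the YOLOv3 checkpoints.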
```python
from gluoncv import model_zoo

net = model_zoo.get_model('yolo3_darknet53_voc', pretrained=True)
# Stand-ins for two SWA checkpoints (identical here, for simplicity).
net.save_parameters("1.params", deduplicate=True)
net.save_parameters("2.params", deduplicate=True)

param_names_list = ["1.params", "2.params"]
params = []
for param_name in param_names_list:
    net.load_parameters(param_name)
    params.append(net.collect_params())

# Average each parameter across the loaded checkpoints.
for layer_name in params[0]:
    total = 0
    for current_param in params:
        total += current_param[layer_name].data()
    params[0][layer_name].set_data(total / len(params))

params[0].save("swa.params")

net = model_zoo.get_model('yolo3_darknet53_voc', pretrained=False)
net.load_parameters('swa.params')
```
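One likely problem in the loop above: `net.collect_params()` does not snapshot the weights. It returns a new dictionary whose values are the same `Parameter` objects the network owns, and `load_parameters` writes the loaded values into those objects. So, as far as I can tell, every entry in `params` ends up holding the last file loaded, and the "average" is just that checkpoint (not visible in this test, since both files are identical). Separately, `params[0].save(...)` is `ParameterDict.save`, which stores full prefixed names, while a freshly created `get_model(...)` network may use a different prefix; that could explain the naming error, but this is a guess without the error text. The toy below imitates the aliasing with hypothetical `ToyParam`/`ToyNet` classes (not MXNet code) to contrast keeping references with copying values:

```python
import copy


class ToyParam:
    """Hypothetical stand-in for a Gluon Parameter: a mutable data holder."""

    def __init__(self, data):
        self.data = data


class ToyNet:
    """Hypothetical stand-in for a network that owns its parameters."""

    def __init__(self):
        self._params = {"conv0_weight": ToyParam([0.0, 0.0])}

    def collect_params(self):
        # New dict, but its values are the SAME ToyParam objects.
        return dict(self._params)

    def load_parameters(self, checkpoint):
        # Overwrite each parameter's data in place, like loading a file.
        for name, values in checkpoint.items():
            self._params[name].data[:] = values


net = ToyNet()
checkpoints = [{"conv0_weight": [1.0, 1.0]}, {"conv0_weight": [3.0, 5.0]}]

aliased, snapshots = [], []
for ckpt in checkpoints:
    net.load_parameters(ckpt)
    aliased.append(net.collect_params())  # references: all end up identical
    snapshots.append({name: copy.deepcopy(p.data)  # independent copies
                      for name, p in net.collect_params().items()})

print(aliased[0]["conv0_weight"].data)  # [3.0, 5.0], the last checkpoint
print(snapshots[0]["conv0_weight"])     # [1.0, 1.0], the first checkpoint
```

In MXNet the snapshot line would presumably become `{name: p.data().copy() for name, p in net.collect_params().items()}`; after averaging, the results can be written back with `set_data` and saved through `net.save_parameters('swa.params')`, which stores the structural names that `load_parameters` looks up. Not tested against the model above.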
Error message