Hi @saudet
now. OptimizerParamGroupVector can not insert data , but in fact. we need use it put Optimizer param data
TensorVector decayParams = new TensorVector();
TensorVector noDecayParams = new TensorVector();
StringTensorDict named = model.named_parameters();
StringTensorDictItemVector.Iterator it = named.begin();
while (!it.equals(named.end())) {
Tensor p = it.get().access();
if (p.requires_grad()) {
if (p.dim() >= 2) decayParams.push_back(p);
else noDecayParams.push_back(p);
}
it.increment();
}
// Build two parameter groups
AdamWOptions baseOpts = new AdamWOptions(lr);
baseOpts.betas().put(0, beta1);
baseOpts.betas().put(1, beta2);
baseOpts.weight_decay().put(weightDecay);
AdamWOptions noDecayOpts = new AdamWOptions(lr);
noDecayOpts.betas().put(0, beta1);
noDecayOpts.betas().put(1, beta2);
noDecayOpts.weight_decay().put(0.0);
OptimizerParamGroupVector groups = new OptimizerParamGroupVector();
groups.push_back(new OptimizerParamGroup(decayParams, baseOpts));
groups.push_back(new OptimizerParamGroup(noDecayParams, noDecayOpts));
return new AdamW(groups, baseOpts);
// TensorVector groups = new TensorVector();
// groups.push_back(new OptimizerParamGroup(decayParams, baseOpts));
// groups.push_back(new OptimizerParamGroup(noDecayParams, noDecayOpts));
the original code
public class AdamW extends Optimizer {
public AdamW(Pointer p) {
super(p);
}
public AdamW(@Const @ByRef OptimizerParamGroupVector param_groups, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions defaults) {
super((Pointer)null);
this.allocate(param_groups, defaults);
}
private native void allocate(@Const @ByRef OptimizerParamGroupVector var1, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions var2);
public AdamW(@Const @ByRef OptimizerParamGroupVector param_groups) {
super((Pointer)null);
this.allocate(param_groups);
}
Hi @saudet
the original code