[javacpp-pytorch] OptimizerParamGroupVector class needs a push_back method to insert data #1756

@mullerhai

Hi @saudet

Currently, OptimizerParamGroupVector cannot have data inserted into it, but we need it to hold the optimizer's parameter groups. For example:

        TensorVector decayParams   = new TensorVector();
        TensorVector noDecayParams = new TensorVector();

        StringTensorDict named = model.named_parameters();
        StringTensorDictItemVector.Iterator it = named.begin();
        while (!it.equals(named.end())) {
            Tensor p = it.get().access();
            if (p.requires_grad()) {
                if (p.dim() >= 2) decayParams.push_back(p);
                else              noDecayParams.push_back(p);
            }
            it.increment();
        }

        // Build two parameter groups
        AdamWOptions baseOpts = new AdamWOptions(lr);
        baseOpts.betas().put(0, beta1);
        baseOpts.betas().put(1, beta2);
        baseOpts.weight_decay().put(weightDecay);

        AdamWOptions noDecayOpts = new AdamWOptions(lr);
        noDecayOpts.betas().put(0, beta1);
        noDecayOpts.betas().put(1, beta2);
        noDecayOpts.weight_decay().put(0.0);

        OptimizerParamGroupVector groups = new OptimizerParamGroupVector();
        groups.push_back(new OptimizerParamGroup(decayParams,   baseOpts));
        groups.push_back(new OptimizerParamGroup(noDecayParams, noDecayOpts));
        return new AdamW(groups, baseOpts);


//        TensorVector groups = new TensorVector();
//        groups.push_back(new OptimizerParamGroup(decayParams,   baseOpts));
//        groups.push_back(new OptimizerParamGroup(noDecayParams, noDecayOpts));
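
To make the intent of the snippet above easier to check, here is a minimal sketch of the same grouping rule in plain Java, without the PyTorch bindings. `ParamGroupSketch`, `Param`, `Group`, and `split` are hypothetical stand-ins invented for illustration; they are not part of javacpp-pytorch. The sketch only assumes what the snippet shows: trainable parameters with `dim() >= 2` go into a group that uses weight decay, and all other trainable parameters go into a group with weight decay 0.0.

```java
import java.util.ArrayList;
import java.util.List;

public class ParamGroupSketch {

    /** Hypothetical stand-in for a named tensor: only the fields the rule reads. */
    static final class Param {
        final String name;
        final int dim;
        final boolean requiresGrad;

        Param(String name, int dim, boolean requiresGrad) {
            this.name = name;
            this.dim = dim;
            this.requiresGrad = requiresGrad;
        }
    }

    /** Hypothetical stand-in for a parameter group: member names plus its weight decay. */
    static final class Group {
        final List<String> names;
        final double weightDecay;

        Group(List<String> names, double weightDecay) {
            this.names = names;
            this.weightDecay = weightDecay;
        }
    }

    /**
     * Splits trainable parameters into two groups:
     * index 0 holds parameters with dim >= 2 and uses the given weight decay,
     * index 1 holds the remaining trainable parameters and uses 0.0.
     * Parameters that do not require gradients are skipped.
     */
    static List<Group> split(List<Param> params, double weightDecay) {
        List<String> decay = new ArrayList<>();
        List<String> noDecay = new ArrayList<>();
        for (Param p : params) {
            if (!p.requiresGrad) {
                continue;
            }
            if (p.dim >= 2) {
                decay.add(p.name);
            } else {
                noDecay.add(p.name);
            }
        }
        List<Group> groups = new ArrayList<>();
        groups.add(new Group(decay, weightDecay));
        groups.add(new Group(noDecay, 0.0));
        return groups;
    }

    public static void main(String[] args) {
        List<Param> params = new ArrayList<>();
        params.add(new Param("fc.weight", 2, true));
        params.add(new Param("fc.bias", 1, true));
        params.add(new Param("frozen.weight", 2, false));

        for (Group g : split(params, 0.01)) {
            System.out.println(g.names + " weight_decay=" + g.weightDecay);
        }
    }
}
```

With the real bindings, the two resulting groups are what would need to be appended to an OptimizerParamGroupVector, which is the step this issue asks to have supported.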

The existing AdamW class only accepts an OptimizerParamGroupVector:

public class AdamW extends Optimizer {
    public AdamW(Pointer p) {
        super(p);
    }

    public AdamW(@Const @ByRef OptimizerParamGroupVector param_groups, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions defaults) {
        super((Pointer)null);
        this.allocate(param_groups, defaults);
    }

    private native void allocate(@Const @ByRef OptimizerParamGroupVector var1, @ByVal(nullValue = "torch::optim::AdamWOptions{}") AdamWOptions var2);

    public AdamW(@Const @ByRef OptimizerParamGroupVector param_groups) {
        super((Pointer)null);
        this.allocate(param_groups);
    }
