I have been trying out VPNLS for fitting scaling laws of the Chinchilla form (Approach 3). First of all, thanks for making the code open-source and for including all the details on how to reproduce the Chinchilla law.
I noticed that VPNLS is more sensitive than the standard Huber approach to the pre-filtering of high loss values. As you might know, the Epoch AI paper (as well as the Chinchilla dataset you are using) omits five entries (the five highest loss values) of the original Chinchilla data.
Here is what I did:
- Fit with VPNLS and with the standard Huber loss, each on the filtered and the unfiltered dataset (217 vs. 222 data points).
- Below are the parameter estimates and the MAD (mean absolute deviation) for each fit. I also checked that my integration of your VPNLS code into my fitting codebase works correctly by calling it both through my own interface and directly against your source code.
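For clarity, this is roughly what I mean by "the standard Huber approach": a single-initialization sketch of the Approach 3 fit (predicted log-loss is `LSE(a - alpha*log N, b - beta*log D, e)`, minimized under a Huber loss with delta = 1e-3 on log-space residuals). This is a simplification, not my exact code; the real reproduction sweeps a grid of initializations, and the MAD computation here (mean absolute residual in log-loss) is my assumption about the reported metric.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import huber, logsumexp

def fit_approach3(N, D, L, delta=1e-3, x0=(6.0, 6.0, 0.5, 0.3, 0.3)):
    """Sketch of a Chinchilla Approach 3 fit from a single initialization.

    N, D, L: arrays of parameter counts, token counts, and final losses.
    x0: initial (a, b, e, alpha, beta), where A=exp(a), B=exp(b), E=exp(e).
    """
    logN, logD, logL = np.log(N), np.log(D), np.log(L)

    def objective(x):
        a, b, e, alpha, beta = x
        # Predicted log-loss: LSE over the three additive terms of
        # L(N, D) = A / N^alpha + B / D^beta + E, in log space.
        pred = logsumexp(
            np.stack([a - alpha * logN, b - beta * logD, np.full_like(logN, e)]),
            axis=0,
        )
        # Huber loss on log-space residuals, as in the paper (delta = 1e-3).
        return np.sum(huber(delta, pred - logL))

    res = minimize(objective, x0, method="L-BFGS-B")
    a, b, e, alpha, beta = res.x
    pred = logsumexp(
        np.stack([a - alpha * logN, b - beta * logD, np.full_like(logN, e)]),
        axis=0,
    )
    mad = np.mean(np.abs(pred - logL))  # assumed definition of the reported MAD
    return dict(A=np.exp(a), B=np.exp(b), E=np.exp(e),
                alpha=alpha, beta=beta, MAD=mad)
```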
My general question is: did you experience something similar, and do you think this higher sensitivity of VPNLS could be fixed? In my opinion, this filtering step seems a bit arbitrary, so ideally the fitted curve would not vary too much depending on whether the five points are included.
Note that for each of these runs, we fit only once on the entire dataset. Doing held-out splits and averaging might reduce the sensitivity (I haven't tried that yet). I used the same hyperparameters as in your reproduction experiment; in particular, both runs additionally filter on compute <= 1e21.
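For concreteness, the pre-filtering I apply looks roughly like this (the column layout and the `filter_runs` helper are just illustrative, not your API):

```python
import numpy as np

def filter_runs(data: np.ndarray, k_drop: int = 5,
                max_compute: float = 1e21) -> np.ndarray:
    """Drop the k_drop highest-loss runs, then keep runs with compute <= max_compute.

    Assumed row layout: (N, D, loss, compute).
    """
    if k_drop > 0:
        worst = np.argsort(data[:, 2])[-k_drop:]  # indices of the k largest losses
        data = np.delete(data, worst, axis=0)
    return data[data[:, 3] <= max_compute]
```

The "filtered" runs use `k_drop=5` (222 → 217 points); the "unfiltered" runs use `k_drop=0`. The compute cutoff is applied in both cases.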
Filtered fits:
Filtering out 5 largest losses.
Number of data points = 217.
===============================================================
Fit scaling law with VPNLS and OpenAthena interface
Took 2.254456043243408 sec.
alpha=0.35105819345651423, beta=0.45874179661923226
===============================================================
Fit scaling law with VPNLS and our interface
Took 2.2685959339141846 sec.
alpha=0.35105819345651423, beta=0.45874179661923226
MAD=0.012207698257555546
===============================================================
Fit scaling law with Huber (delta=1e-3) and our interface
Took 157.65054273605347 sec.
alpha=0.3271250772453775, beta=0.39608689828723487
MAD=0.011556386528744847
===============================================================
EpochAI: alpha=0.3478, beta=0.3658
Unfiltered fits:
Filtering out 0 largest losses.
Number of data points = 222.
===============================================================
Fit scaling law with VPNLS and OpenAthena interface
Took 2.2905590534210205 sec.
alpha=0.37936714498312474, beta=0.6801618572985497
===============================================================
Fit scaling law with VPNLS and our interface
Took 2.287907838821411 sec.
alpha=0.37936714498312474, beta=0.6801618572985497
MAD=0.02840711792913962
===============================================================
Fit scaling law with Huber (delta=1e-3) and our interface
Took 139.45674204826355 sec.
alpha=0.33783614111954324, beta=0.49909753509299015
MAD=0.023446800488209733
===============================================================
EpochAI: alpha=0.3478, beta=0.3658