
kfac_jax 0.0.7

@james-martens released this 20 May 17:48

What's Changed

  • Create a pyproject.toml file to replace the requirements.txt files by @copybara-service in #208
  • Match kfac jaxpr debug info result paths with out vars. by @copybara-service in #220
  • Nest the compute_exact_quad_model to allow filtering of vectors that will be multiplied by zero, saving expensive matrix-vector products by @copybara-service in #222
  • Remove deprecated jax.tree_map calls by @copybara-service in #227
  • Fixing all Losses to return everything in non-auxiliary data during flattening, to avoid any tracer leaks when the weight is dynamic. by @copybara-service in #232
  • Log different block class assignments in the curvature estimator. by @copybara-service in #234
  • Internal cleanup by @copybara-service in #236
  • Simplifying the LayerTag Primitive machinery. by @copybara-service in #237
  • Internal change. by @copybara-service in #238
  • Minor refactor by @copybara-service in #241
  • Updating schedule construction code in examples folder so that it properly detects misspelled argument names. by @copybara-service in #243
  • Fixing bug in log_train_stats_with_polyak_avg_every_n_steps of example code. by @copybara-service in #245
  • Replace Bernoulli distributions with Rademachers (see the notes below the list) by @copybara-service in #247
  • Add precon_power option to KFAC optimizer. by @copybara-service in #251
  • Minor non-functional change. by @copybara-service in #255
  • Change the default estimation mode of the curvature estimators to ggn_curvature_prop by @copybara-service in #253
  • Added pytype None checks to accumulators.py. by @copybara-service in #203
  • Remove the TwoKroneckerFactored class and use the KroneckerFactored class instead. by @copybara-service in #257
  • Add TNT blocks to kfac_jax. by @copybara-service in #258
  • Add an option to specify a different value function for the preconditioner's curvature estimator. by @copybara-service in #259
  • Fix progress off-by-one. by @copybara-service in #260
  • Split curvature_estimator.py module into a package. by @copybara-service in #261
  • Adding the repeated dense graph patterns. by @copybara-service in #263
  • Separated the optimizers module in kfac examples into separate modules by @copybara-service in #264
  • Adding support for the "Schedule-free" method to be used as a wrapper for Optax optimizers in the examples codebase. by @copybara-service in #268
  • [kfac-jax] Update graph matching test to support the new "algorithm" tuning parameters for dot_general that will be included in the next JAX release. by @copybara-service in #269
  • Improving polynomial schedule in the examples codebase so that it works as expected when the initial value is lower than the final value. by @copybara-service in #271
  • Changing automatic registration (aka the graph scanner) so that it doesn't automatically register a parameter if said parameter is used more than once in the graph. In that case, it resorts to the default "generic" registration, which doesn't make any structural assumptions about how the parameter is used (see the notes below the list). by @copybara-service in #276
  • Changing optimizer to throw an exception when using burnin without a provided data iterator instead of silently skipping burnin. by @copybara-service in #277
  • Removing check that initial_damping is not set when use_adaptive_damping is False. by @copybara-service in #280
  • Add sharding rules to some more primitives so that the backward pass of minformer passes. by @copybara-service in #283
  • Removing hacky "fixes" to test_graph_matcher. Basically, the test insists that the manual registration includes all of the params from the main equation in the match found by the graph scanner. Instead of filtering these out, we now ensure that they are included in the manual registrations done in tests/models.py. Note that passing all these params won't be required when using manual registration in practice. Only certain params are mandatory for particular layers (based on the type of curvature block that gets assigned to them). by @copybara-service in #284
  • Improved and simplified implementation of "debug" mode based on jax.disable_jit(). by @copybara-service in #287
  • Fixing bug that made graph scanner register repeated dense layers as regular dense layers. by @copybara-service in #288
  • Modifying internal function clean_jaxpr to properly eliminate unused output variables from higher order primitives. Should have no effect on optimizer behavior. by @copybara-service in #286
  • Bumping JAX version requirement to 0.4.25 due to the requirement of the jax.tree API. by @copybara-service in #291
  • Using version guard to fix change that broke backwards compatibility with some older versions of JAX. by @copybara-service in #290
  • Adding step rejection feature by @copybara-service in #293
  • Add reshape parameter to normalization tag to find flax LayerNorms by @copybara-service in #294
  • Make conv2d tag graph matcher more general by @copybara-service in #295
  • Enabled a greater range of preconditioner powers; some math utilities added (see the notes below the list). by @copybara-service in #292
  • Adding support in the graph scanner for Haiku & Flax normalization layers without learnable shift/offset params. by @copybara-service in #297
  • Making pmap axis names consistent in examples code to support things like cross-replica batch norm layers. by @copybara-service in #301
  • Adding a NaN/Inf guard on calls to matrix inverses/solves, since LU decomposition on GPU can enter an infinite loop when the matrix contains these values (see the notes below the list). by @copybara-service in #305
  • Fixing bug with step rejection where reject_damping_increase_factor was applied even when the step was not rejected. by @copybara-service in #303
  • Migrate from jax.core to jax.extend.core for several deprecated symbols by @copybara-service in #307
  • jax.numpy.clip: update use of deprecated arguments. by @copybara-service in #309
  • Adding out_sharding to test model to fix recent test failure. by @copybara-service in #313
  • Small change to how quadratic models are represented internally. by @copybara-service in #316
  • Enable squared error loss in classifier_loss_and_stats. by @copybara-service in #315
  • Adding support for arrays of arbitrary dimensions to loss function classes (for inputs, targets, and masks). This was essentially there already except for the methods multiply_X_factor_replicated_one_hot_unweighted, which are used by the fisher_exact and ggn_exact estimation modes. by @copybara-service in #319
  • Some minor refactoring of ImplicitExactCurvature: making the "loss" methods public, and regular methods instead of class methods. by @copybara-service in #321
  • Changing schedules in examples codebase to use generic names for arguments instead of "learning_rate". by @copybara-service in #323
  • Because JAX/XLA currently doesn't share computations across cond barriers, we must avoid doing expensive operations inside of conds, and this CL changes the optimizer code in this respect (see the notes below the list). The upside is that the code will run faster in certain scenarios: when curvature_update_period > 1, or damping_adaptation_interval > 1, or if exact quadratic models are evaluated with non-adaptive momentum values that can't be statically evaluated (such as when doing damping adaptation with the exact quadratic model). The downside is more compilation time in these same situations (except the last). If this behavior of XLA ever changes, this CL could potentially be rolled back. by @copybara-service in #322
  • Major redesign of schedules code in the examples directory; should make it much easier to add new schedules. Schedules now all have a "mode" argument which determines how their other arguments are interpreted: as steps, epochs, or fraction of training complete (see the notes below the list). by @copybara-service in #324
  • Do not assume that jnp.array allows None by @copybara-service in #325
  • Enable construct_schedule to be in 'fraction' mode while also being passed total_steps (instead of just total_epochs). This enables, e.g., a cosine schedule to be specified in fraction mode while the number of train steps is specified in the config (see the notes below the list). by @copybara-service in #326
  • Do not assume jex.core.Var is ordered by @copybara-service in #327
  • Dependency update by @copybara-service in #331
  • Enable skip and take of MNIST dataset. by @copybara-service in #330
  • Automated Code Change by @copybara-service in #333
  • Changing should_update_damping to include a check of self._use_adaptive_damping. This ensures that when self._use_adaptive_damping is False, we don't needlessly compile the step function twice for the two different values of ((step_counter + 1) % self._damping_adaptation_interval == 0) (see the notes below the list). by @copybara-service in #335
  • No longer using "iterator_on_device" in dataset functions in examples code, as this 1) is arguably pointless, since data transfers cannot occur in parallel with other ops (at least on TPU), and 2) leads to a nasty error on TPUs now that xla_tpu_use_enhanced_launch_barrier=True by default. The error happens because Jaxline's py_prefetch function runs in a separate thread, and this leads to random timing of the step function vs the data transfer function, which is interpreted as an "out of order" error by the "launch barrier" when JAX is running in multi-process mode on TPUs. by @copybara-service in #334
  • Bumping version number in preparation for next official PyPI release. by @copybara-service in #336
  • Reversing previous change. by @copybara-service in #339
  • Updating dm-haiku test dependency. We can use a numbered version again now that dm-haiku has been updated on PyPI. by @copybara-service in #340
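
Notes on selected changes

The switch to Rademacher probes in #247 concerns the random vectors used by stochastic curvature estimation. A minimal sketch of drawing such a probe with JAX's built-in sampler (the variable names are illustrative, not kfac_jax API):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# A Rademacher probe has entries -1 or +1 with equal probability, so
# E[v] = 0 and E[v v^T] = I -- the identity-second-moment property that
# stochastic curvature propagation relies on.
v = jax.random.rademacher(key, (128,), dtype=jnp.float32)
```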
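
#276 changes how the graph scanner treats parameters that appear more than once in the traced graph. A minimal Haiku sketch of the kind of weight sharing that now triggers the generic fallback (the model itself is hypothetical, not from the library):

```python
import haiku as hk
import jax.numpy as jnp

def tied_autoencoder(x):
  # 'w' appears twice in the traced computation (encode and decode), so
  # the scanner no longer assumes a dense-layer structure for it and
  # instead falls back to generic registration, which makes no
  # assumptions about how the parameter is used.
  w = hk.get_parameter(
      "w", [x.shape[-1], 32], init=hk.initializers.VarianceScaling())
  h = jnp.tanh(x @ w)  # first use
  return h @ w.T       # second use
```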
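
Supporting arbitrary preconditioner powers (#251, #292) amounts to raising a symmetric PSD curvature factor to a real power. A standard way to do this, sketched here as a hypothetical helper (not the library's actual math utility):

```python
import jax.numpy as jnp

def psd_matrix_power(a, power, damping=0.0):
  # Symmetric eigendecomposition: a = Q diag(evals) Q^T.
  evals, q = jnp.linalg.eigh(a)
  # Damp, clip to non-negative, and raise the spectrum to the given power;
  # then a^power = Q diag(evals^power) Q^T.
  powered = jnp.power(jnp.maximum(evals + damping, 0.0), power)
  return (q * powered) @ q.T
```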
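
The guard added in #305 protects matrix inverses/solves from non-finite inputs, which can make LU decomposition hang on GPU. A sketch of the general pattern (hypothetical helper, not the library's code):

```python
import jax.numpy as jnp

def guarded_solve(a, b):
  # Swap in the identity when 'a' contains NaN/Inf so the LU-based solve
  # cannot spin forever, then poison the output so the failure stays
  # visible downstream instead of being silently masked.
  ok = jnp.all(jnp.isfinite(a))
  a_safe = jnp.where(ok, a, jnp.eye(a.shape[0], dtype=a.dtype))
  x = jnp.linalg.solve(a_safe, b)
  return jnp.where(ok, x, jnp.full_like(x, jnp.nan))
```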
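
The cond change in #322 exploits the fact that XLA does not share computations across jax.lax.cond boundaries: a product computed both inside a branch and outside the cond is materialized twice. A minimal before/after sketch of the hoisting pattern (toy functions, not optimizer code):

```python
import jax
import jax.numpy as jnp

def before(matrix, vector, update_now):
  product = matrix @ vector  # needed unconditionally below
  delta = jax.lax.cond(
      update_now,
      lambda: matrix @ vector * 0.9,   # identical product recomputed here;
      lambda: jnp.zeros_like(vector),  # XLA won't share it across the cond
  )
  return product + delta

def after(matrix, vector, update_now):
  product = matrix @ vector  # computed once, shared everywhere
  delta = jax.lax.cond(
      update_now,
      lambda p: p * 0.9,
      lambda p: jnp.zeros_like(p),
      product,  # passed in as an operand instead of recomputed
  )
  return product + delta
```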
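
The schedules redesign (#324) gives every schedule a "mode" argument, and #326 additionally lets fraction mode be driven by total_steps. A hypothetical cosine schedule illustrating the idea (not the actual construct_schedule signature):

```python
import jax.numpy as jnp

def cosine_schedule(global_step, peak_value, duration, mode,
                    steps_per_epoch=None, total_steps=None):
  # 'mode' fixes how 'duration' is interpreted.
  if mode == "steps":
    horizon = duration
  elif mode == "epochs":
    horizon = duration * steps_per_epoch
  elif mode == "fraction":
    # Fraction of training; total_steps may come straight from the config,
    # so the schedule shape and the training length stay decoupled.
    horizon = duration * total_steps
  else:
    raise ValueError(f"Unknown mode: {mode!r}")
  t = jnp.minimum(global_step / horizon, 1.0)
  return peak_value * 0.5 * (1.0 + jnp.cos(jnp.pi * t))
```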
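
The should_update_damping change in #335 keeps a static Python bool ahead of the traced modular check, so disabling adaptive damping no longer compiles the step function for both truth values of the condition. A sketch of the logic (free-function form for brevity, not the actual method):

```python
def should_update_damping(step_counter, use_adaptive_damping,
                          damping_adaptation_interval):
  # 'use_adaptive_damping' is a static Python bool: short-circuiting on it
  # keeps the traced condition out of the graph when adaptation is off,
  # so JAX does not specialize the compiled step on its two values.
  return (use_adaptive_damping and
          (step_counter + 1) % damping_adaptation_interval == 0)
```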

Full Changelog: v0.0.6...v0.0.7