wkambale/jax-gemini

Jax-Gemini

Natural language-driven JAX/Flax model building powered by Google Gemini.


jax-gemini is a Python library that lets AI/ML practitioners, researchers, and domain experts build, train, evaluate, and checkpoint JAX/Flax neural network models from plain-English prompts, without writing JAX boilerplate.

Unlike LLM wrappers that only return text descriptions, jax-gemini returns real, executable Python objects such as flax.nnx.Module instances. You describe what you want in a prompt, Gemini generates the corresponding code, and jax-gemini runs that code inside a sandboxed namespace and hands the resulting object back to you.
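The validate-then-execute flow described above can be sketched with the standard library alone. This is a minimal illustration, not jax-gemini's actual implementation: the allow-list `ALLOWED_IMPORTS`, the helper names `validate_source` and `run_validated`, and the convention that generated code binds its result to a variable called `model` are all assumptions made for this example.

```python
import ast

# Hypothetical allow-list; the library's real policy is not documented here.
ALLOWED_IMPORTS = {"jax", "flax", "optax", "numpy", "math"}
# Hypothetical deny-list of names that make escaping the namespace easy.
BLOCKED_NAMES = {"eval", "exec", "open", "__import__"}


def validate_source(source: str) -> None:
    """Raise ValueError if the code imports a module outside the allow-list,
    touches a dunder attribute, or references a blocked builtin name."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            roots = [alias.name.split(".")[0] for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            roots = [(node.module or "").split(".")[0]]
        else:
            roots = []
        for root in roots:
            if root not in ALLOWED_IMPORTS:
                raise ValueError(f"import of {root!r} is not allowed")
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            raise ValueError(f"attribute {node.attr!r} is not allowed")
        if isinstance(node, ast.Name) and node.id in BLOCKED_NAMES:
            raise ValueError(f"name {node.id!r} is not allowed")


def run_validated(source: str, result_name: str = "model"):
    """Validate the source, execute it in a fresh namespace, and return
    the object bound to `result_name`."""
    validate_source(source)
    namespace = {}
    exec(compile(source, "<generated>", "exec"), namespace)
    return namespace[result_name]
```

For instance, `run_validated("import math\nmodel = math.sqrt(16.0)")` returns the computed object, while source containing `import os` is rejected before any of it runs. A fresh dictionary isolates names from the caller's module but is not a full security boundary on its own, which is why the static check comes first.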

Features

  • Real Objects, Not Text: Every operation returns a genuine Python object, such as a flax.nnx.Module, trained weights, or dictionaries, that works directly with the rest of your JAX code.
  • Conversational Memory: Context is retained across turns, so you can refine a model iteratively: add dropout layers, reshape variables, or tweak hyperparameters step by step.
  • Safe Execution by Default: All LLM-generated code passes through an abstract syntax tree (AST) validator that restricts imports and blocks arbitrary script execution.
  • Auto-Healing: When generated code raises an exception, jax-gemini relays the failed code and the Python traceback back to Gemini for automatic repair.
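The auto-healing idea in the last bullet amounts to a retry loop that feeds the error back into the next prompt. The sketch below shows one way such a loop could look; it is not the library's code. The function `build_with_retries`, its parameters, and the `generate` callable (a stand-in for the Gemini call) are hypothetical names introduced for illustration.

```python
import traceback
from typing import Callable


def build_with_retries(
    prompt: str,
    generate: Callable[[str], str],
    max_attempts: int = 3,
    result_name: str = "model",
):
    """Ask `generate` for code and run it; on failure, re-prompt with the
    failed code and its traceback. `generate` stands in for the LLM call."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    current_prompt = prompt
    for attempt in range(1, max_attempts + 1):
        source = generate(current_prompt)
        namespace = {}
        try:
            exec(compile(source, "<generated>", "exec"), namespace)
            return namespace[result_name]
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error to the caller
            error_text = traceback.format_exc()
            # Build the repair prompt from the original request plus evidence.
            current_prompt = (
                f"{prompt}\n\nThe previous attempt failed.\n"
                f"Code:\n{source}\n\nTraceback:\n{error_text}\n"
                "Return a corrected version."
            )
```

Capping the number of attempts keeps a persistently failing prompt from looping forever, and re-raising the final exception means the caller still sees the real error.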

Quickstart

Installation

Install via PyPI:

pip install jax-gemini

Set up your Gemini API key:

export GEMINI_API_KEY="your-api-key-here"
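If you want a script to fail fast when the key is missing, a small check like the following can run before any model-building call. The helper name `require_api_key` is hypothetical and not part of jax-gemini; only the `GEMINI_API_KEY` variable name comes from the instructions above.

```python
import os


def require_api_key(env_var: str = "GEMINI_API_KEY") -> str:
    """Return the API key from the environment, or raise a clear error."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"{env_var} is not set; export it before running.")
    return key
```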

Conversational Model Building

import jax_gemini as jg
import numpy as np

# Note: jg automatically picks up GEMINI_API_KEY from env
jg.config.set({"model_name": "gemini-3.1-pro"})

# Build a model purely from text
model = jg.build("Build a 4-layer MLP for handwritten digit classification")

# Refine the architecture interactively
model = jg.modify("Ah, wait, add dropout with rate 0.2 between each layer for regularization")

# Train your model
X_train = np.random.randn(100, 28, 28, 1).astype(np.float32)
y_train = np.random.randint(0, 10, size=(100,))

model, metrics = jg.train(
    "Train for 10 epochs with Adam optimizer and Cross Entropy Loss",
    dataset=(X_train, y_train)
)
print(f"Accuracy: {metrics['accuracy']:.2%}")

# Checkpoint Persistence
checkpoint_path = jg.save("digit_classifier_v1")
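In the quickstart, jg.modify takes only an instruction, which implies the library keeps state between calls. One way such multi-turn context could be carried is sketched below. The `Session` class and its methods are hypothetical and purely illustrative; they do not describe jax-gemini's internals.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Session:
    """Hypothetical container for multi-turn state (not jax-gemini's class)."""

    # Each turn stores the user's prompt and the code generated for it.
    turns: List[Tuple[str, str]] = field(default_factory=list)

    def record(self, prompt: str, code: str) -> None:
        self.turns.append((prompt, code))

    def next_prompt(self, instruction: str) -> str:
        """Prefix a new instruction with the most recent generated code."""
        if not self.turns:
            return instruction
        _, last_code = self.turns[-1]
        return f"Current code:\n{last_code}\n\nRequested change: {instruction}"
```

With this shape, a follow-up such as "add dropout" is sent together with the code from the previous turn, so the model edits the existing architecture rather than starting over.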

See examples/ for more comprehensive workflows, such as Jupyter Notebook walkthroughs and end-to-end setups.

Documentation

License and Contributing

This repository thrives on community input. See our Contribution Guidelines before opening an issue or preparing a pull request.

Distributed under the MIT License. See LICENSE for more information.
