This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 48f674c

initial commit (0 parents)

File tree

109 files changed (+12003, -0 lines)


.github/workflows/ci.yml

Lines changed: 31 additions & 0 deletions

name: CI
on: push
jobs:

  lint_and_typecheck:
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v3

      - name: Set up python
        id: setup-python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"

      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true

      - name: poetry install
        run: poetry install --no-interaction --extras=training

      - name: lint
        run: poetry run ruff check .

      - name: typecheck
        run: poetry run pyright
.gitignore

Lines changed: 28 additions & 0 deletions

# compilation and distribution
__pycache__/
*.py[cod]
dist/

# virtual environments
venv/

# unit tests
.pytest_cache/

# tests' model weights
tests/weights/

# ruff
.ruff_cache

# vscode
.vscode

# Weights & Biases (offline trainings)
wandb/

# macos
.DS_Store

# model weights
*.safetensors

LICENSE

Lines changed: 21 additions & 0 deletions

MIT License

Copyright (c) 2023 Lagon Technologies

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 304 additions & 0 deletions

<div align="center">

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="assets/logo_dark.png">
  <source media="(prefers-color-scheme: light)" srcset="assets/logo_light.png">
  <img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
</picture>

**The simplest way to train and run adapters on top of foundational models**

______________________________________________________________________

[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/refiners)](https://pypi.org/project/refiners/)
[![PyPI Status](https://badge.fury.io/py/refiners.svg)](https://badge.fury.io/py/refiners)
[![license](https://img.shields.io/badge/license-MIT-blue)](/LICENSE)
</div>

- [Motivation](#motivation)
- [Design](#design)
- [Downsides](#downsides)
- [Overview](#overview)
- [Key Concepts](#key-concepts)
  - [The Chain class](#the-chain-class)
  - [The Context API](#the-context-api)
  - [The Adapter API](#the-adapter-api)
- [Getting Started](#getting-started)
  - [Install](#install)
  - [Hello World](#hello-world)
- [Training](#training)
- [Credits](#credits)
- [Citation](#citation)
## Motivation

At [Finegrain](https://finegrain.ai), we're on a mission to automate product photography. Given our "no human in the loop" approach, nailing the quality of the outputs we generate is paramount to our success.

That's why we're building Refiners.

It's a framework to easily bridge the last-mile quality gap of foundational models like Stable Diffusion or the Segment Anything Model (SAM), by adapting them to specific tasks with lightweight, trainable and composable patches.

We decided to build Refiners in the open.

That's because model adaptation is a new paradigm that goes beyond our specific use cases. Our hope is to help people looking to create their own adapters save time, whatever foundation model they're using.

## Design

We are huge fans of PyTorch (we were actually core committers to [Torch](http://torch.ch/) in another life), but we felt it was too low level for the specific task of model adaptation: PyTorch models are generally hard to understand, and their adaptation requires intricate ad hoc code.

Instead, we needed:

- A model structure that's human readable, so that you know what models do and how they work right here, right now
- A mechanism to easily inject parameters into some target layers, or between them
- A way to easily pass data (like a conditioning input) between layers, even when deeply nested
- Native support for iconic adapter types like LoRAs and their community-trained incarnations (hosted on [Civitai](http://civitai.com/) and the like)

Refiners is designed to tackle all these challenges while remaining just one abstraction away from our beloved PyTorch.

## Downsides

As they say, there is no free lunch. Since Refiners comes with a new model structure, it can only work with models implemented that way. For now, we support Stable Diffusion 1.5, but more are in the works (SDXL, SAM, ...) - stay tuned.
## Overview

The Refiners library is made of:

1. An abstraction layer (called Fluxion) on top of [PyTorch](https://pytorch.org/) to easily build models
2. A zoo of compatible foundational models
3. Adapter APIs to easily patch supported foundational models
4. Training utils to train concrete adapters
5. Conversion scripts to easily use existing community adapters

## Key Concepts

### The Chain class

The `Chain` class is at the core of Refiners. It basically lets you express models as a composition of basic layers (linear, convolution, attention, etc.) in a **declarative way**.

E.g., this is what a Vision Transformer (ViT) looks like with Refiners:
```python
import torch
import refiners.fluxion.layers as fl

class ViT(fl.Chain):
    # The Vision Transformer model structure is entirely defined in the constructor. It is
    # ready-to-use right after, i.e. no need to implement any forward function or add extra logic
    def __init__(
        self,
        embedding_dim: int = 512,
        patch_size: int = 16,
        image_size: int = 384,
        num_layers: int = 12,
        num_heads: int = 8,
    ):
        num_patches = (image_size // patch_size)
        super().__init__(
            fl.Conv2d(in_channels=3, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size),
            fl.Reshape(num_patches**2, embedding_dim),
            # The Residual layer implements the so-called skip-connection, i.e. x + F(x).
            # Here the patch embeddings (x) are summed with the position embeddings (F(x)) whose
            # weights are stored in the Parameter layer (note: there is no extra classification
            # token in this toy example)
            fl.Residual(fl.Parameter(num_patches**2, embedding_dim)),
            # These are the transformer encoders:
            *(
                fl.Chain(
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        # The Parallel layer is used to pass multiple inputs to a downstream
                        # layer, here multiheaded self-attention
                        fl.Parallel(
                            fl.Identity(),
                            fl.Identity(),
                            fl.Identity()
                        ),
                        fl.Attention(
                            embedding_dim=embedding_dim,
                            num_heads=num_heads,
                            key_embedding_dim=embedding_dim,
                            value_embedding_dim=embedding_dim,
                        ),
                    ),
                    fl.LayerNorm(embedding_dim),
                    fl.Residual(
                        fl.Linear(embedding_dim, embedding_dim * 4),
                        fl.GeLU(),
                        fl.Linear(embedding_dim * 4, embedding_dim),
                    ),
                    fl.Chain(
                        fl.Linear(embedding_dim, embedding_dim * 4),
                        fl.GeLU(),
                        fl.Linear(embedding_dim * 4, embedding_dim),
                    ),
                )
                for _ in range(num_layers)
            ),
            fl.Reshape(embedding_dim, num_patches, num_patches),
        )

vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
x = torch.randn(2, 3, 224, 224)
y = vit(x)
```

### The Context API

The `Chain` class has a context provider that allows you to **pass data to layers even when deeply nested**.

E.g., to implement cross-attention you would just need to modify the ViT model as in the toy example below:

```diff
@@ -21,8 +21,8 @@
                     fl.Residual(
                         fl.Parallel(
                             fl.Identity(),
-                            fl.Identity(),
-                            fl.Identity()
+                            fl.UseContext(context="cross_attention", key="my_embed"),
+                            fl.UseContext(context="cross_attention", key="my_embed"),
                         ), # used to pass multiple inputs to a layer
                         fl.Attention(
                             embedding_dim=embedding_dim,
@@ -49,5 +49,6 @@
         )

 vit = ViT(embedding_dim=768, image_size=224, num_heads=12)  # ~ViT-B/16 like
+vit.set_context("cross_attention", {"my_embed": torch.randn(2, 196, 768)})
 x = torch.randn(2, 3, 224, 224)
 y = vit(x)
```
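
For reference, here is the same mechanism as a minimal, standalone sketch rather than a diff. It only reuses constructs shown above (`fl.Chain`, `fl.Residual`, `fl.UseContext`, `set_context`); the context name, key and shapes are made up for illustration, and it assumes `UseContext` simply returns the tensor stored under that key:

```python
import torch
import refiners.fluxion.layers as fl

# Toy chain: a Linear layer followed by a residual branch that adds a tensor
# fetched from the (hypothetical) "toy" context under the key "extra".
toy = fl.Chain(
    fl.Linear(16, 16),
    fl.Residual(fl.UseContext(context="toy", key="extra")),
)

# Fill the context before calling the model, as in the diff above.
toy.set_context("toy", {"extra": torch.randn(2, 16)})
y = toy(torch.randn(2, 16))  # shape (2, 16)
```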

### The Adapter API

The `Adapter` API lets you **easily patch models** by injecting parameters in targeted layers. It comes with built-in support for canonical adapter types like LoRA, but you can also use it to implement your own custom adapters.

E.g., to inject LoRA layers into all the attention layers' linear layers:

```python
from refiners.adapters.lora import LoraAdapter

for layer in vit.layers(fl.Attention):
    for linear, parent in layer.walk(fl.Linear):
        adapter = LoraAdapter(target=linear, rank=64, device=vit.device, dtype=vit.dtype)
        adapter.inject(parent)

# ... and load existing weights if the LoRAs are pretrained ...
```

## Getting Started

### Install

```bash
# inference only
pip install refiners
```

Or:

```bash
# inference + training
pip install 'refiners[training]'
```
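
If you work from a local clone of the repository instead (e.g. to use the conversion scripts below), the repository's CI workflow (`.github/workflows/ci.yml`) uses a Poetry-based setup, which should also work for local development. As a sketch, not an officially documented install path:

```bash
# from the root of a local clone, assuming Poetry is installed
poetry install --extras=training
```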

### Hello World

Here is how to perform text-to-image inference using the Stable Diffusion 1.5 foundational model patched with a Pokemon LoRA:

Step 1: prepare the model weights in refiners' format:

```bash
python scripts/convert-clip-weights.py --output-file CLIPTextEncoderL.safetensors
python scripts/convert-sd-lda-weights.py --output-file lda.safetensors
python scripts/convert-sd-unet-weights.py --output-file unet.safetensors
```

> Note: this will download the original weights from https://huggingface.co/runwayml/stable-diffusion-v1-5 which takes some time. If you already have this repo cloned locally, use the `--from /path/to/stable-diffusion-v1-5` option instead.

Step 2: download and convert a community Pokemon LoRA, e.g. [this one](https://huggingface.co/pcuenq/pokemon-lora):

```bash
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
python scripts/convert-lora-weights.py \
    --from pytorch_lora_weights.bin \
    --output-file pokemon_lora.safetensors
```

Step 3: run inference using the GPU:
```python
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.fluxion.utils import load_from_safetensors, manual_seed
import torch


sd15 = StableDiffusion_1(device="cuda")
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("CLIPTextEncoderL.safetensors"))
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))

# This uses the LoraAdapter internally and takes care to inject it where it should
lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
lora_weights.patch(sd15, scale=1.0)

prompt = "a cute cat"

with torch.no_grad():
    clip_text_embedding = sd15.compute_text_embedding(prompt)

sd15.set_num_inference_steps(30)

manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=sd15.device)

with torch.no_grad():
    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    predicted_image = sd15.lda.decode_latents(x)

predicted_image.save("pokemon_cat.png")
```

You should get:

![pokemon cat output](assets/pokemon_cat.png)

## Training

Refiners has a built-in training utils library and provides scripts that can be used as a starting point.

E.g., to train a LoRA on top of Stable Diffusion, copy and edit `configs/finetune-lora.toml` to suit your needs and launch the training as follows:

```bash
python scripts/training/finetune-ldm-lora.py configs/finetune-lora.toml
```

## Credits

We took inspiration from these great projects:

- [tinygrad](https://github.com/tinygrad/tinygrad) - For something between PyTorch and [karpathy/micrograd](https://github.com/karpathy/micrograd)
- [Composer](https://github.com/mosaicml/composer) - A PyTorch Library for Efficient Neural Network Training
- [Keras](https://github.com/keras-team/keras) - Deep Learning for humans

## Citation

```bibtex
@misc{the-finegrain-team-2023-refiners,
  author = {Benjamin Trom and Pierre Chapuis and Cédric Deltheil},
  title = {Refiners: The simplest way to train and run adapters on top of foundational models},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/finegrain-ai/refiners}}
}
```

assets/dropy.png (4.5 KB)

assets/logo_dark.png (26.8 KB)

assets/logo_light.png (25.4 KB)

assets/pokemon_cat.png (336 KB)
