Skip to content

sync: sync with latest ggml #670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class UpSampleBlock : public GGMLBlock {
// x: [N, channels, h, w]
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);

x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
return x;
}
Expand Down
4 changes: 2 additions & 2 deletions esrgan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class RRDBNet : public GGMLBlock {
body_feat = conv_body->forward(ctx, body_feat);
feat = ggml_add(ctx, feat, body_feat);
// upsample
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2)));
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2)));
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
return out;
}
Expand Down
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated from ff9052 to 17733d
2 changes: 1 addition & 1 deletion ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g
a->ne[0] * b->ne[0],
a->ne[1] * b->ne[1],
a->ne[2] * b->ne[2],
a->ne[3] * b->ne[3]),
a->ne[3] * b->ne[3], GGML_SCALE_MODE_NEAREST),
b);
}

Expand Down
78 changes: 39 additions & 39 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,69 +350,69 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {

// convert attn to out
if (ends_with(key, "to_out")) {
key += format("%c0", seq);
key += sd_format("%c0", seq);
}

// unet
if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) {
return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
if (match(m, std::regex(sd_format("unet%cconv_in(.*)", seq)), key)) {
return sd_format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
}

if (match(m, std::regex(format("unet%cconv%cout(.*)", seq, seq)), key)) {
return format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0];
if (match(m, std::regex(sd_format("unet%cconv%cout(.*)", seq, seq)), key)) {
return sd_format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0];
}

if (match(m, std::regex(format("unet%cconv_norm_out(.*)", seq)), key)) {
return format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0];
if (match(m, std::regex(sd_format("unet%cconv_norm_out(.*)", seq)), key)) {
return sd_format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0];
}

if (match(m, std::regex(format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
if (match(m, std::regex(sd_format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
return sd_format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
}

if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
if (match(m, std::regex(sd_format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[1], m[3]);
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
return sd_format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
(m[1] == "attentions" ? "1" : "0") + seq + suffix;
}

if (match(m, std::regex(format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) {
if (match(m, std::regex(sd_format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[0], m[2]);
return format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) +
return sd_format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) +
seq + suffix;
}

if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
if (match(m, std::regex(sd_format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[1], m[3]);
return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
return sd_format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
(m[1] == "attentions" ? "1" : "0") + seq + suffix;
}

if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op";
if (match(m, std::regex(sd_format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
return sd_format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op";
}

if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq +
if (match(m, std::regex(sd_format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
return sd_format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq +
(std::stoi(m[0]) > 0 ? "2" : "1") + seq + "conv";
}

// clip
if (match(m, std::regex(format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
return format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1];
if (match(m, std::regex(sd_format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
return sd_format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1];
}

if (match(m, std::regex(format("te%ctext_model(.*)", seq)), key)) {
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
if (match(m, std::regex(sd_format("te%ctext_model(.*)", seq)), key)) {
return sd_format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
}

// vae
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
if (match(m, std::regex(sd_format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
return sd_format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
}

if (match(m, std::regex(format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
if (match(m, std::regex(sd_format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix;
std::string block_name;
if (m[1] == "attentions") {
Expand All @@ -422,40 +422,40 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
block_name = "block";
suffix = m[3];
}
return format("first_stage_model%c%s%cmid%c%s_%d%c%s",
return sd_format("first_stage_model%c%s%cmid%c%s_%d%c%s",
seq, m[0].c_str(), seq, seq, block_name.c_str(), std::stoi(m[2]) + 1, seq, suffix.c_str());
}

if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
if (match(m, std::regex(sd_format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
std::string suffix = m[3];
if (suffix == "conv_shortcut") {
suffix = "nin_shortcut";
}
return format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s",
return sd_format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s",
seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str());
}

if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
return format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv",
if (match(m, std::regex(sd_format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
return sd_format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv",
seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq);
}

if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
if (match(m, std::regex(sd_format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
std::string suffix = m[3];
if (suffix == "conv_shortcut") {
suffix = "nin_shortcut";
}
return format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s",
return sd_format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s",
seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str());
}

if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
return format("first_stage_model%c%s%cup%c%d%cupsample%cconv",
if (match(m, std::regex(sd_format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
return sd_format("first_stage_model%c%s%cup%c%d%cupsample%cconv",
seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq);
}

if (match(m, std::regex(format("vae%c(.*)", seq)), key)) {
return format("first_stage_model%c", seq) + m[0];
if (match(m, std::regex(sd_format("vae%c(.*)", seq)), key)) {
return sd_format("first_stage_model%c", seq) + m[0];
}

return key;
Expand Down Expand Up @@ -756,7 +756,7 @@ void convert_tensor(void* src,
} else {
auto qtype = ggml_get_type_traits(src_type);
if (qtype->to_float == NULL) {
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available",
throw std::runtime_error(sd_format("type %s unsupported for integer quantization: no dequantization available",
ggml_type_name(src_type)));
}
qtype->to_float(src, (float*)dst, n);
Expand All @@ -766,7 +766,7 @@ void convert_tensor(void* src,
// src_type is quantized => dst_type == GGML_TYPE_F16 or dst_type is quantized
auto qtype = ggml_get_type_traits(src_type);
if (qtype->to_float == NULL) {
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available",
throw std::runtime_error(sd_format("type %s unsupported for integer quantization: no dequantization available",
ggml_type_name(src_type)));
}
std::vector<char> buf;
Expand Down
2 changes: 1 addition & 1 deletion tae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class TinyDecoder : public UnaryBlock {
if (i == 1) {
h = ggml_relu_inplace(ctx, h);
} else {
h = ggml_upscale(ctx, h, 2);
h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST);
}
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void replace_all_chars(std::string& str, char target, char replacement) {
}
}

std::string format(const char* fmt, ...) {
std::string sd_format(const char* fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
Expand Down
2 changes: 1 addition & 1 deletion util.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ bool ends_with(const std::string& str, const std::string& ending);
bool starts_with(const std::string& str, const std::string& start);
bool contains(const std::string& str, const std::string& substr);

std::string format(const char* fmt, ...);
std::string sd_format(const char* fmt, ...);

void replace_all_chars(std::string& str, char target, char replacement);

Expand Down