Skip to content

Commit c04832b

Browse files
authored
sd: add eta support (#2164)
1 parent 18a3bed commit c04832b

3 files changed

Lines changed: 13 additions & 0 deletions

File tree

expose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ struct sd_generation_inputs
225225
const int seed = 0;
226226
const char * sample_method = nullptr;
227227
const char * scheduler = nullptr;
228+
const float eta = -1.0f;
228229
const int clip_skip = -1;
229230
const int vid_req_frames = 1;
230231
const int video_output_type = 0; //0=gif, 1=avi, 2=both

koboldcpp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ class sd_generation_inputs(ctypes.Structure):
389389
("seed", ctypes.c_int),
390390
("sample_method", ctypes.c_char_p),
391391
("scheduler", ctypes.c_char_p),
392+
("eta", ctypes.c_float),
392393
("clip_skip", ctypes.c_int),
393394
("vid_req_frames", ctypes.c_int),
394395
("video_output_type", ctypes.c_int),
@@ -2645,6 +2646,7 @@ def sd_generate(genparams):
26452646
sample_method = (genparams.get("sampler_name") or "default")
26462647
scheduler = (genparams.get("scheduler") or "default").lower()
26472648
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
2649+
eta = tryparsefloat(genparams.get("eta", None), None)
26482650
vid_req_frames = tryparseint(genparams.get("frames", 1),1)
26492651
vid_req_frames = 1 if (not vid_req_frames or vid_req_frames < 1) else vid_req_frames
26502652
video_output_type = genparams.get("video_output_type", 0)
@@ -2697,6 +2699,7 @@ def sd_generate(genparams):
26972699
inputs.seed = ((seed + 2**31) % 2**32) - 2**31
26982700
inputs.sample_method = sd_sampler_canonical_name(sample_method).encode("UTF-8")
26992701
inputs.scheduler = scheduler.encode("UTF-8")
2702+
inputs.eta = -1.0 if eta is None else eta
27002703
inputs.clip_skip = clip_skip
27012704
inputs.vid_req_frames = vid_req_frames
27022705
inputs.video_output_type = video_output_type

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ struct SDParams {
133133
float distilled_guidance = -1.0f;
134134
float shifted_timestep = 0;
135135
float flow_shift = -1.0f;
136+
float eta = -1.0f;
136137
float strength = 0.75f;
137138
int64_t seed = 42;
138139
bool clip_on_cpu = false;
@@ -600,6 +601,8 @@ static std::string get_image_params(const sd_img_gen_params_t & params, const st
600601
<< " | Size: " << params.width << "x" << params.height
601602
<< " | Sampler: " << sd_sample_method_name(params.sample_params.sample_method)
602603
<< get_scheduler_name(params.sample_params.scheduler, true);
604+
if (params.sample_params.eta != -1.0f)
605+
ss << "| Eta: " << params.sample_params.eta;
603606
if (params.sample_params.shifted_timestep != 0)
604607
ss << "| Timestep Shift: " << params.sample_params.shifted_timestep;
605608
if (params.sample_params.flow_shift > 0.f && params.sample_params.flow_shift != INFINITY)
@@ -978,6 +981,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
978981
sd_params->sample_steps = inputs.sample_steps;
979982
sd_params->shifted_timestep = inputs.shifted_timestep;
980983
sd_params->flow_shift = inputs.flow_shift;
984+
sd_params->eta = inputs.eta;
981985
sd_params->seed = inputs.seed;
982986
sd_params->width = inputs.width;
983987
sd_params->height = inputs.height;
@@ -1212,6 +1216,9 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
12121216
params.sample_params.scheduler = sd_params->scheduler;
12131217
params.sample_params.sample_steps = sd_params->sample_steps;
12141218
params.sample_params.shifted_timestep = sd_params->shifted_timestep;
1219+
if (sd_params->eta >= 0.f && sd_params->eta <= 1.f) {
1220+
params.sample_params.eta = sd_params->eta;
1221+
}
12151222
if (sd_params->flow_shift > 0.f && sd_params->flow_shift != INFINITY) {
12161223
params.sample_params.flow_shift = sd_params->flow_shift;
12171224
}
@@ -1418,6 +1425,8 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
14181425
jsoninfo["extra_generation_params"] = nlohmann::json::object();
14191426
if (params.sample_params.scheduler != scheduler_t::SCHEDULER_COUNT)
14201427
jsoninfo["extra_generation_params"]["Schedule type"] = get_scheduler_name(params.sample_params.scheduler);
1428+
if (params.sample_params.eta >= 0 && params.sample_params.eta <= 1)
1429+
jsoninfo["eta"] = params.sample_params.eta;
14211430
if (is_img2img)
14221431
jsoninfo["denoising_strength"] = params.strength;
14231432
if (sd_params->model_path.empty())

0 commit comments

Comments
 (0)