Skip to content

Commit 546fa4d

Browse files
committed
random seed: be random for re-generation.
1 parent f0282b5 commit 546fa4d

File tree

4 files changed

+16
-7
lines changed

4 files changed

+16
-7
lines changed

models/bailing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ namespace chatllm::bailing::llada
572572
GenerationConfig _gen_config(gen_config);
573573
_gen_config.do_sample = true;
574574
_gen_config.sampling = "top_p";
575-
std::unique_ptr<Sampler> sampler = std::unique_ptr<Sampler>(SamplerFactory::Create(_gen_config, _seed));
575+
std::unique_ptr<Sampler> sampler = std::unique_ptr<Sampler>(SamplerFactory::Create(_gen_config, get_seed()));
576576

577577
aborted = false;
578578

src/chat.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ namespace chatllm
977977
virtual void set_ctx(int n_ctx)= 0;
978978

979979
virtual void seed(int x) = 0;
980+
virtual int get_seed(void) const = 0;
980981

981982
virtual int get_max_length(void) = 0;
982983

@@ -1097,6 +1098,7 @@ namespace chatllm
10971098
void set_ctx(int n_ctx) override { model->set_ctx(n_ctx); }
10981099

10991100
void seed(int x) override { model->seed(x); }
1101+
int get_seed(void) const override { return model->get_seed(); }
11001102

11011103
int get_max_length(void) override { return model->get_max_length(); }
11021104

@@ -1131,7 +1133,8 @@ namespace chatllm
11311133
BaseModel(uint32_t type, ModelPurpose purpose) :
11321134
type_(type), n_past(0),
11331135
n_past_offset(0), tokenizer(nullptr),
1134-
purpose(purpose), aborted(false)
1136+
purpose(purpose), aborted(false),
1137+
_seed(-1)
11351138
{}
11361139

11371140
virtual ~BaseModel()
@@ -1234,6 +1237,13 @@ namespace chatllm
12341237
void set_ctx(int n_ctx) override {}
12351238

12361239
void seed(int x) override { _seed = x; }
1240+
int get_seed(void) const override
1241+
{
1242+
if (_seed > 0) return _seed;
1243+
1244+
std::random_device rd;
1245+
return rd();
1246+
}
12371247

12381248
int get_n_past(void) override { return n_past; }
12391249

@@ -1273,12 +1283,13 @@ namespace chatllm
12731283
uint32_t type_;
12741284
std::string name_;
12751285
std::string native_name_;
1276-
int _seed;
12771286
int n_past;
12781287
int n_past_offset;
12791288
BaseTokenizer *tokenizer;
12801289
ModelPurpose purpose;
12811290
bool aborted;
1291+
private:
1292+
int _seed;
12821293
};
12831294

12841295
class ModelObject

src/main.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ struct Args
7272
float frequency_penalty = 0.0f;
7373
int num_threads = 0;
7474
bool multi_line = false;
75-
int seed;
75+
int seed = -1;
7676
chatllm::ChatFormat format = chatllm::ChatFormat::CHAT;
7777
bool tokenize = false;
7878
DistanceStrategy vc = DistanceStrategy::MaxInnerProduct;
@@ -287,8 +287,6 @@ static std::string load_txt(const std::string &fn)
287287

288288
static size_t parse_args(Args &args, const std::vector<std::string> &argv)
289289
{
290-
std::random_device rd;
291-
args.seed = rd();
292290
const size_t argc = argv.size();
293291

294292
#define handle_para0(fmt1, field, f) \

src/models.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ namespace chatllm
914914
// printf("%d, ", input_ids[i]);
915915
//printf("\nn_past = %d, %d\n\n", n_past, continuous);
916916

917-
std::unique_ptr<Sampler> sampler = std::unique_ptr<Sampler>(SamplerFactory::Create(gen_config, _seed));
917+
std::unique_ptr<Sampler> sampler = std::unique_ptr<Sampler>(SamplerFactory::Create(gen_config, get_seed()));
918918

919919
aborted = false;
920920

0 commit comments

Comments
 (0)