From b53c6f01de54d94ec4d99d1faebea27f0e59fef6 Mon Sep 17 00:00:00 2001 From: sid Date: Mon, 29 Mar 2021 21:22:39 +0200 Subject: [PATCH 1/4] clean up configs + add pretrained ones --- configs/GPT3_2-7B_pretrained.json | 41 +++++++++++++++++++ configs/GPT3_XL_pretrained.json | 40 ++++++++++++++++++ configs/gpt2_small.json | 5 ++- configs/gpt3_13B_256.json | 2 +- configs/gpt3_13B_256_Pile.json | 38 ----------------- configs/gpt3_2-7B_256.json | 5 ++- configs/gpt3_6-7B_256.json | 6 ++- configs/gpt3_PAR_small_256.json | 5 ++- ...gpt3_XL_256_Pile.json => gpt3_XL_256.json} | 2 +- configs/gpt3_large_256.json | 5 ++- configs/gpt3_medium_256.json | 6 ++- configs/gpt3_small_256.json | 5 ++- 12 files changed, 106 insertions(+), 54 deletions(-) create mode 100644 configs/GPT3_2-7B_pretrained.json create mode 100644 configs/GPT3_XL_pretrained.json delete mode 100644 configs/gpt3_13B_256_Pile.json rename configs/{gpt3_XL_256_Pile.json => gpt3_XL_256.json} (94%) diff --git a/configs/GPT3_2-7B_pretrained.json b/configs/GPT3_2-7B_pretrained.json new file mode 100644 index 00000000..fb9a6105 --- /dev/null +++ b/configs/GPT3_2-7B_pretrained.json @@ -0,0 +1,41 @@ +{ +"n_head" : 20, +"n_vocab" : 50257, +"embed_dropout" : 0, +"lr" : 0.00016, +"lr_decay" : "cosine", +"warmup_steps" : 3000, +"beta1" : 0.9, +"beta2" : 0.95, +"epsilon" : 1e-08, +"ada_epsilon1" : "1e-30", +"ada_epsilon2" : 0.001, +"opt_name" : "adam", +"weight_decay" : 0, +"train_batch_size" : 512, +"attn_dropout" : 0, +"train_steps" : 400000, +"lr_decay_end" : 300000, +"eval_steps" : 10, +"predict_steps" : 0, +"res_dropout" : 0, +"eval_batch_size" : 128, +"predict_batch_size" : 1, +"iterations" : 500, +"n_embd" : 2560, +"datasets" : [["pile", null, null, null]], +"model_path" : "gs://neo-d/models/GPT3_2-7B", +"n_ctx" : 2048, +"n_layer" : 32, +"scale_by_depth" : true, +"scale_by_in" : false, +"attention_types" : [[["global", "local"], 16]], +"mesh_shape" : "x:64,y:4", +"layout" : "batch:x,embd:y", +"activation_function" : "gelu", +"recompute_grad" : true, +"gradient_clipping" : 1.0, +"tokens_per_mb_per_replica" : 4096, +"padding_id" : 50257, +"eos_id" : 50256 +} diff --git a/configs/GPT3_XL_pretrained.json b/configs/GPT3_XL_pretrained.json new file mode 100644 index 00000000..7b9def31 --- /dev/null +++ b/configs/GPT3_XL_pretrained.json @@ -0,0 +1,40 @@ +{ +"n_head" : 16, +"n_vocab" : 50257, +"embed_dropout" : 0, +"lr" : 0.0002, +"lr_decay" : "cosine", +"warmup_steps" : 3000, +"beta1" : 0.9, +"beta2" : 0.95, +"epsilon" : 1e-08, +"opt_name" : "adam", +"weight_decay" : 0, +"train_batch_size" : 512, +"attn_dropout" : 0, +"train_steps" : 400000, +"lr_decay_end" : 300000, +"eval_steps" : 10, +"predict_steps" : 0, +"res_dropout" : 0, +"eval_batch_size" : 128, +"predict_batch_size" : 128, +"iterations" : 500, +"n_embd" : 2048, +"datasets" : [["pile", null, null, null]], +"model_path" : "gs://neo-d/models/GPT3_XL_Pile", +"n_ctx" : 2048, +"n_layer" : 24, +"scale_by_depth" : true, +"scale_by_in" : false, +"attention_types" : [[["global", "local"], 12]], +"mesh_shape" : "x:128,y:2", +"layout" : "batch:x,memory_length:y,embd:y", +"activation_function" : "gelu", +"recompute_grad" : true, +"gradient_clipping" : 1.0, +"tokens_per_mb_per_replica" : 4096, +"precision" : "bfloat16", +"padding_id" : 50257, +"eos_id" : 50256 +} diff --git a/configs/gpt2_small.json b/configs/gpt2_small.json index 2fc767c2..fe6fb350 100644 --- a/configs/gpt2_small.json +++ b/configs/gpt2_small.json @@ -21,7 +21,7 @@ "predict_batch_size": 8, "iterations": 2500, "n_embd": 768, - "datasets": 
["openwebtext2_new_inputs"], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT2_SMALL", "n_ctx": 1024, "n_layer": 12, @@ -32,5 +32,6 @@ "mesh_shape": "all:64", "layout": "batch:all", "recompute_grad": false, - "gradient_clipping": 1.0 + "gradient_clipping": 1.0, + "precision": "bfloat16" } \ No newline at end of file diff --git a/configs/gpt3_13B_256.json b/configs/gpt3_13B_256.json index 3f3195f5..92e5ff06 100644 --- a/configs/gpt3_13B_256.json +++ b/configs/gpt3_13B_256.json @@ -22,7 +22,7 @@ "predict_batch_size": 1, "iterations": 500, "n_embd": 5120, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_13B", "n_ctx": 2048, "n_layer": 40, diff --git a/configs/gpt3_13B_256_Pile.json b/configs/gpt3_13B_256_Pile.json deleted file mode 100644 index a31fe05e..00000000 --- a/configs/gpt3_13B_256_Pile.json +++ /dev/null @@ -1,38 +0,0 @@ - -{ - "n_head": 40, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0001, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 1024, - "attn_dropout": 0, - "train_steps": 286150, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 500, - "n_embd": 5120, - "datasets": [["pile", 25, "documents_random", 1.0]], - "model_path": "gs://neo-models/GPT3_13B_Pile", - "n_ctx": 2048, - "n_layer": 40, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],40]], - "mesh_shape": "x:16,y:16", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16" -} diff --git a/configs/gpt3_2-7B_256.json b/configs/gpt3_2-7B_256.json index 4af98692..a1edc668 100644 --- a/configs/gpt3_2-7B_256.json +++ b/configs/gpt3_2-7B_256.json @@ -22,7 +22,7 @@ "predict_batch_size": 1, "iterations": 500, "n_embd": 2560, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_2-7B", "n_ctx": 2048, "n_layer": 32, @@ -33,6 +33,7 @@ "layout": "embd:y,batch:x", "activation_function": "gelu", "recompute_grad": true, - "gradient_clipping": 1.0 + "gradient_clipping": 1.0, + "precision": "bfloat16" } diff --git a/configs/gpt3_6-7B_256.json b/configs/gpt3_6-7B_256.json index 0d2c8a6f..716dda53 100644 --- a/configs/gpt3_6-7B_256.json +++ b/configs/gpt3_6-7B_256.json @@ -20,7 +20,7 @@ "predict_batch_size": 1, "iterations": 500, "n_embd": 4096, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_6-7B", "n_ctx": 2048, "n_layer": 32, @@ -31,6 +31,8 @@ "layout": "embd:y,batch:x", "activation_function": "gelu", "recompute_grad": true, - "gradient_clipping": 1.0 + "gradient_clipping": 1.0, + "precision": "bfloat16" + } diff --git a/configs/gpt3_PAR_small_256.json b/configs/gpt3_PAR_small_256.json index 3dba88ea..1c333ee1 100644 --- a/configs/gpt3_PAR_small_256.json +++ b/configs/gpt3_PAR_small_256.json @@ -20,7 +20,7 @@ "predict_batch_size": 1, "iterations": 1000, "n_embd": 768, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_PAR_SMALL", 
"n_ctx": 2048, "n_layer": 19, @@ -31,6 +31,7 @@ "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y", "activation_function": "gelu", "recompute_grad": false, - "gradient_clipping": 1.0 + "gradient_clipping": 1.0, + "precision": "bfloat16" } diff --git a/configs/gpt3_XL_256_Pile.json b/configs/gpt3_XL_256.json similarity index 94% rename from configs/gpt3_XL_256_Pile.json rename to configs/gpt3_XL_256.json index c39ad353..1be43946 100644 --- a/configs/gpt3_XL_256_Pile.json +++ b/configs/gpt3_XL_256.json @@ -20,7 +20,7 @@ "predict_batch_size": 1, "iterations": 500, "n_embd": 2048, - "datasets": [["pile", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_XL_Pile", "n_ctx": 2048, "n_layer": 24, diff --git a/configs/gpt3_large_256.json b/configs/gpt3_large_256.json index c8464c74..445673e1 100644 --- a/configs/gpt3_large_256.json +++ b/configs/gpt3_large_256.json @@ -22,7 +22,7 @@ "predict_batch_size": 1, "iterations": 2500, "n_embd": 1536, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_LARGE", "n_ctx": 2048, "n_layer": 24, @@ -34,6 +34,7 @@ "activation_function": "gelu", "recompute_grad": true, "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048 + "tokens_per_mb_per_replica": 2048, + "precision": "bfloat16" } diff --git a/configs/gpt3_medium_256.json b/configs/gpt3_medium_256.json index 7726f483..b9012977 100644 --- a/configs/gpt3_medium_256.json +++ b/configs/gpt3_medium_256.json @@ -20,7 +20,7 @@ "predict_batch_size": 1, "iterations": 2500, "n_embd": 1024, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["pile", null, null, null]], "model_path": "gs://neo-models/GPT3_MEDIUM", "n_ctx": 2048, "n_layer": 24, @@ -31,6 +31,8 @@ "layout": "batch:x,heads:y,vocab:y", "activation_function": "gelu", "recompute_grad": false, - "gradient_clipping": 1.0 + "gradient_clipping": 1.0, + "precision": "bfloat16" + } diff --git a/configs/gpt3_small_256.json b/configs/gpt3_small_256.json index a4afe268..5adcb79f 100644 --- a/configs/gpt3_small_256.json +++ b/configs/gpt3_small_256.json @@ -20,7 +20,7 @@ "predict_batch_size": 1, "iterations": 2500, "n_embd": 768, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], + "datasets": [["openwebtext-documents", null, "documents_random", 1.0]], "model_path": "gs://neo-models/GPT3_SMALL", "n_ctx": 2048, "n_layer": 12, @@ -31,6 +31,7 @@ "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y", "activation_function": "gelu", "recompute_grad": false, - "gradient_clipping": 1.0 + "gradient_clipping": 1.0, + "precision": "bfloat16" } From 5dd8c0f64121a11ec36faadf6d7625ec5300ba27 Mon Sep 17 00:00:00 2001 From: sid Date: Mon, 29 Mar 2021 21:22:58 +0200 Subject: [PATCH 2/4] fix activations.py --- models/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/activations.py b/models/activations.py index 68e543c5..1a34c73d 100644 --- a/models/activations.py +++ b/models/activations.py @@ -48,7 +48,7 @@ def _elish(x): 'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7, 'spike': lambda x: 1 / (1 + x ** 2), 'spike2': lambda x: mtf.exp(-x ** 2), - 'tanhshrink': lambda x: x - tanh(x), + 'tanhshrink': lambda x: x - mtf.tanh(x), 'softsign': lambda x: x / (mtf.abs(x) + 1), 'softmax': lambda x: mtf.softmax(x, x.shape[-1]), 'logsoftmax': lambda x: mtf.log_softmax(x, 
x.shape[-1]), From 6536ecfe97f661578ee1f2890cd606af120f9abf Mon Sep 17 00:00:00 2001 From: sid Date: Mon, 29 Mar 2021 21:24:37 +0200 Subject: [PATCH 3/4] add sampling args and do some minor cleanup --- main.py | 100 +++++++++++++---------------- model_fns.py | 14 +++-- optimizers.py | 170 +++++++++++++++++++++++++------------------------- sample.py | 15 ++--- tasks.py | 51 ++++++++++++++- utils.py | 37 ++++++----- 6 files changed, 218 insertions(+), 169 deletions(-) diff --git a/main.py b/main.py index 2627d022..09bf42fb 100644 --- a/main.py +++ b/main.py @@ -12,15 +12,15 @@ from model_fns import model_fn from data.encoders import fetch_encoder from configs import fetch_model_params -from tasks import task_descriptors +from tasks import task_descriptors, run_eval_tasks, run_eval import argparse -import json -import numpy def parse_args(): # Parse command line arguments parser = argparse.ArgumentParser() + + # training args parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.") parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"], help="If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'") @@ -33,15 +33,32 @@ def parse_args(): " MTF auto layout.") parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and " "starts a new training run") + + # sampling args parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.") - parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.") parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, " "defaults to unicorns.", default="") + parser.add_argument("--temperature", type=float, help="temperature for temperature sampling. Float between 0 and 1", + default=0.9) + parser.add_argument("--top-k", type=int, help="sampling_keep_top_k: an integer - if not -1, only sample from the " + "top k logits", default=-1) + parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling") + parser.add_argument("--max_steps", type=int, help="an optional integer, the max number of steps to decode when " + "sampling.", default=None) + parser.add_argument("--sampling-stop-token", type=int, help="An optional integer. Stop sampling when this token is " + "produced. Defaults to EOS token if none is provided.", + default=None) + parser.add_argument("--remove-prompt", action="store_true", help="whether to remove the prompt from the sampling " + "output. Defaults to False.") + parser.add_argument("--sample-save-path", type=str, help="path to save the samples to. 
If None is provided, " "defaults to predictions_{current_step}.txt") + + # misc args + parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.") parser.add_argument("--check_dataset", action="store_true", help="If set, outputs sample from the dataset and quits.") parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.") - parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling") parser.add_argument("--export", action="store_true", help="If set, will export the model.") args = parser.parse_args() assert args.model is not None, "Model must be set" @@ -62,7 +79,6 @@ def main(args): elif input_fn == "generic_text": input_fn = generic_text pred_input_fn = pred_input - handle_pred_output_fn = handle_pred_output # get current step current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"])) @@ -74,7 +90,6 @@ def main(args): if args.check_dataset: check_dataset(input_fn, params) - # Fetch encoder per params encoder = fetch_encoder(params) @@ -105,13 +120,19 @@ def main(args): # Expand attention types param params["attention_types"] = expand_attention_types_params(params["attention_types"]) assert len(params["attention_types"]) == params["n_layer"] # Assert that the length of expanded list = num layers - params["predict_batch_size"] = params.get("predict_batch_size", 1) # Default to 1 - params["predict"] = args.predict - params['model'] = params.get("model", "GPT") # Default model selection to GPT since it's the only option for now + params['model'] = params.get("model", "GPT") # Default model selection to GPT since it's the only option for now params["export"] = args.export + # Set sampling parameters + params["predict"] = args.predict + params["predict_batch_size"] = params.get("predict_batch_size", 1) # Default to 1 + params["sampling_temperature"] = args.temperature + params["sampling_max_steps"] = args.max_steps + params["sampling_top_k"] = args.top_k params["sampling_use_entmax"] = args.entmax_sampling + params["sampling_stop_token"] = args.sampling_stop_token if args.sampling_stop_token is not None else params[ + "eos_id"] + params["sampling_remove_prompt"] = args.remove_prompt # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if # moe layers are present params["slow_sampling"] = True if params["moe_layers"] is not None else False @@ -131,7 +152,8 @@ def main(args): if args.tpu == "colab": tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params["use_tpu"] else None else: - tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params["use_tpu"] else None + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params[ + "use_tpu"] else None config = tpu_config.RunConfig( cluster=tpu_cluster_resolver, @@ -181,47 +203,15 @@ def _make_task_estimator(task): predictions = estimator.predict(input_fn=pred_input_fn) logger.info("Predictions generated") enc = fetch_encoder(params) - handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}") + out_name = f"predictions_{current_step}" if args.sample_save_path is None else args.sample_save_path + handle_pred_output(predictions, logger, enc, params, out_name=out_name) return - def save_eval_results(task, eval_results): - def as_python(x): - if isinstance(x, numpy.generic): - return x.item() - return x -
eval_results = {k: as_python(v) for k, v in eval_results.items()} - with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh: - json.dump({'task': task, 'current_step': current_step, **eval_results}, fh) - fh.write('\n') - - def run_eval(): - logger.info("Running evaluation...") - eval_results = estimator.evaluate( - input_fn=partial(input_fn, eval=True), - steps=params["eval_steps"]) - logger.info(f"Eval results: {eval_results}") - save_eval_results('validation', eval_results) - - def run_eval_tasks(): - for task in eval_tasks: - logger.info(f"Starting evaluation task '{task}'") - task_info = task_descriptors[task]["get_task_info_fn"](params) - task_estimator = eval_task_estimators[task] - task_input_fn = task_descriptors[task]["input_fn"] - eval_results = task_estimator.evaluate( - input_fn=task_input_fn, - steps=task_info["n_steps"], - name=task) - logger.info(f"Eval task '{task}' results: {eval_results}") - save_eval_results(task, eval_results) - if args.eval: - run_eval_tasks() + run_eval_tasks(params, eval_task_estimators, eval_tasks, logger, current_step) if params["eval_steps"] > 0: - run_eval() + run_eval(params, estimator, logger, input_fn) return - - elif has_predict_or_eval_steps_or_eval_tasks: # Eval and train - stop and predict and/or eval every checkpoint while current_step < params["train_steps"]: @@ -235,20 +225,20 @@ def run_eval_tasks(): logger.info("Running prediction...") predictions = estimator.predict(input_fn=pred_input_fn) enc = fetch_encoder(params) - handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}") + handle_pred_output(predictions, logger, enc, params, + out_name=f"predictions_{current_step}") if params["eval_steps"] > 0: - run_eval() + run_eval(params, estimator, logger, input_fn) if eval_tasks: - run_eval_tasks() - + run_eval_tasks(params, eval_task_estimators, eval_tasks, logger, current_step) + return else: # Else, just train - while current_step < params["train_steps"]: - # Else, don't stop and restart - estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params["train_steps"]) + estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), + max_steps=params["train_steps"]) if __name__ == "__main__": diff --git a/model_fns.py b/model_fns.py index 50d3a898..358fdbf2 100644 --- a/model_fns.py +++ b/model_fns.py @@ -96,10 +96,16 @@ def model_fn(features, labels, mode, params): export = params.get("export", False) if not export: - mtf_samples = sample_autoregressive( - inputs, other_features=other_features, params=params, variable_dtype=variable_dtype, - remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], - sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"]) + mtf_samples = sample_autoregressive(inputs, + temperature=params["sampling_temperature"], + sampling_keep_top_k=params["sampling_top_k"], + max_steps = params["sampling_max_steps"], + sampling_use_entmax=params['sampling_use_entmax'], + remove_partial_sequences=params["sampling_remove_prompt"], + stop_at_token=params["sampling_stop_token"], + other_features=other_features, + params=params, + variable_dtype=variable_dtype) else: with mtf.utils.outside_all_rewrites(): diff --git a/optimizers.py b/optimizers.py index 9470e56b..17b3ac63 100644 --- a/optimizers.py +++ b/optimizers.py @@ -6,6 +6,7 @@ import mesh_tensorflow as mtf import tensorflow.compat.v1 as tf + def clip_by_global_norm(grads, clip_norm): """Clip 
the grads by global norm.""" global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None])) @@ -13,6 +14,7 @@ def clip_by_global_norm(grads, clip_norm): clipped_grads = [None if t is None else t * multiplier for t in grads] return clipped_grads, global_norm + def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): """Creates and returns an optimizer training op.""" global_step = tf.train.get_or_create_global_step() @@ -29,14 +31,14 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads] # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps - end_step = params.get("lr_decay_end", params["train_steps"]) + end_step = params.get("lr_decay_end", params["train_steps"]) if params["lr_decay"] == "linear": learning_rate = tf.train.polynomial_decay( learning_rate, global_step, end_step, - end_learning_rate=params["lr"]*0.1, # Decrease to 10% of initial LR according to GPT-3 paper + end_learning_rate=params["lr"] * 0.1, # Decrease to 10% of initial LR according to GPT-3 paper power=1.0, cycle=False) elif params["lr_decay"] == "cosine": @@ -61,7 +63,7 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype) learning_rate = ((1.0 - is_warmup) * learning_rate + - is_warmup * warmup_learning_rate) + is_warmup * warmup_learning_rate) learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate") mtf.scalar_summary("lr", learning_rate) @@ -93,84 +95,84 @@ def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): - """A basic Adam optimizer that includes "correct" L2 weight decay.""" - - def __init__(self, - learning_rate, - weight_decay_rate=0.0, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-6, - exclude_from_weight_decay=None, - variable_dtype=None): - """Constructs a AdamWeightDecayOptimizer.""" - - self.learning_rate = learning_rate - self.weight_decay_rate = weight_decay_rate - self.beta_1 = beta_1 - self.beta_2 = beta_2 - self.epsilon = epsilon - self.exclude_from_weight_decay = exclude_from_weight_decay - self.variable_dtype = variable_dtype - - def apply_grad(self, grad, var): - """See base class.""" - if grad is None: - tf.logging.warning("Gradient is None for variable %s" % var.name) - return [] - - grad = mtf.to_float(grad) - - assignments = [] - - m = mtf.get_variable( - var.mesh, var.name + "/adam_m", var.shape, - initializer=tf.zeros_initializer(), - # master_dtype=self.variable_dtype.master_dtype, - # slice_dtype=self.variable_dtype.slice_dtype, - # activation_dtype=self.variable_dtype.activation_dtype, - trainable=False) - - v = mtf.get_variable( - var.mesh, var.name + "/adam_v", var.shape, - initializer=tf.zeros_initializer(), - # master_dtype=self.variable_dtype.master_dtype, - # slice_dtype=self.variable_dtype.slice_dtype, - # activation_dtype=self.variable_dtype.activation_dtype, - trainable=False) - - # Standard Adam update. - next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad - next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) - - update = next_m / (mtf.sqrt(next_v) + self.epsilon) - - # Just adding the square of the weights to the loss function is *not* - # the correct way of using L2 regularization/weight decay with Adam, - # since that will interact with the m and v parameters in strange ways. 
- # - # Instead we want to decay the weights in a manner that doesn't interact - # with the m/v parameters. This is equivalent to adding the square - # of the weights to the loss with plain (non-momentum) SGD. - if self._do_use_weight_decay(var.name): - update += mtf.to_float(var.value) * self.weight_decay_rate - - update_with_lr = self.learning_rate * update - - var_update = mtf.assign_sub(var, update_with_lr) - - assignments.extend( - [var_update, - mtf.assign(m, next_m), - mtf.assign(v, next_v)]) - return assignments - - def _do_use_weight_decay(self, param_name): - """Whether to use L2 weight decay for `param_name`.""" - if not self.weight_decay_rate: - return False - if self.exclude_from_weight_decay: - for r in self.exclude_from_weight_decay: - if re.search(r, param_name) is not None: - return False - return True \ No newline at end of file + """A basic Adam optimizer that includes "correct" L2 weight decay.""" + + def __init__(self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + variable_dtype=None): + """Constructs a AdamWeightDecayOptimizer.""" + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + self.variable_dtype = variable_dtype + + def apply_grad(self, grad, var): + """See base class.""" + if grad is None: + tf.logging.warning("Gradient is None for variable %s" % var.name) + return [] + + grad = mtf.to_float(grad) + + assignments = [] + + m = mtf.get_variable( + var.mesh, var.name + "/adam_m", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + v = mtf.get_variable( + var.mesh, var.name + "/adam_v", var.shape, + initializer=tf.zeros_initializer(), + # master_dtype=self.variable_dtype.master_dtype, + # slice_dtype=self.variable_dtype.slice_dtype, + # activation_dtype=self.variable_dtype.activation_dtype, + trainable=False) + + # Standard Adam update. + next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad + next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) + + update = next_m / (mtf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. 
+ if self._do_use_weight_decay(var.name): + update += mtf.to_float(var.value) * self.weight_decay_rate + + update_with_lr = self.learning_rate * update + + var_update = mtf.assign_sub(var, update_with_lr) + + assignments.extend( + [var_update, + mtf.assign(m, next_m), + mtf.assign(v, next_v)]) + return assignments + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True diff --git a/sample.py b/sample.py index 72d2f1ef..e3c55576 100644 --- a/sample.py +++ b/sample.py @@ -5,6 +5,7 @@ from models.utils import entmax, sample_categorical from models.gpt2 import gpt2 + def sample_autoregressive(partial_sequences, other_features, params, @@ -18,11 +19,9 @@ def sample_autoregressive(partial_sequences, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None, - never_end=False, remove_partial_sequences=False, sampling_keep_top_k=-1, - sampling_use_entmax = False, - bos_id=50256, + sampling_use_entmax=False, ): """Sample randomly one token at a time. @@ -66,9 +65,9 @@ def sample_autoregressive(partial_sequences, padding_id = params.get("padding_id", 0) slow_sampling = params.get("slow_sampling", False) - initial_position = mtf.reduce_sum( - mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts + mtf.to_int32(mtf.not_equal(inputs, padding_id)), + reduced_dim=length_dim) # Gets position where zero padding starts length_range = mtf.range(inputs.mesh, length_dim, tf.int32) input_full_attention = True # for now hardcode this to true bc lazy @@ -107,7 +106,8 @@ def sample_autoregressive(partial_sequences, encoder_inputs=encoder_inputs) with tf.variable_scope("gpt2"): - logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part) + logits, _, _ = gpt2.model({"inputs": inputs}, other_features, params, inputs.mesh, + variable_dtype=variable_dtype, context=context_first_part) if not has_partial_sequences: initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] @@ -168,7 +168,8 @@ def body_fn(position, ids, *states): encoder_inputs=encoder_inputs) if not slow_sampling else None with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE): - logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context = context) + logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, + variable_dtype=variable_dtype, context=context) if not sampling_use_entmax: # By default, do top_k sampling of 0.9 diff --git a/tasks.py b/tasks.py index f4a03047..b07aac4a 100644 --- a/tasks.py +++ b/tasks.py @@ -1,12 +1,11 @@ import os.path -import json import requests -import numpy as np import ftfy from data.encoders import fetch_encoder, encode import tensorflow as tf -import re +import numpy as np from functools import partial +import json lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl' normalization = 'NFKC' @@ -107,6 +106,52 @@ def _get_output(bin): return dataset +def save_eval_results(task, eval_results, current_step): + + def as_python(x): + if isinstance(x, np.generic): + return x.item() + return x + + eval_results = {k: as_python(v) for k, v in eval_results.items()} + import os + save_path = f'eval_{current_step}.jsonl' + 
if os.path.isfile(save_path): + num = 0 + while True: + new_save_path = f'{save_path}_{num}' + if not os.path.isfile(new_save_path): + save_path = new_save_path + break + num += 1 + + with open(save_path, 'a') as fh: + json.dump({'task': task, 'current_step': current_step, **eval_results}, fh) + fh.write('\n') + + +def run_eval(params, estimator, logger, input_fn, current_step=None): + logger.info("Running evaluation...") + eval_results = estimator.evaluate( + input_fn=partial(input_fn, eval=True), + steps=params["eval_steps"]) + logger.info(f"Eval results: {eval_results}") + save_eval_results('validation', eval_results, current_step) + + +def run_eval_tasks(params, eval_task_estimators, eval_tasks, logger, current_step): + for task in eval_tasks: + logger.info(f"Starting evaluation task '{task}'") + task_info = task_descriptors[task]["get_task_info_fn"](params) + task_estimator = eval_task_estimators[task] + task_input_fn = task_descriptors[task]["input_fn"] + eval_results = task_estimator.evaluate( + input_fn=task_input_fn, + steps=task_info["n_steps"], + name=task) + logger.info(f"Eval task '{task}' results: {eval_results}") + save_eval_results(task, eval_results, current_step) + + task_descriptors = { 'lambada': { 'init_fn': lambada_init, diff --git a/utils.py b/utils.py index 3666b889..2cca3212 100644 --- a/utils.py +++ b/utils.py @@ -11,6 +11,7 @@ from data.encoders import fetch_encoder import re + def setup_logging(args): Path("logs").mkdir(exist_ok=True) tf.logging.set_verbosity(logging.INFO) @@ -83,7 +84,7 @@ def remove_batch_from_layout(layout): def yes_or_no(question): while True: - reply = str(input(question+' (y/n): ')).lower().strip() + reply = str(input(question + ' (y/n): ')).lower().strip() if reply[:1] == 'y': return True if reply[:1] == 'n': @@ -115,7 +116,7 @@ def save_config(params_dict, logdir): if count == total_params - 1: text += f'"{str(key)}"' + ' : ' + config_value + '\n\n' else: - text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n' + text += f'"{str(key)}"' + ' : ' + config_value + ',\n\n' text += '\n\n}' sess = tf.InteractiveSession() summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text)) @@ -145,11 +146,11 @@ def get_n_trainable_vars(graph): """ total_parameters = 0 for variable in graph.trainable_variables: - shape = variable.shape.dims - variable_parameters = 1 - for dim in shape: - variable_parameters *= dim.size - total_parameters += variable_parameters + shape = variable.shape.dims + variable_parameters = 1 + for dim in shape: + variable_parameters *= dim.size + total_parameters += variable_parameters print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n") @@ -165,7 +166,7 @@ def print_dim_names(graph): all_dim_names.append(names) # Print all dim names in graph & write to file - all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims + all_dim_names = [item for sublist in all_dim_names for item in sublist] # Flatten all dims unique_dims = list(set(all_dim_names)) print("ALL DIM NAMES:") for dim_name in unique_dims: @@ -202,6 +203,7 @@ class constructor.
ret = float(targets.shape.size) * num_microbatches return float(ret) + def check_dataset(input_fn, params, global_step=None): tf.enable_eager_execution() if global_step is not None: @@ -220,17 +222,20 @@ def check_dataset(input_fn, params, global_step=None): print('-' * 50) exit() + def auto_layout(graph, mesh_shape, logits, loss): layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss]) print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout") - quit() + quit() + def auto_layout_and_mesh_shape(graph, num_cores, logits, loss): layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores, - [logits, loss], max_mesh_shape_dimensions=4) + [logits, loss], max_mesh_shape_dimensions=4) print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \ - f"\nRe-initialize graph with selected layout & mesh shape") - quit() + f"\nRe-initialize graph with selected layout & mesh shape") + quit() + def create_host_call(model_dir): """Construct a host_call writing scalar summaries. @@ -285,7 +290,7 @@ def host_call_fn(global_step, *args): return host_call_fn, [global_step_t] + reshaped_tensors -def natural_sort(l): - convert = lambda text: int(text) if text.isdigit() else text.lower() - alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] - return sorted(l, key = alphanum_key) +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] + return sorted(l, key=alphanum_key) From 86b1a8d8b99c44248caab751ac78219f365a7703 Mon Sep 17 00:00:00 2001 From: sid Date: Mon, 29 Mar 2021 21:30:53 +0200 Subject: [PATCH 4/4] update readme to reflect changes --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index ae9d04a4..ad2d16ef 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,14 @@ Once you have a trained model, or you've downloaded one of our pre-trained model python3 main.py --predict --prompt --tpu --model ``` +(Optional) extra arguments for sampling: + +- `--temperature` : Temperature for temperature sampling. Float between 0 and 1. +- `--top-k` : An optional integer - if not -1, only sample from the top k logits. +- `--max_steps` : An optional integer, the max number of steps to decode when sampling. +- `--sampling-stop-token` : An optional integer. Stop sampling when this token is produced. Defaults to EOS token if none is provided. +- `--remove-prompt` : Boolean. whether to remove the prompt from the sampling output. Defaults to False. +- `--sample-save-path` : An optional String. Path to save the samples to. If None is provided, defaults to predictions_{current_step}.txt or, if using GPUs: ```bash