|
867 | 867 | ), |
868 | 868 | ) |
869 | 869 |
|
| 870 | +# Config for v6e-64: llama3.1-8b, max_target_length 8192, per_device_batch_size 5, custom remat policy with "offload" entries, synthetic dataset, checkpointing disabled. |
| 871 | +llama3_1_8b_8192_bs5 = _add_to_model_dictionary( |
| 872 | + trillium_model_dict, |
| 873 | + MaxTextModel( |
| 874 | + model_name="llama3_1-8b-8192-bs5", |
| 875 | + model_type="llama3.1-8b", |
| 876 | + tuning_params={ |
| 877 | + "per_device_batch_size": 5, |
| 878 | + "ici_fsdp_parallelism": -1, |
| 879 | + "remat_policy": "custom", |
| 880 | + "decoder_layer_input": "offload", |
| 881 | + "out_proj": "offload", |
| 882 | + "query_proj": "offload", |
| 883 | + "key_proj": "offload", |
| 884 | + "value_proj": "offload", |
| 885 | + "max_target_length": 8192, |
| 886 | + "attention": "flash", |
| 887 | + "use_iota_embed": True, |
| 888 | + "dataset_path": "gs://max-datasets-rogue", |
| 889 | + "dataset_type": "synthetic", |
| 890 | + "enable_checkpointing": False, |
| 891 | + "sa_block_q": 2048, |
| 892 | + "sa_block_kv": 2048, |
| 893 | + "sa_block_kv_compute": 2048, |
| 894 | + "sa_block_q_dkv": 2048, |
| 895 | + "sa_block_kv_dkv": 2048, |
| 896 | + "sa_block_kv_dkv_compute": 2048, |
| 897 | + "sa_block_q_dq": 2048, |
| 898 | + "sa_block_kv_dq": 2048, |
| 899 | + "sa_use_fused_bwd_kernel": True, |
| 900 | + "profiler": "xplane", |
| 901 | + "skip_first_n_steps_for_profiler": 10, |
| 902 | + "profiler_steps": 5, |
| 903 | + }, |
| 904 | + xla_flags=( |
| 905 | + xla_flags_library.DENSE_VMEM_LIMIT_FLAG |
| 906 | + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER |
| 907 | + + xla_flags_library.DATA_PARALLEL_OVERLAP |
| 908 | + + xla_flags_library.CF_FOR_ALL_GATHER |
| 909 | + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE |
| 910 | + + xla_flags_library.HOST_OFFLOAD_FLAGS |
| 911 | + ), |
| 912 | + ), |
| 913 | +) |
| 914 | + |
870 | 915 |
|
871 | 916 | llama3_1_8b_8192_no_collective_matmul = _add_to_model_dictionary( |
872 | 917 | trillium_model_dict, |
|
956 | 1001 | ), |
957 | 1002 | ) |
958 | 1003 |
|
| 1004 | +# Config for v6e-64: llama3.1-70b, max_target_length 8192, per_device_batch_size 2, custom remat policy with "offload" entries, synthetic dataset, checkpointing disabled. |
| 1005 | +llama3_1_70b_8192_bs2 = _add_to_model_dictionary( |
| 1006 | + trillium_model_dict, |
| 1007 | + MaxTextModel( |
| 1008 | + model_name="llama3_1-70b-8192-bs2", |
| 1009 | + model_type="llama3.1-70b", |
| 1010 | + tuning_params={ |
| 1011 | + "per_device_batch_size": 2, |
| 1012 | + "ici_fsdp_parallelism": -1, |
| 1013 | + "remat_policy": "custom", |
| 1014 | + "decoder_layer_input": "offload", |
| 1015 | + "query_proj": "offload", |
| 1016 | + "key_proj": "offload", |
| 1017 | + "value_proj": "offload", |
| 1018 | + "max_target_length": 8192, |
| 1019 | + "attention": "flash", |
| 1020 | + "use_iota_embed": True, |
| 1021 | + "dataset_path": "gs://max-datasets-rogue", |
| 1022 | + "dataset_type": "synthetic", |
| 1023 | + "enable_checkpointing": False, |
| 1024 | + "sa_block_q": 2048, |
| 1025 | + "sa_block_kv": 2048, |
| 1026 | + "sa_block_kv_compute": 2048, |
| 1027 | + "sa_block_q_dkv": 2048, |
| 1028 | + "sa_block_kv_dkv": 2048, |
| 1029 | + "sa_block_kv_dkv_compute": 2048, |
| 1030 | + "sa_block_q_dq": 2048, |
| 1031 | + "sa_block_kv_dq": 2048, |
| 1032 | + "sa_use_fused_bwd_kernel": True, |
| 1033 | + "profiler": "xplane", |
| 1034 | + "skip_first_n_steps_for_profiler": 10, |
| 1035 | + "profiler_steps": 5, |
| 1036 | + }, |
| 1037 | + xla_flags=( |
| 1038 | + xla_flags_library.DENSE_VMEM_LIMIT_FLAG |
| 1039 | + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER |
| 1040 | + + xla_flags_library.DATA_PARALLEL_OVERLAP |
| 1041 | + + xla_flags_library.CF_FOR_ALL_GATHER |
| 1042 | + + xla_flags_library.HOST_OFFLOAD_FLAGS |
| 1043 | + ), |
| 1044 | + ), |
| 1045 | +) |
| 1046 | + |
| 1047 | +# Config for v6e-32: llama3.1-70b, max_target_length 8192, per_device_batch_size 2, weight_dtype bfloat16, with the DISABLE_COLLECTIVE_MATMUL XLA flag appended. |
| 1048 | +llama3_1_70b_8192_bs2_bfloat16_no_collective_matmul = _add_to_model_dictionary( |
| 1049 | + trillium_model_dict, |
| 1050 | + MaxTextModel( |
| 1051 | + model_name="llama3_1-70b-8192-bs2-bfloat16-no-collective-matmul", |
| 1052 | + model_type="llama3.1-70b", |
| 1053 | + tuning_params={ |
| 1054 | + "per_device_batch_size": 2, |
| 1055 | + "ici_fsdp_parallelism": -1, |
| 1056 | + "remat_policy": "custom", |
| 1057 | + "decoder_layer_input": "offload", |
| 1058 | + "query_proj": "offload", |
| 1059 | + "key_proj": "offload", |
| 1060 | + "value_proj": "offload", |
| 1061 | + "max_target_length": 8192, |
| 1062 | + "attention": "flash", |
| 1063 | + "use_iota_embed": True, |
| 1064 | + "dataset_path": "gs://max-datasets-rogue", |
| 1065 | + "dataset_type": "synthetic", |
| 1066 | + "enable_checkpointing": False, |
| 1067 | + "sa_block_q": 2048, |
| 1068 | + "sa_block_kv": 2048, |
| 1069 | + "sa_block_kv_compute": 2048, |
| 1070 | + "sa_block_q_dkv": 2048, |
| 1071 | + "sa_block_kv_dkv": 2048, |
| 1072 | + "sa_block_kv_dkv_compute": 2048, |
| 1073 | + "sa_block_q_dq": 2048, |
| 1074 | + "sa_block_kv_dq": 2048, |
| 1075 | + "sa_use_fused_bwd_kernel": True, |
| 1076 | + "profiler": "xplane", |
| 1077 | + "skip_first_n_steps_for_profiler": 10, |
| 1078 | + "profiler_steps": 5, |
| 1079 | + "weight_dtype": "bfloat16", |
| 1080 | + }, |
| 1081 | + xla_flags=( |
| 1082 | + xla_flags_library.DENSE_VMEM_LIMIT_FLAG |
| 1083 | + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER |
| 1084 | + + xla_flags_library.DATA_PARALLEL_OVERLAP |
| 1085 | + + xla_flags_library.CF_FOR_ALL_GATHER |
| 1086 | + + xla_flags_library.HOST_OFFLOAD_FLAGS |
| 1087 | + + xla_flags_library.DISABLE_COLLECTIVE_MATMUL |
| 1088 | + ), |
| 1089 | + ), |
| 1090 | +) |
| 1091 | + |
| 1092 | +# Config for v6e-128: llama3.1-70b, max_target_length 8192, per_device_batch_size 4, custom remat policy with "offload" entries, synthetic dataset, checkpointing disabled. |
| 1093 | +llama3_1_70b_8192_bs4 = _add_to_model_dictionary( |
| 1094 | + trillium_model_dict, |
| 1095 | + MaxTextModel( |
| 1096 | + model_name="llama3_1-70b-8192-bs4", |
| 1097 | + model_type="llama3.1-70b", |
| 1098 | + tuning_params={ |
| 1099 | + "per_device_batch_size": 4, |
| 1100 | + "ici_fsdp_parallelism": -1, |
| 1101 | + "remat_policy": "custom", |
| 1102 | + "decoder_layer_input": "offload", |
| 1103 | + "query_proj": "offload", |
| 1104 | + "key_proj": "offload", |
| 1105 | + "value_proj": "offload", |
| 1106 | + "max_target_length": 8192, |
| 1107 | + "attention": "flash", |
| 1108 | + "use_iota_embed": True, |
| 1109 | + "dataset_path": "gs://max-datasets-rogue", |
| 1110 | + "dataset_type": "synthetic", |
| 1111 | + "enable_checkpointing": False, |
| 1112 | + "sa_block_q": 2048, |
| 1113 | + "sa_block_kv": 2048, |
| 1114 | + "sa_block_kv_compute": 2048, |
| 1115 | + "sa_block_q_dkv": 2048, |
| 1116 | + "sa_block_kv_dkv": 2048, |
| 1117 | + "sa_block_kv_dkv_compute": 2048, |
| 1118 | + "sa_block_q_dq": 2048, |
| 1119 | + "sa_block_kv_dq": 2048, |
| 1120 | + "sa_use_fused_bwd_kernel": True, |
| 1121 | + "profiler": "xplane", |
| 1122 | + "skip_first_n_steps_for_profiler": 10, |
| 1123 | + "profiler_steps": 5, |
| 1124 | + }, |
| 1125 | + xla_flags=( |
| 1126 | + xla_flags_library.DENSE_VMEM_LIMIT_FLAG |
| 1127 | + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER |
| 1128 | + + xla_flags_library.DATA_PARALLEL_OVERLAP |
| 1129 | + + xla_flags_library.CF_FOR_ALL_GATHER |
| 1130 | + + xla_flags_library.HOST_OFFLOAD_FLAGS |
| 1131 | + ), |
| 1132 | + ), |
| 1133 | +) |
| 1134 | + |
959 | 1135 | llama3_1_70b_8192_iter_synthetic = _add_to_model_dictionary( |
960 | 1136 | trillium_model_dict, |
961 | 1137 | MaxTextModel( |
|
0 commit comments