@@ -81,6 +81,21 @@ struct DeviceOps {
8181
8282const uint8_t MAX_REVISION = 0xff ;
8383
84+ static bool is_xe2_or_xe3_family (gpu_arch arch) {
85+ switch (arch) {
86+ case gpu_arch::xe2:
87+ case gpu_arch::xe3:
88+ case gpu_arch::xe3p_35_10:
89+ case gpu_arch::xe3p_35_11:
90+ case gpu_arch::xe3p_unknown:
91+ return true ;
92+ default :
93+ return false ;
94+ }
95+ }
96+
97+ const DeviceOps xe2_xe3_fallback_ops = { {}, { 1 , 16 , 32 , 128 , 256 , 0 }, {} };
98+
8499const std::vector<DeviceOps> device_ops_table = {
85100// | gfx_ver | MAD | DPAS (immad) | DP4A | device_id |
86101// | | fp64 | fp32 | fp16 | fp16 | int8 | int8 | |
@@ -117,11 +132,19 @@ float device::get_gops(data_types dt) const {
117132 [&info](auto & entry) {
118133 return entry.match (info);
119134 });
135+
136+ const DeviceOps* selected_ops = nullptr ;
120137 if (it != device_ops_table.end ()) {
121- opsPerEU = it->get_ops (info, dt);
138+ selected_ops = &(*it);
139+ } else if (is_xe2_or_xe3_family (info.arch )) {
140+ selected_ops = &xe2_xe3_fallback_ops;
141+ }
142+
143+ if (selected_ops != nullptr ) {
144+ opsPerEU = selected_ops->get_ops (info, dt);
122145 if (DeviceOps::is_zero (opsPerEU) && (dt == data_types::i8 || dt == data_types::u8 )) {
123146 // WA: ops of i8/u8 is twice of f16.
124- opsPerEU = it ->get_ops_for_mad (data_types::f16 ) * 2 ;
147+ opsPerEU = selected_ops ->get_ops_for_mad (data_types::f16 ) * 2 ;
125148 }
126149 }
127150 } catch (std::exception& e) {
0 commit comments