diff --git a/arch/src/x86_64/mod.rs b/arch/src/x86_64/mod.rs index ab3835993..17bae5f5a 100644 --- a/arch/src/x86_64/mod.rs +++ b/arch/src/x86_64/mod.rs @@ -50,6 +50,9 @@ const AMX_BF16: u8 = 22; // AMX tile computation on bfloat16 numbers const AMX_TILE: u8 = 24; // AMX tile load/store instructions const AMX_INT8: u8 = 25; // AMX tile computation on 8-bit integers +const AMX_FP16: u8 = 21; // AMX tile computation on fp16 numbers +const AMX_COMPLEX: u8 = 8; // AMX tile computation on complex numbers + // KVM feature bits #[cfg(feature = "tdx")] const KVM_FEATURE_CLOCKSOURCE_BIT: u8 = 0; @@ -638,8 +641,14 @@ pub fn generate_common_cpuid( match entry.function { // Clear AMX related bits if the AMX feature is not enabled 0x7 => { - if !config.amx && entry.index == 0 { - entry.edx &= !((1 << AMX_BF16) | (1 << AMX_TILE) | (1 << AMX_INT8)); + if !config.amx { + if entry.index == 0 { + entry.edx &= !((1 << AMX_BF16) | (1 << AMX_TILE) | (1 << AMX_INT8)); + } + if entry.index == 1 { + entry.eax &= !(1 << AMX_FP16); + entry.edx &= !(1 << AMX_COMPLEX); + } } } 0xd => @@ -661,6 +670,25 @@ pub fn generate_common_cpuid( } } } + 0x1d => { + // Tile Information (purely AMX related). + if !config.amx { + entry.eax = 0; + entry.ebx = 0; + entry.ecx = 0; + entry.edx = 0; + } + } + 0x1e => { + // TMUL information (purely AMX related) + if !config.amx { + entry.eax = 0; + entry.ebx = 0; + entry.ecx = 0; + entry.edx = 0; + } + } + // Copy host L1 cache details if not populated by KVM 0x8000_0005 => { if entry.eax == 0 && entry.ebx == 0 && entry.ecx == 0 && entry.edx == 0 {