Skip to content

[API Proposal]: Add AVX-VNNI-INT8 and AVX-VNNI-INT16 API #112586

Open
@khushal1996

Description

@khushal1996

Background and motivation

This API proposal introduces API surface for AVX-VNNI-INT8 and AVX-VNNI-INT16 in .NET.
Spec doc - Link

As part of this proposal, each new class will contain a nested V512 class that represents the relationship between AVX10.2 and the AVX-VNNI-INT8/AVX-VNNI-INT16 ISAs, as discussed here (link).

The following dependencies (implications) on AVX10.2 will be added:

implication        ,X86   ,AVX10v2              ,AVXVNNIINT8
implication        ,X86   ,AVX10v2              ,AVXVNNIINT16
implication        ,X86   ,AVX10v2_V512         ,AVXVNNIINT8_V512
implication        ,X86   ,AVX10v2_V512         ,AVXVNNIINT16_V512

API Proposal

AVX-VNNI-INT8

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;

namespace System.Runtime.Intrinsics.X86
{
    /// <summary>Provides access to the x86 AVX-VNNI-INT8 hardware instructions via intrinsics.</summary>
    // Every member body in this class calls itself with the same arguments.
    // NOTE(review): presumably the [Intrinsic] attribute makes the runtime substitute the
    // instruction named in the comment above each method — confirm against the JIT implementation.
    [Intrinsic]
    [CLSCompliant(false)]
    public abstract class AvxVnniInt8 : Avx2
    {
        // Internal constructor: the type cannot be instantiated or derived from outside this assembly.
        internal AvxVnniInt8() { }

        /// <summary>Gets a value that indicates whether the APIs in this class are supported.</summary>
        /// <value><see langword="true" /> if the APIs are supported; otherwise, <see langword="false" />.</value>
        /// <remarks>A value of <see langword="false" /> indicates that the APIs will throw <see cref="PlatformNotSupportedException" />.</remarks>
        public static new bool IsSupported { get => IsSupported; }

        /// <summary>Provides access to the x86 AVX-VNNI-INT8 hardware instructions, that are only available to 64-bit processes, via intrinsics.</summary>
        // This nested class declares no instruction methods of its own; it only exposes IsSupported.
        [Intrinsic]
        public new abstract class X64 : Avx2.X64
        {
            internal X64() { }

            /// <summary>Gets a value that indicates whether the APIs in this class are supported.</summary>
            /// <value><see langword="true" /> if the APIs are supported; otherwise, <see langword="false" />.</value>
            /// <remarks>A value of <see langword="false" /> indicates that the APIs will throw <see cref="PlatformNotSupportedException" />.</remarks>
            public static new bool IsSupported { get => IsSupported; }
        }

        // ---- MultiplyWideningAndAdd (non-saturating), 128-bit and 256-bit ----
        // Operand element types are (sbyte, sbyte), (sbyte, byte) and (byte, byte).
        // The addend/result element type is int, except for (byte, byte) where it is uint.

        // VPDPBSSD xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<int> addend, Vector128<sbyte> left, Vector128<sbyte> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPBSUD xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<int> addend, Vector128<sbyte> left, Vector128<byte> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPBUUD xmm1, xmm2, xmm3/m128
        public static Vector128<uint> MultiplyWideningAndAdd(Vector128<uint> addend, Vector128<byte> left, Vector128<byte> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPBSSD ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<int> addend, Vector256<sbyte> left, Vector256<sbyte> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPBSUD ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<int> addend, Vector256<sbyte> left, Vector256<byte> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPBUUD ymm1, ymm2, ymm3/m256
        public static Vector256<uint> MultiplyWideningAndAdd(Vector256<uint> addend, Vector256<byte> left, Vector256<byte> right) => MultiplyWideningAndAdd(addend, left, right);

        // ---- MultiplyWideningAndAddSaturate, 128-bit and 256-bit ----
        // Same operand/result type combinations as above; the mapped mnemonics carry a trailing "S".

        // VPDPBSSDS xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<int> addend, Vector128<sbyte> left, Vector128<sbyte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPBSUDS xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<int> addend, Vector128<sbyte> left, Vector128<byte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPBUUDS xmm1, xmm2, xmm3/m128
        public static Vector128<uint> MultiplyWideningAndAddSaturate(Vector128<uint> addend, Vector128<byte> left, Vector128<byte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPBSSDS ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<int> addend, Vector256<sbyte> left, Vector256<sbyte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPBSUDS ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<int> addend, Vector256<sbyte> left, Vector256<byte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPBUUDS ymm1, ymm2, ymm3/m256
        public static Vector256<uint> MultiplyWideningAndAddSaturate(Vector256<uint> addend, Vector256<byte> left, Vector256<byte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        /// <summary>Provides access to the x86 AVX10.2/512 hardware instructions for AVX-VNNI-INT8 via intrinsics.</summary>
        // Vector512 counterparts of the overloads declared in the enclosing class.
        // This class has no base class here, so its IsSupported is declared without "new".
        [Intrinsic]
        public abstract class V512
        {
            internal V512() { }

            /// <summary>Gets a value that indicates whether the APIs in this class are supported.</summary>
            /// <value><see langword="true" /> if the APIs are supported; otherwise, <see langword="false" />.</value>
            /// <remarks>A value of <see langword="false" /> indicates that the APIs will throw <see cref="PlatformNotSupportedException" />.</remarks>
            public static bool IsSupported { get => IsSupported; }

            // VPDPBSSD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<int> addend, Vector512<sbyte> left, Vector512<sbyte> right) => MultiplyWideningAndAdd(addend, left, right);

            // VPDPBSUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<int> addend, Vector512<sbyte> left, Vector512<byte> right) => MultiplyWideningAndAdd(addend, left, right);

            // VPDPBUUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<uint> MultiplyWideningAndAdd(Vector512<uint> addend, Vector512<byte> left, Vector512<byte> right) => MultiplyWideningAndAdd(addend, left, right);

            // VPDPBSSDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<int> addend, Vector512<sbyte> left, Vector512<sbyte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

            // VPDPBSUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<int> addend, Vector512<sbyte> left, Vector512<byte> right) => MultiplyWideningAndAddSaturate(addend, left, right);

            // VPDPBUUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<uint> MultiplyWideningAndAddSaturate(Vector512<uint> addend, Vector512<byte> left, Vector512<byte> right) => MultiplyWideningAndAddSaturate(addend, left, right);
        }
    }
}

AVX-VNNI-INT16

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;

namespace System.Runtime.Intrinsics.X86
{
    /// <summary>Provides access to the x86 AVX-VNNI-INT16 hardware instructions via intrinsics.</summary>
    // Every member body in this class calls itself with the same arguments.
    // NOTE(review): presumably the [Intrinsic] attribute makes the runtime substitute the
    // instruction named in the comment above each method — confirm against the JIT implementation.
    [Intrinsic]
    [CLSCompliant(false)]
    public abstract class AvxVnniInt16 : Avx2
    {
        // Internal constructor: the type cannot be instantiated or derived from outside this assembly.
        internal AvxVnniInt16() { }

        /// <summary>Gets a value that indicates whether the APIs in this class are supported.</summary>
        /// <value><see langword="true" /> if the APIs are supported; otherwise, <see langword="false" />.</value>
        /// <remarks>A value of <see langword="false" /> indicates that the APIs will throw <see cref="PlatformNotSupportedException" />.</remarks>
        public static new bool IsSupported { get => IsSupported; }

        /// <summary>Provides access to the x86 AVX-VNNI-INT16 hardware instructions, that are only available to 64-bit processes, via intrinsics.</summary>
        // This nested class declares no instruction methods of its own; it only exposes IsSupported.
        [Intrinsic]
        public new abstract class X64 : Avx2.X64
        {
            internal X64() { }

            /// <summary>Gets a value that indicates whether the APIs in this class are supported.</summary>
            /// <value><see langword="true" /> if the APIs are supported; otherwise, <see langword="false" />.</value>
            /// <remarks>A value of <see langword="false" /> indicates that the APIs will throw <see cref="PlatformNotSupportedException" />.</remarks>
            public static new bool IsSupported { get => IsSupported; }
        }

        // ---- MultiplyWideningAndAdd (non-saturating), 128-bit and 256-bit ----
        // Operand element types are (short, ushort), (ushort, short) and (ushort, ushort).
        // The addend/result element type is int, except for (ushort, ushort) where it is uint.
        // No (short, short) overload is declared in this class.

        // VPDPWSUD xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<int> addend, Vector128<short> left, Vector128<ushort> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPWUSD xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAdd(Vector128<int> addend, Vector128<ushort> left, Vector128<short> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPWUUD xmm1, xmm2, xmm3/m128
        public static Vector128<uint> MultiplyWideningAndAdd(Vector128<uint> addend, Vector128<ushort> left, Vector128<ushort> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPWSUD ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<int> addend, Vector256<short> left, Vector256<ushort> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPWUSD ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAdd(Vector256<int> addend, Vector256<ushort> left, Vector256<short> right) => MultiplyWideningAndAdd(addend, left, right);

        // VPDPWUUD ymm1, ymm2, ymm3/m256
        public static Vector256<uint> MultiplyWideningAndAdd(Vector256<uint> addend, Vector256<ushort> left, Vector256<ushort> right) => MultiplyWideningAndAdd(addend, left, right);

        // ---- MultiplyWideningAndAddSaturate, 128-bit and 256-bit ----
        // Same operand/result type combinations as above; the mapped mnemonics carry a trailing "S".

        // VPDPWSUDS xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<int> addend, Vector128<short> left, Vector128<ushort> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPWUSDS xmm1, xmm2, xmm3/m128
        public static Vector128<int> MultiplyWideningAndAddSaturate(Vector128<int> addend, Vector128<ushort> left, Vector128<short> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPWUUDS xmm1, xmm2, xmm3/m128
        public static Vector128<uint> MultiplyWideningAndAddSaturate(Vector128<uint> addend, Vector128<ushort> left, Vector128<ushort> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPWSUDS ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<int> addend, Vector256<short> left, Vector256<ushort> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPWUSDS ymm1, ymm2, ymm3/m256
        public static Vector256<int> MultiplyWideningAndAddSaturate(Vector256<int> addend, Vector256<ushort> left, Vector256<short> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        // VPDPWUUDS ymm1, ymm2, ymm3/m256
        public static Vector256<uint> MultiplyWideningAndAddSaturate(Vector256<uint> addend, Vector256<ushort> left, Vector256<ushort> right) => MultiplyWideningAndAddSaturate(addend, left, right);

        /// <summary>Provides access to the x86 AVX10.2/512 hardware instructions for AVX-VNNI-INT16 via intrinsics.</summary>
        // Vector512 counterparts of the overloads declared in the enclosing class.
        // This class has no base class here, so its IsSupported is declared without "new".
        [Intrinsic]
        public abstract class V512
        {
            internal V512() { }

            /// <summary>Gets a value that indicates whether the APIs in this class are supported.</summary>
            /// <value><see langword="true" /> if the APIs are supported; otherwise, <see langword="false" />.</value>
            /// <remarks>A value of <see langword="false" /> indicates that the APIs will throw <see cref="PlatformNotSupportedException" />.</remarks>
            public static bool IsSupported { get => IsSupported; }

            // VPDPWSUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<int> addend, Vector512<short> left, Vector512<ushort> right) => MultiplyWideningAndAdd(addend, left, right);

            // VPDPWUSD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAdd(Vector512<int> addend, Vector512<ushort> left, Vector512<short> right) => MultiplyWideningAndAdd(addend, left, right);

            // VPDPWUUD zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<uint> MultiplyWideningAndAdd(Vector512<uint> addend, Vector512<ushort> left, Vector512<ushort> right) => MultiplyWideningAndAdd(addend, left, right);

            // VPDPWSUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<int> addend, Vector512<short> left, Vector512<ushort> right) => MultiplyWideningAndAddSaturate(addend, left, right);

            // VPDPWUSDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<int> MultiplyWideningAndAddSaturate(Vector512<int> addend, Vector512<ushort> left, Vector512<short> right) => MultiplyWideningAndAddSaturate(addend, left, right);

            // VPDPWUUDS zmm1{k1}{z}, zmm2, zmm3/m512/m32bcst
            public static Vector512<uint> MultiplyWideningAndAddSaturate(Vector512<uint> addend, Vector512<ushort> left, Vector512<ushort> right) => MultiplyWideningAndAddSaturate(addend, left, right);
        }
    }
}

API Usage

// 128-bit usage: broadcast three scalar inputs, then accumulate sbyte x sbyte products into int lanes.
Vector128<sbyte> v1 = Vector128.Create((sbyte)someParam1);
Vector128<sbyte> v2 = Vector128.Create((sbyte)someParam2);
Vector128<int> v3 = Vector128.Create((int)someParam3);
// IsSupported is declared as a static property, so it is read without parentheses.
if (AvxVnniInt8.IsSupported)
{
    Vector128<int> v4 = AvxVnniInt8.MultiplyWideningAndAdd(v3, v1, v2);
    // etc
}
// 512-bit usage. Locals carry a "Wide" suffix so they do not clash with v1..v4 from the
// 128-bit sample when both samples are placed in the same scope.
Vector512<sbyte> v1Wide = Vector512.Create((sbyte)someParam1);
Vector512<sbyte> v2Wide = Vector512.Create((sbyte)someParam2);
Vector512<int> v3Wide = Vector512.Create((int)someParam3);
// Guard on the class whose method is called; IsSupported is a static property, not a method.
// NOTE(review): per the proposed implications, Avx10v2.V512.IsSupported would presumably also
// be a sufficient guard — confirm which check the proposal intends to showcase.
if (AvxVnniInt8.V512.IsSupported)
{
    Vector512<int> v4Wide = AvxVnniInt8.V512.MultiplyWideningAndAdd(v3Wide, v1Wide, v2Wide);
    // etc
}

Alternative Designs

No response

Risks

No response

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions