Skip to content

Create specialized versions of BLAS subroutines with fewer arguments and fewer run-time decisions #65

Open
@Beliavsky

Description

@Beliavsky

I asked ChatGPT o1-preview to create specialized versions of stdlib_dgemm for the cases where the matrix multiplication uses the original matrix a or its transpose. It produced stdlib_dgemm_a_orig and stdlib_dgemm_a_trans, shown below; I have not checked them. In general, a choice that is knowable at compile time (such as whether to use a matrix or its transpose) should not be made at run time: each specialized routine drops the transa argument and the run-time branch that tests it.

pure subroutine stdlib_dgemm_a_orig(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    !! Specialized DGEMM for the case transa = 'N' (A is used as stored).
    !!
    !! Performs C := alpha * A * op(B) + beta * C, where
    !!   op(B) = B     if transb = 'N' or 'n'
    !!   op(B) = B**T  if transb = 'T', 't', 'C' or 'c'
    !! A is m-by-k, op(B) is k-by-n and C is m-by-n.
    !!
    !! Invalid arguments are reported through stdlib_xerbla. Because transa
    !! has been removed, info is the position of the offending argument in
    !! THIS routine's argument list (one less than in DGEMM):
    !!   1 = transb, 2 = m, 3 = n, 4 = k, 7 = lda, 9 = ldb, 12 = ldc.
    implicit none

    ! Scalar arguments
    character, intent(in) :: transb
    integer, intent(in) :: m, n, k
    real(dp), intent(in) :: alpha, beta
    integer, intent(in) :: lda, ldb, ldc

    ! Array arguments
    real(dp), intent(in) :: a(lda, *)
    real(dp), intent(in) :: b(ldb, *)
    real(dp), intent(inout) :: c(ldc, *)

    ! Local scalars
    integer :: j, l, info, nrowb
    real(dp) :: temp
    logical :: notb
    real(dp), parameter :: zero = 0.0_dp, one = 1.0_dp

    intrinsic :: max

    ! notb is true when B is used as stored. In that case B is k-by-n,
    ! otherwise it is n-by-k.
    notb = (transb == 'N' .or. transb == 'n')
    if (notb) then
        nrowb = k
    else
        nrowb = n
    end if

    ! Test the input parameters.
    ! As in reference DGEMM, 'C'/'c' is accepted and, for real data, means
    ! a plain transpose.
    info = 0
    if (.not. (notb .or. transb == 'T' .or. transb == 't' &
               .or. transb == 'C' .or. transb == 'c')) then
        info = 1
    else if (m < 0) then
        info = 2
    else if (n < 0) then
        info = 3
    else if (k < 0) then
        info = 4
    else if (lda < max(1, m)) then
        info = 7
    else if (ldb < max(1, nrowb)) then
        info = 9
    else if (ldc < max(1, m)) then
        info = 12
    end if
    if (info /= 0) then
        ! Report this routine's own name: the info numbering differs from DGEMM's.
        call stdlib_xerbla('STDLIB_DGEMM_A_ORIG', info)
        return
    end if

    ! Quick return if possible.
    if ((m == 0) .or. (n == 0) .or. (((alpha == zero) .or. (k == 0)) .and. (beta == one))) return

    ! alpha == 0: only the beta*C term remains. beta == 0 overwrites C
    ! explicitly so that NaN/Inf already present in C is not propagated.
    if (alpha == zero) then
        if (beta == zero) then
            c(1:m, 1:n) = zero
        else
            c(1:m, 1:n) = beta * c(1:m, 1:n)
        end if
        return
    end if

    ! Start the operations. Both branches are column-oriented: column j of C
    ! is scaled by beta, then accumulated as a sum of scaled columns of A.
    if (notb) then
        ! Form C := alpha*A*B + beta*C.
        do j = 1, n
            if (beta == zero) then
                c(1:m, j) = zero
            else if (beta /= one) then
                c(1:m, j) = beta * c(1:m, j)
            end if
            do l = 1, k
                temp = alpha * b(l, j)
                c(1:m, j) = c(1:m, j) + temp * a(1:m, l)
            end do
        end do
    else
        ! Form C := alpha*A*B**T + beta*C.
        do j = 1, n
            if (beta == zero) then
                c(1:m, j) = zero
            else if (beta /= one) then
                c(1:m, j) = beta * c(1:m, j)
            end if
            do l = 1, k
                temp = alpha * b(j, l)
                c(1:m, j) = c(1:m, j) + temp * a(1:m, l)
            end do
        end do
    end if

end subroutine stdlib_dgemm_a_orig

pure subroutine stdlib_dgemm_a_trans(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    !! Specialized DGEMM for the case transa = 'T' (A is used transposed).
    !!
    !! Performs C := alpha * A**T * op(B) + beta * C, where
    !!   op(B) = B     if transb = 'N' or 'n'
    !!   op(B) = B**T  if transb = 'T', 't', 'C' or 'c'
    !! A is stored k-by-m (so A**T is m-by-k), op(B) is k-by-n and C is m-by-n.
    !!
    !! Invalid arguments are reported through stdlib_xerbla. Because transa
    !! has been removed, info is the position of the offending argument in
    !! THIS routine's argument list (one less than in DGEMM):
    !!   1 = transb, 2 = m, 3 = n, 4 = k, 7 = lda, 9 = ldb, 12 = ldc.
    implicit none

    ! Scalar arguments
    character, intent(in) :: transb
    integer, intent(in) :: m, n, k
    real(dp), intent(in) :: alpha, beta
    integer, intent(in) :: lda, ldb, ldc

    ! Array arguments
    real(dp), intent(in) :: a(lda, *)
    real(dp), intent(in) :: b(ldb, *)
    real(dp), intent(inout) :: c(ldc, *)

    ! Local scalars
    integer :: i, j, l, info, nrowb
    real(dp) :: temp
    logical :: notb
    real(dp), parameter :: zero = 0.0_dp, one = 1.0_dp

    intrinsic :: max

    ! notb is true when B is used as stored. In that case B is k-by-n,
    ! otherwise it is n-by-k.
    notb = (transb == 'N' .or. transb == 'n')
    if (notb) then
        nrowb = k
    else
        nrowb = n
    end if

    ! Test the input parameters. A is stored k-by-m, hence lda >= max(1, k).
    ! As in reference DGEMM, 'C'/'c' is accepted and, for real data, means
    ! a plain transpose.
    info = 0
    if (.not. (notb .or. transb == 'T' .or. transb == 't' &
               .or. transb == 'C' .or. transb == 'c')) then
        info = 1
    else if (m < 0) then
        info = 2
    else if (n < 0) then
        info = 3
    else if (k < 0) then
        info = 4
    else if (lda < max(1, k)) then
        info = 7
    else if (ldb < max(1, nrowb)) then
        info = 9
    else if (ldc < max(1, m)) then
        info = 12
    end if
    if (info /= 0) then
        ! Report this routine's own name: the info numbering differs from DGEMM's.
        call stdlib_xerbla('STDLIB_DGEMM_A_TRANS', info)
        return
    end if

    ! Quick return if possible.
    if ((m == 0) .or. (n == 0) .or. (((alpha == zero) .or. (k == 0)) .and. (beta == one))) return

    ! alpha == 0: only the beta*C term remains. beta == 0 overwrites C
    ! explicitly so that NaN/Inf already present in C is not propagated.
    if (alpha == zero) then
        if (beta == zero) then
            c(1:m, 1:n) = zero
        else
            c(1:m, 1:n) = beta * c(1:m, 1:n)
        end if
        return
    end if

    ! Start the operations. Each C(i,j) is a dot product of column i of A
    ! with a row/column of B. The accumulation loops are kept explicit so
    ! the summation order matches reference DGEMM.
    if (notb) then
        ! Form C := alpha*A**T*B + beta*C.
        do j = 1, n
            do i = 1, m
                temp = zero
                do l = 1, k
                    temp = temp + a(l, i) * b(l, j)
                end do
                if (beta == zero) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    else
        ! Form C := alpha*A**T*B**T + beta*C.
        do j = 1, n
            do i = 1, m
                temp = zero
                do l = 1, k
                    temp = temp + a(l, i) * b(j, l)
                end do
                if (beta == zero) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    end if

end subroutine stdlib_dgemm_a_trans

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions