Open
Description
I asked ChatGPT o1-preview to create specialized versions of `stdlib_dgemm` for the two cases where the matrix multiplication uses either the original matrix `a` or its transpose. It produced `stdlib_dgemm_a_orig` and `stdlib_dgemm_a_trans`, shown below. I have not checked them. In general, a choice that is knowable at compile time (whether to use the matrix or its transpose) should not be made at run time.
pure subroutine stdlib_dgemm_a_orig(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    ! Specialized DGEMM for transa = 'N' (A is used as stored). Computes
    !
    !     C := alpha * A * op(B) + beta * C,    op(B) = B  or  B**T,
    !
    ! where A is m-by-k, op(B) is k-by-n and C is m-by-n. The loop structure
    ! mirrors the reference DGEMM branch for transa = 'N'.
    !
    ! Arguments (the number is the INFO value reported on invalid input):
    !   1  transb : 'N'/'n'            -> op(B) = B
    !               'T'/'t' or 'C'/'c' -> op(B) = B**T
    !               (for real data B**H == B**T, so 'C' is a synonym for 'T',
    !               as in reference DGEMM)
    !   2  m      : number of rows of A and C (m >= 0)
    !   3  n      : number of columns of op(B) and C (n >= 0)
    !   4  k      : number of columns of A and rows of op(B) (k >= 0)
    !   5  alpha  : scalar multiplying A*op(B)
    !   6  a      : a(lda, *), leading m-by-k part holds A
    !   7  lda    : leading dimension of a, lda >= max(1, m)
    !   8  b      : b(ldb, *), holds B (k-by-n) if transb = 'N', else n-by-k
    !   9  ldb    : leading dimension of b, >= max(1, k) if transb = 'N',
    !               otherwise >= max(1, n)
    !  10  beta   : scalar multiplying C; if beta == 0, C is not read
    !  11  c      : c(ldc, *), leading m-by-n part is overwritten by the result
    !  12  ldc    : leading dimension of c, ldc >= max(1, m)
    !
    ! The INFO numbers follow THIS routine's argument list, not DGEMM's, so
    ! the routine's own name is passed to stdlib_xerbla.

    ! Scalar arguments
    character, intent(in) :: transb
    integer, intent(in) :: m, n, k
    real(dp), intent(in) :: alpha, beta
    integer, intent(in) :: lda, ldb, ldc
    ! Array arguments
    real(dp), intent(in) :: a(lda, *)
    real(dp), intent(in) :: b(ldb, *)
    real(dp), intent(inout) :: c(ldc, *)
    ! Locals
    integer :: j, l, info, nrowb
    real(dp) :: temp
    logical :: notb, transposed_b
    real(dp), parameter :: zero = 0.0_dp, one = 1.0_dp

    notb = (transb == 'N' .or. transb == 'n')
    transposed_b = (transb == 'T' .or. transb == 't' .or. transb == 'C' .or. transb == 'c')

    ! Row count of the stored b: B is k-by-n, or n-by-k when it is transposed.
    if (notb) then
        nrowb = k
    else
        nrowb = n
    end if

    ! Validate the arguments; report the first offending position.
    info = 0
    if (.not. (notb .or. transposed_b)) then
        info = 1
    else if (m < 0) then
        info = 2
    else if (n < 0) then
        info = 3
    else if (k < 0) then
        info = 4
    else if (lda < max(1, m)) then
        info = 7
    else if (ldb < max(1, nrowb)) then
        info = 9
    else if (ldc < max(1, m)) then
        info = 12
    end if
    if (info /= 0) then
        call stdlib_xerbla('DGEMM_A_ORIG', info)
        return
    end if

    ! Nothing to do: C is empty, or the product term vanishes and beta == 1.
    if ((m == 0) .or. (n == 0) .or. (((alpha == zero) .or. (k == 0)) .and. (beta == one))) return

    ! alpha == 0: only the beta*C term remains. When beta == 0 the old
    ! contents of C are never read, so NaN/garbage on input does not propagate.
    if (alpha == zero) then
        if (beta == zero) then
            c(1:m, 1:n) = zero
        else
            c(1:m, 1:n) = beta * c(1:m, 1:n)
        end if
        return
    end if

    if (notb) then
        ! C := alpha*A*B + beta*C, one column of C at a time (column-major friendly).
        do j = 1, n
            if (beta == zero) then
                c(1:m, j) = zero
            else if (beta /= one) then
                c(1:m, j) = beta * c(1:m, j)
            end if
            do l = 1, k
                temp = alpha * b(l, j)
                c(1:m, j) = c(1:m, j) + temp * a(1:m, l)
            end do
        end do
    else
        ! C := alpha*A*B**T + beta*C; element (l, j) of B**T is b(j, l).
        do j = 1, n
            if (beta == zero) then
                c(1:m, j) = zero
            else if (beta /= one) then
                c(1:m, j) = beta * c(1:m, j)
            end if
            do l = 1, k
                temp = alpha * b(j, l)
                c(1:m, j) = c(1:m, j) + temp * a(1:m, l)
            end do
        end do
    end if
end subroutine stdlib_dgemm_a_orig
pure subroutine stdlib_dgemm_a_trans(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    ! Specialized DGEMM for transa = 'T' (A is used transposed). Computes
    !
    !     C := alpha * A**T * op(B) + beta * C,    op(B) = B  or  B**T,
    !
    ! where the stored A is k-by-m (so A**T is m-by-k), op(B) is k-by-n and
    ! C is m-by-n. Each C(i, j) is formed as a dot product of column i of A
    ! with op(B)(:, j); the loop structure mirrors the reference DGEMM branch
    ! for transa = 'T'.
    !
    ! Arguments (the number is the INFO value reported on invalid input):
    !   1  transb : 'N'/'n'            -> op(B) = B
    !               'T'/'t' or 'C'/'c' -> op(B) = B**T
    !               (for real data B**H == B**T, so 'C' is a synonym for 'T',
    !               as in reference DGEMM)
    !   2  m      : number of rows of A**T and C (m >= 0)
    !   3  n      : number of columns of op(B) and C (n >= 0)
    !   4  k      : number of columns of A**T and rows of op(B) (k >= 0)
    !   5  alpha  : scalar multiplying A**T*op(B)
    !   6  a      : a(lda, *), leading k-by-m part holds A
    !   7  lda    : leading dimension of a, lda >= max(1, k)
    !   8  b      : b(ldb, *), holds B (k-by-n) if transb = 'N', else n-by-k
    !   9  ldb    : leading dimension of b, >= max(1, k) if transb = 'N',
    !               otherwise >= max(1, n)
    !  10  beta   : scalar multiplying C; if beta == 0, C is not read
    !  11  c      : c(ldc, *), leading m-by-n part is overwritten by the result
    !  12  ldc    : leading dimension of c, ldc >= max(1, m)
    !
    ! The INFO numbers follow THIS routine's argument list, not DGEMM's, so
    ! the routine's own name is passed to stdlib_xerbla.

    ! Scalar arguments
    character, intent(in) :: transb
    integer, intent(in) :: m, n, k
    real(dp), intent(in) :: alpha, beta
    integer, intent(in) :: lda, ldb, ldc
    ! Array arguments
    real(dp), intent(in) :: a(lda, *)
    real(dp), intent(in) :: b(ldb, *)
    real(dp), intent(inout) :: c(ldc, *)
    ! Locals
    integer :: i, j, l, info, nrowb
    real(dp) :: temp
    logical :: notb, transposed_b
    real(dp), parameter :: zero = 0.0_dp, one = 1.0_dp

    notb = (transb == 'N' .or. transb == 'n')
    transposed_b = (transb == 'T' .or. transb == 't' .or. transb == 'C' .or. transb == 'c')

    ! Row count of the stored b: B is k-by-n, or n-by-k when it is transposed.
    if (notb) then
        nrowb = k
    else
        nrowb = n
    end if

    ! Validate the arguments; report the first offending position.
    info = 0
    if (.not. (notb .or. transposed_b)) then
        info = 1
    else if (m < 0) then
        info = 2
    else if (n < 0) then
        info = 3
    else if (k < 0) then
        info = 4
    else if (lda < max(1, k)) then
        info = 7
    else if (ldb < max(1, nrowb)) then
        info = 9
    else if (ldc < max(1, m)) then
        info = 12
    end if
    if (info /= 0) then
        call stdlib_xerbla('DGEMM_A_TRANS', info)
        return
    end if

    ! Nothing to do: C is empty, or the product term vanishes and beta == 1.
    if ((m == 0) .or. (n == 0) .or. (((alpha == zero) .or. (k == 0)) .and. (beta == one))) return

    ! alpha == 0: only the beta*C term remains. When beta == 0 the old
    ! contents of C are never read, so NaN/garbage on input does not propagate.
    if (alpha == zero) then
        if (beta == zero) then
            c(1:m, 1:n) = zero
        else
            c(1:m, 1:n) = beta * c(1:m, 1:n)
        end if
        return
    end if

    ! The inner sums are kept as explicit loops so the summation order
    ! (l = 1, ..., k) is fixed, as in the reference routine.
    if (notb) then
        ! C := alpha*A**T*B + beta*C; A**T(i, l) is a(l, i).
        do j = 1, n
            do i = 1, m
                temp = zero
                do l = 1, k
                    temp = temp + a(l, i) * b(l, j)
                end do
                if (beta == zero) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    else
        ! C := alpha*A**T*B**T + beta*C; B**T(l, j) is b(j, l).
        do j = 1, n
            do i = 1, m
                temp = zero
                do l = 1, k
                    temp = temp + a(l, i) * b(j, l)
                end do
                if (beta == zero) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    end if
end subroutine stdlib_dgemm_a_trans
Metadata
Metadata
Assignees
Labels
No labels