
How does BLAS sgemm/dgemm work?

I am trying to use the BLAS function sgemm from Python via ctypes. To compute C = A x B, the following code works just fine:

no_trans = c_char("n")
m = c_int(number_of_rows_of_A)
n = c_int(number_of_columns_of_B)
k = c_int(number_of_columns_of_A)
one = c_float(1.0)
zero = c_float(0.0)

blaslib.sgemm_(byref(no_trans), byref(no_trans), byref(m), byref(n), byref(k),
               byref(one), A, byref(m), B, byref(k), byref(zero), C, byref(m))

Now I would like to compute C = A' x A, where A' is the transpose of A. The following code runs without an exception, but the result it returns is wrong:

trans = c_char("t")
no_trans = c_char("n")
m = c_int(number_of_rows_of_A)
n = c_int(number_of_columns_of_A)
one = c_float(1.0)
zero = c_float(0.0)

blaslib.sgemm_(byref(trans), byref(no_trans), byref(n), byref(n), byref(m),
               byref(one), A, byref(m), A, byref(m), byref(zero), C, byref(n))

For a test I inserted a matrix A = [1 2; 3 4]. The correct result is C = [10 14; 14 20] but the sgemm routine spits out C = [5 11; 11 25].

As far as I understand it, the matrix A does not have to be transposed by me since the algorithm takes care of it. What is wrong with my parameter passing in the second case?

Any help, link, article, advice is appreciated!


BLAS typically uses column-major matrices (like Fortran), hence A = [1 2; 3 4] means

    |1 3|   
A = |   |
    |2 4|

and the result is correct (assuming that your Python library does the same). See this read-me
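
To make the column-major point concrete, here is a small pure-Python sketch. It does not call BLAS; it only assumes the test matrix was written into memory row by row (1, 2, 3, 4) and then reads that buffer the way a column-major routine would. The helper names are made up for illustration.

```python
# Flat buffer for A = [1 2; 3 4], assuming it was filled row by row.
buf = [1.0, 2.0, 3.0, 4.0]
n = 2

def column_major(flat, rows, cols):
    """Read a flat buffer as Fortran/BLAS does: element (i, j) is flat[i + j*rows]."""
    return [[flat[i + j * rows] for j in range(cols)] for i in range(rows)]

def transpose(x):
    return [list(row) for row in zip(*x)]

def matmul(x, y):
    """Plain triple-loop matrix product on nested lists."""
    return [[sum(x[i][p] * y[p][j] for p in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]

a_blas = column_major(buf, n, n)            # the matrix BLAS actually sees
c = matmul(transpose(a_blas), a_blas)       # A' x A for that matrix

print(a_blas)  # [[1.0, 3.0], [2.0, 4.0]]
print(c)       # [[5.0, 11.0], [11.0, 25.0]]
```

The product matches the "wrong" output from the question, which suggests sgemm did what it was asked; it just saw the transpose of the intended matrix.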


The result you got indicates that sgemm computed A*A' rather than the A'*A you wanted. The simple solution is to switch the two transpose arguments passed to the function.
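
The effect of swapping the flags can be checked without a BLAS library. The sketch below is a pure-Python model of the documented sgemm contract (C = alpha * op(A) * op(B) + beta * C on flat column-major buffers); `ref_gemm` is a hypothetical helper written for this illustration, not part of BLAS, and it assumes A was stored row by row.

```python
def ref_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc):
    """Model of sgemm on flat column-major buffers; updates and returns c."""
    def op_a(i, p):
        # op(A) is m x k; with 't' the stored matrix is k x m.
        return a[i + p * lda] if transa == "n" else a[p + i * lda]

    def op_b(p, j):
        # op(B) is k x n; with 't' the stored matrix is n x k.
        return b[p + j * ldb] if transb == "n" else b[j + p * ldb]

    for j in range(n):
        for i in range(m):
            acc = sum(op_a(i, p) * op_b(p, j) for p in range(k))
            c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc]
    return c

# A = [1 2; 3 4] stored row by row (assumption), so BLAS sees its transpose.
buf = [1.0, 2.0, 3.0, 4.0]

# Flags as in the question: ('t', 'n').
c_question = ref_gemm("t", "n", 2, 2, 2, 1.0, buf, 2, buf, 2, 0.0, [0.0] * 4, 2)

# Flags swapped: ('n', 't').
c_swapped = ref_gemm("n", "t", 2, 2, 2, 1.0, buf, 2, buf, 2, 0.0, [0.0] * 4, 2)

# Non-square check: A = [1 2; 3 4; 5 6] stored row by row (3 rows, 2 columns).
# Seen column-major it is 2 x 3 with leading dimension 2.
buf32 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
c_rect = ref_gemm("n", "t", 2, 2, 3, 1.0, buf32, 2, buf32, 2, 0.0, [0.0] * 4, 2)

print(c_question)  # [5.0, 11.0, 11.0, 25.0]
print(c_swapped)   # [10.0, 14.0, 14.0, 20.0]
print(c_rect)      # [35.0, 44.0, 44.0, 56.0]
```

Under the same row-major assumption, the ctypes call from the question would pass no_trans first and trans second, keep the dimensions n, n, m, and use n (the number of columns of A) for all three leading dimensions. For a square A this is just the flag swap.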
