Einsum matrix multiplication with missing dimensions
I want to make this einsum more flexible. Right now it does a matrix multiplication of the last two dimensions of A against the last three of B:
tf.einsum("...xp,...pyz->...xyz", A, B)
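For reference, here is a quick shape check of the current expression. The sizes are arbitrary examples; the point is that the subscript string hard-codes exactly two axes (`y`, `z`) after `p` in `B`.

```python
import tensorflow as tf

# Arbitrary example sizes: one batch axis of 2, x=3, p=4, y=5, z=6.
A = tf.random.normal([2, 3, 4])      # (..., x, p)
B = tf.random.normal([2, 4, 5, 6])   # (..., p, y, z)

# Contract over p; the batch axes matched by "..." are carried through.
C = tf.einsum("...xp,...pyz->...xyz", A, B)
print(C.shape)  # expected shape: (2, 3, 5, 6)
```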
However, I want it to handle any number of dimensions after p in B. If it helps, I do know the pivot_axis of p (i.e. its position in tf.shape(B)).
I suppose one way to do it would be to get the number of dimensions after p and do some string trick like:
suffix='abcdefghijklmno'[:B.shape.rank-pivot_axis-1]
tf.einsum(f"...xp,...p{suffix}->...x{suffix}", A, B)
and append it to the einsum strings? Is there a cleaner way?
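A minimal sketch of two options, assuming `pivot_axis` is a non-negative Python int and `B` has a statically known rank. `einsum_trailing` is the string trick, counting `rank - pivot_axis - 1` axes after `p` and drawing letters that avoid `x` and `p`. `matmul_trailing` avoids string building entirely: it merges every axis after `p` into one axis, contracts with a fixed subscript string, then reshapes back. Both function names are hypothetical, chosen for illustration.

```python
import string
import tensorflow as tf


def einsum_trailing(A, B, pivot_axis):
    """Hypothetical helper: contract A's last axis with B's axis `pivot_axis`.

    Builds the subscripts from B's static rank, so any number of axes
    may follow p. Assumes pivot_axis is non-negative.
    """
    # Number of axes after p is rank - pivot_axis - 1.
    n_trailing = B.shape.rank - pivot_axis - 1
    # Letters not already used in the string ('x' and 'p').
    letters = [c for c in string.ascii_letters if c not in "xp"]
    if n_trailing > len(letters):
        raise ValueError("too many trailing axes for einsum subscripts")
    suffix = "".join(letters[:n_trailing])
    return tf.einsum(f"...xp,...p{suffix}->...x{suffix}", A, B)


def matmul_trailing(A, B, pivot_axis):
    """Hypothetical helper: same result via flatten, contract, restore."""
    shape_b = tf.shape(B)
    trailing = shape_b[pivot_axis + 1:]  # sizes of the axes after p
    # Merge all axes after p into one, so B becomes (..., p, k).
    B_flat = tf.reshape(B, tf.concat([shape_b[:pivot_axis + 1], [-1]], axis=0))
    C_flat = tf.einsum("...xp,...pk->...xk", A, B_flat)  # (..., x, k)
    # Split k back into the original trailing axes.
    shape_c = tf.shape(C_flat)
    return tf.reshape(C_flat, tf.concat([shape_c[:-1], trailing], axis=0))
```

The reshape version keeps one fixed subscript string regardless of how many axes follow `p`, at the cost of a little shape bookkeeping; the string version stays closer to the original expression.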