Einsum matrix multiplication with missing dimensions

I want to modify this einsum to be more flexible. Right now it's doing a matrix multiplication of the last two dimensions of A against the last 3 of B:

tf.einsum("...xp,...pyz->...xyz", A, B)
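To make the current behaviour concrete, here is a small sketch of what that subscript string computes. It uses NumPy's np.einsum as a stand-in for tf.einsum, which accepts the same "...xp,...pyz->...xyz" spec. The shapes are arbitrary choices for illustration. For every batch index and every (x, y, z), it sums A[..., x, p] * B[..., p, y, z] over p, which is checked here against explicit loops.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))      # (batch, x, p)
B = rng.standard_normal((2, 4, 5, 6))   # (batch, p, y, z)

C = np.einsum("...xp,...pyz->...xyz", A, B)   # shape (2, 3, 5, 6)

# Same contraction written out: sum over p for each (batch, x, y, z).
C_loop = np.zeros((2, 3, 5, 6))
for b in range(2):
    for i in range(3):
        for j in range(5):
            for k in range(6):
                C_loop[b, i, j, k] = sum(A[b, i, q] * B[b, q, j, k] for q in range(4))
```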

However I want it to be able to handle any number of dimensions after p in B. If it helps, I do know the pivot_axis of p (i.e. the position in tf.shape(B)).
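Because pivot_axis is known, one way to avoid naming every trailing axis is to merge all axes after p into a single axis, do an ordinary batched matmul, then split that axis back out. The sketch below assumes pivot_axis is a non-negative index into B's shape and uses NumPy in place of TensorFlow. The helper name matmul_trailing is hypothetical. The same pattern should carry over to tf.reshape and tf.linalg.matmul (using tf.shape for dynamic sizes), but that translation is not shown here.

```python
import numpy as np


def matmul_trailing(A, B, pivot_axis):
    """Hypothetical helper: contract A's last axis with axis `pivot_axis` of B.

    A has shape (..., x, p); B has shape (..., p, *rest), where `rest`
    can be any number of trailing axes, including none.
    Assumes `pivot_axis` is a non-negative index into B.shape.
    """
    rest = B.shape[pivot_axis + 1:]              # trailing axes after p
    # Merge all trailing axes into one so this becomes a plain batched matmul.
    B2 = B.reshape(B.shape[:pivot_axis + 1] + (-1,))
    out = np.matmul(A, B2)                       # shape (..., x, prod(rest))
    # Split the merged axis back into the original trailing axes.
    return out.reshape(out.shape[:-1] + rest)


A = np.ones((2, 3, 4))
B = np.ones((2, 4, 5, 6, 7))
print(matmul_trailing(A, B, pivot_axis=1).shape)  # (2, 3, 5, 6, 7)
```

Batch axes before p broadcast the same way they do in the einsum version, because np.matmul broadcasts its leading dimensions.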

I suppose one way to do it would be to get the number of dimensions after p and do some string trick like:

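# number of axes after p is rank - pivot_axis - 1 (pivot_axis itself is p); B.shape.rank is the static rank
suffix = 'abcdefghijklmno'[:B.shape.rank - pivot_axis - 1]
tf.einsum(f"...xp,...p{suffix}->...x{suffix}", A, B)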

and append it to the einsum strings? Is there a cleaner way?
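The string-building idea can be wrapped in a small helper so the subscript is generated from the rank of B. This is a minimal sketch, assuming pivot_axis is a non-negative index of p in B's shape; the helper name einsum_trailing is hypothetical, and NumPy stands in for TensorFlow. Note that the number of axes after p is rank - pivot_axis - 1, since pivot_axis itself points at p. With B.shape.rank in place of B.ndim, the same subscript string should also work with tf.einsum.

```python
import string

import numpy as np


def einsum_trailing(A, B, pivot_axis):
    """Hypothetical helper: contract A's last axis with axis `pivot_axis` of B.

    Builds the einsum spec dynamically, one letter per axis after p.
    Assumes `pivot_axis` is a non-negative index into B.shape.
    """
    # Axes after p: pivot_axis is p itself, so subtract one more.
    n_rest = B.ndim - pivot_axis - 1
    # Any letters except 'x' and 'p', which the fixed part of the spec uses.
    letters = [c for c in string.ascii_letters if c not in "xp"]
    if n_rest > len(letters):
        raise ValueError("too many trailing axes for single-letter subscripts")
    suffix = "".join(letters[:n_rest])
    return np.einsum(f"...xp,...p{suffix}->...x{suffix}", A, B)
```

The subscript depends only on the rank of B, so the string is built once in Python while the sizes of the individual axes are left to einsum.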
