翼度科技»论坛 编程开发 python 查看内容

Pytorch实现List Tensor转Tensor,reshape拼接等操作

6

主题

6

帖子

18

积分

新手上路

Rank: 1

积分
18
持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。
其它Tensor操作如 einsum等见:待更新。
用到两个函数:

    1. torch.cat
    复制代码
    1. torch.stack
    复制代码

一、List Tensor转Tensor (torch.cat)

  1. // An highlighted block
  2. >>> t1 = torch.FloatTensor([[1,2],[5,6]])
  3. >>> t2 = torch.FloatTensor([[3,4],[7,8]])
  4. >>> l = []
  5. >>> l.append(t1)
  6. >>> l.append(t2)
  7. >>> ta = torch.cat(l,dim=0)
  8. >>> ta = torch.cat(l,dim=0).reshape(2,2,2)
  9. >>> tb = torch.cat(l,dim=1).reshape(2,2,2)
  10. >>> ta
  11. tensor([[[1., 2.],
  12.          [5., 6.]],

  13.         [[3., 4.],
  14.          [7., 8.]]])
  15. >>> tb
  16. tensor([[[1., 2.],
  17.          [3., 4.]],

  18.         [[5., 6.],
  19.          [7., 8.]]])
复制代码
高维tensor

** 如果理解了2D to 3DTensor,以此类推,不难理解3D to 4D,看下面代码即可明白:**
  1. >>> t1 = torch.range(1,8).reshape(2,2,2)
  2. >>> t2 = torch.range(11,18).reshape(2,2,2)
  3. >>> l = []
  4. >>> l.append(t1)
  5. >>> l.append(t2)
  6. >>> torch.cat(l,dim=2).reshape(2,2,2,2)
  7. tensor([[[[ 1.,  2.],
  8.           [11., 12.]],

  9.          [[ 3.,  4.],
  10.           [13., 14.]]],


  11.         [[[ 5.,  6.],
  12.           [15., 16.]],

  13.          [[ 7.,  8.],
  14.           [17., 18.]]]])
  15. >>> torch.cat(l,dim=1).reshape(2,2,2,2)
  16. tensor([[[[ 1.,  2.],
  17.           [ 3.,  4.]],

  18.          [[11., 12.],
  19.           [13., 14.]]],


  20.         [[[ 5.,  6.],
  21.           [ 7.,  8.]],

  22.          [[15., 16.],
  23.           [17., 18.]]]])
  24. >>> torch.cat(l,dim=0).reshape(2,2,2,2)
  25. tensor([[[[ 1.,  2.],
  26.           [ 3.,  4.]],

  27.          [[ 5.,  6.],
  28.           [ 7.,  8.]]],


  29.         [[[11., 12.],
  30.           [13., 14.]],

  31.          [[15., 16.],
  32.           [17., 18.]]]])
复制代码
二、List Tensor转Tensor (torch.stack)


代码:
  1. import torch

  2. t1 = torch.FloatTensor([[1,2],[5,6]])
  3. t2 = torch.FloatTensor([[3,4],[7,8]])
  4. l = [t1, t2]

  5. t3 = torch.stack(l, dim=2)
  6. print(t3.shape)
  7. print(t3)

  8. ## output:
  9. ## torch.Size([2, 2, 2])
  10. ## tensor([[[1., 3.],
  11. ##          [2., 4.]],
  12. ##        [[5., 7.],
  13. ##         [6., 8.]]])
复制代码
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

来源:https://www.jb51.net/article/266622.htm
免责声明:由于采集信息均来自互联网,如果侵犯了您的权益,请联系我们【E-Mail:cb@itdo.tech】 我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x

举报 回复 使用道具