Pytorch基础:Tensor的transpose()方法

张开发
2026/4/10 16:06:17 15 分钟阅读

分享文章

Pytorch基础:Tensor的transpose()方法
相关阅读Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482在Pytorch中transpose()是Tensor类的一个重要方法同时它也是torch模块中的一个函数它们的语法如下所示。Tensor.transpose(dim0, dim1) → Tensor torch.transpose(input, dim0, dim1) → Tensor input (Tensor) – the input tensor. dim0 (int) – the first dimension to be transposed dim1 (int) – the second dimension to be transposed官方的解释如下返回一个张量它是输入张量的转置版本其中将给定的维度dim0和dim1交换。返回的新对象可能会变得不连续使用is_contiguous()方法可以鉴定是否连续。关于非连续张量的更多细节可以看下面的文章。Pytorch基础Tensor的连续性https://blog.csdn.net/weixin_45791458/article/details/140736700?ops_request_misc%257B%2522request%255Fid%2522%253A%2522eb4c722817c335758581a52404bb2dce%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257Drequest_ideb4c722817c335758581a52404bb2dcebiz_id0utm_mediumdistribute.pc_search_result.none-task-blog-2~blog~first_rank_ecpm_v1~rank_v31_ecpm-2-140736700-null-null.nonecaseutm_term%E9%9D%9E%E8%BF%9E%E7%BB%ADspm1018.2226.3001.4450如果输入是一个具有步幅的张量常规稠密张量那么输出张量将与输入张量共享底层存储因此改变一个张量的内容将改变另一个张量的内容。如果输入是一个稀疏张量那么输出张量不与输入张量共享底层存储。如果输入是压缩布局的稀疏张量(SparseCSR, SparseBSR, SparseCSC或SparseBSC)参数dim0和dim1必须同时是批处理维度或者必须同时是稀疏维度稀疏张量的批处理维度是稀疏维之前的维度。在详细说明之前我们需要明确tensor的形状相关的概念对于一个tensor来说它的维度数和维度大小是两个概念维度数是从0开始依次递增的分别称为第0维、第1维...每一个维度又有自己的大小。官方解释中的dim0和dim1指的是维度数。tensor([[1, 2, 3], [4, 5, 6]]) # 这个张量有2个维度分别是0维和1维第0维大小是2第1维大小是3import torch # 创建一个张量 x torch.tensor([[1, 2, 3], [4, 5, 6]]) # 使用transpose操作 y x.transpose(0, 1) # 等价于y x.transpose(1, 0) print(x, y) tensor([[1, 2, 3], [4, 5, 6]]) tensor([[1, 4], [2, 5], [3, 6]]) print(id(x),id(y)) 4347791776 5006922112 # 说明两个张量对象不同 print(x.storage().data_ptr(), y.storage().data_ptr()) 5207345856 5207345856 # 说明两个张量对象里面保存的数据存储是共享的 y[0, 0] 7 print(x, y) tensor([[7, 2, 3], [4, 5, 6]]) tensor([[7, 4], [2, 5], [3, 6]]) # 说明对新tensor的更改影响了原tensor print(x.is_contiguous(), y.is_contiguous()) True False # 说明x是连续的y不是连续的类似于之前在列表的浅拷贝文中说到的那样对新列表内部嵌套的列表中的元素的更改会影响原列表如下所示。import copy my_list [1, 2, [1, 2]] your_list list(my_list) #工厂函数 his_list my_list[:] #切片操作 her_list copy.copy(my_list) #copy模块的copy函数 your_list[2][0] 3 print(my_list) print(your_list) print(his_list) print(her_list) his_list[2][1] 4 print(my_list) print(your_list) print(his_list) print(her_list) her_list[2].append(5) print(my_list) print(your_list) print(his_list) print(her_list) 输出 [1, 2, [3, 2]] [1, 2, [3, 2]] [1, 2, [3, 2]] [1, 2, [3, 2]] [1, 2, [3, 4]] [1, 2, [3, 4]] [1, 2, [3, 4]] [1, 2, [3, 4]] [1, 2, [3, 4, 5]] [1, 2, [3, 4, 5]] [1, 2, [3, 4, 5]] [1, 2, [3, 4, 5]]但与列表不一样的是tensor中非嵌套的内容的修改也会导致另一个tensor受到影响如下所示。import torch # 创建一个张量 x torch.tensor([[1, 2, 3], [4, 5, 6]]) # 使用transpose操作 y x.transpose(0, 1) # 等价于y x.transpose(1, 0) print(x, y) tensor([[1, 2, 3], [4, 5, 6]]) tensor([[1, 4], [2, 5], [3, 6]]) x[0] torch.tensor[4, 4, 4] # 改变其中一个tensor的第0个元素 print(x, y) tensor([[4, 4, 4], [4, 5, 6]]) tensor([[4, 4], [4, 5], [4, 6]])在pytorch中和transpose()方法类似的还有permute()方法它的使用与transpose()方法在功能上有一定重叠但也有区别。transpose()方法只能交换两个维度而permute()方法可以一次性对任意维度进行重排因此在复杂场景中更加灵活。更多细节可以看下面的文章。Pytorch基础Tensor的permute()方法https://blog.csdn.net/weixin_45791458/article/details/133612401?spm1001.2014.3001.5502

更多文章