高维数组的理解
[[[1, 2],[3, 4]]]
按括号拆分,有三个括号,这个是个三维数组。第三层括号里面包含一个二层括号,所以包含一个二维数组
二维数组包含两个一层括号,所以有两个一维数组。
一维数组:1x2
二维数组:2x2
array[0] 会返回整个二维数组 [[1, 2], [3, 4]],因为它是三维数组的第一个元素。
array[0, 0] 会返回第一行 [[1, 2], [3, 4]] 中的第一行 [1, 2]。
array[0, 1, 1] 会返回二维数组中第二行第二列的元素,也就是 4。
concatenate 操作
1. 轴的定义
对于一个 n
维数组(n
是数组的维度),axis
的取值范围是 0
到 n-1
,也可以是负数,表示从最后一维开始倒数计数。
axis=0
:在第一维度上拼接,即沿着“行”方向拼接。axis=1
:在第二维度上拼接,即沿着“列”方向拼接。axis=-1
:在最后一个维度上拼接。axis=-2
:在倒数第二个维度上拼接。
2. 不同维度下 axis
的使用
一维数组
对于一维数组,axis
只有 0
或 -1
两个取值:
import jax.numpy as jnp
array1 = jnp.array([1, 2, 3])
array2 = jnp.array([4, 5, 6])
# 在 axis=0 上拼接
result = jnp.concatenate([array1, array2], axis=0)
# result: [1 2 3 4 5 6]
在一维数组中,axis=0
和 axis=-1
是等价的,因为只有一个维度。
二维数组
对于二维数组,axis
可以是 0
、1
、-1
或 -2
:
import jax.numpy as jnp
array1 = jnp.array([[1, 2],
[3, 4]])
array2 = jnp.array([[5, 6],
[7, 8]])
# 在 axis=0 上拼接(沿着行方向)
result = jnp.concatenate([array1, array2], axis=0)
# result:
# [[1 2]
# [3 4]
# [5 6]
# [7 8]]
# 在 axis=1 上拼接(沿着列方向)
result = jnp.concatenate([array1, array2], axis=1)
# result:
# [[1 2 5 6]
# [3 4 7 8]]
axis=0
:将array2
的行添加到array1
的行之后,结果的行数变多。axis=1
:将array2
的列添加到array1
的列之后,结果的列数变多。
三维数组
对于三维数组,axis
可以是 0
、1
、2
、-1
、-2
、-3
,分别表示在不同的维度上进行拼接:
import jax.numpy as jnp
array1 = jnp.array([[[1, 2],
[3, 4]]])
array2 = jnp.array([[[5, 6],
[7, 8]]])
# 在 axis=0 上拼接(沿着第一维度)
result = jnp.concatenate([array1, array2], axis=0)
# result:
# [[[1 2]
# [3 4]]
# [[5 6]
# [7 8]]]
# 在 axis=1 上拼接(沿着第二维度)
result = jnp.concatenate([array1, array2], axis=1)
# result:
# [[[1 2]
# [3 4]
# [5 6]
# [7 8]]]
# 在 axis=2 上拼接(沿着第三维度)
result = jnp.concatenate([array1, array2], axis=2)
# result:
# [[[1 2 5 6]
# [3 4 7 8]]]
axis=0
:增加第一维的数量,即增加“深度”。axis=1
:增加第二维的数量,即增加“高度”。axis=2
:增加第三维的数量,即增加“宽度”。