【python】DL数组相关操作

高维数组的理解

[[[1, 2],[3, 4]]]

按括号拆分,有三个括号,这个是个三维数组。第三层括号里面包含一个二层括号,所以包含一个二维数组
二维数组包含两个一层括号,所以有两个一维数组。
一维数组:1x2
二维数组:2x2

array[0] 会返回整个二维数组 [[1, 2], [3, 4]],因为它是三维数组的第一个元素。
array[0, 0] 会返回第一行 [[1, 2], [3, 4]] 中的第一行 [1, 2]。
array[0, 1, 1] 会返回二维数组中第二行第二列的元素,也就是 4。

concatenate 操作

1. 轴的定义

对于一个 n 维数组(n 是数组的维度),axis 的取值范围是 0n-1,也可以是负数,表示从最后一维开始倒数计数。

  • axis=0:在第一维度上拼接,即沿着“行”方向拼接。
  • axis=1:在第二维度上拼接,即沿着“列”方向拼接。
  • axis=-1:在最后一个维度上拼接。
  • axis=-2:在倒数第二个维度上拼接。

2. 不同维度下 axis 的使用

一维数组

对于一维数组,axis 只有 0-1 两个取值:

import jax.numpy as jnp

array1 = jnp.array([1, 2, 3])
array2 = jnp.array([4, 5, 6])

# 在 axis=0 上拼接
result = jnp.concatenate([array1, array2], axis=0)
# result: [1 2 3 4 5 6]

在一维数组中,axis=0axis=-1 是等价的,因为只有一个维度。

二维数组

对于二维数组,axis 可以是 01-1-2

import jax.numpy as jnp

array1 = jnp.array([[1, 2],
                    [3, 4]])

array2 = jnp.array([[5, 6],
                    [7, 8]])

# 在 axis=0 上拼接(沿着行方向)
result = jnp.concatenate([array1, array2], axis=0)
# result:
# [[1 2]
#  [3 4]
#  [5 6]
#  [7 8]]

# 在 axis=1 上拼接(沿着列方向)
result = jnp.concatenate([array1, array2], axis=1)
# result:
# [[1 2 5 6]
#  [3 4 7 8]]
  • axis=0:将 array2 的行添加到 array1 的行之后,结果的行数变多。
  • axis=1:将 array2 的列添加到 array1 的列之后,结果的列数变多。

三维数组

对于三维数组,axis 可以是 012-1-2-3,分别表示在不同的维度上进行拼接:

import jax.numpy as jnp

array1 = jnp.array([[[1, 2],
                     [3, 4]]])

array2 = jnp.array([[[5, 6],
                     [7, 8]]])

# 在 axis=0 上拼接(沿着第一维度)
result = jnp.concatenate([array1, array2], axis=0)
# result:
# [[[1 2]
#   [3 4]]
#  [[5 6]
#   [7 8]]]

# 在 axis=1 上拼接(沿着第二维度)
result = jnp.concatenate([array1, array2], axis=1)
# result:
# [[[1 2]
#   [3 4]
#   [5 6]
#   [7 8]]]

# 在 axis=2 上拼接(沿着第三维度)
result = jnp.concatenate([array1, array2], axis=2)
# result:
# [[[1 2 5 6]
#   [3 4 7 8]]]
  • axis=0:增加第一维的数量,即增加“深度”。
  • axis=1:增加第二维的数量,即增加“高度”。
  • axis=2:增加第三维的数量,即增加“宽度”。
赞赏