三行写完高斯消元,这就是Python!!!

洛谷评测链接

先给大家看一眼核心代码

核心代码

1
2
3
for i in range(len(a)):
row = [j for j in range(len(a)) if a[j][i] != 0 and sum(a[j][:i]) < 1e-8][0]
a[:] = [r + a[row]*(-r[i]/a[row][i]) if j != row else r/a[row][i] for (j,r) in enumerate(a)]

没错,就只有三行,完成了高斯消元最核心的操作,把矩阵消元成主对角线为1,其余除了常数项全是0的形式。我只用了Python中的切片操作,列表解析式,还有numpy中array的性质。 为了方便大家理解,我先来介绍一些这些python中的语法

python语法介绍

切片

语法格式是 [开始:结束] 可以取出列表中的一段区间,如果不填写就是默认开始位置是0,结束位置是列表最后一个元素位置。注意这里的区间是左闭右开区间。并且支持倒着数,也就是使用负号

比如:

1
2
3
4
5
l = [0,2,3,4,5]
print(l[2:]) # 取出[3,4,5]
print(l[2:3]) # 取出[3]
print(l[:3]) # 取出[0,2,3]
print(l[:-1]) # 取出[0,2,3,4]

除了1维列表可以用切片,二维矩阵也可以但是需要用numpy里的array才可以。 语法格式是 [行开始:行结束,列开始:列结束],不写就默认是全部部分 或者 [第x行,第x列]

比如

1
2
3
4
5
6
7
8
9
l = np.array([  [0,2,3,4,5], 
[1,5,3,7,8],
[3,2,5,6,0]])

print(l[2:]) # 取出第0-2行
print(l[:,3:]) # 取出第3列到最后1列
print(l[:,-1]) # 取出最后一列
print(l[[0,2],:]) # 取出第0行和第2行
print(l[[0,2],[2]]) # 取出第0行和第2行中第2列的内容

我们还可以使用切片交换矩阵中的2行

1
2
3
l[[0,2],:] = l[[2,0],:]  # 交换第0行和第2行
# 注: 这里用 l[0],l[2] = l[2],l[0] 由于交换引用会出错

列表解析式(List Conprehension)

简单来说,我们可以直接用for来构造一个具有某个性质的列表 语法格式: [ 表达式 for 元素 in 迭代器/列表 if 筛选条件 ]。

比如

1
2
3
4
5
l = [1, 2, 4, 5, 6, 7, 8, 9]

# 把l列表中大于5的取出来,并且值再加上1
l2 = [x + 1 for x in l if x > 5]
print(l2) # 结果: [7, 8, 9, 10]

numpy的矩阵

numpy中的矩阵可以很方便的对集体元素进行操作。

比如

1
2
3
4
from numpy import array # 首先要导入array

l = array([1,2,4,5,6,7,8,9])
print(l*2) # l集体*2 ,结果:[ 2 4 8 10 12 14 16 18]

使用shape可以得到矩阵的行列格式

1
2
3
4
5
l = np.array([  [0,2,3,4,5], 
[1,5,3,7,8],
[3,2,5,6,0]])

print(l.shape) # 结果 (3, 5)

三目运算符

对于只用if语句来判断,接着赋值的操作,我们可以用三目运算符简化。比如

1
2
3
4
5
6
l = [1,2,4,5,6,7]
for i in range(l):
if i > 4:
l[i] -= 1
else:
l[i] += 1

可以简化为

1
2
for i in range(l):
l[i] = l[i]-1 if i > 4 else l[i]+1

如何完成高斯消元

首先这里默认你已经明白高斯消元的原理,我这里只解释如何用python去完成它。如果不明白可以看其他人的博客再会来。这里就不再解释高斯消元,否则会篇幅会很长。

首先我们需要对每一个主元消元一次,这里用for循环来实现

1
2
for i in range(len(a)):  # i 是第i个主元, len(a) 是行数(也是主元个数)
pass

找到第1个第i个系数非0且未使用过的行

比如在这一个矩阵里 \[ A=\left(\begin{array}{cccc} 1 & 3 & 4 & 5 \\ 1 & 4 & 7 & 3 \\ 9 & 3 & 2 & 2 \\ \end{array}\right) \] 假如我们要消第i个主元,那么我们就要找第i列不为0的行,然后用它把其他行的这一列全消元成0,如果用列表解析式,那么可以这样表示满足该条件的行的集合

1
2
# a[j][i] 就是第j行的第i列
[j for j in range(len(a)) if a[j][i] != 0]

那么还有一个条件,就是我们消元的时候,已经消元过的行不能再使用,所以这里要排除已经用过的主元行。那么哪些是用过的呢?比如你正在消第2个主元(从0开始数) \[ A=\left(\begin{array}{cccc} 1 & 0 & 4 & 5 \\ 0 & 0 & 7 & 3 \\ \end{array}\right) \] 这里两行都满足第2列不为0的条件,但是吗显然不能用前面有1的第0行去消,要用前面都是0的第1行的。也就是如果一个行前面没被使用过,那么前面的项和都是0。

如何判断"前面的项都是0"?我们在小学二年级就学过:n个非负数加和是0,那么它们就都是0。由于它前面都是我们已经消元过的,只可能出现0和1,所以只要加和是0就可以了,那么我们可以加上这个:

1
2
# a[j][i] 就是第j行的第i列
[j for j in range(len(a)) if a[j][i] != 0 and sum(a[j][:i]) < 1e-8]

这里由于是处理浮点数,需要加一个精度处理(1e-8)。Python的sum可以轻松把一个列表的加和计算出来,而我们可以用切片操作取出第j行第i列前面的元素组成一个列表。

那么还不够,我们这里取出了所有满足条件的数,我们只需要1个即可,所以我们把第一个行取出来,记行号为row

1
row = [j for j in range(len(a)) if a[j][i] != 0 and sum(a[j][:i]) < 1e-8][0]

用该行消去其他行的系数

在这一步,我们需要把其他行的这一列全消元成0,把自己行的这一列消元成1。我们先取出所有行

1
[r for (j,r) in enumerate(a)]  # 这里用enumerate不仅可以取出a矩阵的行r,还可以生成对应的下标j

然后根据numpy的性质对行进行集体操作来消元

1
2
# r + a[row]*(-r[i]/a[row][i]) 这个是消元
[r + a[row]*(-r[i]/a[row][i]) for (j,r) in enumerate(a)]

如果同时需要把自己行单独判断消元成1,这里可以用三目运算符简化代码

1
2
# r/a[row][i] 这个是把自己消元成1的情况,直接除自己的系数即可
[r + a[row]*(-r[i]/a[row][i]) if j != row else r/a[row][i] for (j,r) in enumerate(a)]

然后我们把处理完的矩阵直接赋值回去就好了

1
2
# 这里要写a[:],不能写a 否则会出错
a[:] = [r + a[row]*(-r[i]/a[row][i]) if j != row else r/a[row][i] for (j,r) in enumerate(a)]

返回解集

最后我们把计算好的解集用切片操作返回即可,也就是返回最后一列

1
return a[:, a.shape[1]-1]

特殊情况:无唯一解

这种情况下,上述代码肯定会报错,所以我们在运行的时候加个try来处理异常即可

1
2
3
4
try:
print("".join(["%.2lf\n" % x for x in gauss(arr)]))
except:
print("No Solution")

那么写道这里,就已经写完了。至于运行效率......毕竟是Python嘛,肯定是比不上C++的,够用就行。

完整AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from numpy import array
from sys import stdin

# 高斯消元
def gauss(a):
for i in range(len(a)):
# 找到第1个第i个系数非0且未使用过的行
row = [j for j in range(len(a)) if a[j][i] != 0 and sum(a[j][:i]) < 1e-8][0]
# 用该行消去其他行的系数
a[:] = [r + a[row]*(-r[i]/a[row][i]) if j != row else r/a[row][i] for (j,r) in enumerate(a)]
# 返回解集
return a[:, a.shape[1]-1]

if __name__ == "__main__":
# 输入矩阵
n = int(stdin.readline())
arr = array([list(map((float), stdin.readline().split(" ")[:-1] )) for _ in range(n)])
# 若无异常则存在唯一解
try:
print("".join(["%.2lf\n" % x for x in gauss(arr)]))
except:
print("No Solution")