chemical_equation_balancer/equation_solver.py

112 lines
3.6 KiB
Python
Raw Permalink Normal View History

2022-12-10 21:38:17 +08:00
from fractions import Fraction
import numpy as np
import scipy as sp
2022-12-10 16:30:23 +08:00
def solve_equation(eq):
"""
配平化学方程式返回配平后的化学方程式
:param eq: 化学方程式
格式为
{
'left': [ {
'atoms': [ {'元素名称': 元素个数}, {'元素名称': 元素个数}, ... ],
'coefficient': 系数,
'pretty_name': 化学式的字符串表示
}, ... ],
'right': [ ... ]
}
:return: 配平后的化学方程式与输入格式相同
若无法配平则返回 None
"""
2022-12-10 21:38:17 +08:00
# 统计所有元素的种类
elements = set()
for each in eq['left']:
for atom in each['atoms']:
elements.add(list(atom.keys())[0])
for each in eq['right']:
for atom in each['atoms']:
elements.add(list(atom.keys())[0])
elements = list(elements)
# 构造系数矩阵
matrix = []
constant = []
for atom in elements:
# 遍历左边,左侧系数为正
row = []
for each in eq['left']:
for atom_ in each['atoms']:
if atom in atom_:
row.append(atom_[atom])
break
else:
row.append(0)
# 遍历右边,右侧系数为负
for each in eq['right']:
for atom_ in each['atoms']:
if atom in atom_:
row.append(-atom_[atom])
break
else:
row.append(0)
# 取出row中的最后一个元素反转后加入常数项
constant.append(-row.pop())
matrix.append(row)
2022-12-10 21:38:17 +08:00
matrix = np.mat(matrix, int)
# 求解线性方程组
try:
2022-12-11 16:00:02 +08:00
result = sp.optimize.lsq_linear(matrix, constant, bounds=(0, None)).x.tolist()
except ValueError:
raise ValueError('无法配平')
2022-12-10 16:30:23 +08:00
2022-12-10 21:38:17 +08:00
# 将结果写入化学方程式,最后一个生成物的系数需要计算得到
last_substance = eq['right'][-1]
last_atom = list(last_substance['atoms'][0].keys())[0]
# 计算除最后一种生成物外的所有生成物包含last_atom的系数之和
sum_ = 0
index = 0
for each in eq['left']:
for atom in each['atoms']:
if last_atom in atom:
# 该生成物包含last_atom则获得last_atom的总个数等于系数*原子个数
sum_ += result[index] * atom[last_atom]
break
index += 1
for each in eq['right'][:-1]:
for atom in each['atoms']:
if last_atom in atom:
sum_ -= result[index] * atom[last_atom]
break
index += 1
result.append(sum_ / last_substance['atoms'][0][last_atom])
result = expand_to_int(*result)
for i, each in enumerate(eq['left']):
each['coefficient'] = result[i]
for i, each in enumerate(eq['right']):
each['coefficient'] = result[i + len(eq['left'])]
return eq
2022-12-10 16:30:23 +08:00
2022-12-10 21:38:17 +08:00
def expand_to_int(*args):
"""
将一系列小数同时扩大全部转换为最接近的整数
2022-12-10 16:30:23 +08:00
2022-12-10 21:38:17 +08:00
:param args: 一系列小数
:return: 一系列整数
"""
# 将所有小数转换为分数
fractions = []
for each in args:
fractions.append(Fraction(each).limit_denominator())
# 计算所有分数的最小公倍数
lcm = 1
for each in fractions:
lcm = lcm * each.denominator // np.gcd(lcm, each.denominator)
# 将所有分数扩大为最小公倍数
result = []
for each in fractions:
result.append(each.numerator * lcm // each.denominator)
2022-12-11 16:00:02 +08:00
return result