symbols = [Symbol(name, real=True) for name in diff_eq_names]
// Coefficients
wildcards = [Wild("c_" + name, exclude=symbols) for name in diff_eq_names]
//Additive constant
constant_wildcard = Wild("c", exclude=symbols)
pattern = reduce(operator.add, [c * s for c, s in zip(wildcards, symbols)])
pattern += constant_wildcard
coefficients = sp.zeros(len(diff_eq_names))
constants = sp.zeros((len(diff_eq_names), 1))
for row_idx, (name, expr) in enumerate(diff_eqs):
s_expr = expr.sympy_expr.expand()
pattern_matches = s_expr.match(pattern)
if pattern_matches is None:
raise ValueError(("The expression "%s", defining the variable %s, "
"could not be separated into linear components") %
(expr, name))
for col_idx in xrange(len(diff_eq_names)):
coefficients[row_idx, col_idx] = pattern_matches[wildcards[col_idx]]
constants[row_idx] = pattern_matches[constant_wildcard]
return (diff_eq_names, coefficients, constants)
After Change
for row_idx, (name, expr) in enumerate(diff_eqs):
s_expr = expr.sympy_expr.expand()
current_s_expr = s_expr
for col_idx, (name, symbol) in enumerate(zip(eqs.diff_eq_names, symbols)):
current_s_expr = current_s_expr.collect(symbol)
constant_wildcard = Wild("c", exclude=[symbol])
factor_wildcard = Wild("c_"+name, exclude=symbols)
one_pattern = factor_wildcard*symbol + constant_wildcard
matches = current_s_expr.match(one_pattern)