Source code for simba.utils

import simba.config as conf
from simba.errors import DimensionError
import sympy

def halve_matrix(mat: sympy.Matrix):
    """
    Return the top-left block of ``mat`` with half as many rows and columns.

    Both dimensions of ``mat`` must be even; otherwise a ``DimensionError``
    is raised and nothing is extracted.

    Example:

        >>> from sympy import Matrix
        >>> m = Matrix([[1, 2], [3, 4]])
        >>> halve_matrix(m) == Matrix([[1]])
        True

    :param mat: matrix whose row and column counts are both even.
    :return: the sub-matrix made of the first ``rows // 2`` rows and the
        first ``cols // 2`` columns of ``mat``.
    :raises DimensionError: if either dimension of ``mat`` is odd.
    """
    n_rows, n_cols = mat.shape
    if n_rows % 2 or n_cols % 2:
        raise DimensionError(f"Matrix dimensions not even: {mat.shape}")

    kept_rows = range(n_rows // 2)
    kept_cols = range(n_cols // 2)
    return mat.extract(kept_rows, kept_cols)
def solve_matrix_eqn(eqn, x):
    """
    Solve a matrix equation (or a list of them) for the entries of ``x``.

    Every entry of every supplied matrix expression is treated as a scalar
    equation whose right-hand side is zero. The combined system is handed to
    ``sympy.linsolve`` with the entries of ``x`` as the unknowns.

    :param eqn: a matrix expression, or a list of matrix expressions, each
        assumed equal to zero.
    :param x: the sympy ``MatrixSymbol`` to solve for.
    :return: list with one ``sympy.Matrix`` per solution, each reshaped to
        ``x.shape``.
    """
    from sympy import linsolve

    if isinstance(eqn, list):
        # Flatten the entries of all matrices into one list of scalar equations.
        scalar_eqns = [entry for matrix_eqn in eqn for entry in matrix_eqn]
    else:
        scalar_eqns = list(eqn)

    unknowns = list(x)
    solution_set = linsolve(scalar_eqns, unknowns)
    return [sympy.Matrix(solution).reshape(*x.shape) for solution in solution_set]
def construct_permutation_matrix(n: int) -> sympy.Matrix:
    """
    Build the ``n x n`` permutation matrix that interleaves two halves.

    Applied to a vector, the result reorders the elements so that
    ``(1, 2, 3, 11, 22, 33) -> (1, 11, 2, 22, 3, 33)``.

    :param n: size of the square matrix; must be even.
    :return: ``n x n`` matrix of zeros and ones.
    :raises DimensionError: if ``n`` is odd.
    """
    if n % 2:
        raise DimensionError("n should be even")

    half = n // 2
    perm = sympy.Matrix.zeros(n, n)
    for k in range(half):
        even_row = 2 * k
        perm[even_row, k] = 1             # even rows pick from the first half
        perm[even_row + 1, half + k] = 1  # odd rows pick from the second half
    return perm
def simplify(expr, rhs=None):
    """
    Simplify the given expression or equation.

    If ``conf.params['wolframscript']`` is truthy, the expression (or the
    equation ``expr == rhs`` when ``rhs`` is given) is printed as Mathematica
    code, passed to the external ``wolframscript`` executable wrapped in
    ``Simplify[...]``, and the program's standard output is parsed back into
    a sympy object.

    Otherwise sympy's own ``simplify`` is applied to ``expr``.

    :param expr: expression to simplify, or the left-hand side of the
        equation when ``rhs`` is given.
    :param rhs: optional right-hand side of the equation.
    :return: with wolframscript, the parsed result of ``Simplify``; without
        it, ``simplify(expr)`` when ``rhs`` is ``None``, else the result of
        comparing ``simplify(expr) == rhs``.
    """
    if conf.params['wolframscript']:
        import subprocess
        from sympy import Eq
        import sympy.printing.mathematica as m
        import sympy.parsing.mathematica as mp

        if rhs is not None:
            # evaluate=False keeps the equation unevaluated so it is printed as-is.
            s = m.mathematica_code(Eq(expr, rhs, evaluate=False))
        else:
            s = m.mathematica_code(expr)

        # Arguments are passed as a list (no shell). text=True makes stdout a
        # str; calling str() on the raw bytes would give the "b'...'" repr
        # instead of the actual program output.
        result = subprocess.run(
            ["wolframscript", "-code", f"Simplify[{s}]"],
            capture_output=True,
            text=True,
        )
        # Drop the trailing newline printed by wolframscript before parsing.
        return mp.mathematica(result.stdout.strip())

    from sympy import simplify as sympy_simplify

    simplified = sympy_simplify(expr)
    if rhs is not None:
        return simplified == rhs
    return simplified
def adiabatically_eliminate(expr: sympy.Expr, gamma: sympy.Symbol) -> sympy.Expr:
    """
    Drop the terms of ``expr`` that are negligible when ``gamma`` is very large.

    ``gamma`` is taken to be much larger than any other frequency. The
    numerator and denominator of ``expr`` are each divided by ``gamma``,
    rewritten with ``gamma = 1 / epsilon``, expanded, and evaluated at
    ``epsilon = 0``.

    :param expr: expression to reduce.
    :param gamma: the symbol treated as the dominant frequency.
    :return: ratio of the reduced numerator to the reduced denominator, or
        ``sympy.Number(0)`` when ``expr`` equals zero.
    """
    if expr == 0:
        return sympy.Number(0)

    from sympy import fraction, Symbol

    small = Symbol('epsilon')

    def leading_part(term):
        # Divide by gamma, swap gamma for 1/epsilon, then set epsilon to zero.
        scaled = term / gamma
        in_epsilon = scaled.subs(gamma, 1 / small)
        return in_epsilon.expand().subs(small, 0)

    top, bottom = fraction(expr)
    reduced_top = leading_part(top)
    reduced_bottom = leading_part(bottom)
    return reduced_top / reduced_bottom