# -*- coding: utf-8 -*-
# Copyright (C) 2011-2017 Martin Sandve Alnæs
#
# This file is part of UFLACS.
#
# UFLACS is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# UFLACS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with UFLACS. If not, see <http://www.gnu.org/licenses/>.
"""Tools for C/C++ expression formatting."""
from ffc.log import error
from ufl.corealg.multifunction import MultiFunction
#from ufl.corealg.map_dag import map_expr_dag
class UFL2CNodesMixin(object):
    """Rules collection mixin for a UFL to CNodes translator class.

    Each handler receives the UFL node ``o`` followed by its operands
    (presumably already translated to CNodes expressions by the traversal
    driving this translator) and returns a CNodes expression built through
    the language module stored in ``self.L``.

    The target-specific hooks ``_cmath``, ``power``, ``abs``, ``min_value``
    and ``max_value`` are not defined here. They must be provided by a
    companion rules class such as ``RulesForC`` or ``RulesForCpp``.
    """

    def __init__(self, language):
        """Store the target language module and default translation flags.

        :param language: namespace providing the CNodes constructors used
            below (``LiteralFloat``, ``LiteralInt``, ``Add``, ``Call``, ...).
        """
        self.L = language
        # When True, integer literals are emitted as floating point literals.
        self.force_floats = False
        # When True, ``a / b`` is emitted as ``a * (1.0 / b)``.
        self.enable_strength_reduction = False

    # === Error handlers for missing formatting rules ===

    def expr(self, o):
        """Generic fallback with error message for missing rules."""
        error("Missing C++ formatting rule for expr type {0}.".format(o._ufl_class_))

    # === Formatting rules for scalar literals ===

    def zero(self, o):
        """Translate a UFL zero into the floating point literal ``0.0``."""
        return self.L.LiteralFloat(0.0)

    # NOTE: complex values are not supported (no ``complex_value`` rule).

    def float_value(self, o):
        """Translate a UFL float constant into a floating point literal."""
        return self.L.LiteralFloat(float(o))

    def int_value(self, o):
        """Translate a UFL int constant.

        Emits a float literal instead of an int literal when ``force_floats``
        is set.
        """
        if self.force_floats:
            return self.float_value(o)
        return self.L.LiteralInt(int(o))

    # === Formatting rules for arithmetic operators ===
    # NOTE: UFL has no subtraction node, so there is no ``sub`` rule.

    def sum(self, o, a, b):
        """Translate ``a + b``."""
        return self.L.Add(a, b)

    def product(self, o, a, b):
        """Translate ``a * b``."""
        return self.L.Mul(a, b)

    def division(self, o, a, b):
        """Translate ``a / b``.

        With ``enable_strength_reduction`` set, this emits ``a * (1.0 / b)``.
        """
        if self.enable_strength_reduction:
            # Wrap the literal explicitly, like every other literal in this
            # class, instead of relying on L.Div to convert a Python float.
            return self.L.Mul(a, self.L.Div(self.L.LiteralFloat(1.0), b))
        return self.L.Div(a, b)

    # === Formatting rules for conditional expressions ===

    def conditional(self, o, c, t, f):
        """Translate ``c ? t : f``."""
        return self.L.Conditional(c, t, f)

    def eq(self, o, a, b):
        """Translate ``a == b``."""
        return self.L.EQ(a, b)

    def ne(self, o, a, b):
        """Translate ``a != b``."""
        return self.L.NE(a, b)

    def le(self, o, a, b):
        """Translate ``a <= b``."""
        return self.L.LE(a, b)

    def ge(self, o, a, b):
        """Translate ``a >= b``."""
        return self.L.GE(a, b)

    def lt(self, o, a, b):
        """Translate ``a < b``."""
        return self.L.LT(a, b)

    def gt(self, o, a, b):
        """Translate ``a > b``."""
        return self.L.GT(a, b)

    def and_condition(self, o, a, b):
        """Translate ``a && b``."""
        return self.L.And(a, b)

    def or_condition(self, o, a, b):
        """Translate ``a || b``."""
        return self.L.Or(a, b)

    def not_condition(self, o, a):
        """Translate ``!a``."""
        return self.L.Not(a)

    # === Formatting rules for cmath functions ===
    # NOTE: cbrt, acosh, asinh, atanh and erfc have no UFL counterpart, so no
    # rules are defined for them (all are available in C99/C++11 math libs).

    def math_function(self, o, op):
        """Fallback for a MathFunction subclass without a dedicated rule.

        Emits a plain call using the UFL function's own name.
        """
        # TODO: Introduce a UserFunction to UFL to keep it separate from MathFunction?
        return self.L.Call(o._name, op)

    def sqrt(self, o, op):
        return self._cmath("sqrt", op)

    def ln(self, o, op):
        # UFL's natural logarithm maps to C's ``log``.
        return self._cmath("log", op)

    def exp(self, o, op):
        return self._cmath("exp", op)

    def cos(self, o, op):
        return self._cmath("cos", op)

    def sin(self, o, op):
        return self._cmath("sin", op)

    def tan(self, o, op):
        return self._cmath("tan", op)

    def cosh(self, o, op):
        return self._cmath("cosh", op)

    def sinh(self, o, op):
        return self._cmath("sinh", op)

    def tanh(self, o, op):
        return self._cmath("tanh", op)

    def atan_2(self, o, y, x):
        # Argument order (y, x) matches C's atan2.
        return self._cmath("atan2", (y, x))

    def acos(self, o, op):
        return self._cmath("acos", op)

    def asin(self, o, op):
        return self._cmath("asin", op)

    def atan(self, o, op):
        return self._cmath("atan", op)

    def erf(self, o, op):
        return self._cmath("erf", op)
class RulesForC(object):
    """C-specific formatting rules, mixed into a UFL to CNodes translator.

    Emits unqualified C99 math function names. Expects ``self.L`` (the CNodes
    language module) to be set by the class this is mixed into.
    """

    def _cmath(self, name, op):
        """Emit a call to the C math function ``name``."""
        return self.L.Call(name, op)

    def power(self, o, a, b):
        """Translate ``a ** b`` as ``pow(a, b)``."""
        return self.L.Call("pow", (a, b))

    def abs(self, o, op):
        """Translate ``|op|`` as ``fabs(op)``."""
        return self.L.Call("fabs", op)

    def min_value(self, o, a, b):
        """Translate ``min(a, b)`` as ``fmin(a, b)``."""
        return self.L.Call("fmin", (a, b))

    def max_value(self, o, a, b):
        """Translate ``max(a, b)`` as ``fmax(a, b)``."""
        return self.L.Call("fmax", (a, b))

    # Bessel functions are deliberately not handled: C has no standard
    # equivalent. NOTE(review): presumably they then reach the generic
    # ``expr`` fallback and report a missing rule. Confirm.
class RulesForCpp(object):
    """C++-specific formatting rules, mixed into a UFL to CNodes translator.

    Emits ``std::``-qualified math functions and Boost.Math Bessel
    functions. Expects ``self.L`` (the CNodes language module) to be set by
    the class this is mixed into.
    """

    def _cmath(self, name, op):
        """Emit a call to ``std::<name>``."""
        return self.L.Call("std::" + name, op)

    def power(self, o, a, b):
        """Translate ``a ** b`` as ``std::pow(a, b)``."""
        return self.L.Call("std::pow", (a, b))

    def abs(self, o, op):
        """Translate ``|op|`` as ``std::abs(op)``."""
        return self.L.Call("std::abs", op)

    # NOTE(review): std::min/std::max deduce a single template type for both
    # arguments, so an int literal mixed with a double would not compile.
    # Presumably force_floats (or the operand types) guards this. Confirm.
    def min_value(self, o, a, b):
        """Translate ``min(a, b)`` as ``std::min(a, b)``."""
        return self.L.Call("std::min", (a, b))

    def max_value(self, o, a, b):
        """Translate ``max(a, b)`` as ``std::max(a, b)``."""
        return self.L.Call("std::max", (a, b))

    # === Formatting rules for bessel functions ===

    def _bessel(self, o, n, v, name):
        """Emit ``boost::math::<name>(n, v)``."""
        return self.L.Call("boost::math::" + name, (n, v))

    def bessel_i(self, o, n, v):
        return self._bessel(o, n, v, "cyl_bessel_i")

    def bessel_j(self, o, n, v):
        return self._bessel(o, n, v, "cyl_bessel_j")

    def bessel_k(self, o, n, v):
        return self._bessel(o, n, v, "cyl_bessel_k")

    def bessel_y(self, o, n, v):
        # Boost names the Bessel function of the second kind "cyl_neumann".
        return self._bessel(o, n, v, "cyl_neumann")
class UFL2CNodesTranslatorC(MultiFunction, UFL2CNodesMixin, RulesForC):
    """UFL to CNodes translator class producing C code.

    Combines MultiFunction handler dispatch with the shared rules from
    ``UFL2CNodesMixin`` and the C-specific rules from ``RulesForC``.
    """

    def __init__(self, language):
        """Set up MultiFunction dispatch, then the shared translator state.

        :param language: CNodes language module used to build expressions.
        """
        MultiFunction.__init__(self)
        UFL2CNodesMixin.__init__(self, language)
class UFL2CNodesTranslatorCpp(MultiFunction, UFL2CNodesMixin, RulesForCpp):
    """UFL to CNodes translator class producing C++ code.

    Combines MultiFunction handler dispatch with the shared rules from
    ``UFL2CNodesMixin`` and the C++-specific rules from ``RulesForCpp``.
    """

    def __init__(self, language):
        """Set up MultiFunction dispatch, then the shared translator state.

        :param language: CNodes language module used to build expressions.
        """
        MultiFunction.__init__(self)
        UFL2CNodesMixin.__init__(self, language)