# -*- coding: utf-8 -*-
"This file contains functions to optimise the code generated for quadrature representation."
# Copyright (C) 2009-2010 Kristian B. Oelgaard
#
# This file is part of FFC.
#
# FFC is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# FFC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with FFC. If not, see <http://www.gnu.org/licenses/>.
from ufl.utils.sorting import sorted_by_key
# FFC modules
from ffc.log import error
from ffc.quadrature.cpp import format
# TODO: Use proper errors, not just RuntimeError.
# TODO: Change all if value == 0.0 to something more safe.
# Some basic variables.
BASIS = 0
IP = 1
GEO = 2
CONST = 3
type_to_string = {BASIS: "BASIS", IP: "IP", GEO: "GEO", CONST: "CONST"}
# Functions and dictionaries for cache implementation.
# Increases speed and should also reduce memory consumption.
_float_cache = {}
[docs]def create_float(val):
if val in _float_cache:
return _float_cache[val]
float_val = FloatValue(val)
_float_cache[val] = float_val
return float_val
_symbol_cache = {}
[docs]def create_symbol(variable, symbol_type, base_expr=None, base_op=0):
key = (variable, symbol_type, base_expr, base_op)
if key in _symbol_cache:
return _symbol_cache[key]
symbol = Symbol(variable, symbol_type, base_expr, base_op)
_symbol_cache[key] = symbol
return symbol
_product_cache = {}
[docs]def create_product(variables):
# NOTE: If I switch on the sorted line, it might be possible to find more
# variables in the cache, but it adds some overhead so I don't think it
# pays off. The member variables are also sorted in the classes
# (Product and Sum) so the list 'variables' is probably already sorted.
key = tuple(variables)
if key in _product_cache:
return _product_cache[key]
product = Product(key)
_product_cache[key] = product
return product
_sum_cache = {}
[docs]def create_sum(variables):
# NOTE: If I switch on the sorted line, it might be possible to
# find more variables in the cache, but it adds some overhead so I
# don't think it pays off. The member variables are also sorted in
# the classes (Product and Sum) so the list 'variables' is
# probably already sorted.
key = tuple(variables)
if key in _sum_cache:
return _sum_cache[key]
s = Sum(key)
_sum_cache[key] = s
return s
_fraction_cache = {}
[docs]def create_fraction(num, denom):
key = (num, denom)
if key in _fraction_cache:
return _fraction_cache[key]
fraction = Fraction(num, denom)
_fraction_cache[key] = fraction
return fraction
[docs]def generate_aux_constants(constant_decl, name, var_type, print_ops=False):
"A helper tool to generate code for constant declarations."
format_comment = format["comment"]
code = []
append = code.append
ops = 0
for num, expr in sorted((v, k) for k, v in sorted_by_key(constant_decl)):
# Expand and reduce expression (If we don't already get
# reduced expressions.)
expr = expr.expand().reduce_ops()
if print_ops:
op = expr.ops()
ops += op
append(format_comment("Number of operations: %d" % op))
append(var_type(name(num), str(expr)))
append("")
else:
ops += expr.ops()
append(var_type(name(num), str(expr)))
return (ops, code)
[docs]def optimise_code(expr, ip_consts, geo_consts, trans_set):
"""Optimise a given expression with respect to, basis functions,
integration points variables and geometric constants. The
function will update the dictionaries ip_const and geo_consts with
new declarations and update the trans_set (used
transformations).
"""
format_G = format["geometry constant"]
format_I = format["ip constant"]
trans_set_update = trans_set.update
# Return constant symbol if expanded value is zero.
exp_expr = expr.expand()
if exp_expr.val == 0.0:
return create_float(0)
# Reduce expression with respect to basis function variable.
basis_expressions = exp_expr.reduce_vartype(BASIS)
# If we had a product instance we'll get a tuple back so embed in
# list.
if not isinstance(basis_expressions, list):
basis_expressions = [basis_expressions]
basis_vals = []
# Process each instance of basis functions.
for basis, ip_expr in basis_expressions:
# Get the basis and the ip expression.
# If we have no basis (like functionals) create a const.
if not basis:
basis = create_float(1)
# If the ip expression doesn't contain any operations skip
# remainder
if not ip_expr or ip_expr.val == 0.0:
basis_vals.append(basis)
continue
if not ip_expr.ops() > 0:
basis_vals.append(create_product([basis, ip_expr]))
continue
# Reduce the ip expressions with respect to IP variables.
ip_expressions = ip_expr.expand().reduce_vartype(IP)
# If we had a product instance we'll get a tuple back so embed in list.
if not isinstance(ip_expressions, list):
ip_expressions = [ip_expressions]
ip_vals = []
# Loop ip expressions.
for ip in sorted(ip_expressions):
ip_dec, geo = ip
# Update transformation set with those values that might
# be embedded in IP terms.
if ip_dec and ip_dec.val != 0.0:
trans_set_update([str(x) for x in ip_dec.get_unique_vars(GEO)])
# Append and continue if we did not have any geo values.
if not geo or geo.val == 0.0:
if ip_dec and ip_dec.val != 0.0:
ip_vals.append(ip_dec)
continue
# Update the transformation set with the variables in the
# geo term.
trans_set_update([str(x) for x in geo.get_unique_vars(GEO)])
# Only declare auxiliary geo terms if we can save
# operations.
if geo.ops() > 0:
# If the geo term is not in the dictionary append it.
if geo not in geo_consts:
geo_consts[geo] = len(geo_consts)
# Substitute geometry expression.
geo = create_symbol(format_G(geo_consts[geo]), GEO)
# If we did not have any ip_declarations use geo, else
# create a product and append to the list of ip_values.
if not ip_dec or ip_dec.val == 0.0:
ip_dec = geo
else:
ip_dec = create_product([ip_dec, geo])
ip_vals.append(ip_dec)
# Create sum of ip expressions to multiply by basis.
if len(ip_vals) > 1:
ip_expr = create_sum(ip_vals)
elif ip_vals:
ip_expr = ip_vals.pop()
# If we can save operations by declaring it as a constant do
# so, if it is not in IP dictionary, add it and use new name.
if ip_expr.ops() > 0 and ip_expr.val != 0.0:
if ip_expr not in ip_consts:
ip_consts[ip_expr] = len(ip_consts)
# Substitute ip expression.
ip_expr = create_symbol(format_I(ip_consts[ip_expr]), IP)
# Multiply by basis and append to basis vals.
basis_vals.append(create_product([basis, ip_expr]))
# Return (possible) sum of basis values.
if len(basis_vals) > 1:
return create_sum(basis_vals)
elif basis_vals:
return basis_vals[0]
# Where did the values go?
error("Values disappeared.")
from .floatvalue import FloatValue
from .symbol import Symbol
from .product import Product
from .sumobj import Sum
from .fraction import Fraction