2 Star 0 Fork 0

Yizhou Guo/Computation of congruence closures

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
equation.py 6.38 KB
一键复制 编辑 原始数据 按行查看 历史
Yizhou Guo 提交于 2023-05-25 19:22 . remove more unused code
from typing import Dict, Tuple
from logger import logger
from enum import Enum
class Expr:
    """Common base for terms: constants (`Const`) and function applications (`Func`).

    Note: because this class defines ``__eq__`` without ``__hash__``, bare
    ``Expr`` instances are unhashable; the concrete subclasses supply their
    own ``__hash__``.
    """

    def coerce_to_const(self):
        """Return a constant standing in for this term.

        A `Func` is replaced by a new `Const` named ``<func_name>_const`` (its
        arguments are discarded); any other term is returned as-is.
        """
        if type(self) != Func:
            return self
        return Const(self.func_name + "_const")

    def __eq__(self, __value: object) -> bool:
        # Base equality only ever holds between two constants with the same
        # name; `Func` overrides this with structural equality.
        if type(self) != Const or type(__value) != Const:
            return False
        return self.name == __value.name
class Replacement:
    """Two-way table between generated symbol names and the expressions they name.

    ``symbols`` maps name -> expression and ``symbols_rev`` maps
    expression -> name; `add_symbol` writes both directions.
    """

    def __init__(self):
        self.symbols: Dict[str, Expr] = {}
        self.symbols_rev: Dict[Expr, str] = {}
        # Suffix for the next name produced by `new_symbol`.
        self.counter = 0

    def add_symbol(self, name, expr):
        """Record that *name* stands for *expr*, in both lookup directions."""
        self.symbols[name] = expr
        self.symbols_rev[expr] = name

    def new_symbol(self):
        """Return a fresh name ``_x<counter>`` and advance the counter.

        The name is not registered until `add_symbol` is called with it.
        NOTE(review): no check is made against user-supplied names that
        already look like ``_x<k>``.
        """
        fresh_name = "_x" + str(self.counter)
        self.counter = self.counter + 1
        return fresh_name

    def has_expr(self, expr: Expr) -> bool:
        """Return True when *expr* already has a symbol assigned to it."""
        return expr in self.symbols_rev

    def to_flat_eqs(self) -> "list[Equation]":
        """Return one equation ``expr = name`` per recorded symbol, in insertion order."""
        flat_eqs = []
        for name, expr in self.symbols.items():
            flat_eqs.append(Equation(expr, Const(name)))
        return flat_eqs

    def __len__(self):
        return len(self.symbols)

    def __str__(self):
        # One newline-terminated "name -> expr" line per recorded symbol.
        return "".join(f"{name} -> {expr}\n" for name, expr in self.symbols.items())
# TODO: consider fusing this into `Func` with `args = []`
class Const(Expr):
    """A constant symbol, identified solely by its name.

    Equality is inherited from `Expr.__eq__` (two constants are equal iff their
    names are equal); ``__hash__`` below hashes the name, consistent with that.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        # The "Const" tag is part of the hashed tuple alongside the name.
        return hash(("Const", self.name))

    def __str__(self) -> str:
        return self.name

    def get_symbol_names(self):
        """Return this constant's name as a one-element list."""
        return [self.name]
def names_in_const_list(const_list: list[Const]) -> set[str]:
    """Collect the ``name`` of every constant in *const_list* into a set."""
    return {const.name for const in const_list}
def name_list_to_const_list(name_list: list[str]) -> list[Const]:
    """Wrap each name from *name_list* in a `Const`, following the input's iteration order."""
    return list(map(Const, name_list))
def remove_duplicates(const_list: list[Const]) -> list[Const]:
    """Return new constants, one per distinct name in *const_list*.

    Names keep the order of their first occurrence in *const_list*. Going
    through a plain ``set`` instead would make the order depend on string-hash
    randomization, so the result could differ from run to run.
    """
    # dict preserves insertion order and drops repeated keys.
    unique_names = dict.fromkeys(const.name for const in const_list)
    return [Const(name) for name in unique_names]
class Func(Expr):
    """A function application ``func_name(arg_1, ..., arg_n)`` over sub-terms."""

    def __init__(self, func_name: str, args: list[Expr]):
        self.func_name = func_name
        self.args = args

    def __eq__(self, __value: object) -> bool:
        """Structural equality: same function name, same arity, pairwise-equal arguments."""
        if not isinstance(__value, Func):
            return False
        return (
            self.func_name == __value.func_name
            and len(self.args) == len(__value.args)
            and all(mine == theirs for mine, theirs in zip(self.args, __value.args))
        )

    def __str__(self) -> str:
        comma_sep_args = ",".join(map(str, self.args))
        return f"{self.func_name}({comma_sep_args})"

    def __hash__(self):
        # Consistent with __eq__: depends only on the name and the argument hashes.
        return hash(("Func", self.func_name, tuple(hash(arg) for arg in self.args)))

    def get_symbol_names(self):
        """Return the distinct symbol names occurring in this term's arguments.

        Each argument's ``get_symbol_names`` result is concatenated left to
        right and duplicates are dropped, keeping first occurrences. The
        function name itself is not included.

        Raises:
            TypeError: if an argument is a raw ``str`` instead of an `Expr`.
        """
        names = []
        for arg in self.args:
            if isinstance(arg, str):
                # A bare string leaked into the argument list. This used to
                # open an interactive pdb session; fail loudly instead.
                raise TypeError(
                    f"argument {arg!r} of {self.func_name!r} is a str; expected an Expr"
                )
            names.extend(arg.get_symbol_names())
        return list(dict.fromkeys(names))

    def args_all_constant(self):
        """Return True when every argument is a `Const` (vacuously True for no args)."""
        return all(isinstance(arg, Const) for arg in self.args)

    def flatten(self, repl: Replacement) -> Expr:
        """Replace this application by a constant recorded in *repl*; return that `Const`.

        Nested `Func` arguments are flattened first (innermost first), so the
        expression stored in *repl* has only non-`Func` arguments. If an equal
        expression is already in *repl*, its existing name is reused;
        otherwise a fresh name is allocated and registered.
        """
        new_args = [arg.flatten(repl) if isinstance(arg, Func) else arg for arg in self.args]
        new_expr = Func(self.func_name, new_args)
        if repl.has_expr(new_expr):
            return Const(repl.symbols_rev[new_expr])
        new_name = repl.new_symbol()
        repl.add_symbol(new_name, new_expr)
        return Const(new_name)
class AtomStatus(Enum):
    """Truth status recorded on an `Equation`.

    TRUE    -- the equation is known to hold.
    UNKNOWN -- nothing has been established about it yet.
    """

    TRUE = 1
    UNKNOWN = 2
class Equation:
    """An equation ``lhs = rhs`` between two terms, plus a status and a deletion mark."""

    def __init__(self, lhs: Expr, rhs: Expr):
        self.lhs = lhs
        self.rhs = rhs
        # Sides that print identically make the equation trivially true.
        self.status = AtomStatus.TRUE if str(lhs) == str(rhs) else AtomStatus.UNKNOWN
        # NOTE(review): not read anywhere in this class; possibly dead.
        self._symb_name_counter = 0
        self.deleted = False

    def is_trivial(self):
        """Return True when both sides compare equal (e.g. ``a = a``)."""
        return self.lhs == self.rhs

    def is_const(self):
        """Return True when both sides are constants."""
        return type(self.lhs) == Const and type(self.rhs) == Const

    def is_flat_eq(self):
        """Return True for the flat shape ``f(c_1, ..., c_n) = c`` with all-constant arguments."""
        if type(self.lhs) == Func and type(self.rhs) == Const:
            return self.lhs.args_all_constant()
        return False

    def __str__(self) -> str:
        return f"{self.lhs} = {self.rhs}"

    ### Algorithm from page 8 of paper
    ### TODO: test removal of trivial equalities of the form `a=a`
    def flatten(self, repl: Replacement) -> "Equation":
        """Return an equivalent equation whose function applications are named in *repl*.

        - ``c = d`` (two constants) is returned unchanged (the same object).
        - ``f(...) = c`` becomes ``x = c``, where ``x`` names ``f(...)`` in *repl*.
        - ``c = f(...)`` is swapped to ``f(...) = c`` and flattened as above.
        - ``f(...) = g(...)`` becomes ``x = y`` with both sides named in *repl*.

        Raises:
            TypeError: if either side is neither a `Const` nor a `Func`.
        """
        both_types = (type(self.lhs), type(self.rhs))
        if both_types == (Const, Const):
            return self
        if both_types == (Func, Const):
            return Equation(self.lhs.flatten(repl), self.rhs)
        if both_types == (Const, Func):
            return Equation(self.rhs, self.lhs).flatten(repl)
        if both_types == (Func, Func):
            return Equation(self.lhs.flatten(repl), self.rhs.flatten(repl))
        # Previously fell off the end and returned None, which only failed later.
        raise TypeError(
            f"cannot flatten {self}: unsupported side types "
            f"{both_types[0].__name__} and {both_types[1].__name__}"
        )
def get_symbol_set(equations: list[Equation]) -> list[Const]:
    """Return one new `Const` per distinct symbol name appearing in *equations*.

    Both sides of every equation contribute their ``get_symbol_names()``.
    The constants are ordered by name, so the result is the same on every
    run (set iteration order over strings varies with hash randomization).
    """
    symbol_names = set()
    for eqn in equations:
        # In-place update; `set.union` in a loop copies the whole set each time.
        symbol_names.update(eqn.lhs.get_symbol_names())
        symbol_names.update(eqn.rhs.get_symbol_names())
    return [Const(name) for name in sorted(symbol_names)]
def flatten_equations(equations: list[Equation], repl: Replacement = None) -> Tuple[list[Equation], Replacement]:
    """Flatten every equation against one shared replacement table.

    Args:
        equations: equations to flatten, processed in order.
        repl: table to extend; a new empty `Replacement` is created when None.

    Returns:
        The flattened equations (same order as the input) and the table that
        was used, which now names every function application encountered.
    """
    table = Replacement() if repl is None else repl
    flattened = []
    for equation in equations:
        flattened.append(equation.flatten(table))
    return flattened, table
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/guoyiz/congruence_closure_computation.git
git@gitee.com:guoyiz/congruence_closure_computation.git
guoyiz
congruence_closure_computation
Computation of congruence closures
master

搜索帮助