added basic constant folding
This commit is contained in:
74
main.py
74
main.py
@@ -1684,6 +1684,41 @@ class FunctionEmitter:
|
||||
])
|
||||
|
||||
|
||||
def _int_trunc_div(lhs: int, rhs: int) -> int:
|
||||
if rhs == 0:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
quotient = abs(lhs) // abs(rhs)
|
||||
if (lhs < 0) ^ (rhs < 0):
|
||||
quotient = -quotient
|
||||
return quotient
|
||||
|
||||
|
||||
def _int_trunc_mod(lhs: int, rhs: int) -> int:
|
||||
if rhs == 0:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
return lhs - _int_trunc_div(lhs, rhs) * rhs
|
||||
|
||||
|
||||
def _bool_to_int(value: bool) -> int:
|
||||
return 1 if value else 0
|
||||
|
||||
|
||||
_FOLDABLE_WORDS: Dict[str, Tuple[int, Callable[..., int]]] = {
|
||||
"+": (2, lambda a, b: a + b),
|
||||
"-": (2, lambda a, b: a - b),
|
||||
"*": (2, lambda a, b: a * b),
|
||||
"/": (2, _int_trunc_div),
|
||||
"%": (2, _int_trunc_mod),
|
||||
"==": (2, lambda a, b: _bool_to_int(a == b)),
|
||||
"!=": (2, lambda a, b: _bool_to_int(a != b)),
|
||||
"<": (2, lambda a, b: _bool_to_int(a < b)),
|
||||
"<=": (2, lambda a, b: _bool_to_int(a <= b)),
|
||||
">": (2, lambda a, b: _bool_to_int(a > b)),
|
||||
">=": (2, lambda a, b: _bool_to_int(a >= b)),
|
||||
"not": (1, lambda a: _bool_to_int(a == 0)),
|
||||
}
|
||||
|
||||
|
||||
def sanitize_label(name: str) -> str:
|
||||
parts: List[str] = []
|
||||
for ch in name:
|
||||
@@ -1831,7 +1866,7 @@ class _CTHandleTable:
|
||||
|
||||
|
||||
class Assembler:
|
||||
def __init__(self, dictionary: Dictionary) -> None:
|
||||
def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None:
|
||||
self.dictionary = dictionary
|
||||
self._string_literals: Dict[str, Tuple[str, int]] = {}
|
||||
self._float_literals: Dict[float, str] = {}
|
||||
@@ -1840,6 +1875,36 @@ class Assembler:
|
||||
self._inline_counter: int = 0
|
||||
self._emit_stack: List[str] = []
|
||||
self._export_all_defs: bool = False
|
||||
self.enable_constant_folding = enable_constant_folding
|
||||
|
||||
def _fold_constants_in_definition(self, definition: Definition) -> None:
|
||||
optimized: List[Op] = []
|
||||
for node in definition.body:
|
||||
optimized.append(node)
|
||||
self._attempt_constant_fold_tail(optimized)
|
||||
definition.body = optimized
|
||||
|
||||
def _attempt_constant_fold_tail(self, nodes: List[Op]) -> None:
|
||||
while nodes:
|
||||
last = nodes[-1]
|
||||
if last.op != "word":
|
||||
return
|
||||
fold_entry = _FOLDABLE_WORDS.get(str(last.data))
|
||||
if fold_entry is None:
|
||||
return
|
||||
arity, func = fold_entry
|
||||
if len(nodes) < arity + 1:
|
||||
return
|
||||
operands = nodes[-(arity + 1):-1]
|
||||
if any(op.op != "literal" or not isinstance(op.data, int) for op in operands):
|
||||
return
|
||||
values = [int(op.data) for op in operands]
|
||||
try:
|
||||
result = func(*values)
|
||||
except Exception:
|
||||
return
|
||||
new_loc = operands[0].loc or last.loc
|
||||
nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)]
|
||||
|
||||
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
|
||||
edges: Dict[str, Set[str]] = {}
|
||||
@@ -1885,6 +1950,10 @@ class Assembler:
|
||||
valid_defs = (Definition, AsmDefinition)
|
||||
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
|
||||
definitions = self._dedup_definitions(raw_defs)
|
||||
if self.enable_constant_folding:
|
||||
for defn in definitions:
|
||||
if isinstance(defn, Definition):
|
||||
self._fold_constants_in_definition(defn)
|
||||
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
|
||||
if stray_forms:
|
||||
raise CompileError("top-level literals or word references are not supported yet")
|
||||
@@ -3917,6 +3986,7 @@ def cli(argv: Sequence[str]) -> int:
|
||||
parser.add_argument("--clean", action="store_true", help="remove the temp build directory and exit")
|
||||
parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional")
|
||||
parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)")
|
||||
parser.add_argument("--no-folding", action="store_true", help="disable constant folding optimization")
|
||||
|
||||
# Parse known and unknown args to allow -l flags anywhere
|
||||
args, unknown = parser.parse_known_args(argv)
|
||||
@@ -3933,6 +4003,7 @@ def cli(argv: Sequence[str]) -> int:
|
||||
i += 1
|
||||
|
||||
artifact_kind = args.artifact
|
||||
folding_enabled = not args.no_folding
|
||||
|
||||
if artifact_kind != "exe" and (args.run or args.dbg):
|
||||
parser.error("--run/--dbg are only available when --artifact exe is selected")
|
||||
@@ -3966,6 +4037,7 @@ def cli(argv: Sequence[str]) -> int:
|
||||
print("[warn] --libs ignored for static/object outputs")
|
||||
|
||||
compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths])
|
||||
compiler.assembler.enable_constant_folding = folding_enabled
|
||||
try:
|
||||
if args.repl:
|
||||
return run_repl(compiler, args.temp_dir, args.libs, debug=args.debug, initial_source=args.source)
|
||||
|
||||
Reference in New Issue
Block a user