From 08892c97bbc9db96f25d529ced06649388322574 Mon Sep 17 00:00:00 2001 From: IgorCielniak Date: Sat, 7 Feb 2026 21:06:01 +0100 Subject: [PATCH] added basic constant folding --- main.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 74d1037..9f29a6b 100644 --- a/main.py +++ b/main.py @@ -1684,6 +1684,41 @@ class FunctionEmitter: ]) +def _int_trunc_div(lhs: int, rhs: int) -> int: + if rhs == 0: + raise ZeroDivisionError("division by zero") + quotient = abs(lhs) // abs(rhs) + if (lhs < 0) ^ (rhs < 0): + quotient = -quotient + return quotient + + +def _int_trunc_mod(lhs: int, rhs: int) -> int: + if rhs == 0: + raise ZeroDivisionError("division by zero") + return lhs - _int_trunc_div(lhs, rhs) * rhs + + +def _bool_to_int(value: bool) -> int: + return 1 if value else 0 + + +_FOLDABLE_WORDS: Dict[str, Tuple[int, Callable[..., int]]] = { + "+": (2, lambda a, b: a + b), + "-": (2, lambda a, b: a - b), + "*": (2, lambda a, b: a * b), + "/": (2, _int_trunc_div), + "%": (2, _int_trunc_mod), + "==": (2, lambda a, b: _bool_to_int(a == b)), + "!=": (2, lambda a, b: _bool_to_int(a != b)), + "<": (2, lambda a, b: _bool_to_int(a < b)), + "<=": (2, lambda a, b: _bool_to_int(a <= b)), + ">": (2, lambda a, b: _bool_to_int(a > b)), + ">=": (2, lambda a, b: _bool_to_int(a >= b)), + "not": (1, lambda a: _bool_to_int(a == 0)), +} + + def sanitize_label(name: str) -> str: parts: List[str] = [] for ch in name: @@ -1831,7 +1866,7 @@ class _CTHandleTable: class Assembler: - def __init__(self, dictionary: Dictionary) -> None: + def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None: self.dictionary = dictionary self._string_literals: Dict[str, Tuple[str, int]] = {} self._float_literals: Dict[float, str] = {} @@ -1840,6 +1875,36 @@ class Assembler: self._inline_counter: int = 0 self._emit_stack: List[str] = [] self._export_all_defs: bool = False + self.enable_constant_folding = enable_constant_folding + + def _fold_constants_in_definition(self, definition: Definition) -> None: + optimized: List[Op] = [] + for node in definition.body: + optimized.append(node) + self._attempt_constant_fold_tail(optimized) + definition.body = optimized + + def _attempt_constant_fold_tail(self, nodes: List[Op]) -> None: + while nodes: + last = nodes[-1] + if last.op != "word": + return + fold_entry = _FOLDABLE_WORDS.get(str(last.data)) + if fold_entry is None: + return + arity, func = fold_entry + if len(nodes) < arity + 1: + return + operands = nodes[-(arity + 1):-1] + if any(op.op != "literal" or not isinstance(op.data, int) for op in operands): + return + values = [int(op.data) for op in operands] + try: + result = func(*values) + except Exception: + return + new_loc = operands[0].loc or last.loc + nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)] def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]: edges: Dict[str, Set[str]] = {} @@ -1885,6 +1950,10 @@ class Assembler: valid_defs = (Definition, AsmDefinition) raw_defs = [form for form in module.forms if isinstance(form, valid_defs)] definitions = self._dedup_definitions(raw_defs) + if self.enable_constant_folding: + for defn in definitions: + if isinstance(defn, Definition): + self._fold_constants_in_definition(defn) stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)] if stray_forms: raise CompileError("top-level literals or word references are not supported yet") @@ -3917,6 +3986,7 @@ def cli(argv: Sequence[str]) -> int: parser.add_argument("--clean", action="store_true", help="remove the temp build directory and exit") parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional") parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)") + parser.add_argument("--no-folding", action="store_true", help="disable constant folding optimization") # Parse known and unknown args to allow -l flags anywhere args, unknown = parser.parse_known_args(argv) @@ -3933,6 +4003,7 @@ def cli(argv: Sequence[str]) -> int: i += 1 artifact_kind = args.artifact + folding_enabled = not args.no_folding if artifact_kind != "exe" and (args.run or args.dbg): parser.error("--run/--dbg are only available when --artifact exe is selected") @@ -3966,6 +4037,7 @@ def cli(argv: Sequence[str]) -> int: print("[warn] --libs ignored for static/object outputs") compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths]) + compiler.assembler.enable_constant_folding = folding_enabled try: if args.repl: return run_repl(compiler, args.temp_dir, args.libs, debug=args.debug, initial_source=args.source)