added basic constant folding

This commit is contained in:
IgorCielniak
2026-02-07 21:06:01 +01:00
parent a8908819c4
commit 08892c97bb

74
main.py
View File

@@ -1684,6 +1684,41 @@ class FunctionEmitter:
]) ])
def _int_trunc_div(lhs: int, rhs: int) -> int:
if rhs == 0:
raise ZeroDivisionError("division by zero")
quotient = abs(lhs) // abs(rhs)
if (lhs < 0) ^ (rhs < 0):
quotient = -quotient
return quotient
def _int_trunc_mod(lhs: int, rhs: int) -> int:
if rhs == 0:
raise ZeroDivisionError("division by zero")
return lhs - _int_trunc_div(lhs, rhs) * rhs
def _bool_to_int(value: bool) -> int:
return 1 if value else 0
_FOLDABLE_WORDS: Dict[str, Tuple[int, Callable[..., int]]] = {
"+": (2, lambda a, b: a + b),
"-": (2, lambda a, b: a - b),
"*": (2, lambda a, b: a * b),
"/": (2, _int_trunc_div),
"%": (2, _int_trunc_mod),
"==": (2, lambda a, b: _bool_to_int(a == b)),
"!=": (2, lambda a, b: _bool_to_int(a != b)),
"<": (2, lambda a, b: _bool_to_int(a < b)),
"<=": (2, lambda a, b: _bool_to_int(a <= b)),
">": (2, lambda a, b: _bool_to_int(a > b)),
">=": (2, lambda a, b: _bool_to_int(a >= b)),
"not": (1, lambda a: _bool_to_int(a == 0)),
}
def sanitize_label(name: str) -> str: def sanitize_label(name: str) -> str:
parts: List[str] = [] parts: List[str] = []
for ch in name: for ch in name:
@@ -1831,7 +1866,7 @@ class _CTHandleTable:
class Assembler: class Assembler:
def __init__(self, dictionary: Dictionary) -> None: def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None:
self.dictionary = dictionary self.dictionary = dictionary
self._string_literals: Dict[str, Tuple[str, int]] = {} self._string_literals: Dict[str, Tuple[str, int]] = {}
self._float_literals: Dict[float, str] = {} self._float_literals: Dict[float, str] = {}
@@ -1840,6 +1875,36 @@ class Assembler:
self._inline_counter: int = 0 self._inline_counter: int = 0
self._emit_stack: List[str] = [] self._emit_stack: List[str] = []
self._export_all_defs: bool = False self._export_all_defs: bool = False
self.enable_constant_folding = enable_constant_folding
def _fold_constants_in_definition(self, definition: Definition) -> None:
optimized: List[Op] = []
for node in definition.body:
optimized.append(node)
self._attempt_constant_fold_tail(optimized)
definition.body = optimized
def _attempt_constant_fold_tail(self, nodes: List[Op]) -> None:
while nodes:
last = nodes[-1]
if last.op != "word":
return
fold_entry = _FOLDABLE_WORDS.get(str(last.data))
if fold_entry is None:
return
arity, func = fold_entry
if len(nodes) < arity + 1:
return
operands = nodes[-(arity + 1):-1]
if any(op.op != "literal" or not isinstance(op.data, int) for op in operands):
return
values = [int(op.data) for op in operands]
try:
result = func(*values)
except Exception:
return
new_loc = operands[0].loc or last.loc
nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)]
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]: def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
edges: Dict[str, Set[str]] = {} edges: Dict[str, Set[str]] = {}
@@ -1885,6 +1950,10 @@ class Assembler:
valid_defs = (Definition, AsmDefinition) valid_defs = (Definition, AsmDefinition)
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)] raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
definitions = self._dedup_definitions(raw_defs) definitions = self._dedup_definitions(raw_defs)
if self.enable_constant_folding:
for defn in definitions:
if isinstance(defn, Definition):
self._fold_constants_in_definition(defn)
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)] stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
if stray_forms: if stray_forms:
raise CompileError("top-level literals or word references are not supported yet") raise CompileError("top-level literals or word references are not supported yet")
@@ -3917,6 +3986,7 @@ def cli(argv: Sequence[str]) -> int:
parser.add_argument("--clean", action="store_true", help="remove the temp build directory and exit") parser.add_argument("--clean", action="store_true", help="remove the temp build directory and exit")
parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional") parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional")
parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)") parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)")
parser.add_argument("--no-folding", action="store_true", help="disable constant folding optimization")
# Parse known and unknown args to allow -l flags anywhere # Parse known and unknown args to allow -l flags anywhere
args, unknown = parser.parse_known_args(argv) args, unknown = parser.parse_known_args(argv)
@@ -3933,6 +4003,7 @@ def cli(argv: Sequence[str]) -> int:
i += 1 i += 1
artifact_kind = args.artifact artifact_kind = args.artifact
folding_enabled = not args.no_folding
if artifact_kind != "exe" and (args.run or args.dbg): if artifact_kind != "exe" and (args.run or args.dbg):
parser.error("--run/--dbg are only available when --artifact exe is selected") parser.error("--run/--dbg are only available when --artifact exe is selected")
@@ -3966,6 +4037,7 @@ def cli(argv: Sequence[str]) -> int:
print("[warn] --libs ignored for static/object outputs") print("[warn] --libs ignored for static/object outputs")
compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths]) compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths])
compiler.assembler.enable_constant_folding = folding_enabled
try: try:
if args.repl: if args.repl:
return run_repl(compiler, args.temp_dir, args.libs, debug=args.debug, initial_source=args.source) return run_repl(compiler, args.temp_dir, args.libs, debug=args.debug, initial_source=args.source)