added basic constant folding
This commit is contained in:
74
main.py
74
main.py
@@ -1684,6 +1684,41 @@ class FunctionEmitter:
|
|||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _int_trunc_div(lhs: int, rhs: int) -> int:
|
||||||
|
if rhs == 0:
|
||||||
|
raise ZeroDivisionError("division by zero")
|
||||||
|
quotient = abs(lhs) // abs(rhs)
|
||||||
|
if (lhs < 0) ^ (rhs < 0):
|
||||||
|
quotient = -quotient
|
||||||
|
return quotient
|
||||||
|
|
||||||
|
|
||||||
|
def _int_trunc_mod(lhs: int, rhs: int) -> int:
|
||||||
|
if rhs == 0:
|
||||||
|
raise ZeroDivisionError("division by zero")
|
||||||
|
return lhs - _int_trunc_div(lhs, rhs) * rhs
|
||||||
|
|
||||||
|
|
||||||
|
def _bool_to_int(value: bool) -> int:
|
||||||
|
return 1 if value else 0
|
||||||
|
|
||||||
|
|
||||||
|
_FOLDABLE_WORDS: Dict[str, Tuple[int, Callable[..., int]]] = {
|
||||||
|
"+": (2, lambda a, b: a + b),
|
||||||
|
"-": (2, lambda a, b: a - b),
|
||||||
|
"*": (2, lambda a, b: a * b),
|
||||||
|
"/": (2, _int_trunc_div),
|
||||||
|
"%": (2, _int_trunc_mod),
|
||||||
|
"==": (2, lambda a, b: _bool_to_int(a == b)),
|
||||||
|
"!=": (2, lambda a, b: _bool_to_int(a != b)),
|
||||||
|
"<": (2, lambda a, b: _bool_to_int(a < b)),
|
||||||
|
"<=": (2, lambda a, b: _bool_to_int(a <= b)),
|
||||||
|
">": (2, lambda a, b: _bool_to_int(a > b)),
|
||||||
|
">=": (2, lambda a, b: _bool_to_int(a >= b)),
|
||||||
|
"not": (1, lambda a: _bool_to_int(a == 0)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def sanitize_label(name: str) -> str:
|
def sanitize_label(name: str) -> str:
|
||||||
parts: List[str] = []
|
parts: List[str] = []
|
||||||
for ch in name:
|
for ch in name:
|
||||||
@@ -1831,7 +1866,7 @@ class _CTHandleTable:
|
|||||||
|
|
||||||
|
|
||||||
class Assembler:
|
class Assembler:
|
||||||
def __init__(self, dictionary: Dictionary) -> None:
|
def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None:
|
||||||
self.dictionary = dictionary
|
self.dictionary = dictionary
|
||||||
self._string_literals: Dict[str, Tuple[str, int]] = {}
|
self._string_literals: Dict[str, Tuple[str, int]] = {}
|
||||||
self._float_literals: Dict[float, str] = {}
|
self._float_literals: Dict[float, str] = {}
|
||||||
@@ -1840,6 +1875,36 @@ class Assembler:
|
|||||||
self._inline_counter: int = 0
|
self._inline_counter: int = 0
|
||||||
self._emit_stack: List[str] = []
|
self._emit_stack: List[str] = []
|
||||||
self._export_all_defs: bool = False
|
self._export_all_defs: bool = False
|
||||||
|
self.enable_constant_folding = enable_constant_folding
|
||||||
|
|
||||||
|
def _fold_constants_in_definition(self, definition: Definition) -> None:
|
||||||
|
optimized: List[Op] = []
|
||||||
|
for node in definition.body:
|
||||||
|
optimized.append(node)
|
||||||
|
self._attempt_constant_fold_tail(optimized)
|
||||||
|
definition.body = optimized
|
||||||
|
|
||||||
|
def _attempt_constant_fold_tail(self, nodes: List[Op]) -> None:
|
||||||
|
while nodes:
|
||||||
|
last = nodes[-1]
|
||||||
|
if last.op != "word":
|
||||||
|
return
|
||||||
|
fold_entry = _FOLDABLE_WORDS.get(str(last.data))
|
||||||
|
if fold_entry is None:
|
||||||
|
return
|
||||||
|
arity, func = fold_entry
|
||||||
|
if len(nodes) < arity + 1:
|
||||||
|
return
|
||||||
|
operands = nodes[-(arity + 1):-1]
|
||||||
|
if any(op.op != "literal" or not isinstance(op.data, int) for op in operands):
|
||||||
|
return
|
||||||
|
values = [int(op.data) for op in operands]
|
||||||
|
try:
|
||||||
|
result = func(*values)
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
new_loc = operands[0].loc or last.loc
|
||||||
|
nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)]
|
||||||
|
|
||||||
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
|
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
|
||||||
edges: Dict[str, Set[str]] = {}
|
edges: Dict[str, Set[str]] = {}
|
||||||
@@ -1885,6 +1950,10 @@ class Assembler:
|
|||||||
valid_defs = (Definition, AsmDefinition)
|
valid_defs = (Definition, AsmDefinition)
|
||||||
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
|
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
|
||||||
definitions = self._dedup_definitions(raw_defs)
|
definitions = self._dedup_definitions(raw_defs)
|
||||||
|
if self.enable_constant_folding:
|
||||||
|
for defn in definitions:
|
||||||
|
if isinstance(defn, Definition):
|
||||||
|
self._fold_constants_in_definition(defn)
|
||||||
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
|
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
|
||||||
if stray_forms:
|
if stray_forms:
|
||||||
raise CompileError("top-level literals or word references are not supported yet")
|
raise CompileError("top-level literals or word references are not supported yet")
|
||||||
@@ -3917,6 +3986,7 @@ def cli(argv: Sequence[str]) -> int:
|
|||||||
parser.add_argument("--clean", action="store_true", help="remove the temp build directory and exit")
|
parser.add_argument("--clean", action="store_true", help="remove the temp build directory and exit")
|
||||||
parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional")
|
parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional")
|
||||||
parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)")
|
parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)")
|
||||||
|
parser.add_argument("--no-folding", action="store_true", help="disable constant folding optimization")
|
||||||
|
|
||||||
# Parse known and unknown args to allow -l flags anywhere
|
# Parse known and unknown args to allow -l flags anywhere
|
||||||
args, unknown = parser.parse_known_args(argv)
|
args, unknown = parser.parse_known_args(argv)
|
||||||
@@ -3933,6 +4003,7 @@ def cli(argv: Sequence[str]) -> int:
|
|||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
artifact_kind = args.artifact
|
artifact_kind = args.artifact
|
||||||
|
folding_enabled = not args.no_folding
|
||||||
|
|
||||||
if artifact_kind != "exe" and (args.run or args.dbg):
|
if artifact_kind != "exe" and (args.run or args.dbg):
|
||||||
parser.error("--run/--dbg are only available when --artifact exe is selected")
|
parser.error("--run/--dbg are only available when --artifact exe is selected")
|
||||||
@@ -3966,6 +4037,7 @@ def cli(argv: Sequence[str]) -> int:
|
|||||||
print("[warn] --libs ignored for static/object outputs")
|
print("[warn] --libs ignored for static/object outputs")
|
||||||
|
|
||||||
compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths])
|
compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths])
|
||||||
|
compiler.assembler.enable_constant_folding = folding_enabled
|
||||||
try:
|
try:
|
||||||
if args.repl:
|
if args.repl:
|
||||||
return run_repl(compiler, args.temp_dir, args.libs, debug=args.debug, initial_source=args.source)
|
return run_repl(compiler, args.temp_dir, args.libs, debug=args.debug, initial_source=args.source)
|
||||||
|
|||||||
Reference in New Issue
Block a user