optimizations

This commit is contained in:
igor
2026-02-18 14:18:41 +01:00
parent b0de836beb
commit 5030138f30
2 changed files with 164 additions and 1 deletions

163
main.py
View File

@@ -2510,6 +2510,18 @@ class CompileTimeVM:
ip += 1
continue
if kind == "list_literal":
values = list(node.data or [])
count = len(values)
buf_size = (count + 1) * 8
addr = self.memory.allocate(buf_size)
CTMemory.write_qword(addr, count)
for idx_item, val in enumerate(values):
CTMemory.write_qword(addr + 8 + idx_item * 8, int(val))
_push(addr)
ip += 1
continue
if kind == "list_end":
if not self._list_capture_stack:
raise ParseError("']' without matching '['")
@@ -2931,6 +2943,82 @@ class Assembler:
optimized.append(nodes[idx])
idx += 1
nodes = optimized
# Literal-aware algebraic identities and redundant unary chains.
changed = True
while changed:
changed = False
optimized = []
idx = 0
while idx < len(nodes):
# Redundant unary pairs.
if idx + 1 < len(nodes):
a = nodes[idx]
b = nodes[idx + 1]
if a.op == "word" and b.op == "word":
wa = str(a.data)
wb = str(b.data)
if (wa, wb) in {
("not", "not"),
("neg", "neg"),
}:
idx += 2
changed = True
continue
# Binary op identities where right operand is a literal.
if idx + 1 < len(nodes):
lit = nodes[idx]
op = nodes[idx + 1]
if lit.op == "literal" and isinstance(lit.data, int) and op.op == "word":
k = int(lit.data)
w = str(op.data)
base_loc = lit.loc or op.loc
if (w == "+" and k == 0) or (w == "-" and k == 0) or (w == "*" and k == 1) or (w == "/" and k == 1):
idx += 2
changed = True
continue
if w == "*" and k == -1:
optimized.append(Op(op="word", data="neg", loc=base_loc))
idx += 2
changed = True
continue
if w == "%" and k == 1:
optimized.append(Op(op="word", data="drop", loc=base_loc))
optimized.append(Op(op="literal", data=0, loc=base_loc))
idx += 2
changed = True
continue
if w == "==" and k == 0:
optimized.append(Op(op="word", data="not", loc=base_loc))
idx += 2
changed = True
continue
if (w == "bor" and k == 0) or (w == "bxor" and k == 0):
idx += 2
changed = True
continue
if w == "band" and k == -1:
idx += 2
changed = True
continue
if w in {"shl", "shr", "sar"} and k == 0:
idx += 2
changed = True
continue
optimized.append(nodes[idx])
idx += 1
nodes = optimized
definition.body = nodes
def _fold_constants_in_definition(self, definition: Definition) -> None:
@@ -3084,6 +3172,58 @@ class Assembler:
definition.body = rebuilt
def _fold_static_list_literals_definition(self, definition: Definition) -> None:
    """Fold compile-time-constant list expressions into ``list_literal`` ops.

    Walks ``definition.body`` and replaces every ``list_begin`` that is
    followed only by integer ``literal`` ops and then a ``list_end`` with a
    single ``list_literal`` op.  The new op carries the collected integers
    as its ``data`` and inherits ``loc`` from the ``list_begin``.

    A list is left untouched (its ``list_begin`` is copied through and the
    walk resumes at the next node) when any of the following holds:

    * it contains anything other than integer literals, including a nested
      ``list_begin`` (an inner list is still examined on its own when the
      walk reaches it);
    * it has no closing ``list_end``;
    * one of its integers does not fit in a signed 32-bit immediate.

    Args:
        definition: The definition whose ``body`` is rewritten in place.
    """
    # The list_literal code generator stores each element with
    # ``mov qword [mem], imm``.  x86-64 only encodes a sign-extended 32-bit
    # immediate for that form, so wider values must stay on the runtime
    # list_begin/list_end path instead of being folded.
    imm32_min = -(1 << 31)
    imm32_max = (1 << 31) - 1

    nodes = definition.body
    total = len(nodes)
    rebuilt: List[Op] = []
    idx = 0
    while idx < total:
        node = nodes[idx]
        if node.op != "list_begin":
            rebuilt.append(node)
            idx += 1
            continue

        # Scan forward from the opening bracket.  A foldable list is exactly
        # "zero or more in-range integer literals, then list_end", so the
        # scan can stop at the first node that breaks that shape.
        static_values: List[int] = []
        end = idx + 1
        foldable = False
        while end < total:
            cur = nodes[end]
            if cur.op == "list_end":
                foldable = True
                break
            if cur.op != "literal" or not isinstance(cur.data, int):
                break
            value = int(cur.data)
            if not imm32_min <= value <= imm32_max:
                break
            static_values.append(value)
            end += 1

        if foldable:
            rebuilt.append(Op(op="list_literal", data=static_values, loc=node.loc))
            idx = end + 1  # resume after the consumed list_end
        else:
            # Not foldable (or unterminated): keep the opening op as-is and
            # let the following nodes be processed individually.
            rebuilt.append(node)
            idx += 1
    definition.body = rebuilt
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
edges: Dict[str, Set[str]] = {}
for definition in runtime_defs:
@@ -3139,6 +3279,9 @@ class Assembler:
for defn in definitions:
if isinstance(defn, Definition):
self._fold_constants_in_definition(defn)
for defn in definitions:
if isinstance(defn, Definition):
self._fold_static_list_literals_definition(defn)
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
if stray_forms:
raise CompileError("top-level literals or word references are not supported yet")
@@ -3386,6 +3529,26 @@ class Assembler:
builder.emit(" mov [rel list_capture_sp], rax")
return
if kind == "list_literal":
values = list(data or [])
count = len(values)
bytes_needed = (count + 1) * 8
builder.comment("list literal")
builder.emit(" xor rdi, rdi")
builder.emit(f" mov rsi, {bytes_needed}")
builder.emit(" mov rdx, 3")
builder.emit(" mov r10, 34")
builder.emit(" mov r8, -1")
builder.emit(" xor r9, r9")
builder.emit(" mov rax, 9")
builder.emit(" syscall")
builder.emit(f" mov qword [rax], {count}")
for idx_item, value in enumerate(values):
builder.emit(f" mov qword [rax + {8 + idx_item * 8}], {int(value)}")
builder.emit(" sub r12, 8")
builder.emit(" mov [r12], rax")
return
if kind == "list_end":
base = str(data)
loop_label = f"{base}_copy_loop"