optimizations
This commit is contained in:
163
main.py
163
main.py
@@ -2510,6 +2510,18 @@ class CompileTimeVM:
|
||||
ip += 1
|
||||
continue
|
||||
|
||||
if kind == "list_literal":
|
||||
values = list(node.data or [])
|
||||
count = len(values)
|
||||
buf_size = (count + 1) * 8
|
||||
addr = self.memory.allocate(buf_size)
|
||||
CTMemory.write_qword(addr, count)
|
||||
for idx_item, val in enumerate(values):
|
||||
CTMemory.write_qword(addr + 8 + idx_item * 8, int(val))
|
||||
_push(addr)
|
||||
ip += 1
|
||||
continue
|
||||
|
||||
if kind == "list_end":
|
||||
if not self._list_capture_stack:
|
||||
raise ParseError("']' without matching '['")
|
||||
@@ -2931,6 +2943,82 @@ class Assembler:
|
||||
optimized.append(nodes[idx])
|
||||
idx += 1
|
||||
nodes = optimized
|
||||
|
||||
# Literal-aware algebraic identities and redundant unary chains.
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
optimized = []
|
||||
idx = 0
|
||||
|
||||
while idx < len(nodes):
|
||||
# Redundant unary pairs.
|
||||
if idx + 1 < len(nodes):
|
||||
a = nodes[idx]
|
||||
b = nodes[idx + 1]
|
||||
if a.op == "word" and b.op == "word":
|
||||
wa = str(a.data)
|
||||
wb = str(b.data)
|
||||
if (wa, wb) in {
|
||||
("not", "not"),
|
||||
("neg", "neg"),
|
||||
}:
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
# Binary op identities where right operand is a literal.
|
||||
if idx + 1 < len(nodes):
|
||||
lit = nodes[idx]
|
||||
op = nodes[idx + 1]
|
||||
if lit.op == "literal" and isinstance(lit.data, int) and op.op == "word":
|
||||
k = int(lit.data)
|
||||
w = str(op.data)
|
||||
base_loc = lit.loc or op.loc
|
||||
|
||||
if (w == "+" and k == 0) or (w == "-" and k == 0) or (w == "*" and k == 1) or (w == "/" and k == 1):
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if w == "*" and k == -1:
|
||||
optimized.append(Op(op="word", data="neg", loc=base_loc))
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if w == "%" and k == 1:
|
||||
optimized.append(Op(op="word", data="drop", loc=base_loc))
|
||||
optimized.append(Op(op="literal", data=0, loc=base_loc))
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if w == "==" and k == 0:
|
||||
optimized.append(Op(op="word", data="not", loc=base_loc))
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if (w == "bor" and k == 0) or (w == "bxor" and k == 0):
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if w == "band" and k == -1:
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if w in {"shl", "shr", "sar"} and k == 0:
|
||||
idx += 2
|
||||
changed = True
|
||||
continue
|
||||
|
||||
optimized.append(nodes[idx])
|
||||
idx += 1
|
||||
|
||||
nodes = optimized
|
||||
definition.body = nodes
|
||||
|
||||
def _fold_constants_in_definition(self, definition: Definition) -> None:
|
||||
@@ -3084,6 +3172,58 @@ class Assembler:
|
||||
|
||||
definition.body = rebuilt
|
||||
|
||||
def _fold_static_list_literals_definition(self, definition: Definition) -> None:
|
||||
nodes = definition.body
|
||||
rebuilt: List[Op] = []
|
||||
idx = 0
|
||||
while idx < len(nodes):
|
||||
node = nodes[idx]
|
||||
if node.op != "list_begin":
|
||||
rebuilt.append(node)
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
depth = 1
|
||||
j = idx + 1
|
||||
static_values: List[int] = []
|
||||
is_static = True
|
||||
|
||||
while j < len(nodes):
|
||||
cur = nodes[j]
|
||||
if cur.op == "list_begin":
|
||||
depth += 1
|
||||
is_static = False
|
||||
j += 1
|
||||
continue
|
||||
if cur.op == "list_end":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
break
|
||||
j += 1
|
||||
continue
|
||||
|
||||
if depth == 1:
|
||||
if cur.op == "literal" and isinstance(cur.data, int):
|
||||
static_values.append(int(cur.data))
|
||||
else:
|
||||
is_static = False
|
||||
j += 1
|
||||
|
||||
if depth != 0:
|
||||
rebuilt.append(node)
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
if is_static:
|
||||
rebuilt.append(Op(op="list_literal", data=static_values, loc=node.loc))
|
||||
idx = j + 1
|
||||
continue
|
||||
|
||||
rebuilt.append(node)
|
||||
idx += 1
|
||||
|
||||
definition.body = rebuilt
|
||||
|
||||
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
|
||||
edges: Dict[str, Set[str]] = {}
|
||||
for definition in runtime_defs:
|
||||
@@ -3139,6 +3279,9 @@ class Assembler:
|
||||
for defn in definitions:
|
||||
if isinstance(defn, Definition):
|
||||
self._fold_constants_in_definition(defn)
|
||||
for defn in definitions:
|
||||
if isinstance(defn, Definition):
|
||||
self._fold_static_list_literals_definition(defn)
|
||||
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
|
||||
if stray_forms:
|
||||
raise CompileError("top-level literals or word references are not supported yet")
|
||||
@@ -3386,6 +3529,26 @@ class Assembler:
|
||||
builder.emit(" mov [rel list_capture_sp], rax")
|
||||
return
|
||||
|
||||
if kind == "list_literal":
|
||||
values = list(data or [])
|
||||
count = len(values)
|
||||
bytes_needed = (count + 1) * 8
|
||||
builder.comment("list literal")
|
||||
builder.emit(" xor rdi, rdi")
|
||||
builder.emit(f" mov rsi, {bytes_needed}")
|
||||
builder.emit(" mov rdx, 3")
|
||||
builder.emit(" mov r10, 34")
|
||||
builder.emit(" mov r8, -1")
|
||||
builder.emit(" xor r9, r9")
|
||||
builder.emit(" mov rax, 9")
|
||||
builder.emit(" syscall")
|
||||
builder.emit(f" mov qword [rax], {count}")
|
||||
for idx_item, value in enumerate(values):
|
||||
builder.emit(f" mov qword [rax + {8 + idx_item * 8}], {int(value)}")
|
||||
builder.emit(" sub r12, 8")
|
||||
builder.emit(" mov [r12], rax")
|
||||
return
|
||||
|
||||
if kind == "list_end":
|
||||
base = str(data)
|
||||
loop_label = f"{base}_copy_loop"
|
||||
|
||||
Reference in New Issue
Block a user