added more extensive peephole optimizations

This commit is contained in:
igor
2026-03-02 14:44:49 +01:00
parent 4a5dd89932
commit 269055be5f

441
main.py
View File

@@ -3631,143 +3631,344 @@ class Assembler:
return "\n".join(lines)
def _peephole_optimize_definition(self, definition: Definition) -> None:
    """Iteratively peephole-optimize ``definition.body`` in place.

    Three passes run inside a fixed-point loop until none of them
    changes the op sequence:

    1. word-only pattern rewriting (longest match wins at each position),
    2. literal-aware algebraic identities and strength reduction,
    3. dead-code removal after an unconditional jump (a label revives
       the instruction stream).
    """
    # Word-only rewrite rules: pattern -> replacement (both tuples of
    # word names). The engine scans left-to-right and applies the
    # longest matching rule at each position.
    word_rules: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = [
        # --- stack no-ops (cancellation) ---
        (("dup", "drop"), ()),
        (("swap", "swap"), ()),
        (("over", "drop"), ()),
        (("dup", "nip"), ()),
        (("2dup", "2drop"), ()),
        (("2swap", "2swap"), ()),
        (("rot", "rot", "rot"), ()),
        (("rot", "-rot"), ()),
        (("-rot", "rot"), ()),
        (("drop", "drop"), ("2drop",)),
        (("over", "over"), ("2dup",)),
        (("inc", "dec"), ()),
        (("dec", "inc"), ()),
        (("neg", "neg"), ()),
        (("not", "not"), ()),
        (("bitnot", "bitnot"), ()),
        (("bnot", "bnot"), ()),
        (("abs", "abs"), ("abs",)),
        # --- canonicalizations that merge into single ops ---
        (("swap", "drop"), ("nip",)),
        (("swap", "over"), ("tuck",)),
        (("swap", "nip"), ("drop",)),
        (("nip", "drop"), ("2drop",)),
        (("tuck", "drop"), ("swap",)),
        # --- commutative ops: swap before them is a no-op ---
        (("swap", "+"), ("+",)),
        (("swap", "*"), ("*",)),
        (("swap", "=="), ("==",)),
        (("swap", "!="), ("!=",)),
        (("swap", "band"), ("band",)),
        (("swap", "bor"), ("bor",)),
        (("swap", "bxor"), ("bxor",)),
        (("swap", "and"), ("and",)),
        (("swap", "or"), ("or",)),
        (("swap", "min"), ("min",)),
        (("swap", "max"), ("max",)),
        # --- dup + self-idempotent binary -> identity ---
        (("dup", "bor"), ()),  # x | x == x
        (("dup", "band"), ()),  # x & x == x
        (("dup", "bxor"), ("drop", "literal_0")),  # x ^ x == 0
        (("dup", "=="), ("drop", "literal_1")),  # x == x always true
        (("dup", "-"), ("drop", "literal_0")),  # x - x == 0
    ]
    # Split off placeholder rules whose replacements contain
    # "literal_<N>" pseudo-words; those are expanded into proper
    # literal Ops by the scanner below instead of word Ops.
    _PLACEHOLDER_RULES: Dict[Tuple[str, ...], Tuple[str, ...]] = {}
    clean_rules: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = []
    for pat, repl in word_rules:
        if any(r.startswith("literal_") for r in repl):
            _PLACEHOLDER_RULES[pat] = repl
        else:
            clean_rules.append((pat, repl))
    word_rules = clean_rules
    # Build index: first word -> list of (pattern, replacement), so the
    # scanner only probes rules that can possibly match at a position.
    max_pat_len = max(len(p) for p, _ in word_rules) if word_rules else 0
    rule_index: Dict[str, List[Tuple[Tuple[str, ...], Tuple[str, ...]]]] = {}
    for pattern, repl in word_rules:
        rule_index.setdefault(pattern[0], []).append((pattern, repl))
    nodes = definition.body
    # Outer loop: keeps re-running all passes until nothing changes.
    any_changed = True
    while any_changed:
        any_changed = False
        # ---------- Pass 1: word-only pattern rewriting ----------
        changed = True
        while changed:
            changed = False
            optimized: List[Op] = []
            idx = 0
            while idx < len(nodes):
                node = nodes[idx]
                matched = False
                if node._opcode == OP_WORD:
                    word_name = str(node.data)
                    # --- placeholder rules (produce literals) ---
                    for pat, repl in _PLACEHOLDER_RULES.items():
                        plen = len(pat)
                        if pat[0] != word_name:
                            continue
                        if idx + plen > len(nodes):
                            continue
                        seg = nodes[idx:idx + plen]
                        if any(n._opcode != OP_WORD for n in seg):
                            continue
                        if tuple(str(n.data) for n in seg) == pat:
                            base_loc = seg[0].loc
                            for r in repl:
                                if r.startswith("literal_"):
                                    val = int(r[len("literal_"):])
                                    optimized.append(Op(op="literal", data=val, loc=base_loc))
                                else:
                                    optimized.append(Op(op="word", data=r, loc=base_loc))
                            idx += plen
                            changed = True
                            matched = True
                            break
                    if matched:
                        continue
                    # --- normal word-only rules, longest window first ---
                    candidates = rule_index.get(word_name)
                    if candidates:
                        for window in range(min(max_pat_len, len(nodes) - idx), 1, -1):
                            segment = nodes[idx:idx + window]
                            if any(n._opcode != OP_WORD for n in segment):
                                continue
                            names = tuple(str(n.data) for n in segment)
                            replacement: Optional[Tuple[str, ...]] = None
                            for pattern, repl in candidates:
                                if names == pattern:
                                    replacement = repl
                                    break
                            if replacement is None:
                                continue
                            base_loc = segment[0].loc
                            for repl_name in replacement:
                                optimized.append(Op(op="word", data=repl_name, loc=base_loc))
                            idx += window
                            changed = True
                            matched = True
                            break
                if matched:
                    continue
                optimized.append(nodes[idx])
                idx += 1
            if changed:
                any_changed = True
            nodes = optimized
        # ---------- Pass 2: literal + word algebraic identities ----------
        # String literals push TWO values (pointer + length), so most
        # literal-aware rewrites must be restricted to scalar literals.
        def _is_scalar_literal(node: Op) -> bool:
            return node._opcode == OP_LITERAL and not isinstance(node.data, str)
        changed = True
        while changed:
            changed = False
            optimized = []
            idx = 0
            while idx < len(nodes):
                # -- redundant unary pairs (word word) --
                if idx + 1 < len(nodes):
                    a, b = nodes[idx], nodes[idx + 1]
                    if a._opcode == OP_WORD and b._opcode == OP_WORD:
                        wa, wb = str(a.data), str(b.data)
                        if (wa, wb) in {
                            ("not", "not"), ("neg", "neg"),
                            ("bitnot", "bitnot"), ("bnot", "bnot"),
                            ("inc", "dec"), ("dec", "inc"),
                        }:
                            idx += 2
                            changed = True
                            continue
                        # abs is idempotent: keep a single application.
                        if wa == "abs" and wb == "abs":
                            optimized.append(a)
                            idx += 2
                            changed = True
                            continue
                # -- scalar literal + dup -> literal literal --
                if idx + 1 < len(nodes):
                    a, b = nodes[idx], nodes[idx + 1]
                    if _is_scalar_literal(a) and b._opcode == OP_WORD and str(b.data) == "dup":
                        optimized.append(a)
                        optimized.append(Op(op="literal", data=a.data, loc=a.loc))
                        idx += 2
                        changed = True
                        continue
                # -- scalar literal + drop -> (nothing) --
                if idx + 1 < len(nodes):
                    a, b = nodes[idx], nodes[idx + 1]
                    if _is_scalar_literal(a) and b._opcode == OP_WORD and str(b.data) == "drop":
                        idx += 2
                        changed = True
                        continue
                # -- scalar literal scalar literal + 2drop -> (nothing) --
                if idx + 2 < len(nodes):
                    a, b, c = nodes[idx], nodes[idx + 1], nodes[idx + 2]
                    if _is_scalar_literal(a) and _is_scalar_literal(b) and c._opcode == OP_WORD and str(c.data) == "2drop":
                        idx += 3
                        changed = True
                        continue
                # -- scalar literal scalar literal + swap -> swapped --
                if idx + 2 < len(nodes):
                    a, b, c = nodes[idx], nodes[idx + 1], nodes[idx + 2]
                    if _is_scalar_literal(a) and _is_scalar_literal(b) and c._opcode == OP_WORD and str(c.data) == "swap":
                        optimized.append(Op(op="literal", data=b.data, loc=b.loc))
                        optimized.append(Op(op="literal", data=a.data, loc=a.loc))
                        idx += 3
                        changed = True
                        continue
                # -- binary op identities: literal K + word --
                if idx + 1 < len(nodes):
                    lit, op = nodes[idx], nodes[idx + 1]
                    if lit._opcode == OP_LITERAL and isinstance(lit.data, int) and op._opcode == OP_WORD:
                        k = int(lit.data)
                        w = str(op.data)
                        base_loc = lit.loc or op.loc
                        # Identity elements: x+0, x-0, x*1, x/1.
                        if (w == "+" and k == 0) or (w == "-" and k == 0) or (w == "*" and k == 1) or (w == "/" and k == 1):
                            idx += 2
                            changed = True
                            continue
                        # Absorbing elements: x*0 == 0, x&0 == 0, x|-1 == -1.
                        if w == "*" and k == 0:
                            optimized.append(Op(op="word", data="drop", loc=base_loc))
                            optimized.append(Op(op="literal", data=0, loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        if w == "band" and k == 0:
                            optimized.append(Op(op="word", data="drop", loc=base_loc))
                            optimized.append(Op(op="literal", data=0, loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        if w == "bor" and k == -1:
                            optimized.append(Op(op="word", data="drop", loc=base_loc))
                            optimized.append(Op(op="literal", data=-1, loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        # Negate: x * -1 -> neg.
                        if w == "*" and k == -1:
                            optimized.append(Op(op="word", data="neg", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        # Modulo 1 -> always 0.
                        if w == "%" and k == 1:
                            optimized.append(Op(op="word", data="drop", loc=base_loc))
                            optimized.append(Op(op="literal", data=0, loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        # 0 == -> not.
                        if w == "==" and k == 0:
                            optimized.append(Op(op="word", data="not", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        # No-op bitwise and zero shifts.
                        if (w == "bor" and k == 0) or (w == "bxor" and k == 0):
                            idx += 2
                            changed = True
                            continue
                        if w == "band" and k == -1:
                            idx += 2
                            changed = True
                            continue
                        if w in {"shl", "shr", "sar"} and k == 0:
                            idx += 2
                            changed = True
                            continue
                        # Strength reduction: multiply by power of 2 -> shl.
                        if w == "*" and k > 1 and (k & (k - 1)) == 0:
                            shift = k.bit_length() - 1
                            optimized.append(Op(op="literal", data=shift, loc=base_loc))
                            optimized.append(Op(op="word", data="shl", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        # +1 -> inc, +(-1) -> dec, -1 -> dec, -(-1) -> inc.
                        if w == "+" and k == 1:
                            optimized.append(Op(op="word", data="inc", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        if w == "+" and k == -1:
                            optimized.append(Op(op="word", data="dec", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        if w == "-" and k == 1:
                            optimized.append(Op(op="word", data="dec", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                        if w == "-" and k == -1:
                            optimized.append(Op(op="word", data="inc", loc=base_loc))
                            idx += 2
                            changed = True
                            continue
                optimized.append(nodes[idx])
                idx += 1
            if changed:
                any_changed = True
            nodes = optimized
        # ---------- Pass 3: dead-code after unconditional jump/end ----------
        # Opcodes that prevent fall-through.
        _TERMINATORS = {OP_JUMP}
        new_nodes: List[Op] = []
        dead = False
        for node in nodes:
            kind = node._opcode
            if dead:
                # A label ends the dead region; anything else is dropped.
                if kind == OP_LABEL:
                    dead = False
                    new_nodes.append(node)
                else:
                    any_changed = True
                continue
            new_nodes.append(node)
            if kind in _TERMINATORS:
                dead = True
        if len(new_nodes) != len(nodes):
            any_changed = True
        nodes = new_nodes
    definition.body = nodes
def _fold_constants_in_definition(self, definition: Definition) -> None: