diff --git a/main.py b/main.py index c3217c5..0b54f58 100644 --- a/main.py +++ b/main.py @@ -3631,143 +3631,344 @@ class Assembler: return "\n".join(lines) def _peephole_optimize_definition(self, definition: Definition) -> None: - # Rewrite short stack-manipulation sequences into canonical forms. - rules: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = [ + # Word-only rewrite rules: pattern → replacement (both tuples of + # word names). The engine scans left-to-right and applies the + # longest matching rule at each position. + word_rules: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = [ + # --- stack no-ops (cancellation) --- + (("dup", "drop"), ()), + (("swap", "swap"), ()), + (("over", "drop"), ()), + (("dup", "nip"), ()), + (("2dup", "2drop"), ()), + (("2swap", "2swap"), ()), + (("rot", "rot", "rot"), ()), + (("rot", "-rot"), ()), + (("-rot", "rot"), ()), + (("drop", "drop"), ("2drop",)), + (("over", "over"), ("2dup",)), + (("inc", "dec"), ()), + (("dec", "inc"), ()), + (("neg", "neg"), ()), + (("not", "not"), ()), + (("bitnot", "bitnot"), ()), + (("bnot", "bnot"), ()), + (("abs", "abs"), ("abs",)), + # --- canonicalizations that merge into single ops --- (("swap", "drop"), ("nip",)), - # Stack no-ops - (("dup", "drop"), tuple()), - (("swap", "swap"), tuple()), - (("over", "drop"), tuple()), - (("dup", "nip"), tuple()), - (("2dup", "2drop"), tuple()), - (("2swap", "2swap"), tuple()), - (("rot", "rot", "rot"), tuple()), - # Canonicalizations (("swap", "over"), ("tuck",)), (("swap", "nip"), ("drop",)), (("nip", "drop"), ("2drop",)), (("tuck", "drop"), ("swap",)), + # --- commutative ops: swap before them is a no-op --- + (("swap", "+"), ("+",)), + (("swap", "*"), ("*",)), + (("swap", "=="), ("==",)), + (("swap", "!="), ("!=",)), + (("swap", "band"), ("band",)), + (("swap", "bor"), ("bor",)), + (("swap", "bxor"), ("bxor",)), + (("swap", "and"), ("and",)), + (("swap", "or"), ("or",)), + (("swap", "min"), ("min",)), + (("swap", "max"), ("max",)), + # --- dup + self-idempotent binary → identity --- + (("dup", "bor"), ()), # x | x == x + (("dup", "band"), ()), # x & x == x + (("dup", "bxor"), ("drop", "literal_0")), # x ^ x == 0 + (("dup", "=="), ("drop", "literal_1")), # x == x always true + (("dup", "-"), ("drop", "literal_0")), # x - x == 0 ] - max_pat_len = max(len(pattern) for pattern, _ in rules) + # Filter out placeholder rules whose replacements contain + # pseudo-words; expand them into proper Op sequences later. + _PLACEHOLDER_RULES: Dict[Tuple[str, ...], Tuple[str, ...]] = {} + clean_rules: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = [] + for pat, repl in word_rules: + if any(r.startswith("literal_") for r in repl): + _PLACEHOLDER_RULES[pat] = repl + else: + clean_rules.append((pat, repl)) + word_rules = clean_rules - # Build index: first word -> list of (pattern, replacement) + max_pat_len = max(len(p) for p, _ in word_rules) if word_rules else 0 rule_index: Dict[str, List[Tuple[Tuple[str, ...], Tuple[str, ...]]]] = {} - for pattern, repl in rules: + for pattern, repl in word_rules: rule_index.setdefault(pattern[0], []).append((pattern, repl)) nodes = definition.body - changed = True - while changed: - changed = False - optimized: List[Op] = [] - idx = 0 - while idx < len(nodes): - node = nodes[idx] - matched = False - if node._opcode == OP_WORD: - candidates = rule_index.get(str(node.data)) - if candidates: - for window in range(min(max_pat_len, len(nodes) - idx), 1, -1): - segment = nodes[idx:idx + window] - if any(n._opcode != OP_WORD for n in segment): + + # Outer loop: keeps re-running all passes until nothing changes. + any_changed = True + while any_changed: + any_changed = False + + # ---------- Pass 1: word-only pattern rewriting ---------- + changed = True + while changed: + changed = False + optimized: List[Op] = [] + idx = 0 + while idx < len(nodes): + node = nodes[idx] + matched = False + if node._opcode == OP_WORD: + word_name = str(node.data) + + # --- placeholder rules (produce literals) --- + for pat, repl in _PLACEHOLDER_RULES.items(): + plen = len(pat) + if pat[0] != word_name: continue - names = tuple(str(n.data) for n in segment) - replacement: Optional[Tuple[str, ...]] = None - for pattern, repl in candidates: - if names == pattern: - replacement = repl - break - if replacement is None: + if idx + plen > len(nodes): continue - base_loc = segment[0].loc - for repl_name in replacement: - optimized.append(Op(op="word", data=repl_name, loc=base_loc)) - idx += window + seg = nodes[idx:idx + plen] + if any(n._opcode != OP_WORD for n in seg): + continue + if tuple(str(n.data) for n in seg) == pat: + base_loc = seg[0].loc + for r in repl: + if r.startswith("literal_"): + val = int(r[len("literal_"):]) + optimized.append(Op(op="literal", data=val, loc=base_loc)) + else: + optimized.append(Op(op="word", data=r, loc=base_loc)) + idx += plen + changed = True + matched = True + break + if matched: + continue + + # --- normal word-only rules --- + candidates = rule_index.get(word_name) + if candidates: + for window in range(min(max_pat_len, len(nodes) - idx), 1, -1): + segment = nodes[idx:idx + window] + if any(n._opcode != OP_WORD for n in segment): + continue + names = tuple(str(n.data) for n in segment) + replacement: Optional[Tuple[str, ...]] = None + for pattern, repl in candidates: + if names == pattern: + replacement = repl + break + if replacement is None: + continue + base_loc = segment[0].loc + for repl_name in replacement: + optimized.append(Op(op="word", data=repl_name, loc=base_loc)) + idx += window + changed = True + matched = True + break + if matched: + continue + optimized.append(nodes[idx]) + idx += 1 + if changed: + any_changed = True + nodes = optimized + + # ---------- Pass 2: literal + word algebraic identities ---------- + # String literals push TWO values (pointer + length), so + # most literal-aware rewrites must be restricted to scalars. + def _is_scalar_literal(node: Op) -> bool: + return node._opcode == OP_LITERAL and not isinstance(node.data, str) + + changed = True + while changed: + changed = False + optimized = [] + idx = 0 + while idx < len(nodes): + # -- Redundant unary pairs (word word) -- + if idx + 1 < len(nodes): + a, b = nodes[idx], nodes[idx + 1] + if a._opcode == OP_WORD and b._opcode == OP_WORD: + wa, wb = str(a.data), str(b.data) + if (wa, wb) in { + ("not", "not"), ("neg", "neg"), + ("bitnot", "bitnot"), ("bnot", "bnot"), + ("inc", "dec"), ("dec", "inc"), + }: + idx += 2 + changed = True + continue + # abs is idempotent + if wa == "abs" and wb == "abs": + optimized.append(a) + idx += 2 + changed = True + continue + + # -- scalar literal + dup → literal literal -- + if idx + 1 < len(nodes): + a, b = nodes[idx], nodes[idx + 1] + if _is_scalar_literal(a) and b._opcode == OP_WORD and str(b.data) == "dup": + optimized.append(a) + optimized.append(Op(op="literal", data=a.data, loc=a.loc)) + idx += 2 changed = True - matched = True - break - if matched: + continue + + # -- scalar literal + drop → (nothing) -- + if idx + 1 < len(nodes): + a, b = nodes[idx], nodes[idx + 1] + if _is_scalar_literal(a) and b._opcode == OP_WORD and str(b.data) == "drop": + idx += 2 + changed = True + continue + + # -- scalar literal scalar literal + 2drop → (nothing) -- + if idx + 2 < len(nodes): + a, b, c = nodes[idx], nodes[idx + 1], nodes[idx + 2] + if _is_scalar_literal(a) and _is_scalar_literal(b) and c._opcode == OP_WORD and str(c.data) == "2drop": + idx += 3 + changed = True + continue + + # -- scalar literal scalar literal + swap → swapped -- + if idx + 2 < len(nodes): + a, b, c = nodes[idx], nodes[idx + 1], nodes[idx + 2] + if _is_scalar_literal(a) and _is_scalar_literal(b) and c._opcode == OP_WORD and str(c.data) == "swap": + optimized.append(Op(op="literal", data=b.data, loc=b.loc)) + optimized.append(Op(op="literal", data=a.data, loc=a.loc)) + idx += 3 + changed = True + continue + + # -- Binary op identities: literal K + word -- + if idx + 1 < len(nodes): + lit, op = nodes[idx], nodes[idx + 1] + if lit._opcode == OP_LITERAL and isinstance(lit.data, int) and op._opcode == OP_WORD: + k = int(lit.data) + w = str(op.data) + base_loc = lit.loc or op.loc + + # Identity elements + if (w == "+" and k == 0) or (w == "-" and k == 0) or (w == "*" and k == 1) or (w == "/" and k == 1): + idx += 2 + changed = True + continue + + # Absorbing elements + if w == "*" and k == 0: + optimized.append(Op(op="word", data="drop", loc=base_loc)) + optimized.append(Op(op="literal", data=0, loc=base_loc)) + idx += 2 + changed = True + continue + + if w == "band" and k == 0: + optimized.append(Op(op="word", data="drop", loc=base_loc)) + optimized.append(Op(op="literal", data=0, loc=base_loc)) + idx += 2 + changed = True + continue + + if w == "bor" and k == -1: + optimized.append(Op(op="word", data="drop", loc=base_loc)) + optimized.append(Op(op="literal", data=-1, loc=base_loc)) + idx += 2 + changed = True + continue + + # Negate + if w == "*" and k == -1: + optimized.append(Op(op="word", data="neg", loc=base_loc)) + idx += 2 + changed = True + continue + + # Modulo 1 → always 0 + if w == "%" and k == 1: + optimized.append(Op(op="word", data="drop", loc=base_loc)) + optimized.append(Op(op="literal", data=0, loc=base_loc)) + idx += 2 + changed = True + continue + + # 0 == → not + if w == "==" and k == 0: + optimized.append(Op(op="word", data="not", loc=base_loc)) + idx += 2 + changed = True + continue + + # No-op bitwise + if (w == "bor" and k == 0) or (w == "bxor" and k == 0): + idx += 2 + changed = True + continue + if w == "band" and k == -1: + idx += 2 + changed = True + continue + if w in {"shl", "shr", "sar"} and k == 0: + idx += 2 + changed = True + continue + + # Strength reduction: multiply by power of 2 → shl + if w == "*" and k > 1 and (k & (k - 1)) == 0: + shift = k.bit_length() - 1 + optimized.append(Op(op="literal", data=shift, loc=base_loc)) + optimized.append(Op(op="word", data="shl", loc=base_loc)) + idx += 2 + changed = True + continue + + # +1 → inc, -1 → dec, +(-1) → dec, -(−1) → inc + if w == "+" and k == 1: + optimized.append(Op(op="word", data="inc", loc=base_loc)) + idx += 2 + changed = True + continue + if w == "+" and k == -1: + optimized.append(Op(op="word", data="dec", loc=base_loc)) + idx += 2 + changed = True + continue + if w == "-" and k == 1: + optimized.append(Op(op="word", data="dec", loc=base_loc)) + idx += 2 + changed = True + continue + if w == "-" and k == -1: + optimized.append(Op(op="word", data="inc", loc=base_loc)) + idx += 2 + changed = True + continue + + optimized.append(nodes[idx]) + idx += 1 + if changed: + any_changed = True + nodes = optimized + + # ---------- Pass 3: dead-code after unconditional jump/end ---------- + # Opcodes that prevent fall-through. + _TERMINATORS = {OP_JUMP} + new_nodes: List[Op] = [] + dead = False + for node in nodes: + kind = node._opcode + if dead: + # A label ends the dead region. + if kind == OP_LABEL: + dead = False + new_nodes.append(node) + else: + any_changed = True continue - optimized.append(nodes[idx]) - idx += 1 - nodes = optimized + new_nodes.append(node) + if kind in _TERMINATORS: + dead = True + if len(new_nodes) != len(nodes): + any_changed = True + nodes = new_nodes - # Literal-aware algebraic identities and redundant unary chains. - changed = True - while changed: - changed = False - optimized = [] - idx = 0 - - while idx < len(nodes): - # Redundant unary pairs. - if idx + 1 < len(nodes): - a = nodes[idx] - b = nodes[idx + 1] - if a._opcode == OP_WORD and b._opcode == OP_WORD: - wa = str(a.data) - wb = str(b.data) - if (wa, wb) in { - ("not", "not"), - ("neg", "neg"), - }: - idx += 2 - changed = True - continue - - # Binary op identities where right operand is a literal. - if idx + 1 < len(nodes): - lit = nodes[idx] - op = nodes[idx + 1] - if lit._opcode == OP_LITERAL and isinstance(lit.data, int) and op._opcode == OP_WORD: - k = int(lit.data) - w = str(op.data) - base_loc = lit.loc or op.loc - - if (w == "+" and k == 0) or (w == "-" and k == 0) or (w == "*" and k == 1) or (w == "/" and k == 1): - idx += 2 - changed = True - continue - - if w == "*" and k == -1: - optimized.append(Op(op="word", data="neg", loc=base_loc)) - idx += 2 - changed = True - continue - - if w == "%" and k == 1: - optimized.append(Op(op="word", data="drop", loc=base_loc)) - optimized.append(Op(op="literal", data=0, loc=base_loc)) - idx += 2 - changed = True - continue - - if w == "==" and k == 0: - optimized.append(Op(op="word", data="not", loc=base_loc)) - idx += 2 - changed = True - continue - - if (w == "bor" and k == 0) or (w == "bxor" and k == 0): - idx += 2 - changed = True - continue - - if w == "band" and k == -1: - idx += 2 - changed = True - continue - - if w in {"shl", "shr", "sar"} and k == 0: - idx += 2 - changed = True - continue - - optimized.append(nodes[idx]) - idx += 1 - - nodes = optimized definition.body = nodes def _fold_constants_in_definition(self, definition: Definition) -> None: