From 2b48e2d7e6e43e5ebc053f866067392ff385353d Mon Sep 17 00:00:00 2001 From: IgorCielniak Date: Tue, 10 Feb 2026 15:18:30 +0100 Subject: [PATCH] added loop unroling --- main.py | 135 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 9f29a6b..85d1c6b 100644 --- a/main.py +++ b/main.py @@ -1866,16 +1866,24 @@ class _CTHandleTable: class Assembler: - def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None: + def __init__( + self, + dictionary: Dictionary, + *, + enable_constant_folding: bool = True, + loop_unroll_threshold: int = 8, + ) -> None: self.dictionary = dictionary self._string_literals: Dict[str, Tuple[str, int]] = {} self._float_literals: Dict[float, str] = {} self._data_section: Optional[List[str]] = None self._inline_stack: List[str] = [] self._inline_counter: int = 0 + self._unroll_counter: int = 0 self._emit_stack: List[str] = [] self._export_all_defs: bool = False self.enable_constant_folding = enable_constant_folding + self.loop_unroll_threshold = loop_unroll_threshold def _fold_constants_in_definition(self, definition: Definition) -> None: optimized: List[Op] = [] @@ -1906,6 +1914,128 @@ class Assembler: new_loc = operands[0].loc or last.loc nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)] + def _for_pairs(self, nodes: Sequence[Op]) -> Dict[int, int]: + stack: List[int] = [] + pairs: Dict[int, int] = {} + for idx, node in enumerate(nodes): + if node.op == "for_begin": + stack.append(idx) + elif node.op == "for_end": + if not stack: + raise CompileError("'end' without matching 'for'") + begin_idx = stack.pop() + pairs[begin_idx] = idx + pairs[idx] = begin_idx + if stack: + raise CompileError("'for' without matching 'end'") + return pairs + + def _collect_internal_labels(self, nodes: Sequence[Op]) -> Set[str]: + labels: Set[str] = set() + for node in nodes: + kind = node.op + data = node.data + if kind == "label": + labels.add(str(data)) + elif kind in ("for_begin", "for_end"): + labels.add(str(data["loop"])) + labels.add(str(data["end"])) + elif kind in ("list_begin", "list_end"): + labels.add(str(data)) + return labels + + def _clone_nodes_with_label_remap( + self, + nodes: Sequence[Op], + internal_labels: Set[str], + suffix: str, + ) -> List[Op]: + label_map: Dict[str, str] = {} + + def remap(label: str) -> str: + if label not in internal_labels: + return label + if label not in label_map: + label_map[label] = f"{label}__unr{suffix}" + return label_map[label] + + cloned: List[Op] = [] + for node in nodes: + kind = node.op + data = node.data + if kind == "label": + cloned.append(Op(op="label", data=remap(str(data)), loc=node.loc)) + continue + if kind in ("jump", "branch_zero"): + target = str(data) + mapped = remap(target) if target in internal_labels else target + cloned.append(Op(op=kind, data=mapped, loc=node.loc)) + continue + if kind in ("for_begin", "for_end"): + cloned.append( + Op( + op=kind, + data={ + "loop": remap(str(data["loop"])), + "end": remap(str(data["end"])), + }, + loc=node.loc, + ) + ) + continue + if kind in ("list_begin", "list_end"): + cloned.append(Op(op=kind, data=remap(str(data)), loc=node.loc)) + continue + cloned.append(Op(op=kind, data=data, loc=node.loc)) + return cloned + + def _unroll_constant_for_loops(self, definition: Definition) -> None: + threshold = self.loop_unroll_threshold + if threshold <= 0: + return + nodes = definition.body + pairs = self._for_pairs(nodes) + if not pairs: + return + + rebuilt: List[Op] = [] + idx = 0 + while idx < len(nodes): + node = nodes[idx] + if node.op == "for_begin" and idx > 0: + prev = nodes[idx - 1] + if prev.op == "literal" and isinstance(prev.data, int): + count = int(prev.data) + end_idx = pairs.get(idx) + if end_idx is None: + raise CompileError("internal loop bookkeeping error") + if count <= 0: + if rebuilt and rebuilt[-1] is prev: + rebuilt.pop() + idx = end_idx + 1 + continue + if count <= threshold: + if rebuilt and rebuilt[-1] is prev: + rebuilt.pop() + body = nodes[idx + 1:end_idx] + internal_labels = self._collect_internal_labels(body) + for copy_idx in range(count): + suffix = f"{self._unroll_counter}_{copy_idx}" + rebuilt.extend( + self._clone_nodes_with_label_remap( + body, + internal_labels, + suffix, + ) + ) + self._unroll_counter += 1 + idx = end_idx + 1 + continue + rebuilt.append(node) + idx += 1 + + definition.body = rebuilt + def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]: edges: Dict[str, Set[str]] = {} for definition in runtime_defs: @@ -1950,6 +2080,9 @@ class Assembler: valid_defs = (Definition, AsmDefinition) raw_defs = [form for form in module.forms if isinstance(form, valid_defs)] definitions = self._dedup_definitions(raw_defs) + for defn in definitions: + if isinstance(defn, Definition): + self._unroll_constant_for_loops(defn) if self.enable_constant_folding: for defn in definitions: if isinstance(defn, Definition):