added loop unrolling

This commit is contained in:
IgorCielniak
2026-02-10 15:18:30 +01:00
parent 2f5b1f40b1
commit 2b48e2d7e6

135
main.py
View File

@@ -1866,16 +1866,24 @@ class _CTHandleTable:
class Assembler: class Assembler:
def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None: def __init__(
self,
dictionary: Dictionary,
*,
enable_constant_folding: bool = True,
loop_unroll_threshold: int = 8,
) -> None:
self.dictionary = dictionary self.dictionary = dictionary
self._string_literals: Dict[str, Tuple[str, int]] = {} self._string_literals: Dict[str, Tuple[str, int]] = {}
self._float_literals: Dict[float, str] = {} self._float_literals: Dict[float, str] = {}
self._data_section: Optional[List[str]] = None self._data_section: Optional[List[str]] = None
self._inline_stack: List[str] = [] self._inline_stack: List[str] = []
self._inline_counter: int = 0 self._inline_counter: int = 0
self._unroll_counter: int = 0
self._emit_stack: List[str] = [] self._emit_stack: List[str] = []
self._export_all_defs: bool = False self._export_all_defs: bool = False
self.enable_constant_folding = enable_constant_folding self.enable_constant_folding = enable_constant_folding
self.loop_unroll_threshold = loop_unroll_threshold
def _fold_constants_in_definition(self, definition: Definition) -> None: def _fold_constants_in_definition(self, definition: Definition) -> None:
optimized: List[Op] = [] optimized: List[Op] = []
@@ -1906,6 +1914,128 @@ class Assembler:
new_loc = operands[0].loc or last.loc new_loc = operands[0].loc or last.loc
nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)] nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)]
def _for_pairs(self, nodes: Sequence[Op]) -> Dict[int, int]:
stack: List[int] = []
pairs: Dict[int, int] = {}
for idx, node in enumerate(nodes):
if node.op == "for_begin":
stack.append(idx)
elif node.op == "for_end":
if not stack:
raise CompileError("'end' without matching 'for'")
begin_idx = stack.pop()
pairs[begin_idx] = idx
pairs[idx] = begin_idx
if stack:
raise CompileError("'for' without matching 'end'")
return pairs
def _collect_internal_labels(self, nodes: Sequence[Op]) -> Set[str]:
labels: Set[str] = set()
for node in nodes:
kind = node.op
data = node.data
if kind == "label":
labels.add(str(data))
elif kind in ("for_begin", "for_end"):
labels.add(str(data["loop"]))
labels.add(str(data["end"]))
elif kind in ("list_begin", "list_end"):
labels.add(str(data))
return labels
def _clone_nodes_with_label_remap(
    self,
    nodes: Sequence[Op],
    internal_labels: Set[str],
    suffix: str,
) -> List[Op]:
    """Copy ``nodes``, renaming every internal label for this copy.

    A label listed in ``internal_labels`` is renamed to
    ``f"{label}__unr{suffix}"``; any other label is left untouched.

    Per op kind:
      * ``label``, ``jump``, ``branch_zero``, ``list_begin``, ``list_end``:
        ``data`` is stringified and then remapped.
      * ``for_begin``, ``for_end``: ``data`` is replaced by a new dict that
        holds only the remapped ``"loop"`` and ``"end"`` entries.
      * anything else: the new op references the same ``data`` object as
        the source op (it is not copied).

    Args:
        nodes: Ops to duplicate.
        internal_labels: Label names that must be made unique in the copy.
        suffix: Text appended (after ``__unr``) to each renamed label.

    Returns:
        A new list of ``Op`` objects, in the same order as ``nodes`` and
        with each op's ``loc`` carried over.
    """
    renamed: Dict[str, str] = {}

    def rename(name: str) -> str:
        # Labels defined outside the copied region keep their original name.
        if name not in internal_labels:
            return name
        return renamed.setdefault(name, f"{name}__unr{suffix}")

    copies: List[Op] = []
    for original in nodes:
        opcode = original.op
        payload = original.data
        if opcode in ("label", "jump", "branch_zero", "list_begin", "list_end"):
            new_payload = rename(str(payload))
        elif opcode in ("for_begin", "for_end"):
            new_payload = {
                "loop": rename(str(payload["loop"])),
                "end": rename(str(payload["end"])),
            }
        else:
            new_payload = payload
        copies.append(Op(op=opcode, data=new_payload, loc=original.loc))
    return copies
def _unroll_constant_for_loops(self, definition: Definition) -> None:
    """Replace small constant-count ``for`` loops in ``definition.body``.

    A loop qualifies when its ``for_begin`` op is immediately preceded by a
    ``literal`` op whose ``data`` is an ``int``.  For such a loop:

      * count <= 0: the literal and the whole loop (begin, body, end) are
        dropped;
      * 0 < count <= ``self.loop_unroll_threshold``: the literal and the
        loop are replaced by ``count`` copies of the loop body, each copy
        with its own label names;
      * otherwise the loop is left as it is.

    The pass does nothing when the threshold is <= 0 or the body contains
    no ``for`` loops.  Loops nested inside a body that gets unrolled are
    copied verbatim; they are not examined by this pass.

    NOTE(review): the ``for_begin`` / ``for_end`` ops of an unrolled loop
    are discarded, so this presumably relies on the loop body never
    referring to that loop's own ``"loop"`` / ``"end"`` labels or to a
    runtime loop counter -- confirm against the code generator.

    Args:
        definition: Definition whose ``body`` is rewritten in place (the
            attribute is rebound to a new list).

    Raises:
        CompileError: If the ``for`` / ``end`` ops are unbalanced, or a
            ``for_begin`` has no recorded partner.
    """
    limit = self.loop_unroll_threshold
    if limit <= 0:
        return

    ops = definition.body
    partner = self._for_pairs(ops)
    if not partner:
        return

    output: List[Op] = []
    cursor = 0
    while cursor < len(ops):
        current = ops[cursor]

        # Does a compile-time integer count sit directly before this op?
        has_constant_count = False
        if current.op == "for_begin" and cursor > 0:
            counter = ops[cursor - 1]
            has_constant_count = (
                counter.op == "literal" and isinstance(counter.data, int)
            )
        if not has_constant_count:
            output.append(current)
            cursor += 1
            continue

        trips = int(counter.data)
        closing = partner.get(cursor)
        if closing is None:
            raise CompileError("internal loop bookkeeping error")

        if trips > limit:
            # Too many iterations to duplicate; keep the loop as written.
            output.append(current)
            cursor += 1
            continue

        # The count literal was already emitted on the previous step; take
        # it back out because the loop that consumed it is going away.
        if output and output[-1] is counter:
            output.pop()

        if trips > 0:
            loop_body = ops[cursor + 1:closing]
            local_labels = self._collect_internal_labels(loop_body)
            tag = self._unroll_counter
            for copy_no in range(trips):
                output.extend(
                    self._clone_nodes_with_label_remap(
                        loop_body,
                        local_labels,
                        f"{tag}_{copy_no}",
                    )
                )
            # One counter value per unrolled loop keeps suffixes unique.
            self._unroll_counter += 1

        # Resume after the matching 'for_end'.
        cursor = closing + 1

    definition.body = output
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]: def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
edges: Dict[str, Set[str]] = {} edges: Dict[str, Set[str]] = {}
for definition in runtime_defs: for definition in runtime_defs:
@@ -1950,6 +2080,9 @@ class Assembler:
valid_defs = (Definition, AsmDefinition) valid_defs = (Definition, AsmDefinition)
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)] raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
definitions = self._dedup_definitions(raw_defs) definitions = self._dedup_definitions(raw_defs)
for defn in definitions:
if isinstance(defn, Definition):
self._unroll_constant_for_loops(defn)
if self.enable_constant_folding: if self.enable_constant_folding:
for defn in definitions: for defn in definitions:
if isinstance(defn, Definition): if isinstance(defn, Definition):