added loop unrolling
This commit is contained in:
135
main.py
135
main.py
@@ -1866,16 +1866,24 @@ class _CTHandleTable:
|
||||
|
||||
|
||||
class Assembler:
|
||||
def __init__(self, dictionary: Dictionary, *, enable_constant_folding: bool = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
dictionary: Dictionary,
|
||||
*,
|
||||
enable_constant_folding: bool = True,
|
||||
loop_unroll_threshold: int = 8,
|
||||
) -> None:
|
||||
self.dictionary = dictionary
|
||||
self._string_literals: Dict[str, Tuple[str, int]] = {}
|
||||
self._float_literals: Dict[float, str] = {}
|
||||
self._data_section: Optional[List[str]] = None
|
||||
self._inline_stack: List[str] = []
|
||||
self._inline_counter: int = 0
|
||||
self._unroll_counter: int = 0
|
||||
self._emit_stack: List[str] = []
|
||||
self._export_all_defs: bool = False
|
||||
self.enable_constant_folding = enable_constant_folding
|
||||
self.loop_unroll_threshold = loop_unroll_threshold
|
||||
|
||||
def _fold_constants_in_definition(self, definition: Definition) -> None:
|
||||
optimized: List[Op] = []
|
||||
@@ -1906,6 +1914,128 @@ class Assembler:
|
||||
new_loc = operands[0].loc or last.loc
|
||||
nodes[-(arity + 1):] = [Op(op="literal", data=result, loc=new_loc)]
|
||||
|
||||
def _for_pairs(self, nodes: Sequence[Op]) -> Dict[int, int]:
|
||||
stack: List[int] = []
|
||||
pairs: Dict[int, int] = {}
|
||||
for idx, node in enumerate(nodes):
|
||||
if node.op == "for_begin":
|
||||
stack.append(idx)
|
||||
elif node.op == "for_end":
|
||||
if not stack:
|
||||
raise CompileError("'end' without matching 'for'")
|
||||
begin_idx = stack.pop()
|
||||
pairs[begin_idx] = idx
|
||||
pairs[idx] = begin_idx
|
||||
if stack:
|
||||
raise CompileError("'for' without matching 'end'")
|
||||
return pairs
|
||||
|
||||
def _collect_internal_labels(self, nodes: Sequence[Op]) -> Set[str]:
|
||||
labels: Set[str] = set()
|
||||
for node in nodes:
|
||||
kind = node.op
|
||||
data = node.data
|
||||
if kind == "label":
|
||||
labels.add(str(data))
|
||||
elif kind in ("for_begin", "for_end"):
|
||||
labels.add(str(data["loop"]))
|
||||
labels.add(str(data["end"]))
|
||||
elif kind in ("list_begin", "list_end"):
|
||||
labels.add(str(data))
|
||||
return labels
|
||||
|
||||
def _clone_nodes_with_label_remap(
|
||||
self,
|
||||
nodes: Sequence[Op],
|
||||
internal_labels: Set[str],
|
||||
suffix: str,
|
||||
) -> List[Op]:
|
||||
label_map: Dict[str, str] = {}
|
||||
|
||||
def remap(label: str) -> str:
|
||||
if label not in internal_labels:
|
||||
return label
|
||||
if label not in label_map:
|
||||
label_map[label] = f"{label}__unr{suffix}"
|
||||
return label_map[label]
|
||||
|
||||
cloned: List[Op] = []
|
||||
for node in nodes:
|
||||
kind = node.op
|
||||
data = node.data
|
||||
if kind == "label":
|
||||
cloned.append(Op(op="label", data=remap(str(data)), loc=node.loc))
|
||||
continue
|
||||
if kind in ("jump", "branch_zero"):
|
||||
target = str(data)
|
||||
mapped = remap(target) if target in internal_labels else target
|
||||
cloned.append(Op(op=kind, data=mapped, loc=node.loc))
|
||||
continue
|
||||
if kind in ("for_begin", "for_end"):
|
||||
cloned.append(
|
||||
Op(
|
||||
op=kind,
|
||||
data={
|
||||
"loop": remap(str(data["loop"])),
|
||||
"end": remap(str(data["end"])),
|
||||
},
|
||||
loc=node.loc,
|
||||
)
|
||||
)
|
||||
continue
|
||||
if kind in ("list_begin", "list_end"):
|
||||
cloned.append(Op(op=kind, data=remap(str(data)), loc=node.loc))
|
||||
continue
|
||||
cloned.append(Op(op=kind, data=data, loc=node.loc))
|
||||
return cloned
|
||||
|
||||
def _unroll_constant_for_loops(self, definition: Definition) -> None:
|
||||
threshold = self.loop_unroll_threshold
|
||||
if threshold <= 0:
|
||||
return
|
||||
nodes = definition.body
|
||||
pairs = self._for_pairs(nodes)
|
||||
if not pairs:
|
||||
return
|
||||
|
||||
rebuilt: List[Op] = []
|
||||
idx = 0
|
||||
while idx < len(nodes):
|
||||
node = nodes[idx]
|
||||
if node.op == "for_begin" and idx > 0:
|
||||
prev = nodes[idx - 1]
|
||||
if prev.op == "literal" and isinstance(prev.data, int):
|
||||
count = int(prev.data)
|
||||
end_idx = pairs.get(idx)
|
||||
if end_idx is None:
|
||||
raise CompileError("internal loop bookkeeping error")
|
||||
if count <= 0:
|
||||
if rebuilt and rebuilt[-1] is prev:
|
||||
rebuilt.pop()
|
||||
idx = end_idx + 1
|
||||
continue
|
||||
if count <= threshold:
|
||||
if rebuilt and rebuilt[-1] is prev:
|
||||
rebuilt.pop()
|
||||
body = nodes[idx + 1:end_idx]
|
||||
internal_labels = self._collect_internal_labels(body)
|
||||
for copy_idx in range(count):
|
||||
suffix = f"{self._unroll_counter}_{copy_idx}"
|
||||
rebuilt.extend(
|
||||
self._clone_nodes_with_label_remap(
|
||||
body,
|
||||
internal_labels,
|
||||
suffix,
|
||||
)
|
||||
)
|
||||
self._unroll_counter += 1
|
||||
idx = end_idx + 1
|
||||
continue
|
||||
rebuilt.append(node)
|
||||
idx += 1
|
||||
|
||||
definition.body = rebuilt
|
||||
|
||||
def _reachable_runtime_defs(self, runtime_defs: Sequence[Union[Definition, AsmDefinition]]) -> Set[str]:
|
||||
edges: Dict[str, Set[str]] = {}
|
||||
for definition in runtime_defs:
|
||||
@@ -1950,6 +2080,9 @@ class Assembler:
|
||||
valid_defs = (Definition, AsmDefinition)
|
||||
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
|
||||
definitions = self._dedup_definitions(raw_defs)
|
||||
for defn in definitions:
|
||||
if isinstance(defn, Definition):
|
||||
self._unroll_constant_for_loops(defn)
|
||||
if self.enable_constant_folding:
|
||||
for defn in definitions:
|
||||
if isinstance(defn, Definition):
|
||||
|
||||
Reference in New Issue
Block a user