small refactor and cleanup

This commit is contained in:
IgorCielniak
2026-01-08 13:15:27 +01:00
parent 963e024108
commit d4dc6ceef5
52 changed files with 418 additions and 372 deletions

295
main.py
View File

@@ -187,30 +187,24 @@ class Reader:
# ---------------------------------------------------------------------------
class ASTNode:
"""Base class for all AST nodes."""
@dataclass
class Op:
"""Flat operation used for both compile-time execution and emission."""
op: str
data: Any = None
@dataclass
class WordRef(ASTNode):
class Definition:
name: str
@dataclass
class Literal(ASTNode):
value: Any
@dataclass
class Definition(ASTNode):
name: str
body: List[ASTNode]
body: List[Op]
immediate: bool = False
compile_only: bool = False
@dataclass
class AsmDefinition(ASTNode):
class AsmDefinition:
name: str
body: str
immediate: bool = False
@@ -218,8 +212,8 @@ class AsmDefinition(ASTNode):
@dataclass
class Module(ASTNode):
forms: List[ASTNode]
class Module:
forms: List[Any]
@dataclass
@@ -236,33 +230,6 @@ class StructField:
size: int
@dataclass
class BranchZero(ASTNode):
target: str
@dataclass
class Jump(ASTNode):
target: str
@dataclass
class Label(ASTNode):
name: str
@dataclass
class ForBegin(ASTNode):
loop_label: str
end_label: str
@dataclass
class ForEnd(ASTNode):
loop_label: str
end_label: str
class MacroContext:
"""Small facade exposed to Python-defined macros."""
@@ -280,12 +247,12 @@ class MacroContext:
return self._parser.peek_token()
def emit_literal(self, value: int) -> None:
self._parser.emit_node(Literal(value=value))
self._parser.emit_node(Op(op="literal", data=value))
def emit_word(self, name: str) -> None:
self._parser.emit_node(WordRef(name=name))
self._parser.emit_node(Op(op="word", data=name))
def emit_node(self, node: ASTNode) -> None:
def emit_node(self, node: Op) -> None:
self._parser.emit_node(node)
def inject_tokens(self, tokens: Sequence[str], template: Optional[Token] = None) -> None:
@@ -316,7 +283,7 @@ class MacroContext:
return self._parser.most_recent_definition()
MacroHandler = Callable[[MacroContext], Optional[List[ASTNode]]]
MacroHandler = Callable[[MacroContext], Optional[List[Op]]]
IntrinsicEmitter = Callable[["FunctionEmitter"], None]
@@ -390,8 +357,8 @@ class Parser:
self._ensure_tokens(self.pos)
return None if self._eof() else self.tokens[self.pos]
def emit_node(self, node: ASTNode) -> None:
self._append_node(node)
def emit_node(self, node: Op) -> None:
self._append_op(node)
def most_recent_definition(self) -> Optional[Word]:
return self.last_defined
@@ -406,18 +373,18 @@ class Parser:
if entry["type"] == "if":
# For if without else
if "false" in entry:
self._append_node(Label(name=entry["false"]))
self._append_op(Op(op="label", data=entry["false"]))
elif entry["type"] == "else":
self._append_node(Label(name=entry["end"]))
self._append_op(Op(op="label", data=entry["end"]))
elif entry["type"] == "while":
self._append_node(Jump(target=entry["begin"]))
self._append_node(Label(name=entry["end"]))
self._append_op(Op(op="jump", data=entry["begin"]))
self._append_op(Op(op="label", data=entry["end"]))
elif entry["type"] == "for":
# Emit ForEnd node for loop decrement
self._append_node(ForEnd(loop_label=entry["loop"], end_label=entry["end"]))
self._append_op(Op(op="for_end", data={"loop": entry["loop"], "end": entry["end"]}))
elif entry["type"] == "begin":
self._append_node(Jump(target=entry["begin"]))
self._append_node(Label(name=entry["end"]))
self._append_op(Op(op="jump", data=entry["begin"]))
self._append_op(Op(op="label", data=entry["end"]))
# Parsing ------------------------------------------------------------------
def parse(self, tokens: Iterable[Token], source: str) -> Module:
@@ -608,12 +575,12 @@ class Parser:
produced = word.macro(MacroContext(self))
if produced:
for node in produced:
self._append_node(node)
self._append_op(node)
else:
self._execute_immediate_word(word)
return
self._append_node(WordRef(name=token.lexeme))
self._append_op(Op(op="word", data=token.lexeme))
def _execute_immediate_word(self, word: Word) -> None:
try:
@@ -721,31 +688,31 @@ class Parser:
def _handle_if_control(self) -> None:
false_label = self._new_label("if_false")
self._append_node(BranchZero(target=false_label))
self._append_op(Op(op="branch_zero", data=false_label))
self._push_control({"type": "if", "false": false_label})
def _handle_else_control(self) -> None:
entry = self._pop_control(("if",))
end_label = self._new_label("if_end")
self._append_node(Jump(target=end_label))
self._append_node(Label(name=entry["false"]))
self._append_op(Op(op="jump", data=end_label))
self._append_op(Op(op="label", data=entry["false"]))
self._push_control({"type": "else", "end": end_label})
def _handle_for_control(self) -> None:
loop_label = self._new_label("for_loop")
end_label = self._new_label("for_end")
self._append_node(ForBegin(loop_label=loop_label, end_label=end_label))
self._append_op(Op(op="for_begin", data={"loop": loop_label, "end": end_label}))
self._push_control({"type": "for", "loop": loop_label, "end": end_label})
def _handle_while_control(self) -> None:
begin_label = self._new_label("begin")
end_label = self._new_label("end")
self._append_node(Label(name=begin_label))
self._append_op(Op(op="label", data=begin_label))
self._push_control({"type": "begin", "begin": begin_label, "end": end_label})
def _handle_do_control(self) -> None:
entry = self._pop_control(("begin",))
self._append_node(BranchZero(target=entry["end"]))
self._append_op(Op(op="branch_zero", data=entry["end"]))
self._push_control(entry)
def _begin_definition(self, token: Token) -> None:
@@ -859,7 +826,7 @@ class Parser:
def _py_exec_namespace(self) -> Dict[str, Any]:
return dict(PY_EXEC_GLOBALS)
def _append_node(self, node: ASTNode) -> None:
def _append_op(self, node: Op) -> None:
target = self.context_stack[-1]
if isinstance(target, Module):
target.forms.append(node)
@@ -868,10 +835,10 @@ class Parser:
else: # pragma: no cover - defensive
raise ParseError("unknown parse context")
def _try_literal(self, token: Token) -> None:
def _try_literal(self, token: Token) -> bool:
try:
value = int(token.lexeme, 0)
self._append_node(Literal(value=value))
self._append_op(Op(op="literal", data=value))
return True
except ValueError:
pass
@@ -880,14 +847,14 @@ class Parser:
try:
if "." in token.lexeme or "e" in token.lexeme.lower():
value = float(token.lexeme)
self._append_node(Literal(value=value))
self._append_op(Op(op="literal", data=value))
return True
except ValueError:
pass
string_value = _parse_string_literal(token)
if string_value is not None:
self._append_node(Literal(value=string_value))
self._append_op(Op(op="literal", data=string_value))
return True
return False
@@ -1229,7 +1196,7 @@ class CompileTimeVM:
raise ParseError(f"unknown word '{name}' during compile-time execution")
self._call_word(word)
def _execute_nodes(self, nodes: Sequence[ASTNode]) -> None:
def _execute_nodes(self, nodes: Sequence[Op]) -> None:
label_positions = self._label_positions(nodes)
loop_pairs = self._for_pairs(nodes)
begin_pairs = self._begin_pairs(nodes)
@@ -1238,12 +1205,16 @@ class CompileTimeVM:
ip = 0
while ip < len(nodes):
node = nodes[ip]
if isinstance(node, Literal):
self.push(node.value)
kind = node.op
data = node.data
if kind == "literal":
self.push(data)
ip += 1
continue
if isinstance(node, WordRef):
name = node.name
if kind == "word":
name = str(data)
if name == "begin":
end_idx = begin_pairs.get(ip)
if end_idx is None:
@@ -1270,9 +1241,9 @@ class CompileTimeVM:
self._call_word_by_name(name)
ip += 1
continue
if isinstance(node, BranchZero):
if kind == "branch_zero":
condition = self.pop()
flag: bool
if isinstance(condition, bool):
flag = condition
elif isinstance(condition, int):
@@ -1280,17 +1251,20 @@ class CompileTimeVM:
else:
raise ParseError("branch expects integer or boolean condition")
if not flag:
ip = self._jump_to_label(label_positions, node.target)
ip = self._jump_to_label(label_positions, str(data))
else:
ip += 1
continue
if isinstance(node, Jump):
ip = self._jump_to_label(label_positions, node.target)
if kind == "jump":
ip = self._jump_to_label(label_positions, str(data))
continue
if isinstance(node, Label):
if kind == "label":
ip += 1
continue
if isinstance(node, ForBegin):
if kind == "for_begin":
count = self.pop_int()
if count <= 0:
match = loop_pairs.get(ip)
@@ -1301,7 +1275,8 @@ class CompileTimeVM:
self.loop_stack.append({"remaining": count, "begin": ip, "initial": count})
ip += 1
continue
if isinstance(node, ForEnd):
if kind == "for_end":
if not self.loop_stack:
raise ParseError("'next' without matching 'for'")
frame = self.loop_stack[-1]
@@ -1312,22 +1287,23 @@ class CompileTimeVM:
self.loop_stack.pop()
ip += 1
continue
raise ParseError(f"unsupported compile-time AST node {node!r}")
def _label_positions(self, nodes: Sequence[ASTNode]) -> Dict[str, int]:
raise ParseError(f"unsupported compile-time op {node!r}")
def _label_positions(self, nodes: Sequence[Op]) -> Dict[str, int]:
positions: Dict[str, int] = {}
for idx, node in enumerate(nodes):
if isinstance(node, Label):
positions[node.name] = idx
if node.op == "label":
positions[str(node.data)] = idx
return positions
def _for_pairs(self, nodes: Sequence[ASTNode]) -> Dict[int, int]:
def _for_pairs(self, nodes: Sequence[Op]) -> Dict[int, int]:
stack: List[int] = []
pairs: Dict[int, int] = {}
for idx, node in enumerate(nodes):
if isinstance(node, ForBegin):
if node.op == "for_begin":
stack.append(idx)
elif isinstance(node, ForEnd):
elif node.op == "for_end":
if not stack:
raise ParseError("'next' without matching 'for'")
begin_idx = stack.pop()
@@ -1337,13 +1313,13 @@ class CompileTimeVM:
raise ParseError("'for' without matching 'next'")
return pairs
def _begin_pairs(self, nodes: Sequence[ASTNode]) -> Dict[int, int]:
def _begin_pairs(self, nodes: Sequence[Op]) -> Dict[int, int]:
stack: List[int] = []
pairs: Dict[int, int] = {}
for idx, node in enumerate(nodes):
if isinstance(node, WordRef) and node.name == "begin":
if node.op == "word" and node.data == "begin":
stack.append(idx)
elif isinstance(node, WordRef) and node.name == "again":
elif node.op == "word" and node.data == "again":
if not stack:
raise ParseError("'again' without matching 'begin'")
begin_idx = stack.pop()
@@ -1611,48 +1587,57 @@ class Assembler:
else:
builder.emit("")
def _emit_node(self, node: ASTNode, builder: FunctionEmitter) -> None:
if isinstance(node, Literal):
if isinstance(node.value, int):
builder.push_literal(node.value)
def _emit_node(self, node: Op, builder: FunctionEmitter) -> None:
kind = node.op
data = node.data
if kind == "literal":
if isinstance(data, int):
builder.push_literal(data)
return
if isinstance(node.value, float):
label = self._intern_float_literal(node.value)
if isinstance(data, float):
label = self._intern_float_literal(data)
builder.push_float(label)
return
if isinstance(node.value, str):
label, length = self._intern_string_literal(node.value)
if isinstance(data, str):
label, length = self._intern_string_literal(data)
builder.push_label(label)
builder.push_literal(length)
return
raise CompileError(f"unsupported literal type {type(node.value)!r}")
return
if isinstance(node, WordRef):
self._emit_wordref(node, builder)
return
if isinstance(node, BranchZero):
self._emit_branch_zero(node, builder)
return
if isinstance(node, Jump):
builder.emit(f" jmp {node.target}")
return
if isinstance(node, Label):
builder.emit(f"{node.name}:")
return
if isinstance(node, ForBegin):
self._emit_for_begin(node, builder)
return
if isinstance(node, ForEnd):
self._emit_for_next(node, builder)
return
raise CompileError(f"unsupported AST node {node!r}")
raise CompileError(f"unsupported literal type {type(data)!r}")
def _emit_wordref(self, ref: WordRef, builder: FunctionEmitter) -> None:
word = self.dictionary.lookup(ref.name)
if kind == "word":
self._emit_wordref(str(data), builder)
return
if kind == "branch_zero":
self._emit_branch_zero(str(data), builder)
return
if kind == "jump":
builder.emit(f" jmp {data}")
return
if kind == "label":
builder.emit(f"{data}:")
return
if kind == "for_begin":
self._emit_for_begin(data, builder)
return
if kind == "for_end":
self._emit_for_next(data, builder)
return
raise CompileError(f"unsupported op {node!r}")
def _emit_wordref(self, name: str, builder: FunctionEmitter) -> None:
word = self.dictionary.lookup(name)
if word is None:
raise CompileError(f"unknown word '{ref.name}'")
raise CompileError(f"unknown word '{name}'")
if word.compile_only:
raise CompileError(f"word '{ref.name}' is compile-time only")
raise CompileError(f"word '{name}' is compile-time only")
if word.intrinsic:
word.intrinsic(builder)
return
@@ -1669,7 +1654,7 @@ class Assembler:
ret_type = signature[1] if signature else None
if len(arg_types) != inputs and signature:
raise CompileError(f"extern '{ref.name}' mismatch: {inputs} inputs vs {len(arg_types)} types")
raise CompileError(f"extern '{name}' mismatch: {inputs} inputs vs {len(arg_types)} types")
int_idx = 0
xmm_idx = 0
@@ -1679,19 +1664,19 @@ class Assembler:
if not arg_types:
# Legacy/Raw mode: assume all ints
if inputs > 6:
raise CompileError(f"extern '{ref.name}' has too many inputs ({inputs} > 6)")
raise CompileError(f"extern '{name}' has too many inputs ({inputs} > 6)")
for i in range(inputs):
mapping.append(("int", regs[i]))
else:
for type_name in arg_types:
if type_name in ("float", "double"):
if xmm_idx >= 8:
raise CompileError(f"extern '{ref.name}' has too many float inputs")
raise CompileError(f"extern '{name}' has too many float inputs")
mapping.append(("float", xmm_regs[xmm_idx]))
xmm_idx += 1
else:
if int_idx >= 6:
raise CompileError(f"extern '{ref.name}' has too many int inputs")
raise CompileError(f"extern '{name}' has too many int inputs")
mapping.append(("int", regs[int_idx]))
int_idx += 1
@@ -1706,7 +1691,7 @@ class Assembler:
builder.emit(" mov rbp, rsp")
builder.emit(" and rsp, -16")
builder.emit(f" mov al, {xmm_idx}")
builder.emit(f" call {ref.name}")
builder.emit(f" call {name}")
builder.emit(" leave")
# Handle Return Value
@@ -1721,30 +1706,34 @@ class Assembler:
raise CompileError("extern only supports 0 or 1 output")
else:
# Emit call to unresolved symbol (let linker resolve it)
builder.emit(f" call {ref.name}")
builder.emit(f" call {name}")
else:
builder.emit(f" call {sanitize_label(ref.name)}")
builder.emit(f" call {sanitize_label(name)}")
def _emit_branch_zero(self, node: BranchZero, builder: FunctionEmitter) -> None:
def _emit_branch_zero(self, target: str, builder: FunctionEmitter) -> None:
builder.pop_to("rax")
builder.emit(" test rax, rax")
builder.emit(f" jz {node.target}")
builder.emit(f" jz {target}")
def _emit_for_begin(self, node: ForBegin, builder: FunctionEmitter) -> None:
def _emit_for_begin(self, data: Dict[str, str], builder: FunctionEmitter) -> None:
loop_label = data["loop"]
end_label = data["end"]
builder.pop_to("rax")
builder.emit(" cmp rax, 0")
builder.emit(f" jle {node.end_label}")
builder.emit(f" jle {end_label}")
builder.emit(" sub r13, 8")
builder.emit(" mov [r13], rax")
builder.emit(f"{node.loop_label}:")
builder.emit(f"{loop_label}:")
def _emit_for_next(self, node: ForEnd, builder: FunctionEmitter) -> None:
def _emit_for_next(self, data: Dict[str, str], builder: FunctionEmitter) -> None:
loop_label = data["loop"]
end_label = data["end"]
builder.emit(" mov rax, [r13]")
builder.emit(" dec rax")
builder.emit(" mov [r13], rax")
builder.emit(f" jg {node.loop_label}")
builder.emit(f" jg {loop_label}")
builder.emit(" add r13, 8")
builder.emit(f"{node.end_label}:")
builder.emit(f"{end_label}:")
def _runtime_prelude(self) -> List[str]:
return [
@@ -1804,7 +1793,7 @@ class Assembler:
# ---------------------------------------------------------------------------
def macro_immediate(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_immediate(ctx: MacroContext) -> Optional[List[Op]]:
parser = ctx.parser
word = parser.most_recent_definition()
if word is None:
@@ -1815,7 +1804,7 @@ def macro_immediate(ctx: MacroContext) -> Optional[List[ASTNode]]:
return None
def macro_compile_only(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_compile_only(ctx: MacroContext) -> Optional[List[Op]]:
parser = ctx.parser
word = parser.most_recent_definition()
if word is None:
@@ -1826,7 +1815,7 @@ def macro_compile_only(ctx: MacroContext) -> Optional[List[ASTNode]]:
return None
def macro_compile_time(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_compile_time(ctx: MacroContext) -> Optional[List[Op]]:
"""Run the next word at compile time and still emit it for runtime."""
parser = ctx.parser
if parser._eof():
@@ -1840,11 +1829,11 @@ def macro_compile_time(ctx: MacroContext) -> Optional[List[ASTNode]]:
raise ParseError(f"word '{name}' is compile-time only")
parser.compile_time_vm.invoke(word)
if isinstance(parser.context_stack[-1], Definition):
parser.emit_node(WordRef(name=name))
parser.emit_node(Op(op="word", data=name))
return None
def macro_begin_text_macro(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_begin_text_macro(ctx: MacroContext) -> Optional[List[Op]]:
parser = ctx.parser
if parser._eof():
raise ParseError("macro name missing after 'macro:'")
@@ -1861,7 +1850,7 @@ def macro_begin_text_macro(ctx: MacroContext) -> Optional[List[ASTNode]]:
return None
def macro_end_text_macro(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_end_text_macro(ctx: MacroContext) -> Optional[List[Op]]:
parser = ctx.parser
if parser.macro_recording is None:
raise ParseError("';macro' without matching 'macro:'")
@@ -2458,13 +2447,7 @@ def _register_compile_time_primitives(dictionary: Dictionary) -> None:
PY_EXEC_GLOBALS: Dict[str, Any] = {
"MacroContext": MacroContext,
"Token": Token,
"Literal": Literal,
"WordRef": WordRef,
"BranchZero": BranchZero,
"Jump": Jump,
"Label": Label,
"ForBegin": ForBegin,
"ForEnd": ForEnd,
"Op": Op,
"StructField": StructField,
"Definition": Definition,
"Module": Module,
@@ -2474,7 +2457,7 @@ PY_EXEC_GLOBALS: Dict[str, Any] = {
}
def macro_struct_begin(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_struct_begin(ctx: MacroContext) -> Optional[List[Op]]:
parser = ctx.parser
if parser._eof():
raise ParseError("struct name missing after 'struct:'")
@@ -2529,7 +2512,7 @@ def macro_struct_begin(ctx: MacroContext) -> Optional[List[ASTNode]]:
return None
def macro_struct_end(ctx: MacroContext) -> Optional[List[ASTNode]]:
def macro_struct_end(ctx: MacroContext) -> Optional[List[Op]]:
raise ParseError("';struct' must follow a 'struct:' block")