From c4b15e066679c03cf989012fcd8f8d12754d9b7c Mon Sep 17 00:00:00 2001 From: igor Date: Fri, 20 Feb 2026 15:03:49 +0100 Subject: [PATCH] implemented quick sort, ported the repl to run on the ct vm and added some more optimizations --- main.py | 614 +++++++++++++++++++++++--------------- tests/quick_sort.expected | 31 ++ tests/quick_sort.sl | 102 +++++++ 3 files changed, 506 insertions(+), 241 deletions(-) create mode 100644 tests/quick_sort.expected create mode 100644 tests/quick_sort.sl diff --git a/main.py b/main.py index 2203cd4..d685f6e 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ This file now contains working scaffolding for: from __future__ import annotations import argparse +import bisect import ctypes import hashlib import json @@ -18,6 +19,7 @@ import mmap import os import re import shlex +import struct import subprocess import sys import shutil @@ -33,6 +35,11 @@ except Exception: # pragma: no cover - optional dependency KsError = Exception KS_ARCH_X86 = KS_MODE_64 = None +# Pre-compiled regex patterns used by JIT and BSS code +_RE_REL_PAT = re.compile(r'\[rel\s+(\w+)\]') +_RE_LABEL_PAT = re.compile(r'^(\.\w+|\w+):') +_RE_BSS_PERSISTENT = re.compile(r'persistent:\s*resb\s+(\d+)') + class ParseError(Exception): """Raised when the source stream cannot be parsed.""" @@ -51,7 +58,7 @@ class CompileTimeError(ParseError): # --------------------------------------------------------------------------- -@dataclass +@dataclass(slots=True) class Token: lexeme: str line: int @@ -63,7 +70,7 @@ class Token: return f"Token({self.lexeme!r}@{self.line}:{self.column})" -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class SourceLocation: path: Path line: int @@ -78,6 +85,8 @@ class Reader: self.column = 0 self.custom_tokens: Set[str] = {"(", ")", "{", "}", ";", ",", "[", "]"} self._token_order: List[str] = sorted(self.custom_tokens, key=len, reverse=True) + self._single_char_tokens: Set[str] = {t for t in self.custom_tokens if len(t) == 1} + self._multi_char_tokens: List[str] = [t for t in self._token_order if len(t) > 1] def add_tokens(self, tokens: Iterable[str]) -> None: updated = False @@ -89,6 +98,8 @@ class Reader: updated = True if updated: self._token_order = sorted(self.custom_tokens, key=len, reverse=True) + self._single_char_tokens = {t for t in self.custom_tokens if len(t) == 1} + self._multi_char_tokens = [t for t in self._token_order if len(t) > 1] def add_token_chars(self, chars: str) -> None: self.add_tokens(chars) @@ -155,10 +166,13 @@ class Reader: self.column += 1 continue matched_token: Optional[str] = None - for tok in self._token_order: - if source.startswith(tok, index): - matched_token = tok - break + if char in self._single_char_tokens: + matched_token = char + elif self._multi_char_tokens: + for tok in self._multi_char_tokens: + if source.startswith(tok, index): + matched_token = tok + break if matched_token is not None: if lexeme: yield Token("".join(lexeme), token_line, token_column, token_start, index) @@ -203,7 +217,36 @@ class Reader: # --------------------------------------------------------------------------- -@dataclass +# Integer opcode constants for hot-path dispatch +OP_WORD = 0 +OP_LITERAL = 1 +OP_WORD_PTR = 2 +OP_FOR_BEGIN = 3 +OP_FOR_END = 4 +OP_BRANCH_ZERO = 5 +OP_JUMP = 6 +OP_LABEL = 7 +OP_LIST_BEGIN = 8 +OP_LIST_END = 9 +OP_LIST_LITERAL = 10 +OP_OTHER = 11 + +_OP_STR_TO_INT = { + "word": OP_WORD, + "literal": OP_LITERAL, + "word_ptr": OP_WORD_PTR, + "for_begin": OP_FOR_BEGIN, + "for_end": OP_FOR_END, + "branch_zero": OP_BRANCH_ZERO, + "jump": OP_JUMP, + "label": OP_LABEL, + "list_begin": OP_LIST_BEGIN, + "list_end": OP_LIST_END, + "list_literal": OP_LIST_LITERAL, +} + + +@dataclass(slots=True) class Op: """Flat operation used for both compile-time execution and emission.""" @@ -211,9 +254,13 @@ class Op: data: Any = None loc: Optional[SourceLocation] = None _word_ref: Optional["Word"] = field(default=None, repr=False, compare=False) + _opcode: int = field(default=OP_OTHER, repr=False, compare=False) + + def __post_init__(self) -> None: + self._opcode = _OP_STR_TO_INT.get(self.op, OP_OTHER) -@dataclass +@dataclass(slots=True) class Definition: name: str body: List[Op] @@ -230,7 +277,7 @@ class Definition: _merged_runs: Optional[Dict[int, Tuple[int, str]]] = field(default=None, repr=False, compare=False) -@dataclass +@dataclass(slots=True) class AsmDefinition: name: str body: str @@ -239,7 +286,7 @@ class AsmDefinition: effects: Set[str] = field(default_factory=set) -@dataclass +@dataclass(slots=True) class Module: forms: List[Any] variables: Dict[str, str] = field(default_factory=dict) @@ -247,14 +294,14 @@ class Module: bss: Optional[List[str]] = None -@dataclass +@dataclass(slots=True) class MacroDefinition: name: str tokens: List[str] param_count: int = 0 -@dataclass +@dataclass(slots=True) class StructField: name: str offset: int @@ -332,7 +379,7 @@ _WORD_EFFECT_ALIASES: Dict[str, str] = { } -@dataclass +@dataclass(slots=True) class Word: name: str priority: int = 0 @@ -356,7 +403,12 @@ class Word: _suppress_redefine_warnings = False -@dataclass +def _suppress_redefine_warnings_set(value: bool) -> None: + global _suppress_redefine_warnings + _suppress_redefine_warnings = value + + +@dataclass(slots=True) class Dictionary: words: Dict[str, Word] = field(default_factory=dict) @@ -433,9 +485,17 @@ class Parser: self._pending_inline_definition: bool = False self._pending_priority: Optional[int] = None + def _rebuild_span_index(self) -> None: + """Rebuild bisect index after file_spans changes.""" + self._span_starts: List[int] = [s.start_line for s in self.file_spans] + def location_for_token(self, token: Token) -> SourceLocation: - for span in self.file_spans: - if span.start_line <= token.line < span.end_line: + if not hasattr(self, '_span_starts') or len(self._span_starts) != len(self.file_spans): + self._rebuild_span_index() + idx = bisect.bisect_right(self._span_starts, token.line) - 1 + if idx >= 0: + span = self.file_spans[idx] + if token.line < span.end_line: local_line = span.local_start_line + (token.line - span.start_line) return SourceLocation(span.path, local_line, token.column) return SourceLocation(Path(""), token.line, token.column) @@ -1197,21 +1257,25 @@ class Parser: raise ParseError("unknown parse context") def _try_literal(self, token: Token) -> bool: - try: - value = int(token.lexeme, 0) - self._append_op(Op(op="literal", data=value)) - return True - except ValueError: - pass - - # Try float - try: - if "." in token.lexeme or "e" in token.lexeme.lower(): - value = float(token.lexeme) + lexeme = token.lexeme + first = lexeme[0] if lexeme else '\0' + if first.isdigit() or first == '-' or first == '+': + try: + value = int(lexeme, 0) self._append_op(Op(op="literal", data=value)) return True - except ValueError: - pass + except ValueError: + pass + + # Try float + if first.isdigit() or first == '-' or first == '+' or first == '.': + try: + if "." in lexeme or "e" in lexeme.lower(): + value = float(lexeme) + self._append_op(Op(op="literal", data=value)) + return True + except ValueError: + pass string_value = _parse_string_literal(token) if string_value is not None: @@ -1460,9 +1524,8 @@ class CompileTimeVM: # Determine persistent size from BSS overrides if available. persistent_size = 0 if self.parser.custom_bss: - import re as _re_bss for bss_line in self.parser.custom_bss: - m = _re_bss.search(r'persistent:\s*resb\s+(\d+)', bss_line) + m = _RE_BSS_PERSISTENT.search(bss_line) if m: persistent_size = int(m.group(1)) self.memory = CTMemory(persistent_size) # fresh memory per invocation @@ -1531,8 +1594,7 @@ class CompileTimeVM: if self.runtime_mode: self.r12 -= 8 if isinstance(value, float): - import struct as _struct - bits = _struct.unpack("q", _struct.pack("d", value))[0] + bits = struct.unpack("q", struct.pack("d", value))[0] CTMemory.write_qword(self.r12, bits) else: CTMemory.write_qword(self.r12, _to_i64(int(value))) @@ -1758,13 +1820,12 @@ class CompileTimeVM: raw_args.reverse() # Convert arguments to proper ctypes values - import struct as _struct call_args = [] for i, raw in enumerate(raw_args): if i < len(arg_types) and arg_types[i] in ("float", "double"): # Reinterpret the int64 bits as a double (matching the language's convention) raw_int = _to_i64(int(raw)) - double_val = _struct.unpack("d", _struct.pack("q", raw_int))[0] + double_val = struct.unpack("d", struct.pack("q", raw_int))[0] call_args.append(double_val) else: call_args.append(int(raw)) @@ -1774,7 +1835,7 @@ class CompileTimeVM: if outputs > 0 and result is not None: ret_type = func._ct_signature[1] if func._ct_signature else None if ret_type in ("float", "double"): - int_bits = _struct.unpack("q", _struct.pack("d", float(result)))[0] + int_bits = struct.unpack("q", struct.pack("d", float(result)))[0] self.push(int_bits) else: self.push(int(result)) @@ -1843,8 +1904,6 @@ class CompileTimeVM: raise ParseError(f"word '{word.name}' has no asm body") asm_body = definition.body.strip("\n") - import re as _re - _rel_pat = _re.compile(r'\[rel\s+(\w+)\]') bss = self._bss_symbols # Build wrapper @@ -1874,19 +1933,19 @@ class CompileTimeVM: line = "jmp _ct_save" # Patch [rel SYMBOL] → concrete address - m = _rel_pat.search(line) + m = _RE_REL_PAT.search(line) if m and m.group(1) in bss: sym = m.group(1) addr = bss[sym] if line.lstrip().startswith("lea"): # lea REG, [rel X] → mov REG, addr - line = _rel_pat.sub(str(addr), line).replace("lea", "mov", 1) + line = _RE_REL_PAT.sub(str(addr), line).replace("lea", "mov", 1) else: # e.g. mov rax, [rel X] or mov byte [rel X], val # Replace with push/mov-rax/substitute/pop trampoline lines.append(" push rax") lines.append(f" mov rax, {addr}") - new_line = _rel_pat.sub("[rax]", line) + new_line = _RE_REL_PAT.sub("[rax]", line) lines.append(f" {new_line}") lines.append(" pop rax") continue @@ -2033,21 +2092,18 @@ class CompileTimeVM: "persistent": self.memory.persistent_addr, "persistent_end": self.memory.persistent_addr + self.memory._persistent_size, }) - import re as _re - _rel_pat = _re.compile(r'\[rel\s+(\w+)\]') - for line in asm_body.splitlines(): line = line.strip() if line == "ret": line = "jmp _ct_save" # Replace [rel SYMBOL] with concrete addresses - m = _rel_pat.search(line) + m = _RE_REL_PAT.search(line) if m and m.group(1) in _bss_symbols: sym = m.group(1) addr = _bss_symbols[sym] # lea REG, [rel X] → mov REG, addr if line.lstrip().startswith("lea"): - line = _rel_pat.sub(str(addr), line).replace("lea", "mov", 1) + line = _RE_REL_PAT.sub(str(addr), line).replace("lea", "mov", 1) else: # For memory operands like mov byte [rel X], val # replace [rel X] with [] @@ -2055,7 +2111,7 @@ class CompileTimeVM: # Use a scratch register to hold the address patched_body.append(f"push rax") patched_body.append(f"mov rax, {addr}") - new_line = _rel_pat.sub("[rax]", line) + new_line = _RE_REL_PAT.sub("[rax]", line) patched_body.append(new_line) patched_body.append(f"pop rax") continue @@ -2184,7 +2240,7 @@ class CompileTimeVM: return lookup = self.dictionary.lookup for node in defn.body: - if node.op == "word" and node._word_ref is None: + if node._opcode == OP_WORD and node._word_ref is None: name = str(node.data) # Skip structural keywords that _execute_nodes handles inline if name not in ("begin", "again", "continue", "exit", "get_addr"): @@ -2196,11 +2252,10 @@ class CompileTimeVM: def _prepare_definition(self, defn: Definition) -> Tuple[Dict[str, int], Dict[int, int], Dict[int, int]]: """Return (label_positions, for_pairs, begin_pairs), cached on the Definition.""" if defn._label_positions is None: - defn._label_positions = self._label_positions(defn.body) - if defn._for_pairs is None: - defn._for_pairs = self._for_pairs(defn.body) - if defn._begin_pairs is None: - defn._begin_pairs = self._begin_pairs(defn.body) + lp, fp, bp = self._analyze_nodes(defn.body) + defn._label_positions = lp + defn._for_pairs = fp + defn._begin_pairs = bp self._resolve_words_in_body(defn) if self.runtime_mode: # Merged JIT runs are a performance optimization, but have shown @@ -2220,14 +2275,14 @@ class CompileTimeVM: i = 0 while i < n: # Start of a potential run - if body[i].op == "word" and body[i]._word_ref is not None: + if body[i]._opcode == OP_WORD and body[i]._word_ref is not None: w = body[i]._word_ref if (w.runtime_intrinsic is None and isinstance(w.definition, AsmDefinition) and not w.compile_time_override): run_start = i run_words = [w.name] i += 1 - while i < n and body[i].op == "word" and body[i]._word_ref is not None: + while i < n and body[i]._opcode == OP_WORD and body[i]._word_ref is not None: w2 = body[i]._word_ref if (w2.runtime_intrinsic is None and isinstance(w2.definition, AsmDefinition) and not w2.compile_time_override): @@ -2247,9 +2302,6 @@ class CompileTimeVM: if Ks is None: raise ParseError("keystone-engine is required for JIT execution") - import re as _re - _rel_pat = _re.compile(r'\[rel\s+(\w+)\]') - _label_pat = _re.compile(r'^(\.\w+|\w+):') bss = self._bss_symbols lines: List[str] = [] @@ -2277,7 +2329,7 @@ class CompileTimeVM: local_labels: Set[str] = set() for raw_line in asm_body.splitlines(): line = raw_line.strip() - lm = _label_pat.match(line) + lm = _RE_LABEL_PAT.match(line) if lm: local_labels.add(lm.group(1)) @@ -2300,16 +2352,16 @@ class CompileTimeVM: line = _re.sub(rf'(? Tuple[Dict[str, int], Dict[int, int], Dict[int, int]]: + """Single-pass analysis: returns (label_positions, for_pairs, begin_pairs).""" + label_positions: Dict[str, int] = {} + for_pairs: Dict[int, int] = {} + begin_pairs: Dict[int, int] = {} + for_stack: List[int] = [] + begin_stack: List[int] = [] + for idx, node in enumerate(nodes): + opc = node._opcode + if opc == OP_LABEL: + label_positions[str(node.data)] = idx + elif opc == OP_FOR_BEGIN: + for_stack.append(idx) + elif opc == OP_FOR_END: + if not for_stack: + raise ParseError("'next' without matching 'for'") + begin_idx = for_stack.pop() + for_pairs[begin_idx] = idx + for_pairs[idx] = begin_idx + elif opc == OP_WORD: + d = node.data + if d == "begin": + begin_stack.append(idx) + elif d == "again": + if not begin_stack: + raise ParseError("'again' without matching 'begin'") + begin_idx = begin_stack.pop() + begin_pairs[begin_idx] = idx + begin_pairs[idx] = begin_idx + if for_stack: + raise ParseError("'for' without matching 'next'") + if begin_stack: + raise ParseError("'begin' without matching 'again'") + return label_positions, for_pairs, begin_pairs + def _label_positions(self, nodes: Sequence[Op]) -> Dict[str, int]: positions: Dict[str, int] = {} for idx, node in enumerate(nodes): - if node.op == "label": + if node._opcode == OP_LABEL: positions[str(node.data)] = idx return positions @@ -2657,9 +2754,9 @@ class CompileTimeVM: stack: List[int] = [] pairs: Dict[int, int] = {} for idx, node in enumerate(nodes): - if node.op == "for_begin": + if node._opcode == OP_FOR_BEGIN: stack.append(idx) - elif node.op == "for_end": + elif node._opcode == OP_FOR_END: if not stack: raise ParseError("'next' without matching 'for'") begin_idx = stack.pop() @@ -2673,9 +2770,9 @@ class CompileTimeVM: stack: List[int] = [] pairs: Dict[int, int] = {} for idx, node in enumerate(nodes): - if node.op == "word" and node.data == "begin": + if node._opcode == OP_WORD and node.data == "begin": stack.append(idx) - elif node.op == "word" and node.data == "again": + elif node._opcode == OP_WORD and node.data == "again": if not stack: raise ParseError("'again' without matching 'begin'") begin_idx = stack.pop() @@ -2696,7 +2793,7 @@ class CompileTimeVM: # --------------------------------------------------------------------------- -@dataclass +@dataclass(slots=True) class Emission: text: List[str] = field(default_factory=list) data: List[str] = field(default_factory=list) @@ -2817,7 +2914,13 @@ _FOLDABLE_WORDS: Dict[str, Tuple[int, Callable[..., int]]] = { } +_sanitize_label_cache: Dict[str, str] = {} + + def sanitize_label(name: str) -> str: + cached = _sanitize_label_cache.get(name) + if cached is not None: + return cached parts: List[str] = [] for ch in name: if ch.isalnum() or ch == "_": @@ -2827,7 +2930,8 @@ def sanitize_label(name: str) -> str: safe = "".join(parts) or "anon" if safe[0].isdigit(): safe = "_" + safe - return f"{safe}" + _sanitize_label_cache[name] = safe + return safe def _is_identifier(text: str) -> bool: @@ -2987,7 +3091,6 @@ class Assembler: def _peephole_optimize_definition(self, definition: Definition) -> None: # Rewrite short stack-manipulation sequences into canonical forms. - # Rules are ordered longest-first by matching logic below. rules: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = [ (("swap", "drop"), ("nip",)), # Stack no-ops @@ -3007,6 +3110,11 @@ class Assembler: max_pat_len = max(len(pattern) for pattern, _ in rules) + # Build index: first word -> list of (pattern, replacement) + rule_index: Dict[str, List[Tuple[Tuple[str, ...], Tuple[str, ...]]]] = {} + for pattern, repl in rules: + rule_index.setdefault(pattern[0], []).append((pattern, repl)) + nodes = definition.body changed = True while changed: @@ -3014,26 +3122,30 @@ class Assembler: optimized: List[Op] = [] idx = 0 while idx < len(nodes): + node = nodes[idx] matched = False - for window in range(min(max_pat_len, len(nodes) - idx), 1, -1): - segment = nodes[idx:idx + window] - if any(node.op != "word" for node in segment): - continue - names = tuple(str(node.data) for node in segment) - replacement: Optional[Tuple[str, ...]] = None - for pattern, repl in rules: - if names == pattern: - replacement = repl + if node._opcode == OP_WORD: + candidates = rule_index.get(str(node.data)) + if candidates: + for window in range(min(max_pat_len, len(nodes) - idx), 1, -1): + segment = nodes[idx:idx + window] + if any(n._opcode != OP_WORD for n in segment): + continue + names = tuple(str(n.data) for n in segment) + replacement: Optional[Tuple[str, ...]] = None + for pattern, repl in candidates: + if names == pattern: + replacement = repl + break + if replacement is None: + continue + base_loc = segment[0].loc + for repl_name in replacement: + optimized.append(Op(op="word", data=repl_name, loc=base_loc)) + idx += window + changed = True + matched = True break - if replacement is None: - continue - base_loc = segment[0].loc - for repl_name in replacement: - optimized.append(Op(op="word", data=repl_name, loc=base_loc)) - idx += window - changed = True - matched = True - break if matched: continue optimized.append(nodes[idx]) @@ -3052,7 +3164,7 @@ class Assembler: if idx + 1 < len(nodes): a = nodes[idx] b = nodes[idx + 1] - if a.op == "word" and b.op == "word": + if a._opcode == OP_WORD and b._opcode == OP_WORD: wa = str(a.data) wb = str(b.data) if (wa, wb) in { @@ -3067,7 +3179,7 @@ class Assembler: if idx + 1 < len(nodes): lit = nodes[idx] op = nodes[idx + 1] - if lit.op == "literal" and isinstance(lit.data, int) and op.op == "word": + if lit._opcode == OP_LITERAL and isinstance(lit.data, int) and op._opcode == OP_WORD: k = int(lit.data) w = str(op.data) base_loc = lit.loc or op.loc @@ -3136,7 +3248,7 @@ class Assembler: if len(nodes) < arity + 1: return operands = nodes[-(arity + 1):-1] - if any(op.op != "literal" or not isinstance(op.data, int) for op in operands): + if any(op._opcode != OP_LITERAL or not isinstance(op.data, int) for op in operands): return values = [int(op.data) for op in operands] try: @@ -3150,9 +3262,9 @@ class Assembler: stack: List[int] = [] pairs: Dict[int, int] = {} for idx, node in enumerate(nodes): - if node.op == "for_begin": + if node._opcode == OP_FOR_BEGIN: stack.append(idx) - elif node.op == "for_end": + elif node._opcode == OP_FOR_END: if not stack: raise CompileError("'end' without matching 'for'") begin_idx = stack.pop() @@ -3165,14 +3277,14 @@ class Assembler: def _collect_internal_labels(self, nodes: Sequence[Op]) -> Set[str]: labels: Set[str] = set() for node in nodes: - kind = node.op + kind = node._opcode data = node.data - if kind == "label": + if kind == OP_LABEL: labels.add(str(data)) - elif kind in ("for_begin", "for_end"): + elif kind == OP_FOR_BEGIN or kind == OP_FOR_END: labels.add(str(data["loop"])) labels.add(str(data["end"])) - elif kind in ("list_begin", "list_end"): + elif kind == OP_LIST_BEGIN or kind == OP_LIST_END: labels.add(str(data)) return labels @@ -3193,20 +3305,20 @@ class Assembler: cloned: List[Op] = [] for node in nodes: - kind = node.op + kind = node._opcode data = node.data - if kind == "label": + if kind == OP_LABEL: cloned.append(Op(op="label", data=remap(str(data)), loc=node.loc)) continue - if kind in ("jump", "branch_zero"): + if kind == OP_JUMP or kind == OP_BRANCH_ZERO: target = str(data) mapped = remap(target) if target in internal_labels else target - cloned.append(Op(op=kind, data=mapped, loc=node.loc)) + cloned.append(Op(op=node.op, data=mapped, loc=node.loc)) continue - if kind in ("for_begin", "for_end"): + if kind == OP_FOR_BEGIN or kind == OP_FOR_END: cloned.append( Op( - op=kind, + op=node.op, data={ "loop": remap(str(data["loop"])), "end": remap(str(data["end"])), @@ -3215,10 +3327,10 @@ class Assembler: ) ) continue - if kind in ("list_begin", "list_end"): - cloned.append(Op(op=kind, data=remap(str(data)), loc=node.loc)) + if kind == OP_LIST_BEGIN or kind == OP_LIST_END: + cloned.append(Op(op=node.op, data=remap(str(data)), loc=node.loc)) continue - cloned.append(Op(op=kind, data=data, loc=node.loc)) + cloned.append(Op(op=node.op, data=data, loc=node.loc)) return cloned def _unroll_constant_for_loops(self, definition: Definition) -> None: @@ -3234,9 +3346,9 @@ class Assembler: idx = 0 while idx < len(nodes): node = nodes[idx] - if node.op == "for_begin" and idx > 0: + if node._opcode == OP_FOR_BEGIN and idx > 0: prev = nodes[idx - 1] - if prev.op == "literal" and isinstance(prev.data, int): + if prev._opcode == OP_LITERAL and isinstance(prev.data, int): count = int(prev.data) end_idx = pairs.get(idx) if end_idx is None: @@ -3286,12 +3398,12 @@ class Assembler: while j < len(nodes): cur = nodes[j] - if cur.op == "list_begin": + if cur._opcode == OP_LIST_BEGIN: depth += 1 is_static = False j += 1 continue - if cur.op == "list_end": + if cur._opcode == OP_LIST_END: depth -= 1 if depth == 0: break @@ -3299,7 +3411,7 @@ class Assembler: continue if depth == 1: - if cur.op == "literal" and isinstance(cur.data, int): + if cur._opcode == OP_LITERAL and isinstance(cur.data, int): static_values.append(int(cur.data)) else: is_static = False @@ -3519,37 +3631,37 @@ class Assembler: return label_map[label] for node in definition.body: - kind = node.op + kind = node._opcode data = node.data - if kind == "label": + if kind == OP_LABEL: mapped = remap(str(data)) self._emit_node(Op(op="label", data=mapped), builder) continue - if kind == "jump": + if kind == OP_JUMP: mapped = remap(str(data)) self._emit_node(Op(op="jump", data=mapped), builder) continue - if kind == "branch_zero": + if kind == OP_BRANCH_ZERO: mapped = remap(str(data)) self._emit_node(Op(op="branch_zero", data=mapped), builder) continue - if kind == "for_begin": + if kind == OP_FOR_BEGIN: mapped = { "loop": remap(data["loop"]), "end": remap(data["end"]), } self._emit_node(Op(op="for_begin", data=mapped), builder) continue - if kind == "for_end": + if kind == OP_FOR_END: mapped = { "loop": remap(data["loop"]), "end": remap(data["end"]), } self._emit_node(Op(op="for_end", data=mapped), builder) continue - if kind in ("list_begin", "list_end"): + if kind == OP_LIST_BEGIN or kind == OP_LIST_END: mapped = remap(str(data)) - self._emit_node(Op(op=kind, data=mapped), builder) + self._emit_node(Op(op=node.op, data=mapped), builder) continue self._emit_node(node, builder) @@ -3566,14 +3678,14 @@ class Assembler: builder.emit("") def _emit_node(self, node: Op, builder: FunctionEmitter) -> None: - kind = node.op + kind = node._opcode data = node.data builder.set_location(node.loc) def ctx() -> str: return f" while emitting '{self._emit_stack[-1]}'" if self._emit_stack else "" - if kind == "literal": + if kind == OP_LITERAL: if isinstance(data, int): builder.push_literal(data) return @@ -3588,35 +3700,35 @@ class Assembler: return raise CompileError(f"unsupported literal type {type(data)!r}{ctx()}") - if kind == "word": + if kind == OP_WORD: self._emit_wordref(str(data), builder) return - if kind == "word_ptr": + if kind == OP_WORD_PTR: self._emit_wordptr(str(data), builder) return - if kind == "branch_zero": + if kind == OP_BRANCH_ZERO: self._emit_branch_zero(str(data), builder) return - if kind == "jump": + if kind == OP_JUMP: builder.emit(f" jmp {data}") return - if kind == "label": + if kind == OP_LABEL: builder.emit(f"{data}:") return - if kind == "for_begin": + if kind == OP_FOR_BEGIN: self._emit_for_begin(data, builder) return - if kind == "for_end": + if kind == OP_FOR_END: self._emit_for_next(data, builder) return - if kind == "list_begin": + if kind == OP_LIST_BEGIN: builder.comment("list begin") builder.emit(" mov rax, [rel list_capture_sp]") builder.emit(" lea rdx, [rel list_capture_stack]") @@ -3625,7 +3737,7 @@ class Assembler: builder.emit(" mov [rel list_capture_sp], rax") return - if kind == "list_literal": + if kind == OP_LIST_LITERAL: values = list(data or []) count = len(values) bytes_needed = (count + 1) * 8 @@ -3645,7 +3757,7 @@ class Assembler: builder.emit(" mov [r12], rax") return - if kind == "list_end": + if kind == OP_LIST_END: base = str(data) loop_label = f"{base}_copy_loop" done_label = f"{base}_copy_done" @@ -3710,8 +3822,7 @@ class Assembler: suffix = f" while emitting '{self._emit_stack[-1]}'" if self._emit_stack else "" raise CompileError(f"unknown word '{name}'{suffix}") if word.compile_only: - suffix = f" while emitting '{self._emit_stack[-1]}'" if self._emit_stack else "" - raise CompileError(f"word '{name}' is compile-time only{suffix}") + return # silently skip compile-time-only words during emission if getattr(word, "inline", False) and isinstance(word.definition, Definition): if word.name in self._inline_stack: suffix = f" while emitting '{self._emit_stack[-1]}'" if self._emit_stack else "" @@ -3953,7 +4064,7 @@ def macro_label(ctx: MacroContext) -> Optional[List[Op]]: if not _is_identifier(name): raise ParseError(f"invalid label name '{name}'") definition = _require_definition_context(parser, "label") - if any(node.op == "label" and node.data == name for node in definition.body): + if any(node._opcode == OP_LABEL and node.data == name for node in definition.body): raise ParseError(f"duplicate label '{name}' in definition '{definition.name}'") parser.emit_node(Op(op="label", data=name)) return None @@ -4651,6 +4762,48 @@ def _ct_lexer_push_back(vm: CompileTimeVM) -> None: vm.push(lexer) +def _ct_eval(vm: CompileTimeVM) -> None: + """Pop a string from TOS and execute it in the compile-time VM.""" + if vm.runtime_mode: + length = vm.pop_int() + addr = vm.pop_int() + source = ctypes.string_at(addr, length).decode("utf-8") + else: + source = vm.pop_str() + tokens = list(vm.parser.reader.tokenize(source)) + # Parse as if inside a definition body to get Op nodes + parser = vm.parser + # Save parser state + old_tokens = parser.tokens + old_pos = parser.pos + old_iter = parser._token_iter + old_exhausted = parser._token_iter_exhausted + old_source = parser.source + # Set up temporary token stream + parser.tokens = list(tokens) + parser.pos = 0 + parser._token_iter = iter([]) + parser._token_iter_exhausted = True + parser.source = "" + # Collect ops by capturing what _handle_token appends + temp_defn = Definition(name="__eval__", body=[]) + parser.context_stack.append(temp_defn) + try: + while not parser._eof(): + token = parser._consume() + parser._handle_token(token) + finally: + parser.context_stack.pop() + # Restore parser state + parser.tokens = old_tokens + parser.pos = old_pos + parser._token_iter = old_iter + parser._token_iter_exhausted = old_exhausted + parser.source = old_source + # Execute collected ops in the VM + if temp_defn.body: + vm._execute_nodes(temp_defn.body) + # --------------------------------------------------------------------------- # Runtime intrinsics that cannot run as native JIT (for --ct-run-main) @@ -4945,6 +5098,7 @@ def _register_compile_time_primitives(dictionary: Dictionary) -> None: register("lexer-expect", _ct_lexer_expect, compile_only=True) register("lexer-collect-brace", _ct_lexer_collect_brace, compile_only=True) register("lexer-push-back", _ct_lexer_push_back, compile_only=True) + register("eval", _ct_eval, compile_only=True) @@ -5082,6 +5236,13 @@ class Compiler: module = self.parser.parse(tokens, source) return self.assembler.emit(module, debug=debug, entry_mode=entry_mode) + def parse_file(self, path: Path) -> None: + """Parse a source file to populate the dictionary without emitting assembly.""" + source, spans = self._load_with_imports(path.resolve()) + self.parser.file_spans = spans or [] + tokens = self.reader.tokenize(source) + self.parser.parse(tokens, source) + def compile_file(self, path: Path, *, debug: bool = False, entry_mode: str = "program") -> Emission: source, spans = self._load_with_imports(path.resolve()) return self.compile_source(source, spans=spans, debug=debug, entry_mode=entry_mode) @@ -5473,6 +5634,8 @@ def run_repl( debug: bool = False, initial_source: Optional[Path] = None, ) -> int: + """REPL backed by the compile-time VM for instant execution.""" + def _block_defines_main(block: str) -> bool: stripped_lines = [ln.strip() for ln in block.splitlines() if ln.strip() and not ln.strip().startswith("#")] for idx, stripped in enumerate(stripped_lines): @@ -5487,10 +5650,6 @@ def run_repl( return False temp_dir.mkdir(parents=True, exist_ok=True) - global _suppress_redefine_warnings - asm_path = temp_dir / "repl.asm" - obj_path = temp_dir / "repl.o" - exe_path = temp_dir / "repl.out" src_path = temp_dir / "repl.sl" editor_cmd = os.environ.get("EDITOR") or "vim" @@ -5501,6 +5660,8 @@ def run_repl( main_body: List[str] = [] has_user_main = False + include_paths = list(compiler.include_paths) + if initial_source is not None: try: initial_text = initial_source.read_text() @@ -5512,27 +5673,62 @@ def run_repl( except Exception as exc: print(f"[repl] failed to load {initial_source}: {exc}") + def _run_on_ct_vm(source: str, word_name: str = "main") -> bool: + """Parse source and execute word_name via the compile-time VM. + + Returns True on success, False on error (already printed). + """ + nonlocal compiler + src_path.write_text(source) + try: + _suppress_redefine_warnings_set(True) + compiler._loaded_files.clear() + compiler.parse_file(src_path) + except (ParseError, CompileError, CompileTimeError) as exc: + print(f"[error] {exc}") + return False + except Exception as exc: + print(f"[error] parse failed: {exc}") + return False + finally: + _suppress_redefine_warnings_set(False) + + try: + compiler.run_compile_time_word(word_name, libs=list(libs)) + except (CompileTimeError, _CTVMExit) as exc: + if isinstance(exc, _CTVMExit): + code = exc.args[0] if exc.args else 0 + if code != 0: + print(f"[warn] program exited with code {code}") + else: + print(f"[error] {exc}") + return False + except Exception as exc: + print(f"[error] execution failed: {exc}") + return False + return True + def _print_help() -> None: print("[repl] commands:") print(" :help show this help") print(" :show display current session source (with synthetic main if pending snippet)") print(" :reset clear session imports/defs") print(" :load load a source file into the session") - print(" :call compile and run a program that calls ") + print(" :call execute a word via the compile-time VM") print(" :edit [file] open session file or given file in editor") print(" :seteditor [cmd] show/set editor command (default from $EDITOR or vim)") print(" :quit | :q exit the REPL") print("[repl] free-form input:") print(" definitions (word/:asm/:py/extern/macro/struct) extend the session") print(" imports add to session imports") - print(" other lines run immediately in an isolated temp program (not saved)") + print(" other lines run immediately via the compile-time VM (not saved)") print(" multiline: end lines with \\ to continue; finish with a non-\\ line") print("[repl] type L2 code; :help for commands; :quit to exit") + print("[repl] execution via compile-time VM (instant, no nasm/ld)") print("[repl] enter multiline with trailing \\; finish with a line without \\") pending_block: List[str] = [] - snippet_counter = 0 while True: try: @@ -5554,6 +5750,8 @@ def run_repl( main_body.clear() has_user_main = False pending_block.clear() + # Re-create compiler for a clean dictionary state + compiler = Compiler(include_paths=include_paths) print("[repl] session cleared") continue if stripped.startswith(":seteditor"): @@ -5633,32 +5831,17 @@ def run_repl( if not word_name: print("[repl] usage: :call ") continue - try: - if word_name == "main" and not has_user_main: - print("[repl] cannot call main; no user-defined main present") - continue - if word_name == "main" and has_user_main: - builder_source = _repl_build_source(imports, user_defs_files, user_defs_repl, [], True, force_synthetic=False) - else: - # Override entrypoint with a tiny wrapper that calls the target word. - temp_defs_repl = [*user_defs_repl, f"word main\n {word_name}\nend"] - builder_source = _repl_build_source(imports, user_defs_files, temp_defs_repl, [], True, force_synthetic=False) - src_path.write_text(builder_source) - _suppress_redefine_warnings = True - try: - emission = compiler.compile_file(src_path, debug=debug) - finally: - _suppress_redefine_warnings = False - compiler.assembler.write_asm(emission, asm_path) - run_nasm(asm_path, obj_path, debug=debug) - run_linker(obj_path, exe_path, debug=debug, libs=list(libs)) - result = subprocess.run([str(exe_path)]) - if result.returncode != 0: - print(f"[warn] program exited with code {result.returncode}") - except (ParseError, CompileError, CompileTimeError) as exc: - print(f"[error] {exc}") - except Exception as exc: - print(f"[error] build failed: {exc}") + if word_name == "main" and not has_user_main: + print("[repl] cannot call main; no user-defined main present") + continue + if word_name == "main" and has_user_main: + source = _repl_build_source(imports, user_defs_files, user_defs_repl, [], True, force_synthetic=False) + else: + temp_defs = [*user_defs_repl, f"word __repl_call__\n {word_name}\nend"] + source = _repl_build_source(imports, user_defs_files, temp_defs, [], True, force_synthetic=False) + _run_on_ct_vm(source, "__repl_call__") + continue + _run_on_ct_vm(source, word_name) continue if not stripped: continue @@ -5689,15 +5872,8 @@ def run_repl( main_body.clear() user_defs_repl.append(block) else: - # Run arbitrary snippet in an isolated temp program without touching session files. - snippet_counter += 1 - snippet_id = snippet_counter - snippet_src = temp_dir / f"repl_snippet_{snippet_id}.sl" - snippet_asm = temp_dir / f"repl_snippet_{snippet_id}.asm" - snippet_obj = temp_dir / f"repl_snippet_{snippet_id}.o" - snippet_exe = temp_dir / f"repl_snippet_{snippet_id}.out" - - snippet_source = _repl_build_source( + # Execute snippet immediately via the compile-time VM. + source = _repl_build_source( imports, user_defs_files, user_defs_repl, @@ -5705,66 +5881,22 @@ def run_repl( has_user_main, force_synthetic=True, ) - try: - snippet_src.write_text(snippet_source) - _suppress_redefine_warnings = True - try: - emission = compiler.compile_file(snippet_src, debug=debug) - finally: - _suppress_redefine_warnings = False - compiler.assembler.write_asm(emission, snippet_asm) - run_nasm(snippet_asm, snippet_obj, debug=debug) - run_linker(snippet_obj, snippet_exe, debug=debug, libs=list(libs)) - except (ParseError, CompileError, CompileTimeError) as exc: - print(f"[error] {exc}") - for p in (snippet_src, snippet_asm, snippet_obj, snippet_exe): - try: - p.unlink(missing_ok=True) - except Exception: - pass - continue - except Exception as exc: - print(f"[error] build failed: {exc}") - for p in (snippet_src, snippet_asm, snippet_obj, snippet_exe): - try: - p.unlink(missing_ok=True) - except Exception: - pass - continue - - try: - result = subprocess.run([str(snippet_exe)]) - if result.returncode != 0: - print(f"[warn] program exited with code {result.returncode}") - except Exception as exc: - print(f"[error] execution failed: {exc}") - finally: - for p in (snippet_src, snippet_asm, snippet_obj, snippet_exe): - try: - p.unlink(missing_ok=True) - except Exception: - pass + _run_on_ct_vm(source) continue + # Validate definitions by parsing (no execution needed). source = _repl_build_source(imports, user_defs_files, user_defs_repl, main_body, has_user_main, force_synthetic=bool(main_body)) - try: src_path.write_text(source) - _suppress_redefine_warnings = True + _suppress_redefine_warnings_set(True) try: - emission = compiler.compile_file(src_path, debug=debug) + compiler._loaded_files.clear() + compiler.parse_file(src_path) finally: - _suppress_redefine_warnings = False + _suppress_redefine_warnings_set(False) except (ParseError, CompileError, CompileTimeError) as exc: print(f"[error] {exc}") continue - try: - compiler.assembler.write_asm(emission, asm_path) - run_nasm(asm_path, obj_path, debug=debug) - run_linker(obj_path, exe_path, debug=debug, libs=list(libs)) - except Exception as exc: - print(f"[error] build failed: {exc}") - continue return 0 diff --git a/tests/quick_sort.expected b/tests/quick_sort.expected new file mode 100644 index 0000000..d74af6e --- /dev/null +++ b/tests/quick_sort.expected @@ -0,0 +1,31 @@ +1 +2 +3 +4 +5 +6 +7 +8 +1 +2 +3 +4 +5 +1 +3 +5 +7 +9 +42 +1 +1 +2 +3 +3 +4 +5 +5 +6 +9 +10 +20 diff --git a/tests/quick_sort.sl b/tests/quick_sort.sl new file mode 100644 index 0000000..a5223e7 --- /dev/null +++ b/tests/quick_sort.sl @@ -0,0 +1,102 @@ +import ../stdlib/stdlib.sl +import ../stdlib/io.sl +import ../stdlib/arr.sl + +# Get element from static array, preserving the array pointer +# [*, arr | i] -> [*, arr | value] +word aget + over swap arr_get_static +end + +# Set element in static array, preserving the array pointer +# [*, arr, value | i] -> [* | arr] +word aset + rot dup >r -rot arr_set_static r> +end + +# Swap elements at indices i and j in a static array +# [*, arr, i | j] -> [* | arr] +word arr_swap + >r >r + 0 rpick aget + swap + 1 rpick aget + 0 rpick aset + swap + 1 rpick aset + rdrop rdrop +end + +# Lomuto partition (ascending, signed comparison) +# [*, arr, lo | hi] -> [*, arr | pivot_index] +word partition + >r >r + 1 rpick aget + >r + 1 rpick dec + 1 rpick + while dup 2 rpick < do + 2 pick over aget nip + 0 rpick <= + if + swap inc swap + 2 pick 2 pick 2 pick arr_swap drop + end + inc + end + drop inc + over over 2 rpick arr_swap drop + rdrop rdrop rdrop +end + +# Recursive quicksort +# [*, arr, lo | hi] -> [* | arr] +word qsort_rec + over over >= if + drop drop + else + >r >r + 0 rpick 1 rpick + partition + over 0 rpick + 2 pick dec + qsort_rec + drop + over swap inc + 1 rpick + qsort_rec + drop + rdrop rdrop + end +end + +# Quicksort for static arrays (in-place, ascending) +# [* | arr] -> [* | arr] +word arr_qsort + dup @ dec + dup 0 < if + drop + else + >r dup 0 r> + qsort_rec + end +end + +# Print all elements of a static array, one per line +word print_arr + dup @ 0 + while 2dup > do + 2 pick over arr_get_static puti cr + 1 + + end + 2drop drop +end + +word main + [ 5 3 8 1 7 2 6 4 ] arr_qsort print_arr + [ 1 2 3 4 5 ] arr_qsort print_arr + [ 9 7 5 3 1 ] arr_qsort print_arr + [ 42 ] arr_qsort print_arr + [ 3 1 4 1 5 9 2 6 5 3 ] arr_qsort print_arr + [ 20 10 ] arr_qsort print_arr +end