diff --git a/extra_tests/ct_test.compile.expected b/extra_tests/ct_test.compile.expected index 084d950..2e54142 100644 --- a/extra_tests/ct_test.compile.expected +++ b/extra_tests/ct_test.compile.expected @@ -1,2 +1,3 @@ hello world -[info] built /home/igor/programming/IgorCielniak/l2/build/ct_test +[info] built /ct_test +[warn] redefining word syscall diff --git a/main.py b/main.py index 226822d..d622190 100644 --- a/main.py +++ b/main.py @@ -207,6 +207,7 @@ class Op: op: str data: Any = None loc: Optional[SourceLocation] = None + _word_ref: Optional["Word"] = field(default=None, repr=False, compare=False) @dataclass @@ -217,6 +218,13 @@ class Definition: compile_only: bool = False terminator: str = "end" inline: bool = False + # Cached analysis (populated lazily by CT VM) + _label_positions: Optional[Dict[str, int]] = field(default=None, repr=False, compare=False) + _for_pairs: Optional[Dict[int, int]] = field(default=None, repr=False, compare=False) + _begin_pairs: Optional[Dict[int, int]] = field(default=None, repr=False, compare=False) + _words_resolved: bool = field(default=False, repr=False, compare=False) + # Merged JIT runs: maps start_ip → (end_ip_exclusive, cache_key) + _merged_runs: Optional[Dict[int, Tuple[int, str]]] = field(default=None, repr=False, compare=False) @dataclass @@ -331,6 +339,7 @@ class Word: macro_expansion: Optional[List[str]] = None macro_params: int = 0 compile_time_intrinsic: Optional[Callable[["CompileTimeVM"], None]] = None + runtime_intrinsic: Optional[Callable[["CompileTimeVM"], None]] = None compile_only: bool = False compile_time_override: bool = False is_extern: bool = False @@ -432,6 +441,18 @@ class Parser: builder.push_label(target) word.intrinsic = _intrinsic + + # CT intrinsic: allocate a qword in CTMemory for this variable. + # The address is lazily created on first use and cached. + _ct_var_addrs: Dict[str, int] = {} + + def _ct_intrinsic(vm: CompileTimeVM, var_name: str = name) -> None: + if var_name not in _ct_var_addrs: + _ct_var_addrs[var_name] = vm.memory.allocate(8) + vm.push(_ct_var_addrs[var_name]) + + word.compile_time_intrinsic = _ct_intrinsic + word.runtime_intrinsic = _ct_intrinsic self.dictionary.register(word) return label, hidden_word @@ -1133,7 +1154,160 @@ class Parser: self.tokens.append(next_tok) +# --------------------------------------------------------------------------- +# Compile-time VM helpers +# --------------------------------------------------------------------------- + + +def _to_i64(v: int) -> int: + """Truncate to signed 64-bit integer (matching x86-64 register semantics).""" + v = v & 0xFFFFFFFFFFFFFFFF + if v >= 0x8000000000000000: + v -= 0x10000000000000000 + return v + + +class _CTVMJump(Exception): + """Raised by the ``jmp`` intrinsic to transfer control in _execute_nodes.""" + + def __init__(self, target_ip: int) -> None: + self.target_ip = target_ip + + +class _CTVMExit(Exception): + """Raised by the ``exit`` intrinsic to stop compile-time execution.""" + + def __init__(self, code: int = 0) -> None: + self.code = code + + +class CTMemory: + """Managed memory for the compile-time VM. + + Uses ctypes buffers with real process addresses so that ``c@``, ``c!``, + ``@``, ``!`` can operate on them directly via ``ctypes.from_address``. + + String literals are slab-allocated from a contiguous data section so that + ``data_start``/``data_end`` bracket them correctly for ``print``'s range + check. + """ + + PERSISTENT_SIZE = 64 # matches default BSS ``persistent: resb 64`` + PRINT_BUF_SIZE = 128 # matches ``PRINT_BUF_BYTES`` + DATA_SECTION_SIZE = 4 * 1024 * 1024 # 4 MB slab for string literals + + def __init__(self, persistent_size: int = 0) -> None: + self._buffers: List[Any] = [] # prevent GC of ctypes objects + self._string_cache: Dict[str, Tuple[int, int]] = {} # cache string literals + + # Persistent BSS region (for ``mem`` word) + actual_persistent = persistent_size if persistent_size > 0 else self.PERSISTENT_SIZE + self._persistent = ctypes.create_string_buffer(actual_persistent) + self._persistent_size = actual_persistent + self._buffers.append(self._persistent) + self.persistent_addr: int = ctypes.addressof(self._persistent) + + # print_buf region (for words that use ``[rel print_buf]``) + self._print_buf = ctypes.create_string_buffer(self.PRINT_BUF_SIZE) + self._buffers.append(self._print_buf) + self.print_buf_addr: int = ctypes.addressof(self._print_buf) + + # Data section – contiguous slab for string literals so that + # data_start..data_end consistently brackets all of them. + self._data_section = ctypes.create_string_buffer(self.DATA_SECTION_SIZE) + self._buffers.append(self._data_section) + self.data_start: int = ctypes.addressof(self._data_section) + self.data_end: int = self.data_start + self.DATA_SECTION_SIZE + self._data_offset: int = 0 + + # sys_argc / sys_argv – populated by invoke() + self._sys_argc = ctypes.c_int64(0) + self._buffers.append(self._sys_argc) + self.sys_argc_addr: int = ctypes.addressof(self._sys_argc) + + self._sys_argv_ptrs: Optional[ctypes.Array[Any]] = None + self._sys_argv = ctypes.c_int64(0) # qword holding pointer to argv array + self._buffers.append(self._sys_argv) + self.sys_argv_addr: int = ctypes.addressof(self._sys_argv) + + # -- argv helpers ------------------------------------------------------ + + def setup_argv(self, args: List[str]) -> None: + """Populate sys_argc / sys_argv from *args*.""" + self._sys_argc.value = len(args) + # Build null-terminated C string array + argv_bufs: List[Any] = [] + for arg in args: + encoded = arg.encode("utf-8") + b"\x00" + buf = ctypes.create_string_buffer(encoded, len(encoded)) + self._buffers.append(buf) + argv_bufs.append(buf) + # pointer array (+ NULL sentinel) + arr_type = ctypes.c_int64 * (len(args) + 1) + self._sys_argv_ptrs = arr_type() + for i, buf in enumerate(argv_bufs): + self._sys_argv_ptrs[i] = ctypes.addressof(buf) + self._sys_argv_ptrs[len(args)] = 0 + self._buffers.append(self._sys_argv_ptrs) + self._sys_argv.value = ctypes.addressof(self._sys_argv_ptrs) + + # -- allocation -------------------------------------------------------- + + def allocate(self, size: int) -> int: + """Allocate a zero-filled region, return its real address. + Adds padding to mimic real mmap which always gives full pages.""" + if size <= 0: + size = 1 + buf = ctypes.create_string_buffer(size + 16) # padding for null terminators + addr = ctypes.addressof(buf) + self._buffers.append(buf) + return addr + + def store_string(self, s: str) -> Tuple[int, int]: + """Store a UTF-8 string in the data section slab. Returns ``(addr, length)``. + Caches immutable string literals to avoid redundant allocations.""" + cached = self._string_cache.get(s) + if cached is not None: + return cached + encoded = s.encode("utf-8") + needed = len(encoded) + 1 # null terminator + aligned = (needed + 7) & ~7 # 8-byte align + if self._data_offset + aligned > self.DATA_SECTION_SIZE: + raise RuntimeError("CT data section overflow") + addr = self.data_start + self._data_offset + ctypes.memmove(addr, encoded, len(encoded)) + ctypes.c_uint8.from_address(addr + len(encoded)).value = 0 # null terminator + self._data_offset += aligned + result = (addr, len(encoded)) + self._string_cache[s] = result + return result + + # -- low-level access -------------------------------------------------- + + @staticmethod + def read_byte(addr: int) -> int: + return ctypes.c_uint8.from_address(addr).value + + @staticmethod + def write_byte(addr: int, value: int) -> None: + ctypes.c_uint8.from_address(addr).value = value & 0xFF + + @staticmethod + def read_qword(addr: int) -> int: + return ctypes.c_int64.from_address(addr).value + + @staticmethod + def write_qword(addr: int, value: int) -> None: + ctypes.c_int64.from_address(addr).value = _to_i64(value) + + @staticmethod + def read_bytes(addr: int, length: int) -> bytes: + return ctypes.string_at(addr, length) + + class CompileTimeVM: + NATIVE_STACK_SIZE = 8 * 1024 * 1024 # 8 MB per native stack + def __init__(self, parser: Parser) -> None: self.parser = parser self.dictionary = parser.dictionary @@ -1142,6 +1316,32 @@ class CompileTimeVM: self.loop_stack: List[Dict[str, Any]] = [] self._handles = _CTHandleTable() self.call_stack: List[str] = [] + # Runtime-faithful execution state + self.memory = CTMemory() + self.runtime_mode: bool = False + self._list_capture_stack: List[Any] = [] # for list_begin/list_end (int depth or native r12 addr) + self._ct_executed: Set[str] = set() # words already executed at CT + # Native stack state (used only in runtime_mode) + self.r12: int = 0 # data stack pointer (grows downward) + self.r13: int = 0 # return stack pointer (grows downward) + self._native_data_stack: Optional[Any] = None # ctypes buffer + self._native_data_top: int = 0 + self._native_return_stack: Optional[Any] = None # ctypes buffer + self._native_return_top: int = 0 + # JIT cache: word name → ctypes callable + self._jit_cache: Dict[str, Any] = {} + self._jit_code_pages: List[Any] = [] # keep mmap pages alive + # Pre-allocated output structs for JIT calls (avoid per-call allocation) + self._jit_out2 = (ctypes.c_int64 * 2)() + self._jit_out2_addr = ctypes.addressof(self._jit_out2) + self._jit_out4 = (ctypes.c_int64 * 4)() + self._jit_out4_addr = ctypes.addressof(self._jit_out4) + # BSS symbol table for JIT patching + self._bss_symbols: Dict[str, int] = {} + # dlopen handles for C extern support + self._dl_handles: List[Any] = [] # ctypes.CDLL handles + self._dl_func_cache: Dict[str, Any] = {} # name → ctypes callable + self._ct_libs: List[str] = [] # library names from -l flags def reset(self) -> None: self.stack.clear() @@ -1149,11 +1349,104 @@ class CompileTimeVM: self.loop_stack.clear() self._handles.clear() self.call_stack.clear() + self._list_capture_stack.clear() + self.r12 = 0 + self.r13 = 0 + + def invoke(self, word: Word, *, runtime_mode: bool = False, libs: Optional[List[str]] = None) -> None: + self.reset() + prev_mode = self.runtime_mode + self.runtime_mode = runtime_mode + if runtime_mode: + # Determine persistent size from BSS overrides if available. + persistent_size = 0 + if self.parser.custom_bss: + import re as _re_bss + for bss_line in self.parser.custom_bss: + m = _re_bss.search(r'persistent:\s*resb\s+(\d+)', bss_line) + if m: + persistent_size = int(m.group(1)) + self.memory = CTMemory(persistent_size) # fresh memory per invocation + self.memory.setup_argv(sys.argv) + + # Allocate native stacks + self._native_data_stack = ctypes.create_string_buffer(self.NATIVE_STACK_SIZE) + self._native_data_top = ctypes.addressof(self._native_data_stack) + self.NATIVE_STACK_SIZE + self.r12 = self._native_data_top # empty, grows downward + + self._native_return_stack = ctypes.create_string_buffer(self.NATIVE_STACK_SIZE) + self._native_return_top = ctypes.addressof(self._native_return_stack) + self.NATIVE_STACK_SIZE + self.r13 = self._native_return_top # empty, grows downward + + # BSS symbol table for JIT [rel SYMBOL] patching + self._bss_symbols = { + "data_start": self.memory.data_start, + "data_end": self.memory.data_start + self.memory._data_offset if self.memory._data_offset else self.memory.data_end, + "print_buf": self.memory.print_buf_addr, + "print_buf_end": self.memory.print_buf_addr + CTMemory.PRINT_BUF_SIZE, + "persistent": self.memory.persistent_addr, + "persistent_end": self.memory.persistent_addr + self.memory._persistent_size, + "sys_argc": self.memory.sys_argc_addr, + "sys_argv": self.memory.sys_argv_addr, + } + + # JIT cache is per-invocation (addresses change) + self._jit_cache = {} + self._jit_code_pages = [] + + # dlopen libraries for C extern support + self._dl_handles = [] + self._dl_func_cache = {} + all_libs = list(self._ct_libs) + if libs: + for lib in libs: + if lib not in all_libs: + all_libs.append(lib) + for lib_name in all_libs: + self._dlopen(lib_name) + + # Deep word chains need extra Python stack depth. + old_limit = sys.getrecursionlimit() + if old_limit < 10000: + sys.setrecursionlimit(10000) + try: + self._call_word(word) + except _CTVMExit: + pass # graceful exit from CT execution + finally: + self.runtime_mode = prev_mode + # Clear JIT cache; code pages are libc mmap'd and we intentionally + # leak them — the OS reclaims them at process exit. + self._jit_cache.clear() + self._jit_code_pages.clear() + self._dl_func_cache.clear() + self._dl_handles.clear() + + def invoke_with_args(self, word: Word, args: Sequence[Any]) -> None: + self.reset() + for value in args: + self.push(value) + self._call_word(word) def push(self, value: Any) -> None: - self.stack.append(value) + if self.runtime_mode: + self.r12 -= 8 + if isinstance(value, float): + import struct as _struct + bits = _struct.unpack("q", _struct.pack("d", value))[0] + CTMemory.write_qword(self.r12, bits) + else: + CTMemory.write_qword(self.r12, _to_i64(int(value))) + else: + self.stack.append(value) def pop(self) -> Any: + if self.runtime_mode: + if self.r12 >= self._native_data_top: + raise ParseError("compile-time stack underflow") + val = CTMemory.read_qword(self.r12) + self.r12 += 8 + return val if not self.stack: raise ParseError("compile-time stack underflow") return self.stack.pop() @@ -1175,16 +1468,63 @@ class CompileTimeVM: return value def peek(self) -> Any: + if self.runtime_mode: + if self.r12 >= self._native_data_top: + raise ParseError("compile-time stack underflow") + return CTMemory.read_qword(self.r12) if not self.stack: raise ParseError("compile-time stack underflow") return self.stack[-1] def pop_int(self) -> int: + if self.runtime_mode: + return self.pop() # already returns int from native stack value = self.pop() + if isinstance(value, bool): + return int(value) if not isinstance(value, int): - raise ParseError("expected integer on compile-time stack") + raise ParseError(f"expected integer on compile-time stack, got {type(value).__name__}: {value!r}") return value + # -- return stack helpers (native r13 in runtime_mode) ----------------- + + def push_return(self, value: int) -> None: + if self.runtime_mode: + self.r13 -= 8 + CTMemory.write_qword(self.r13, _to_i64(value)) + else: + self.return_stack.append(value) + + def pop_return(self) -> int: + if self.runtime_mode: + val = CTMemory.read_qword(self.r13) + self.r13 += 8 + return val + return self.return_stack.pop() + + def peek_return(self) -> int: + if self.runtime_mode: + return CTMemory.read_qword(self.r13) + return self.return_stack[-1] + + def poke_return(self, value: int) -> None: + """Overwrite top of return stack.""" + if self.runtime_mode: + CTMemory.write_qword(self.r13, _to_i64(value)) + else: + self.return_stack[-1] = value + + def return_stack_empty(self) -> bool: + if self.runtime_mode: + return self.r13 >= self._native_return_top + return len(self.return_stack) == 0 + + # -- native stack depth ------------------------------------------------ + + def native_stack_depth(self) -> int: + """Number of items on data stack (runtime_mode only).""" + return (self._native_data_top - self.r12) // 8 + def pop_str(self) -> str: value = self._resolve_handle(self.pop()) if not isinstance(value, str): @@ -1208,32 +1548,168 @@ class CompileTimeVM: raise ParseError("expected token on compile-time stack") return value - def invoke(self, word: Word) -> None: - self.reset() - self._call_word(word) + # -- dlopen / C extern support ----------------------------------------- - def invoke_with_args(self, word: Word, args: Sequence[Any]) -> None: - self.reset() - for value in args: - self.push(value) - self._call_word(word) + def _dlopen(self, lib_name: str) -> None: + """Open a shared library and append to _dl_handles.""" + import ctypes.util + # Try as given first (handles absolute paths, "libc.so.6", etc.) + candidates = [lib_name] + # Try lib.so + if not lib_name.startswith("lib") and "." not in lib_name: + candidates.append(f"lib{lib_name}.so") + # Use ctypes.util.find_library for short names like "m", "c" + found = ctypes.util.find_library(lib_name) + if found: + candidates.append(found) + for candidate in candidates: + try: + handle = ctypes.CDLL(candidate, use_errno=True) + self._dl_handles.append(handle) + return + except OSError: + continue + # Not fatal — the library may not be needed at CT + + _CTYPE_MAP: Dict[str, Any] = { + "int": ctypes.c_int, + "long": ctypes.c_long, + "long long": ctypes.c_longlong, + "unsigned int": ctypes.c_uint, + "unsigned long": ctypes.c_ulong, + "size_t": ctypes.c_size_t, + "char": ctypes.c_char, + "char*": ctypes.c_void_p, # use void* so raw integer addrs work + "void*": ctypes.c_void_p, + "double": ctypes.c_double, + "float": ctypes.c_float, + } + + def _resolve_ctype(self, type_name: str) -> Any: + """Map a C type name string to a ctypes type.""" + t = type_name.strip().replace("*", "* ").replace(" ", " ").strip() + if t in self._CTYPE_MAP: + return self._CTYPE_MAP[t] + # Pointer types + if t.endswith("*"): + return ctypes.c_void_p + # Default to c_long (64-bit on Linux x86-64) + return ctypes.c_long + + def _dlsym(self, name: str) -> Any: + """Look up a symbol across all dl handles, return a raw function pointer or None.""" + for handle in self._dl_handles: + try: + return getattr(handle, name) + except AttributeError: + continue + return None + + def _call_extern_ct(self, word: Word) -> None: + """Call an extern C function via dlsym/ctypes on the native stacks.""" + name = word.name + + # Special handling for exit — intercept it before doing anything + if name == "exit": + raise _CTVMExit() + + func = self._dl_func_cache.get(name) + if func is None: + raw = self._dlsym(name) + if raw is None: + raise ParseError(f"extern '{name}' not found in any loaded library") + + signature = word.extern_signature + inputs = word.extern_inputs + outputs = word.extern_outputs + + if signature: + arg_types, ret_type = signature + c_arg_types = [self._resolve_ctype(t) for t in arg_types] + if ret_type == "void": + c_ret_type = None + else: + c_ret_type = self._resolve_ctype(ret_type) + else: + # Legacy mode: assume all int64 args + arg_types = [] + c_arg_types = [ctypes.c_int64] * inputs + c_ret_type = ctypes.c_int64 if outputs > 0 else None + + # Configure the ctypes function object directly + raw.restype = c_ret_type + raw.argtypes = c_arg_types + # Stash metadata for calling + raw._ct_inputs = inputs + raw._ct_outputs = outputs + raw._ct_arg_types = c_arg_types + raw._ct_ret_type = c_ret_type + raw._ct_signature = signature + func = raw + self._dl_func_cache[name] = func + + inputs = func._ct_inputs + outputs = func._ct_outputs + arg_types = func._ct_signature[0] if func._ct_signature else [] + + # Pop arguments off the native data stack (right-to-left / reverse order) + raw_args = [] + for i in range(inputs): + raw_args.append(self.pop()) + raw_args.reverse() + + # Convert arguments to proper ctypes values + import struct as _struct + call_args = [] + for i, raw in enumerate(raw_args): + if i < len(arg_types) and arg_types[i] in ("float", "double"): + # Reinterpret the int64 bits as a double (matching the language's convention) + raw_int = _to_i64(int(raw)) + double_val = _struct.unpack("d", _struct.pack("q", raw_int))[0] + call_args.append(double_val) + else: + call_args.append(int(raw)) + + result = func(*call_args) + + if outputs > 0 and result is not None: + ret_type = func._ct_signature[1] if func._ct_signature else None + if ret_type in ("float", "double"): + int_bits = _struct.unpack("q", _struct.pack("d", float(result)))[0] + self.push(int_bits) + else: + self.push(int(result)) def _call_word(self, word: Word) -> None: self.call_stack.append(word.name) try: definition = word.definition + # In runtime_mode, prefer runtime_intrinsic (for exit/jmp/syscall + # and __with_* variables). All other :asm words run as native JIT. + if self.runtime_mode and word.runtime_intrinsic is not None: + word.runtime_intrinsic(self) + return prefer_definition = word.compile_time_override or (isinstance(definition, Definition) and (word.immediate or word.compile_only)) if not prefer_definition and word.compile_time_intrinsic is not None: word.compile_time_intrinsic(self) return + # C extern words: call via dlopen/dlsym in runtime_mode + if self.runtime_mode and getattr(word, "is_extern", False): + self._call_extern_ct(word) + return if definition is None: raise ParseError(f"word '{word.name}' has no compile-time definition") if isinstance(definition, AsmDefinition): - self._run_asm_definition(word) + if self.runtime_mode: + self._run_jit(word) + else: + self._run_asm_definition(word) return - self._execute_nodes(definition.body) + self._execute_nodes(definition.body, _defn=definition) except CompileTimeError: raise + except (_CTVMJump, _CTVMExit): + raise except ParseError as exc: raise CompileTimeError(f"{exc}\ncompile-time stack: {' -> '.join(self.call_stack)}") from None except Exception as exc: @@ -1243,6 +1719,137 @@ class CompileTimeVM: finally: self.call_stack.pop() + # -- Native JIT execution (runtime_mode) -------------------------------- + + _JIT_FUNC_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_int64, ctypes.c_int64, ctypes.c_void_p) + + def _run_jit(self, word: Word) -> None: + """JIT-compile (once) and execute an :asm word on the native r12/r13 stacks.""" + func = self._jit_cache.get(word.name) + if func is None: + func = self._compile_jit(word) + self._jit_cache[word.name] = func + + out = self._jit_out2 + func(self.r12, self.r13, self._jit_out2_addr) + self.r12 = out[0] + self.r13 = out[1] + + def _compile_jit(self, word: Word) -> Any: + """Assemble an :asm word into executable memory and return a ctypes callable.""" + if Ks is None: + raise ParseError("keystone-engine is required for JIT execution") + definition = word.definition + if not isinstance(definition, AsmDefinition): + raise ParseError(f"word '{word.name}' has no asm body") + asm_body = definition.body.strip("\n") + + import re as _re + _rel_pat = _re.compile(r'\[rel\s+(\w+)\]') + bss = self._bss_symbols + + # Build wrapper + lines: List[str] = [] + # Entry: save callee-saved regs, set r12/r13, stash output ptr at [rsp] + lines.extend([ + "_ct_entry:", + " push rbx", + " push r12", + " push r13", + " push r14", + " push r15", + " sub rsp, 16", # align + room for output ptr + " mov [rsp], rdx", # save output-struct pointer + " mov r12, rdi", # data stack + " mov r13, rsi", # return stack + ]) + + # Patch asm body + for raw_line in asm_body.splitlines(): + line = raw_line.strip() + if not line or line.startswith(";"): + continue + if line.startswith("extern"): + continue # strip extern declarations + if line == "ret": + line = "jmp _ct_save" + + # Patch [rel SYMBOL] → concrete address + m = _rel_pat.search(line) + if m and m.group(1) in bss: + sym = m.group(1) + addr = bss[sym] + if line.lstrip().startswith("lea"): + # lea REG, [rel X] → mov REG, addr + line = _rel_pat.sub(str(addr), line).replace("lea", "mov", 1) + else: + # e.g. mov rax, [rel X] or mov byte [rel X], val + # Replace with push/mov-rax/substitute/pop trampoline + lines.append(" push rax") + lines.append(f" mov rax, {addr}") + new_line = _rel_pat.sub("[rax]", line) + lines.append(f" {new_line}") + lines.append(" pop rax") + continue + lines.append(f" {line}") + + # Save: restore output ptr from [rsp], write r12/r13 out, restore regs + lines.extend([ + "_ct_save:", + " mov rax, [rsp]", # output-struct pointer + " mov [rax], r12", + " mov [rax + 8], r13", + " add rsp, 16", + " pop r15", + " pop r14", + " pop r13", + " pop r12", + " pop rbx", + " ret", + ]) + + # Normalize for Keystone + def _norm(l: str) -> str: + l = l.split(";", 1)[0].rstrip() + for sz in ("qword", "dword", "word", "byte"): + l = l.replace(f"{sz} [", f"{sz} ptr [") + return l + normalized = [_norm(l) for l in lines if _norm(l).strip()] + + ks = Ks(KS_ARCH_X86, KS_MODE_64) + try: + encoding, _ = ks.asm("\n".join(normalized)) + except KsError as exc: + debug_txt = "\n".join(normalized) + raise ParseError( + f"JIT assembly failed for '{word.name}': {exc}\n--- asm ---\n{debug_txt}\n--- end ---" + ) from exc + if encoding is None: + raise ParseError(f"JIT produced no code for '{word.name}'") + + code = bytes(encoding) + # Allocate RWX memory via libc mmap (not Python's mmap module) so + # Python's GC never tries to finalize the mapping. + page_size = max(len(code), 4096) + _libc = ctypes.CDLL(None, use_errno=True) + _libc.mmap.restype = ctypes.c_void_p + _libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, + ctypes.c_int, ctypes.c_int, ctypes.c_long] + PROT_RWX = 0x1 | 0x2 | 0x4 # READ | WRITE | EXEC + MAP_PRIVATE = 0x02 + MAP_ANONYMOUS = 0x20 + ptr = _libc.mmap(None, page_size, PROT_RWX, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0) + if ptr == ctypes.c_void_p(-1).value or ptr is None: + raise RuntimeError(f"mmap failed for JIT code ({page_size} bytes)") + ctypes.memmove(ptr, code, len(code)) + # Store (ptr, size) so we can munmap later + self._jit_code_pages.append((ptr, page_size)) + func = self._JIT_FUNC_TYPE(ptr) + return func + + # -- Old non-runtime asm execution (kept for non-runtime CT mode) ------- + def _run_asm_definition(self, word: Word) -> None: definition = word.definition if Ks is None: @@ -1315,22 +1922,44 @@ class CompileTimeVM: ]) if asm_body: patched_body = [] + # Build BSS symbol table for [rel X] → concrete address substitution + _bss_symbols: Dict[str, int] = { + "data_start": data_start, + "data_end": data_end, + "print_buf": print_buf, + "print_buf_end": print_buf + PRINT_BUF_BYTES, + } + if self.memory is not None: + _bss_symbols.update({ + "persistent": self.memory.persistent_addr, + "persistent_end": self.memory.persistent_addr + self.memory._persistent_size, + }) + import re as _re + _rel_pat = _re.compile(r'\[rel\s+(\w+)\]') + for line in asm_body.splitlines(): line = line.strip() if line == "ret": line = "jmp _ct_save" - if "lea r8, [rel data_start]" in line: - line = line.replace("lea r8, [rel data_start]", f"mov r8, {data_start}") - if "lea r9, [rel data_end]" in line: - line = line.replace("lea r9, [rel data_end]", f"mov r9, {data_end}") - if "mov byte [rel print_buf]" in line or "mov byte ptr [rel print_buf]" in line: - patched_body.append(f"mov rax, {print_buf}") - patched_body.append("mov byte ptr [rax], 10") - continue - if "lea rsi, [rel print_buf_end]" in line: - line = f"mov rsi, {print_buf + PRINT_BUF_BYTES}" - if "lea rsi, [rel print_buf]" in line: - line = f"mov rsi, {print_buf}" + # Replace [rel SYMBOL] with concrete addresses + m = _rel_pat.search(line) + if m and m.group(1) in _bss_symbols: + sym = m.group(1) + addr = _bss_symbols[sym] + # lea REG, [rel X] → mov REG, addr + if line.lstrip().startswith("lea"): + line = _rel_pat.sub(str(addr), line).replace("lea", "mov", 1) + else: + # For memory operands like mov byte [rel X], val + # replace [rel X] with [] + tmp_reg = "rax" + # Use a scratch register to hold the address + patched_body.append(f"push rax") + patched_body.append(f"mov rax, {addr}") + new_line = _rel_pat.sub("[rax]", line) + patched_body.append(new_line) + patched_body.append(f"pop rax") + continue patched_body.append(line) wrapper_lines.extend(patched_body) wrapper_lines.extend([ @@ -1450,27 +2079,297 @@ class CompileTimeVM: raise ParseError(f"unknown word '{name}' during compile-time execution") self._call_word(word) - def _execute_nodes(self, nodes: Sequence[Op]) -> None: - label_positions = self._label_positions(nodes) - loop_pairs = self._for_pairs(nodes) - begin_pairs = self._begin_pairs(nodes) + def _resolve_words_in_body(self, defn: Definition) -> None: + """Pre-resolve word name → Word objects on Op nodes (once per Definition).""" + if defn._words_resolved: + return + lookup = self.dictionary.lookup + for node in defn.body: + if node.op == "word" and node._word_ref is None: + name = str(node.data) + # Skip structural keywords that _execute_nodes handles inline + if name not in ("begin", "again", "continue", "exit", "get_addr"): + ref = lookup(name) + if ref is not None: + node._word_ref = ref + defn._words_resolved = True + + def _prepare_definition(self, defn: Definition) -> Tuple[Dict[str, int], Dict[int, int], Dict[int, int]]: + """Return (label_positions, for_pairs, begin_pairs), cached on the Definition.""" + if defn._label_positions is None: + defn._label_positions = self._label_positions(defn.body) + if defn._for_pairs is None: + defn._for_pairs = self._for_pairs(defn.body) + if defn._begin_pairs is None: + defn._begin_pairs = self._begin_pairs(defn.body) + self._resolve_words_in_body(defn) + if self.runtime_mode and defn._merged_runs is None: + defn._merged_runs = self._find_mergeable_runs(defn) + return defn._label_positions, defn._for_pairs, defn._begin_pairs + + def _find_mergeable_runs(self, defn: Definition) -> Dict[int, Tuple[int, str]]: + """Find consecutive runs of JIT-able asm word ops (length >= 2).""" + runs: Dict[int, Tuple[int, str]] = {} + body = defn.body + n = len(body) + i = 0 + while i < n: + # Start of a potential run + if body[i].op == "word" and body[i]._word_ref is not None: + w = body[i]._word_ref + if (w.runtime_intrinsic is None and isinstance(w.definition, AsmDefinition) + and not w.compile_time_override): + run_start = i + run_words = [w.name] + i += 1 + while i < n and body[i].op == "word" and body[i]._word_ref is not None: + w2 = body[i]._word_ref + if (w2.runtime_intrinsic is None and isinstance(w2.definition, AsmDefinition) + and not w2.compile_time_override): + run_words.append(w2.name) + i += 1 + else: + break + if len(run_words) >= 2: + key = f"__merged_{defn.name}_{run_start}_{i}" + runs[run_start] = (i, key) + continue + i += 1 + return runs + + def _compile_merged_jit(self, words: List[Word], cache_key: str) -> Any: + """Compile multiple asm word bodies into a single JIT function.""" + if Ks is None: + raise ParseError("keystone-engine is required for JIT execution") + + import re as _re + _rel_pat = _re.compile(r'\[rel\s+(\w+)\]') + _label_pat = _re.compile(r'^(\.\w+|\w+):') + bss = self._bss_symbols + + lines: List[str] = [] + # Entry wrapper (same as _compile_jit) + lines.extend([ + "_ct_entry:", + " push rbx", + " push r12", + " push r13", + " push r14", + " push r15", + " sub rsp, 16", + " mov [rsp], rdx", + " mov r12, rdi", + " mov r13, rsi", + ]) + + # Append each word's asm body, with labels uniquified + for word_idx, word in enumerate(words): + defn = word.definition + asm_body = defn.body.strip("\n") + prefix = f"_m{word_idx}_" + + # Collect all labels in this asm body first + local_labels: Set[str] = set() + for raw_line in asm_body.splitlines(): + line = raw_line.strip() + lm = _label_pat.match(line) + if lm: + local_labels.add(lm.group(1)) + + for raw_line in asm_body.splitlines(): + line = raw_line.strip() + if not line or line.startswith(";"): + continue + if line.startswith("extern"): + continue + if line == "ret": + # Last word: jmp to save; others: fall through + if word_idx < len(words) - 1: + continue # just skip ret → fall through + else: + line = "jmp _ct_save" + + # Replace all references to local labels with prefixed versions + for label in local_labels: + # Use word-boundary replacement to avoid partial matches + line = _re.sub(rf'(? str: + l = l.split(";", 1)[0].rstrip() + for sz in ("qword", "dword", "word", "byte"): + l = l.replace(f"{sz} [", f"{sz} ptr [") + return l + normalized = [_norm(l) for l in lines if _norm(l).strip()] + + ks = Ks(KS_ARCH_X86, KS_MODE_64) + try: + encoding, _ = ks.asm("\n".join(normalized)) + except KsError as exc: + debug_txt = "\n".join(normalized) + raise ParseError( + f"JIT merged assembly failed for '{cache_key}': {exc}\n--- asm ---\n{debug_txt}\n--- end ---" + ) from exc + if encoding is None: + raise ParseError(f"JIT merged produced no code for '{cache_key}'") + + code = bytes(encoding) + page_size = max(len(code), 4096) + _libc = ctypes.CDLL(None, use_errno=True) + _libc.mmap.restype = ctypes.c_void_p + _libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, + ctypes.c_int, ctypes.c_int, ctypes.c_long] + PROT_RWX = 0x1 | 0x2 | 0x4 + MAP_PRIVATE = 0x02 + MAP_ANONYMOUS = 0x20 + ptr = _libc.mmap(None, page_size, PROT_RWX, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0) + if ptr == ctypes.c_void_p(-1).value or ptr is None: + raise RuntimeError(f"mmap failed for merged JIT code ({page_size} bytes)") + ctypes.memmove(ptr, code, len(code)) + self._jit_code_pages.append((ptr, page_size)) + return self._JIT_FUNC_TYPE(ptr) + + def _execute_nodes(self, nodes: Sequence[Op], *, _defn: Optional[Definition] = None) -> None: + # Use cached analysis if we have one, else compute fresh + if _defn is not None: + label_positions, loop_pairs, begin_pairs = self._prepare_definition(_defn) + else: + label_positions = self._label_positions(nodes) + loop_pairs = self._for_pairs(nodes) + begin_pairs = self._begin_pairs(nodes) prev_loop_stack = self.loop_stack self.loop_stack = [] begin_stack: List[Dict[str, int]] = [] + + # Local variable aliases for hot-path speedup + _runtime_mode = self.runtime_mode + _push = self.push + _pop = self.pop + _pop_int = self.pop_int + _push_return = self.push_return + _pop_return = self.pop_return + _peek_return = self.peek_return + _poke_return = self.poke_return + _call_word = self._call_word + _dict_lookup = self.dictionary.lookup + + # Hot JIT-call locals (avoid repeated attribute access) + _jit_cache = self._jit_cache if _runtime_mode else None + _jit_out2 = self._jit_out2 if _runtime_mode else None + _jit_out2_addr = self._jit_out2_addr if _runtime_mode else 0 + _compile_jit = self._compile_jit if _runtime_mode else None + _compile_merged = self._compile_merged_jit if _runtime_mode else None + _AsmDef = AsmDefinition + _merged_runs = (_defn._merged_runs if _defn is not None and _defn._merged_runs else None) if _runtime_mode else None + + n_nodes = len(nodes) ip = 0 try: - while ip < len(nodes): + while ip < n_nodes: node = nodes[ip] kind = node.op - data = node.data - - if kind == "literal": - self.push(data) - ip += 1 - continue if kind == "word": - name = str(data) + # Merged JIT run: call one combined function for N words + if _merged_runs is not None: + run_info = _merged_runs.get(ip) + if run_info is not None: + end_ip, cache_key = run_info + func = _jit_cache.get(cache_key) + if func is None: + # Warmup: only compile merged function after seen 2+ times + hit_key = cache_key + "_hits" + hits = _jit_cache.get(hit_key, 0) + 1 + _jit_cache[hit_key] = hits + if hits < 2: + # Fall through to individual JIT calls + pass + else: + run_words = [nodes[j]._word_ref for j in range(ip, end_ip)] + func = _compile_merged(run_words, cache_key) + _jit_cache[cache_key] = func + if func is not None: + func(self.r12, self.r13, _jit_out2_addr) + self.r12 = _jit_out2[0] + self.r13 = _jit_out2[1] + ip = end_ip + continue + + # Fast path: pre-resolved word reference + word = node._word_ref + if word is not None: + # Inlined _call_word for common cases (JIT asm words) + if _runtime_mode: + ri = word.runtime_intrinsic + if ri is not None: + self.call_stack.append(word.name) + try: + ri(self) + except _CTVMJump as jmp: + self.call_stack.pop() + ip = jmp.target_ip + continue + finally: + if self.call_stack and self.call_stack[-1] == word.name: + self.call_stack.pop() + ip += 1 + continue + defn = word.definition + if isinstance(defn, _AsmDef): + # Ultra-hot path: inline JIT call, skip call_stack + wn = word.name + func = _jit_cache.get(wn) + if func is None: + func = _compile_jit(word) + _jit_cache[wn] = func + func(self.r12, self.r13, _jit_out2_addr) + self.r12 = _jit_out2[0] + self.r13 = _jit_out2[1] + ip += 1 + continue + # Fall through to full _call_word for other cases + try: + _call_word(word) + except _CTVMJump as jmp: + ip = jmp.target_ip + continue + ip += 1 + continue + + # Structural keywords or unresolved words + name = str(node.data) if name == "begin": end_idx = begin_pairs.get(ip) if end_idx is None: @@ -1494,12 +2393,61 @@ class CompileTimeVM: ip = frame["end"] + 1 continue return - self._call_word_by_name(name) + if _runtime_mode and name == "get_addr": + _push(ip + 1) + ip += 1 + continue + # Lookup at runtime (rare: word was defined after body was compiled) + w = _dict_lookup(name) + if w is None: + raise ParseError(f"unknown word '{name}' during compile-time execution") + try: + _call_word(w) + except _CTVMJump as jmp: + ip = jmp.target_ip + continue + ip += 1 + continue + + if kind == "literal": + data = node.data + if _runtime_mode and isinstance(data, str): + addr, length = self.memory.store_string(data) + _push(addr) + _push(length) + else: + _push(data) + ip += 1 + continue + + if kind == "for_end": + if not self.loop_stack: + raise ParseError("'next' without matching 'for'") + val = _peek_return() - 1 + _poke_return(val) + if val > 0: + ip = self.loop_stack[-1]["begin"] + 1 + continue + _pop_return() + self.loop_stack.pop() + ip += 1 + continue + + if kind == "for_begin": + count = _pop_int() + if count <= 0: + match = loop_pairs.get(ip) + if match is None: + raise ParseError("internal loop bookkeeping error") + ip = match + 1 + continue + _push_return(count) + self.loop_stack.append({"begin": ip}) ip += 1 continue if kind == "branch_zero": - condition = self.pop() + condition = _pop() if isinstance(condition, bool): flag = condition elif isinstance(condition, int): @@ -1507,40 +2455,52 @@ class CompileTimeVM: else: raise ParseError("branch expects integer or boolean condition") if not flag: - ip = self._jump_to_label(label_positions, str(data)) + ip = label_positions.get(str(node.data), -1) + if ip == -1: + raise ParseError(f"unknown label '{node.data}' during compile-time execution") else: ip += 1 continue if kind == "jump": - ip = self._jump_to_label(label_positions, str(data)) + ip = label_positions.get(str(node.data), -1) + if ip == -1: + raise ParseError(f"unknown label '{node.data}' during compile-time execution") continue if kind == "label": ip += 1 continue - if kind == "for_begin": - count = self.pop_int() - if count <= 0: - match = loop_pairs.get(ip) - if match is None: - raise ParseError("internal loop bookkeeping error") - ip = match + 1 - continue - self.loop_stack.append({"remaining": count, "begin": ip, "initial": count}) + if kind == "list_begin": + if _runtime_mode: + self._list_capture_stack.append(self.r12) + else: + self._list_capture_stack.append(len(self.stack)) ip += 1 continue - if kind == "for_end": - if not self.loop_stack: - raise ParseError("'next' without matching 'for'") - frame = self.loop_stack[-1] - frame["remaining"] -= 1 - if frame["remaining"] > 0: - ip = frame["begin"] + 1 - continue - self.loop_stack.pop() + if kind == "list_end": + if not self._list_capture_stack: + raise ParseError("']' without matching '['") + saved = self._list_capture_stack.pop() + if _runtime_mode: + items: List[int] = [] + ptr = saved - 8 + while ptr >= self.r12: + items.append(CTMemory.read_qword(ptr)) + ptr -= 8 + self.r12 = saved + else: + items = self.stack[saved:] + del self.stack[saved:] + count = len(items) + buf_size = (count + 1) * 8 + addr = self.memory.allocate(buf_size) + CTMemory.write_qword(addr, count) + for idx_item, val in enumerate(items): + CTMemory.write_qword(addr + 8 + idx_item * 8, val) + _push(addr) ip += 1 continue @@ -2660,6 +3620,7 @@ def macro_compile_time(ctx: MacroContext) -> Optional[List[Op]]: if word.compile_only: raise ParseError(f"word '{name}' is compile-time only") parser.compile_time_vm.invoke(word) + parser.compile_time_vm._ct_executed.add(name) if isinstance(parser.context_stack[-1], Definition): parser.emit_node(Op(op="word", data=name)) return None @@ -3308,6 +4269,197 @@ def _ct_lexer_push_back(vm: CompileTimeVM) -> None: vm.push(lexer) +# --------------------------------------------------------------------------- +# Runtime intrinsics that cannot run as native JIT (for --ct-run-main) +# --------------------------------------------------------------------------- + +def _rt_exit(vm: CompileTimeVM) -> None: + code = vm.pop_int() + raise _CTVMExit(code) + + +def _rt_jmp(vm: CompileTimeVM) -> None: + target_ip = vm.pop_int() + raise _CTVMJump(target_ip) + + +def _rt_syscall(vm: CompileTimeVM) -> None: + """Execute a real Linux syscall via a JIT stub, intercepting exit/exit_group.""" + # Lazily compile the syscall JIT stub + stub = vm._jit_cache.get("__syscall_stub") + if stub is None: + stub = _compile_syscall_stub(vm) + vm._jit_cache["__syscall_stub"] = stub + + # out[0] = final r12, out[1] = final r13, out[2] = flag (0=normal, 1=exit, code in out[3]) + out = vm._jit_out4 + stub(vm.r12, vm.r13, vm._jit_out4_addr) + vm.r12 = out[0] + vm.r13 = out[1] + if out[2] == 1: + raise _CTVMExit(out[3]) + + +def _compile_syscall_stub(vm: CompileTimeVM) -> Any: + """JIT-compile a native syscall stub that intercepts exit/exit_group.""" + if Ks is None: + raise ParseError("keystone-engine is required for JIT syscall execution") + + # The stub uses the same wrapper convention as _compile_jit: + # rdi = r12 (data stack ptr), rsi = r13 (return stack ptr), rdx = output ptr + # Output struct: [r12, r13, exit_flag, exit_code] + # + # Stack protocol (matching _emit_syscall_intrinsic): + # TOS: syscall number → rax + # TOS-1: arg count → rcx + # then up to 6 args → rdi, rsi, rdx, r10, r8, r9 + lines = [ + "_stub_entry:", + " push rbx", + " push r12", + " push r13", + " push r14", + " push r15", + " sub rsp, 16", + " mov [rsp], rdx", # save output-struct pointer + " mov r12, rdi", # data stack + " mov r13, rsi", # return stack + # Pop syscall number + " mov rax, [r12]", + " add r12, 8", + # Pop arg count + " mov rcx, [r12]", + " add r12, 8", + # Clamp to [0,6] + " cmp rcx, 0", + " jge _count_nonneg", + " xor rcx, rcx", + "_count_nonneg:", + " cmp rcx, 6", + " jle _count_clamped", + " mov rcx, 6", + "_count_clamped:", + # Save arg count in r14 and syscall num in r15 + " mov r14, rcx", + " mov r15, rax", + # Pop args into scratch area on machine stack (up to 6 qwords) + # We pop them into rbx, r8-r11 area, then assign to syscall regs after + # Pop all args onto machine stack (reverse order) + " sub rsp, 48", # 6 * 8 bytes for args + " xor rbx, rbx", # index + "_pop_args:", + " cmp rbx, r14", + " jge _pop_done", + " mov rax, [r12]", + " add r12, 8", + " mov [rsp + rbx*8], rax", + " inc rbx", + " jmp _pop_args", + "_pop_done:", + # Check for exit (60) / exit_group (231) + " cmp r15, 60", + " je _do_exit", + " cmp r15, 231", + " je _do_exit", + # Assign args to syscall registers from the scratch area + # arg0 → rdi, arg1 → rsi, arg2 → rdx, arg3 → r10, arg4 → r8, arg5 → r9 + " mov rdi, [rsp]", + " mov rsi, [rsp+8]", + " mov rdx, [rsp+16]", + " mov r10, [rsp+24]", + " mov r8, [rsp+32]", + " mov r9, [rsp+40]", + " mov rax, r15", # syscall number + " syscall", + # Push result + " sub r12, 8", + " mov [r12], rax", + # Normal return: flag=0 + " add rsp, 48", + " mov rax, [rsp]", # output-struct pointer + " mov qword [rax], r12", + " mov qword [rax+8], r13", + " mov qword [rax+16], 0", # exit_flag = 0 + " mov qword [rax+24], 0", # exit_code = 0 + " jmp _stub_epilogue", + # Exit path: don't actually call syscall, just report it + "_do_exit:", + " mov rbx, [rsp]", # arg0 = exit code + " add rsp, 48", + " mov rax, [rsp]", # output-struct pointer + " mov qword [rax], r12", + " mov qword [rax+8], r13", + " mov qword [rax+16], 1", # exit_flag = 1 + " mov [rax+24], rbx", # exit_code + "_stub_epilogue:", + " add rsp, 16", + " pop r15", + " pop r14", + " pop r13", + " pop r12", + " pop rbx", + " ret", + ] + + def _norm(l: str) -> str: + l = l.split(";", 1)[0].rstrip() + for sz in ("qword", "dword", "word", "byte"): + l = l.replace(f"{sz} [", f"{sz} ptr [") + return l + normalized = [_norm(l) for l in lines if _norm(l).strip()] + + ks = Ks(KS_ARCH_X86, KS_MODE_64) + try: + encoding, _ = ks.asm("\n".join(normalized)) + except KsError as exc: + debug_txt = "\n".join(normalized) + raise ParseError(f"JIT syscall stub assembly failed: {exc}\n--- asm ---\n{debug_txt}\n--- end ---") from exc + if encoding is None: + raise ParseError("JIT syscall stub produced no code") + + code = bytes(encoding) + page_size = max(len(code), 4096) + _libc = ctypes.CDLL(None, use_errno=True) + _libc.mmap.restype = ctypes.c_void_p + _libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, + ctypes.c_int, ctypes.c_int, ctypes.c_long] + PROT_RWX = 0x1 | 0x2 | 0x4 + MAP_PRIVATE = 0x02 + MAP_ANONYMOUS = 0x20 + ptr = _libc.mmap(None, page_size, PROT_RWX, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0) + if ptr == ctypes.c_void_p(-1).value or ptr is None: + raise RuntimeError("mmap failed for JIT syscall stub") + ctypes.memmove(ptr, code, len(code)) + vm._jit_code_pages.append((ptr, page_size)) + # Same signature: (r12, r13, out_ptr) → void + func = CompileTimeVM._JIT_FUNC_TYPE(ptr) + return func + + +def _register_runtime_intrinsics(dictionary: Dictionary) -> None: + """Register runtime intrinsics only for words that cannot run as native JIT. + + Most :asm words now run as native JIT-compiled machine code on real + memory stacks. Only a handful need Python-level interception: + - exit : must not actually call sys_exit (would kill the compiler) + - jmp : needs interpreter-level IP manipulation + - syscall : the ``syscall`` word is compiler-generated (no asm body); + intercept to block sys_exit and handle safely + Note: get_addr is handled inline in _execute_nodes before _call_word. + """ + _RT_MAP: Dict[str, Callable[[CompileTimeVM], None]] = { + "exit": _rt_exit, + "jmp": _rt_jmp, + "syscall": _rt_syscall, + } + for name, func in _RT_MAP.items(): + word = dictionary.lookup(name) + if word is None: + word = Word(name=name) + dictionary.register(word) + word.runtime_intrinsic = func + + def _register_compile_time_primitives(dictionary: Dictionary) -> None: def register(name: str, func: Callable[[CompileTimeVM], None], *, compile_only: bool = False) -> None: word = dictionary.lookup(name) @@ -3473,6 +4625,7 @@ def bootstrap_dictionary() -> Dictionary: dictionary.register(Word(name="macro", immediate=True, macro=macro_begin_text_macro)) dictionary.register(Word(name="struct", immediate=True, macro=macro_struct_begin)) _register_compile_time_primitives(dictionary) + _register_runtime_intrinsics(dictionary) return dictionary @@ -3518,11 +4671,14 @@ class Compiler: source, spans = self._load_with_imports(path.resolve()) return self.compile_source(source, spans=spans, debug=debug, entry_mode=entry_mode) - def run_compile_time_word(self, name: str) -> None: + def run_compile_time_word(self, name: str, *, libs: Optional[List[str]] = None) -> None: word = self.dictionary.lookup(name) if word is None: raise CompileTimeError(f"word '{name}' not defined; cannot run at compile time") - self.parser.compile_time_vm.invoke(word) + # Skip if already executed via a ``compile-time `` directive. + if name in self.parser.compile_time_vm._ct_executed: + return + self.parser.compile_time_vm.invoke(word, runtime_mode=True, libs=libs) def _resolve_import_target(self, importing_file: Path, target: str) -> Path: raw = Path(target) @@ -4190,7 +5346,7 @@ def cli(argv: Sequence[str]) -> int: if args.ct_run_main: try: - compiler.run_compile_time_word("main") + compiler.run_compile_time_word("main", libs=args.libs) except CompileTimeError as exc: print(f"[error] compile-time execution of 'main' failed: {exc}") return 1 @@ -4241,7 +5397,12 @@ def cli(argv: Sequence[str]) -> int: def main() -> None: - sys.exit(cli(sys.argv[1:])) + code = cli(sys.argv[1:]) + # Flush all output then use os._exit to avoid SIGSEGV from ctypes/native + # memory finalization during Python's shutdown sequence. + sys.stdout.flush() + sys.stderr.flush() + os._exit(code) if __name__ == "__main__": diff --git a/test.py b/test.py index 74a89fb..c5e3dd1 100644 --- a/test.py +++ b/test.py @@ -432,11 +432,17 @@ class TestRunner: cmd.append("--ct-run-main") if self.args.verbose: print(f"\n{format_status('CMD', 'blue')} {quote_cmd(cmd)}") + # When --ct-run-main is used, the compiler executes main at compile time, + # so it may need stdin data that would normally go to the binary. + compile_input = None + if self.args.ct_run_main and case.stdin_data() is not None: + compile_input = case.stdin_data() return subprocess.run( cmd, cwd=self.root, capture_output=True, text=True, + input=compile_input, env=self._env_for(case), )