This commit is contained in:
IgorCielniak
2025-12-14 00:38:19 +01:00
parent 6910e05be6
commit 6574222280
23 changed files with 1473 additions and 277 deletions

567
main.py
View File

@@ -11,6 +11,8 @@ This file now contains working scaffolding for:
from __future__ import annotations
import argparse
import ctypes
import mmap
import subprocess
import sys
import textwrap
@@ -18,6 +20,13 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Union, Tuple
try: # lazy optional import; required for compile-time :asm execution
from keystone import Ks, KsError, KS_ARCH_X86, KS_MODE_64
except Exception: # pragma: no cover - optional dependency
Ks = None
KsError = Exception
KS_ARCH_X86 = KS_MODE_64 = None
class ParseError(Exception):
"""Raised when the source stream cannot be parsed."""
@@ -764,11 +773,13 @@ class CompileTimeVM:
self.stack: List[Any] = []
self.return_stack: List[Any] = []
self.loop_stack: List[Dict[str, Any]] = []
self._handles = _CTHandleTable()
def reset(self) -> None:
self.stack.clear()
self.return_stack.clear()
self.loop_stack.clear()
self._handles.clear()
def push(self, value: Any) -> None:
self.stack.append(value)
@@ -778,6 +789,22 @@ class CompileTimeVM:
raise ParseError("compile-time stack underflow")
return self.stack.pop()
def _resolve_handle(self, value: Any) -> Any:
if isinstance(value, int):
for delta in (0, -1, 1):
candidate = value + delta
if candidate in self._handles.objects:
obj = self._handles.objects[candidate]
self._handles.objects[value] = obj
return obj
# Occasionally a raw object id can appear on the stack; recover it if we still
# hold the object reference.
for obj in self._handles.objects.values():
if id(obj) == value:
self._handles.objects[value] = obj
return obj
return value
def peek(self) -> Any:
if not self.stack:
raise ParseError("compile-time stack underflow")
@@ -790,19 +817,24 @@ class CompileTimeVM:
return value
def pop_str(self) -> str:
value = self.pop()
value = self._resolve_handle(self.pop())
if not isinstance(value, str):
raise ParseError("expected string on compile-time stack")
return value
def pop_list(self) -> List[Any]:
value = self.pop()
value = self._resolve_handle(self.pop())
if not isinstance(value, list):
raise ParseError("expected list on compile-time stack")
known = value in self._handles.objects if isinstance(value, int) else False
handles_size = len(self._handles.objects)
handle_keys = list(self._handles.objects.keys())
raise ParseError(
f"expected list on compile-time stack, got {type(value).__name__} value={value!r} known_handle={known} handles={handles_size}:{handle_keys!r} stack={self.stack!r}"
)
return value
def pop_token(self) -> Token:
value = self.pop()
value = self._resolve_handle(self.pop())
if not isinstance(value, Token):
raise ParseError("expected token on compile-time stack")
return value
@@ -826,9 +858,210 @@ class CompileTimeVM:
if definition is None:
raise ParseError(f"word '{word.name}' has no compile-time definition")
if isinstance(definition, AsmDefinition):
raise ParseError(f"word '{word.name}' cannot run at compile time")
self._run_asm_definition(word)
return
self._execute_nodes(definition.body)
def _run_asm_definition(self, word: Word) -> None:
definition = word.definition
if Ks is None:
raise ParseError("keystone is required for compile-time :asm execution; install keystone-engine")
if not isinstance(definition, AsmDefinition): # pragma: no cover - defensive
raise ParseError(f"word '{word.name}' has no asm body")
asm_body = definition.body.strip("\n")
string_mode = word.name == "puts"
handles = self._handles
non_int_data = any(not isinstance(v, int) for v in self.stack)
non_int_return = any(not isinstance(v, int) for v in self.return_stack)
# Collect all strings present on data and return stacks so we can point
# puts() at a real buffer and pass its range check (data_start..data_end).
strings: List[str] = []
if string_mode:
for v in self.stack + self.return_stack:
if isinstance(v, str):
strings.append(v)
data_blob = b""
string_addrs: Dict[str, Tuple[int, int]] = {}
if strings:
offset = 0
parts: List[bytes] = []
seen: Dict[str, Tuple[int, int]] = {}
for s in strings:
if s in seen:
string_addrs[s] = seen[s]
continue
encoded = s.encode("utf-8") + b"\x00"
parts.append(encoded)
addr = offset
length = len(encoded) - 1
seen[s] = (addr, length)
string_addrs[s] = (addr, length)
offset += len(encoded)
data_blob = b"".join(parts)
string_buffer: Optional[ctypes.Array[Any]] = None
data_start = 0
data_end = 0
if data_blob:
string_buffer = ctypes.create_string_buffer(data_blob)
data_start = ctypes.addressof(string_buffer)
data_end = data_start + len(data_blob)
handles.refs.append(string_buffer)
for s, (off, _len) in string_addrs.items():
handles.objects[data_start + off] = s
PRINT_BUF_BYTES = 128
print_buffer = ctypes.create_string_buffer(PRINT_BUF_BYTES)
handles.refs.append(print_buffer)
print_buf = ctypes.addressof(print_buffer)
wrapper_lines = []
wrapper_lines.extend([
"_ct_entry:",
" push rbx",
" push r12",
" push r13",
" push r14",
" push r15",
" mov r12, rdi", # data stack pointer
" mov r13, rsi", # return stack pointer
" mov r14, rdx", # out ptr for r12
" mov r15, rcx", # out ptr for r13
])
if asm_body:
patched_body = []
for line in asm_body.splitlines():
line = line.strip()
if line == "ret":
line = "jmp _ct_save"
if "lea r8, [rel data_start]" in line:
line = line.replace("lea r8, [rel data_start]", f"mov r8, {data_start}")
if "lea r9, [rel data_end]" in line:
line = line.replace("lea r9, [rel data_end]", f"mov r9, {data_end}")
if "mov byte [rel print_buf]" in line or "mov byte ptr [rel print_buf]" in line:
patched_body.append(f"mov rax, {print_buf}")
patched_body.append("mov byte ptr [rax], 10")
continue
if "lea rsi, [rel print_buf_end]" in line:
line = f"mov rsi, {print_buf + PRINT_BUF_BYTES}"
if "lea rsi, [rel print_buf]" in line:
line = f"mov rsi, {print_buf}"
patched_body.append(line)
wrapper_lines.extend(patched_body)
wrapper_lines.extend([
"_ct_save:",
" mov [r14], r12",
" mov [r15], r13",
" pop r15",
" pop r14",
" pop r13",
" pop r12",
" pop rbx",
" ret",
])
def _normalize_sizes(line: str) -> str:
for size in ("qword", "dword", "word", "byte"):
line = line.replace(f"{size} [", f"{size} ptr [")
return line
def _strip_comment(line: str) -> str:
return line.split(";", 1)[0].rstrip()
normalized_lines = []
for raw in wrapper_lines:
stripped = _strip_comment(raw)
if not stripped.strip():
continue
normalized_lines.append(_normalize_sizes(stripped))
ks = Ks(KS_ARCH_X86, KS_MODE_64)
try:
encoding, _ = ks.asm("\n".join(normalized_lines))
except KsError as exc:
debug_lines = "\n".join(normalized_lines)
raise ParseError(
f"keystone failed for word '{word.name}': {exc}\n--- asm ---\n{debug_lines}\n--- end asm ---"
) from exc
if encoding is None:
raise ParseError(
f"keystone produced no code for word '{word.name}' (lines: {len(wrapper_lines)})"
)
code = bytes(encoding)
code_buf = mmap.mmap(-1, len(code), prot=mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
code_buf.write(code)
code_ptr = ctypes.addressof(ctypes.c_char.from_buffer(code_buf))
func_type = ctypes.CFUNCTYPE(None, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64)
func = func_type(code_ptr)
handles = self._handles
def _marshal_stack(py_stack: List[Any]) -> Tuple[int, int, int, Any]:
capacity = len(py_stack) + 16
buffer = (ctypes.c_int64 * capacity)()
base = ctypes.addressof(buffer)
top = base + capacity * 8
sp = top
for value in py_stack:
sp -= 8
if isinstance(value, int):
ctypes.c_int64.from_address(sp).value = value
elif isinstance(value, str):
if string_mode:
offset, strlen = string_addrs.get(value, (0, 0))
addr = data_start + offset if data_start else handles.store(value)
# puts expects (len, addr) with len on top
ctypes.c_int64.from_address(sp).value = addr
sp -= 8
ctypes.c_int64.from_address(sp).value = strlen
else:
ctypes.c_int64.from_address(sp).value = handles.store(value)
else:
ctypes.c_int64.from_address(sp).value = handles.store(value)
return sp, top, base, buffer
# r12/r13 must point at the top element (or top of buffer if empty)
buffers: List[Any] = []
d_sp, d_top, d_base, d_buf = _marshal_stack(self.stack)
buffers.append(d_buf)
r_sp, r_top, r_base, r_buf = _marshal_stack(self.return_stack)
buffers.append(r_buf)
out_d = ctypes.c_uint64(0)
out_r = ctypes.c_uint64(0)
func(d_sp, r_sp, ctypes.addressof(out_d), ctypes.addressof(out_r))
new_d = out_d.value
new_r = out_r.value
if not (d_base <= new_d <= d_top):
raise ParseError(f"compile-time asm '{word.name}' corrupted data stack pointer")
if not (r_base <= new_r <= r_top):
raise ParseError(f"compile-time asm '{word.name}' corrupted return stack pointer")
def _unmarshal_stack(sp: int, top: int, table: _CTHandleTable) -> List[Any]:
if sp == top:
return []
values: List[Any] = []
addr = top - 8
while addr >= sp:
raw = ctypes.c_int64.from_address(addr).value
if raw in table.objects:
obj = table.objects[raw]
if isinstance(obj, str) and values and isinstance(values[-1], int):
# collapse (len, addr) pairs back into the original string
values.pop()
values.append(obj)
else:
values.append(obj)
else:
values.append(raw)
addr -= 8
return values
self.stack = _unmarshal_stack(new_d, d_top, handles)
self.return_stack = _unmarshal_stack(new_r, r_top, handles)
def _call_word_by_name(self, name: str) -> None:
word = self.dictionary.lookup(name)
if word is None:
@@ -1085,6 +1318,27 @@ def _parse_string_literal(token: Token) -> Optional[str]:
return "".join(result)
class _CTHandleTable:
"""Keeps Python object references stable across compile-time asm calls."""
def __init__(self) -> None:
self.objects: Dict[int, Any] = {}
self.refs: List[Any] = []
self.string_buffers: List[ctypes.Array[Any]] = []
def clear(self) -> None:
self.objects.clear()
self.refs.clear()
self.string_buffers.clear()
def store(self, value: Any) -> int:
addr = id(value)
self.refs.append(value)
self.objects[addr] = value
return addr
class Assembler:
def __init__(self, dictionary: Dictionary) -> None:
self.dictionary = dictionary
@@ -1298,6 +1552,24 @@ def macro_compile_only(ctx: MacroContext) -> Optional[List[ASTNode]]:
return None
def macro_compile_time(ctx: MacroContext) -> Optional[List[ASTNode]]:
"""Run the next word at compile time and still emit it for runtime."""
parser = ctx.parser
if parser._eof():
raise ParseError("word name missing after 'compile-time'")
tok = parser.next_token()
name = tok.lexeme
word = parser.dictionary.lookup(name)
if word is None:
raise ParseError(f"unknown word '{name}' for compile-time")
if word.compile_only:
raise ParseError(f"word '{name}' is compile-time only")
parser.compile_time_vm.invoke(word)
if isinstance(parser.context_stack[-1], Definition):
parser.emit_node(WordRef(name=name))
return None
def macro_begin_text_macro(ctx: MacroContext) -> Optional[List[ASTNode]]:
parser = ctx.parser
if parser._eof():
@@ -1447,14 +1719,6 @@ def _ensure_lexer(value: Any) -> SplitLexer:
return value
def _truthy(value: Any) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
return value is not None
def _coerce_str(value: Any) -> str:
if isinstance(value, str):
return value
@@ -1473,217 +1737,21 @@ def _default_template(template: Optional[Token]) -> Token:
return template
def _trunc_divmod(a: int, b: int) -> Tuple[int, int]:
if b == 0:
raise ParseError("division by zero")
quot = abs(a) // abs(b)
if (a < 0) ^ (b < 0):
quot = -quot
rem = a - quot * b
return quot, rem
def _ct_dup(vm: CompileTimeVM) -> None:
vm.push(vm.peek())
def _ct_drop(vm: CompileTimeVM) -> None:
vm.pop()
def _ct_swap(vm: CompileTimeVM) -> None:
a = vm.pop()
b = vm.pop()
vm.push(a)
vm.push(b)
def _ct_over(vm: CompileTimeVM) -> None:
if len(vm.stack) < 2:
raise ParseError("over requires two stack values")
vm.push(vm.stack[-2])
def _ct_rot(vm: CompileTimeVM) -> None:
if len(vm.stack) < 3:
raise ParseError("rot requires three stack values")
vm.stack[-3], vm.stack[-2], vm.stack[-1] = vm.stack[-2], vm.stack[-1], vm.stack[-3]
def _ct_nip(vm: CompileTimeVM) -> None:
if len(vm.stack) < 2:
raise ParseError("nip requires two stack values")
top = vm.pop()
vm.pop()
vm.push(top)
def _ct_tuck(vm: CompileTimeVM) -> None:
if len(vm.stack) < 2:
raise ParseError("tuck requires two stack values")
first = vm.pop()
second = vm.pop()
vm.push(first)
vm.push(second)
vm.push(first)
def _ct_2dup(vm: CompileTimeVM) -> None:
if len(vm.stack) < 2:
raise ParseError("2dup requires two stack values")
second = vm.pop()
first = vm.pop()
vm.push(first)
vm.push(second)
vm.push(first)
vm.push(second)
def _ct_2drop(vm: CompileTimeVM) -> None:
if len(vm.stack) < 2:
raise ParseError("2drop requires two stack values")
vm.pop()
vm.pop()
def _ct_2swap(vm: CompileTimeVM) -> None:
if len(vm.stack) < 4:
raise ParseError("2swap requires four stack values")
a = vm.pop()
b = vm.pop()
c = vm.pop()
d = vm.pop()
vm.push(a)
vm.push(b)
vm.push(c)
vm.push(d)
def _ct_2over(vm: CompileTimeVM) -> None:
if len(vm.stack) < 4:
raise ParseError("2over requires four stack values")
vm.push(vm.stack[-4])
vm.push(vm.stack[-3])
def _ct_minus_rot(vm: CompileTimeVM) -> None:
if len(vm.stack) < 3:
raise ParseError("-rot requires three stack values")
vm.stack[-3], vm.stack[-2], vm.stack[-1] = vm.stack[-1], vm.stack[-3], vm.stack[-2]
def _ct_binary_int(vm: CompileTimeVM, func: Callable[[int, int], int]) -> None:
b = vm.pop_int()
a = vm.pop_int()
vm.push(func(a, b))
def _ct_add(vm: CompileTimeVM) -> None:
_ct_binary_int(vm, lambda a, b: a + b)
def _ct_sub(vm: CompileTimeVM) -> None:
_ct_binary_int(vm, lambda a, b: a - b)
def _ct_mul(vm: CompileTimeVM) -> None:
_ct_binary_int(vm, lambda a, b: a * b)
def _ct_div(vm: CompileTimeVM) -> None:
divisor = vm.pop_int()
dividend = vm.pop_int()
quot, _ = _trunc_divmod(dividend, divisor)
vm.push(quot)
def _ct_mod(vm: CompileTimeVM) -> None:
divisor = vm.pop_int()
dividend = vm.pop_int()
_, rem = _trunc_divmod(dividend, divisor)
vm.push(rem)
def _ct_compare(vm: CompileTimeVM, predicate: Callable[[Any, Any], bool]) -> None:
b = vm.pop()
a = vm.pop()
vm.push(1 if predicate(a, b) else 0)
def _ct_eq(vm: CompileTimeVM) -> None:
_ct_compare(vm, lambda a, b: a == b)
def _ct_ne(vm: CompileTimeVM) -> None:
_ct_compare(vm, lambda a, b: a != b)
def _ct_lt(vm: CompileTimeVM) -> None:
_ct_compare(vm, lambda a, b: a < b)
def _ct_le(vm: CompileTimeVM) -> None:
_ct_compare(vm, lambda a, b: a <= b)
def _ct_gt(vm: CompileTimeVM) -> None:
_ct_compare(vm, lambda a, b: a > b)
def _ct_ge(vm: CompileTimeVM) -> None:
_ct_compare(vm, lambda a, b: a >= b)
def _ct_and(vm: CompileTimeVM) -> None:
b = _truthy(vm.pop())
a = _truthy(vm.pop())
vm.push(1 if (a and b) else 0)
def _ct_or(vm: CompileTimeVM) -> None:
b = _truthy(vm.pop())
a = _truthy(vm.pop())
vm.push(1 if (a or b) else 0)
def _ct_not(vm: CompileTimeVM) -> None:
vm.push(1 if not _truthy(vm.pop()) else 0)
def _ct_to_r(vm: CompileTimeVM) -> None:
vm.return_stack.append(vm.pop())
def _ct_r_from(vm: CompileTimeVM) -> None:
if not vm.return_stack:
raise ParseError("return stack underflow")
vm.push(vm.return_stack.pop())
def _ct_rdrop(vm: CompileTimeVM) -> None:
if not vm.return_stack:
raise ParseError("return stack underflow")
vm.return_stack.pop()
def _ct_rpick(vm: CompileTimeVM) -> None:
index = vm.pop_int()
if index < 0 or index >= len(vm.return_stack):
raise ParseError("rpick index out of range")
vm.push(vm.return_stack[-1 - index])
def _ct_pick(vm: CompileTimeVM) -> None:
index = vm.pop_int()
if index < 0 or index >= len(vm.stack):
raise ParseError("pick index out of range")
vm.push(vm.stack[-1 - index])
def _ct_nil(vm: CompileTimeVM) -> None:
vm.push(None)
def _ct_puts(vm: CompileTimeVM) -> None:
value = vm.pop()
if isinstance(value, str):
print(value)
return
if isinstance(value, int):
print(value)
return
raise ParseError("puts expects string or integer at compile time")
def _ct_nil_p(vm: CompileTimeVM) -> None:
vm.push(1 if vm.pop() is None else 0)
@@ -1704,6 +1772,12 @@ def _ct_list_append(vm: CompileTimeVM) -> None:
vm.push(lst)
def _ct_drop(vm: CompileTimeVM) -> None:
if not vm.stack:
return
vm.pop()
def _ct_list_pop(vm: CompileTimeVM) -> None:
lst = _ensure_list(vm.pop())
if not lst:
@@ -1723,7 +1797,7 @@ def _ct_list_pop_front(vm: CompileTimeVM) -> None:
def _ct_list_length(vm: CompileTimeVM) -> None:
lst = _ensure_list(vm.pop())
lst = vm.pop_list()
vm.push(len(lst))
@@ -1955,13 +2029,24 @@ def _ct_int_to_string(vm: CompileTimeVM) -> None:
def _ct_identifier_p(vm: CompileTimeVM) -> None:
value = vm.pop_str()
value = vm._resolve_handle(vm.pop())
if isinstance(value, Token):
value = value.lexeme
if not isinstance(value, str):
vm.push(0)
return
vm.push(1 if _is_identifier(value) else 0)
def _ct_token_lexeme(vm: CompileTimeVM) -> None:
token = vm.pop_token()
vm.push(token.lexeme)
value = vm._resolve_handle(vm.pop())
if isinstance(value, Token):
vm.push(value.lexeme)
return
if isinstance(value, str):
vm.push(value)
return
raise ParseError("expected token or string on compile-time stack")
def _ct_token_from_lexeme(vm: CompileTimeVM) -> None:
@@ -2068,43 +2153,12 @@ def _register_compile_time_primitives(dictionary: Dictionary) -> None:
if compile_only:
word.compile_only = True
register("dup", _ct_dup)
register("drop", _ct_drop)
register("swap", _ct_swap)
register("over", _ct_over)
register("rot", _ct_rot)
register("nip", _ct_nip)
register("tuck", _ct_tuck)
register("2dup", _ct_2dup)
register("2drop", _ct_2drop)
register("2swap", _ct_2swap)
register("2over", _ct_2over)
register("-rot", _ct_minus_rot)
register("+", _ct_add)
register("-", _ct_sub)
register("*", _ct_mul)
register("/", _ct_div)
register("%", _ct_mod)
register("==", _ct_eq)
register("!=", _ct_ne)
register("<", _ct_lt)
register("<=", _ct_le)
register(">", _ct_gt)
register(">=", _ct_ge)
register("and", _ct_and)
register("or", _ct_or)
register("not", _ct_not)
register(">r", _ct_to_r)
register("r>", _ct_r_from)
register("rdrop", _ct_rdrop)
register("rpick", _ct_rpick)
register("pick", _ct_pick)
register("nil", _ct_nil, compile_only=True)
register("nil?", _ct_nil_p, compile_only=True)
register("list-new", _ct_list_new, compile_only=True)
register("list-clone", _ct_list_clone, compile_only=True)
register("list-append", _ct_list_append, compile_only=True)
register("drop", _ct_drop)
register("list-pop", _ct_list_pop, compile_only=True)
register("list-pop-front", _ct_list_pop_front, compile_only=True)
register("list-length", _ct_list_length, compile_only=True)
@@ -2239,6 +2293,7 @@ def bootstrap_dictionary() -> Dictionary:
dictionary = Dictionary()
dictionary.register(Word(name="immediate", immediate=True, macro=macro_immediate))
dictionary.register(Word(name="compile-only", immediate=True, macro=macro_compile_only))
dictionary.register(Word(name="compile-time", immediate=True, macro=macro_compile_time))
dictionary.register(Word(name="macro:", immediate=True, macro=macro_begin_text_macro))
dictionary.register(Word(name=";macro", immediate=True, macro=macro_end_text_macro))
dictionary.register(Word(name="struct:", immediate=True, macro=macro_struct_begin))