Added better support for structs in extern functions, added a flag to disable list folding, and made extern functions automatically have priority 1

This commit is contained in:
igor
2026-03-02 10:21:42 +01:00
parent 86d4ffbb9a
commit 13f33d7820
6 changed files with 836 additions and 129 deletions

764
main.py
View File

@@ -292,6 +292,7 @@ class Module:
variables: Dict[str, str] = field(default_factory=dict)
prelude: Optional[List[str]] = None
bss: Optional[List[str]] = None
cstruct_layouts: Dict[str, CStructLayout] = field(default_factory=dict)
@dataclass(slots=True)
@@ -308,6 +309,23 @@ class StructField:
size: int
@dataclass(slots=True)
class CStructField:
name: str
type_name: str
offset: int
size: int
align: int
@dataclass(slots=True)
class CStructLayout:
name: str
size: int
align: int
fields: List[CStructField]
class MacroContext:
"""Small facade exposed to Python-defined macros."""
@@ -460,6 +478,8 @@ Context = Union[Module, Definition]
class Parser:
EXTERN_DEFAULT_PRIORITY = 1
def __init__(self, dictionary: Dictionary, reader: Optional[Reader] = None) -> None:
self.dictionary = dictionary
self.reader = reader or Reader()
@@ -482,6 +502,7 @@ class Parser:
self.compile_time_vm = CompileTimeVM(self)
self.custom_prelude: Optional[List[str]] = None
self.custom_bss: Optional[List[str]] = None
self.cstruct_layouts: Dict[str, CStructLayout] = {}
self._pending_inline_definition: bool = False
self._pending_priority: Optional[int] = None
@@ -588,7 +609,14 @@ class Parser:
self.pos = 0
self.variable_labels = {}
self.variable_words = {}
self.context_stack = [Module(forms=[], variables=self.variable_labels)]
self.cstruct_layouts = {}
self.context_stack = [
Module(
forms=[],
variables=self.variable_labels,
cstruct_layouts=self.cstruct_layouts,
)
]
self.definition_stack.clear()
self.last_defined = None
self.control_stack = []
@@ -693,6 +721,7 @@ class Parser:
module.variables = dict(self.variable_labels)
module.prelude = self.custom_prelude
module.bss = self.custom_bss
module.cstruct_layouts = dict(self.cstruct_layouts)
return module
def _handle_list_begin(self) -> None:
@@ -717,9 +746,9 @@ class Parser:
)
self._pending_priority = value
def _consume_pending_priority(self) -> int:
def _consume_pending_priority(self, *, default: int = 0) -> int:
if self._pending_priority is None:
return 0
return default
value = self._pending_priority
self._pending_priority = None
return value
@@ -734,7 +763,7 @@ class Parser:
if self._eof():
raise ParseError(f"extern missing name at {token.line}:{token.column}")
priority = self._consume_pending_priority()
priority = self._consume_pending_priority(default=self.EXTERN_DEFAULT_PRIORITY)
first_token = self._consume()
if self._try_parse_c_extern(first_token, priority=priority):
return
@@ -1503,6 +1532,7 @@ class CompileTimeVM:
self._dl_handles: List[Any] = [] # ctypes.CDLL handles
self._dl_func_cache: Dict[str, Any] = {} # name → ctypes callable
self._ct_libs: List[str] = [] # library names from -l flags
self._ctypes_struct_cache: Dict[str, Any] = {}
self.current_location: Optional[SourceLocation] = None
def reset(self) -> None:
@@ -1734,11 +1764,21 @@ class CompileTimeVM:
_CTYPE_MAP: Dict[str, Any] = {
"int": ctypes.c_int,
"int8_t": ctypes.c_int8,
"uint8_t": ctypes.c_uint8,
"int16_t": ctypes.c_int16,
"uint16_t": ctypes.c_uint16,
"int32_t": ctypes.c_int32,
"uint32_t": ctypes.c_uint32,
"long": ctypes.c_long,
"long long": ctypes.c_longlong,
"int64_t": ctypes.c_int64,
"unsigned int": ctypes.c_uint,
"unsigned long": ctypes.c_ulong,
"unsigned long long": ctypes.c_ulonglong,
"uint64_t": ctypes.c_uint64,
"size_t": ctypes.c_size_t,
"ssize_t": ctypes.c_ssize_t,
"char": ctypes.c_char,
"char*": ctypes.c_void_p, # use void* so raw integer addrs work
"void*": ctypes.c_void_p,
@@ -1746,14 +1786,30 @@ class CompileTimeVM:
"float": ctypes.c_float,
}
def _resolve_struct_ctype(self, struct_name: str) -> Any:
cached = self._ctypes_struct_cache.get(struct_name)
if cached is not None:
return cached
layout = self.parser.cstruct_layouts.get(struct_name)
if layout is None:
raise ParseError(f"unknown cstruct '{struct_name}' used in extern signature")
fields = []
for field in layout.fields:
fields.append((field.name, self._resolve_ctype(field.type_name)))
struct_cls = type(f"CTStruct_{sanitize_label(struct_name)}", (ctypes.Structure,), {"_fields_": fields})
self._ctypes_struct_cache[struct_name] = struct_cls
return struct_cls
def _resolve_ctype(self, type_name: str) -> Any:
    """Map a C type name string to a ctypes type.

    The name is canonicalised first (whitespace collapsed, field aliases
    expanded, pointer stars attached). Resolution order:

    1. any pointer type            -> ``ctypes.c_void_p``
    2. ``struct <name>`` by value  -> the generated ``ctypes.Structure`` class
    3. a name in ``_CTYPE_MAP``    -> the mapped ctypes type
    4. anything else               -> ``ctypes.c_long``

    Raises:
        ParseError: (from ``_resolve_struct_ctype``) for an unknown cstruct.
    """
    t = _canonical_c_type_name(type_name)
    # Pointer types: use void* so raw integer addresses can be passed through.
    if t.endswith("*"):
        return ctypes.c_void_p
    if t.startswith("struct "):
        return self._resolve_struct_ctype(t[len("struct "):].strip())
    # The canonical name is already whitespace-normalised and, at this point,
    # has no trailing '*', so it can be looked up directly.
    if t in self._CTYPE_MAP:
        return self._CTYPE_MAP[t]
    # Default to c_long (64-bit on Linux x86-64)
    return ctypes.c_long
@@ -1822,21 +1878,34 @@ class CompileTimeVM:
# Convert arguments to proper ctypes values
call_args = []
for i, raw in enumerate(raw_args):
if i < len(arg_types) and arg_types[i] in ("float", "double"):
arg_type = _canonical_c_type_name(arg_types[i]) if i < len(arg_types) else None
if arg_type in ("float", "double"):
# Reinterpret the int64 bits as a double (matching the language's convention)
raw_int = _to_i64(int(raw))
double_val = struct.unpack("d", struct.pack("q", raw_int))[0]
call_args.append(double_val)
elif arg_type is not None and arg_type.startswith("struct ") and not arg_type.endswith("*"):
struct_name = arg_type[len("struct "):].strip()
struct_ctype = self._resolve_struct_ctype(struct_name)
call_args.append(struct_ctype.from_address(int(raw)))
else:
call_args.append(int(raw))
result = func(*call_args)
if outputs > 0 and result is not None:
ret_type = func._ct_signature[1] if func._ct_signature else None
ret_type = _canonical_c_type_name(func._ct_signature[1]) if func._ct_signature else None
if ret_type in ("float", "double"):
int_bits = struct.unpack("q", struct.pack("d", float(result)))[0]
self.push(int_bits)
elif ret_type is not None and ret_type.startswith("struct "):
struct_name = ret_type[len("struct "):].strip()
layout = self.parser.cstruct_layouts.get(struct_name)
if layout is None:
raise ParseError(f"unknown cstruct '{struct_name}' used in extern return type")
out_ptr = self.memory.allocate(layout.size)
ctypes.memmove(out_ptr, ctypes.byref(result), layout.size)
self.push(out_ptr)
else:
self.push(int(result))
@@ -2963,6 +3032,157 @@ _C_TYPE_IGNORED_QUALIFIERS = {
"_Atomic",
}
# Short type spellings mapped to C type names. `_canonical_c_type_name` and
# `_parse_cfield_type` expand these before any size or ctypes lookup.
_C_FIELD_TYPE_ALIASES: Dict[str, str] = {
    # Fixed-width integers
    "i8": "int8_t",
    "u8": "uint8_t",
    "i16": "int16_t",
    "u16": "uint16_t",
    "i32": "int32_t",
    "u32": "uint32_t",
    "i64": "int64_t",
    "u64": "uint64_t",
    # Size-typed integers
    "isize": "long",
    "usize": "size_t",
    # Floating point
    "f32": "float",
    "f64": "double",
    # Untyped pointer
    "ptr": "void*",
}
_C_SCALAR_TYPE_INFO: Dict[str, Tuple[int, int, str]] = {
"char": (1, 1, "INTEGER"),
"signed char": (1, 1, "INTEGER"),
"unsigned char": (1, 1, "INTEGER"),
"short": (2, 2, "INTEGER"),
"short int": (2, 2, "INTEGER"),
"unsigned short": (2, 2, "INTEGER"),
"unsigned short int": (2, 2, "INTEGER"),
"int": (4, 4, "INTEGER"),
"unsigned int": (4, 4, "INTEGER"),
"int32_t": (4, 4, "INTEGER"),
"uint32_t": (4, 4, "INTEGER"),
"long": (8, 8, "INTEGER"),
"unsigned long": (8, 8, "INTEGER"),
"long long": (8, 8, "INTEGER"),
"unsigned long long": (8, 8, "INTEGER"),
"int64_t": (8, 8, "INTEGER"),
"uint64_t": (8, 8, "INTEGER"),
"size_t": (8, 8, "INTEGER"),
"ssize_t": (8, 8, "INTEGER"),
"void": (0, 1, "INTEGER"),
"float": (4, 4, "SSE"),
"double": (8, 8, "SSE"),
}
def _round_up(value: int, align: int) -> int:
if align <= 1:
return value
return ((value + align - 1) // align) * align
def _canonical_c_type_name(type_name: str) -> str:
text = " ".join(type_name.strip().split())
if not text:
return text
text = _C_FIELD_TYPE_ALIASES.get(text, text)
text = text.replace(" *", "*")
return text
def _is_struct_type(type_name: str) -> bool:
    """Return True when the canonical form of ``type_name`` begins with ``"struct "``.

    Pointers to structs (``"struct Foo*"``) also start with that prefix and
    therefore also return True.
    """
    canonical = _canonical_c_type_name(type_name)
    return canonical.startswith("struct ")
def _c_type_size_align_class(
    type_name: str,
    cstruct_layouts: Dict[str, CStructLayout],
) -> Tuple[int, int, str, Optional[CStructLayout]]:
    """Describe a C type as ``(size, align, class, layout)``.

    ``class`` is ``"INTEGER"``, ``"SSE"`` or ``"STRUCT"``; ``layout`` is the
    registered ``CStructLayout`` for by-value structs and ``None`` otherwise.
    Empty names, pointers and unrecognised names are all treated as an
    8-byte INTEGER value.

    Raises:
        CompileError: for ``struct <name>`` with no layout in ``cstruct_layouts``.
    """
    canonical = _canonical_c_type_name(type_name)
    word_sized = (8, 8, "INTEGER", None)
    if not canonical or canonical.endswith("*"):
        return word_sized
    scalar = _C_SCALAR_TYPE_INFO.get(canonical)
    if scalar is not None:
        return (*scalar, None)
    if not canonical.startswith("struct "):
        # Preserve backward compatibility for unknown scalar-ish names.
        return word_sized
    struct_name = canonical[len("struct "):].strip()
    layout = cstruct_layouts.get(struct_name)
    if layout is None:
        raise CompileError(
            f"unknown cstruct '{struct_name}' used in extern signature"
        )
    return layout.size, layout.align, "STRUCT", layout
def _merge_eightbyte_class(current: str, incoming: str) -> str:
if current == "NO_CLASS":
return incoming
if current == incoming:
return current
if current == "INTEGER" or incoming == "INTEGER":
return "INTEGER"
return incoming
def _classify_struct_eightbytes(
layout: CStructLayout,
cstruct_layouts: Dict[str, CStructLayout],
cache: Optional[Dict[str, Optional[List[str]]]] = None,
) -> Optional[List[str]]:
if cache is None:
cache = {}
cached = cache.get(layout.name)
if cached is not None or layout.name in cache:
return cached
if layout.size <= 0:
cache[layout.name] = []
return []
if layout.size > 16:
cache[layout.name] = None
return None
chunk_count = (layout.size + 7) // 8
classes: List[str] = ["NO_CLASS"] * chunk_count
for field in layout.fields:
f_size, _, f_class, nested = _c_type_size_align_class(field.type_name, cstruct_layouts)
if f_size == 0:
continue
if nested is not None:
nested_classes = _classify_struct_eightbytes(nested, cstruct_layouts, cache)
if nested_classes is None:
cache[layout.name] = None
return None
base_chunk = field.offset // 8
for idx, cls in enumerate(nested_classes):
chunk = base_chunk + idx
if chunk >= len(classes):
cache[layout.name] = None
return None
classes[chunk] = _merge_eightbyte_class(classes[chunk], cls or "INTEGER")
continue
start_chunk = field.offset // 8
end_chunk = (field.offset + f_size - 1) // 8
if end_chunk >= len(classes):
cache[layout.name] = None
return None
if f_class == "SSE" and start_chunk != end_chunk:
cache[layout.name] = None
return None
for chunk in range(start_chunk, end_chunk + 1):
classes[chunk] = _merge_eightbyte_class(classes[chunk], f_class)
for idx, cls in enumerate(classes):
if cls == "NO_CLASS":
classes[idx] = "INTEGER"
cache[layout.name] = classes
return classes
def _split_trailing_identifier(text: str) -> Tuple[str, Optional[str]]:
if not text:
@@ -3083,6 +3303,7 @@ class Assembler:
dictionary: Dictionary,
*,
enable_constant_folding: bool = True,
enable_static_list_folding: bool = True,
enable_peephole_optimization: bool = True,
loop_unroll_threshold: int = 8,
) -> None:
@@ -3094,8 +3315,10 @@ class Assembler:
self._inline_counter: int = 0
self._unroll_counter: int = 0
self._emit_stack: List[str] = []
self._cstruct_layouts: Dict[str, CStructLayout] = {}
self._export_all_defs: bool = False
self.enable_constant_folding = enable_constant_folding
self.enable_static_list_folding = enable_static_list_folding
self.enable_peephole_optimization = enable_peephole_optimization
self.loop_unroll_threshold = loop_unroll_threshold
@@ -3545,6 +3768,7 @@ class Assembler:
self._string_literals = {}
self._float_literals = {}
self._data_section = emission.data
self._cstruct_layouts = dict(module.cstruct_layouts)
valid_defs = (Definition, AsmDefinition)
raw_defs = [form for form in module.forms if isinstance(form, valid_defs)]
@@ -3560,9 +3784,10 @@ class Assembler:
for defn in definitions:
if isinstance(defn, Definition):
self._fold_constants_in_definition(defn)
for defn in definitions:
if isinstance(defn, Definition):
self._fold_static_list_literals_definition(defn)
if self.enable_static_list_folding:
for defn in definitions:
if isinstance(defn, Definition):
self._fold_static_list_literals_definition(defn)
stray_forms = [form for form in module.forms if not isinstance(form, valid_defs)]
if stray_forms:
raise CompileError("top-level literals or word references are not supported yet")
@@ -4029,6 +4254,222 @@ class Assembler:
raise CompileError(f"unsupported op {node!r}{ctx()}")
def _emit_mmap_alloc(self, builder: FunctionEmitter, size: int, target_reg: str = "rax") -> None:
alloc_size = max(1, int(size))
builder.emit(" xor rdi, rdi")
builder.emit(f" mov rsi, {alloc_size}")
builder.emit(" mov rdx, 3")
builder.emit(" mov r10, 34")
builder.emit(" mov r8, -1")
builder.emit(" xor r9, r9")
builder.emit(" mov rax, 9")
builder.emit(" syscall")
if target_reg != "rax":
builder.emit(f" mov {target_reg}, rax")
def _analyze_extern_c_type(self, type_name: str) -> Dict[str, Any]:
    """Summarise how a C type in an extern signature is laid out and passed.

    Returns a dict with keys ``name``, ``size``, ``align``, ``class``,
    ``kind`` (``"struct"`` or ``"scalar"``), ``layout``, ``pass_mode``
    (``"scalar"``, ``"register"`` or ``"memory"``) and ``eightbytes`` (the
    per-eightbyte classes of a register-passed struct, otherwise empty).
    """
    size, align, cls, layout = _c_type_size_align_class(type_name, self._cstruct_layouts)
    is_struct = layout is not None

    pass_mode = "scalar"
    eightbytes: List[str] = []
    if is_struct:
        classified = _classify_struct_eightbytes(layout, self._cstruct_layouts)
        # None means the struct cannot be split across registers.
        pass_mode = "memory" if classified is None else "register"
        eightbytes = classified or []

    return {
        "name": _canonical_c_type_name(type_name),
        "size": size,
        "align": align,
        "class": cls,
        "kind": "struct" if is_struct else "scalar",
        "layout": layout,
        "pass_mode": pass_mode,
        "eightbytes": eightbytes,
    }
def _emit_copy_bytes_from_ptr(
self,
builder: FunctionEmitter,
*,
src_ptr_reg: str,
dst_expr: str,
size: int,
) -> None:
copied = 0
while copied + 8 <= size:
builder.emit(f" mov r11, [{src_ptr_reg} + {copied}]")
builder.emit(f" mov qword [{dst_expr} + {copied}], r11")
copied += 8
while copied < size:
builder.emit(f" mov r11b, byte [{src_ptr_reg} + {copied}]")
builder.emit(f" mov byte [{dst_expr} + {copied}], r11b")
copied += 1
def _emit_extern_wordref(self, name: str, word: Word, builder: FunctionEmitter) -> None:
inputs = getattr(word, "extern_inputs", 0)
outputs = getattr(word, "extern_outputs", 0)
signature = getattr(word, "extern_signature", None)
if signature is None and inputs <= 0 and outputs <= 0:
builder.emit(f" call {name}")
return
arg_types = list(signature[0]) if signature else ["long"] * inputs
ret_type = signature[1] if signature else ("long" if outputs > 0 else "void")
if len(arg_types) != inputs and signature is not None:
suffix = f" while emitting '{self._emit_stack[-1]}'" if self._emit_stack else ""
raise CompileError(f"extern '{name}' mismatch: {inputs} inputs vs {len(arg_types)} types{suffix}")
arg_infos = [self._analyze_extern_c_type(t) for t in arg_types]
ret_info = self._analyze_extern_c_type(ret_type)
regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]
xmm_regs = [f"xmm{i}" for i in range(8)]
ret_uses_sret = ret_info["kind"] == "struct" and ret_info["pass_mode"] == "memory"
int_idx = 1 if ret_uses_sret else 0
xmm_idx = 0
arg_locs: List[Dict[str, Any]] = []
stack_cursor = 0
for info in arg_infos:
if info["kind"] == "struct":
if info["pass_mode"] == "register":
classes: List[str] = list(info["eightbytes"])
need_int = sum(1 for c in classes if c == "INTEGER")
need_xmm = sum(1 for c in classes if c == "SSE")
if int_idx + need_int <= len(regs) and xmm_idx + need_xmm <= len(xmm_regs):
chunks: List[Tuple[str, str, int]] = []
int_off = int_idx
xmm_off = xmm_idx
for chunk_idx, cls in enumerate(classes):
if cls == "SSE":
chunks.append((cls, xmm_regs[xmm_off], chunk_idx * 8))
xmm_off += 1
else:
chunks.append(("INTEGER", regs[int_off], chunk_idx * 8))
int_off += 1
int_idx = int_off
xmm_idx = xmm_off
arg_locs.append({"mode": "struct_reg", "chunks": chunks, "info": info})
else:
stack_size = _round_up(int(info["size"]), 8)
stack_off = stack_cursor
stack_cursor += stack_size
arg_locs.append({"mode": "struct_stack", "stack_off": stack_off, "info": info})
else:
stack_size = _round_up(int(info["size"]), 8)
stack_off = stack_cursor
stack_cursor += stack_size
arg_locs.append({"mode": "struct_stack", "stack_off": stack_off, "info": info})
continue
if info["class"] == "SSE":
if xmm_idx < len(xmm_regs):
arg_locs.append({"mode": "scalar_reg", "reg": xmm_regs[xmm_idx], "class": "SSE"})
xmm_idx += 1
else:
stack_off = stack_cursor
stack_cursor += 8
arg_locs.append({"mode": "scalar_stack", "stack_off": stack_off, "class": "SSE"})
else:
if int_idx < len(regs):
arg_locs.append({"mode": "scalar_reg", "reg": regs[int_idx], "class": "INTEGER"})
int_idx += 1
else:
stack_off = stack_cursor
stack_cursor += 8
arg_locs.append({"mode": "scalar_stack", "stack_off": stack_off, "class": "INTEGER"})
# Preserve and realign RSP for C ABI calls regardless current call depth.
stack_bytes = max(15, int(stack_cursor) + 15)
builder.emit(" mov r14, rsp")
builder.emit(f" sub rsp, {stack_bytes}")
builder.emit(" and rsp, -16")
if ret_info["kind"] == "struct":
self._emit_mmap_alloc(builder, int(ret_info["size"]), target_reg="r15")
if ret_uses_sret:
builder.emit(" mov rdi, r15")
total_args = len(arg_locs)
for idx, loc in enumerate(reversed(arg_locs)):
addr = f"[r12 + {idx * 8}]" if idx > 0 else "[r12]"
mode = str(loc["mode"])
if mode == "scalar_reg":
reg = str(loc["reg"])
cls = str(loc["class"])
if cls == "SSE":
builder.emit(f" mov rax, {addr}")
builder.emit(f" movq {reg}, rax")
else:
builder.emit(f" mov {reg}, {addr}")
continue
if mode == "scalar_stack":
stack_off = int(loc["stack_off"])
builder.emit(f" mov rax, {addr}")
builder.emit(f" mov qword [rsp + {stack_off}], rax")
continue
if mode == "struct_reg":
chunks: List[Tuple[str, str, int]] = list(loc["chunks"])
builder.emit(f" mov rax, {addr}")
for cls, target, off in chunks:
if cls == "SSE":
builder.emit(f" movq {target}, [rax + {off}]")
else:
builder.emit(f" mov {target}, [rax + {off}]")
continue
if mode == "struct_stack":
stack_off = int(loc["stack_off"])
size = int(loc["info"]["size"])
builder.emit(f" mov rax, {addr}")
self._emit_copy_bytes_from_ptr(builder, src_ptr_reg="rax", dst_expr=f"rsp + {stack_off}", size=size)
continue
raise CompileError(f"internal extern lowering error for '{name}': unknown arg mode {mode!r}")
if total_args:
builder.emit(f" add r12, {total_args * 8}")
builder.emit(f" mov al, {xmm_idx}")
builder.emit(f" call {name}")
builder.emit(" mov rsp, r14")
if ret_info["kind"] == "struct":
if not ret_uses_sret:
ret_classes: List[str] = list(ret_info["eightbytes"])
int_ret_regs = ["rax", "rdx"]
xmm_ret_regs = ["xmm0", "xmm1"]
int_ret_idx = 0
xmm_ret_idx = 0
for chunk_idx, cls in enumerate(ret_classes):
off = chunk_idx * 8
if cls == "SSE":
src = xmm_ret_regs[xmm_ret_idx]
xmm_ret_idx += 1
builder.emit(f" movq [r15 + {off}], {src}")
else:
src = int_ret_regs[int_ret_idx]
int_ret_idx += 1
builder.emit(f" mov [r15 + {off}], {src}")
builder.emit(" sub r12, 8")
builder.emit(" mov [r12], r15")
return
if _ctype_uses_sse(ret_type):
builder.emit(" sub r12, 8")
builder.emit(" movq rax, xmm0")
builder.emit(" mov [r12], rax")
elif outputs == 1:
builder.push_from("rax")
elif outputs > 1:
raise CompileError("extern only supports 0 or 1 scalar output")
def _emit_wordref(self, name: str, builder: FunctionEmitter) -> None:
word = self.dictionary.lookup(name)
if word is None:
@@ -4048,114 +4489,7 @@ class Assembler:
word.intrinsic(builder)
return
if getattr(word, "is_extern", False):
inputs = getattr(word, "extern_inputs", 0)
outputs = getattr(word, "extern_outputs", 0)
signature = getattr(word, "extern_signature", None)
if signature is not None or inputs > 0 or outputs > 0:
regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]
xmm_regs = [f"xmm{i}" for i in range(8)]
arg_types = signature[0] if signature else []
ret_type = signature[1] if signature else None
if len(arg_types) != inputs and signature:
suffix = f" while emitting '{self._emit_stack[-1]}'" if self._emit_stack else ""
raise CompileError(f"extern '{name}' mismatch: {inputs} inputs vs {len(arg_types)} types{suffix}")
int_idx = 0
xmm_idx = 0
mapping: List[Tuple[str, str]] = [] # (type, target)
# Assign registers for first args; overflow goes to stack
if not arg_types:
# Legacy/Raw mode: assume all ints
for i in range(inputs):
if int_idx < len(regs):
mapping.append(("int", regs[int_idx]))
int_idx += 1
else:
mapping.append(("int", "stack"))
else:
for type_name in arg_types:
if type_name in ("float", "double"):
if xmm_idx < len(xmm_regs):
mapping.append(("float", xmm_regs[xmm_idx]))
xmm_idx += 1
else:
mapping.append(("float", "stack"))
else:
if int_idx < len(regs):
mapping.append(("int", regs[int_idx]))
int_idx += 1
else:
mapping.append(("int", "stack"))
# Count stack slots required
stack_slots = sum(1 for t, target in mapping if target == "stack")
# stack allocation in bytes; make it a multiple of 16 for alignment
stack_bytes = ((stack_slots * 8 + 15) // 16) * 16 if stack_slots > 0 else 0
# Prepare stack-passed arguments: allocate space (16-byte multiple)
if stack_bytes:
builder.emit(f" sub rsp, {stack_bytes}")
# Read all arguments from the CT stack by indexed addressing
# (without advancing r12) and write them to registers or the
# prepared spill area. After all reads are emitted we advance
# r12 once by the total number of arguments to pop them.
total_args = len(mapping)
if stack_slots:
stack_write_idx = stack_slots - 1
else:
stack_write_idx = 0
# Iterate over reversed mapping (right-to-left) but use an
# index to address the CT stack without modifying r12.
for idx, (typ, target) in enumerate(reversed(mapping)):
addr = f"[r12 + {idx * 8}]" if idx > 0 else "[r12]"
if target == "stack":
# Read spilled arg from indexed CT stack slot and store
# it into the caller's spill area at the computed offset.
builder.emit(f" mov rax, {addr}")
offset = stack_write_idx * 8
builder.emit(f" mov [rsp + {offset}], rax")
stack_write_idx -= 1
else:
if typ == "float":
builder.emit(f" mov rax, {addr}")
builder.emit(f" movq {target}, rax")
else:
builder.emit(f" mov {target}, {addr}")
# Advance the CT stack pointer once to pop all arguments.
if total_args:
builder.emit(f" add r12, {total_args * 8}")
# Call the external function. We allocated a multiple-of-16
# area for spilled args above so `rsp` is already aligned
# for the call; set `al` (SSE count) then call directly.
builder.emit(f" mov al, {xmm_idx}")
builder.emit(f" call {name}")
# Restore stack after the call
if stack_bytes:
builder.emit(f" add rsp, {stack_bytes}")
# Handle Return Value
if _ctype_uses_sse(ret_type):
# Result in xmm0, move to stack
builder.emit(" sub r12, 8")
builder.emit(" movq rax, xmm0")
builder.emit(" mov [r12], rax")
elif outputs == 1:
builder.push_from("rax")
elif outputs > 1:
raise CompileError("extern only supports 0 or 1 output")
else:
# Emit call to unresolved symbol (let linker resolve it)
builder.emit(f" call {name}")
self._emit_extern_wordref(name, word, builder)
else:
builder.emit(f" call {sanitize_label(name)}")
@@ -5346,6 +5680,27 @@ PY_EXEC_GLOBALS: Dict[str, Any] = {
}
def _parse_cfield_type(parser: Parser, struct_name: str) -> str:
    """Consume the type of a `cfield` declaration and return its canonical name.

    Accepts either a single type token (aliases are expanded) or
    ``struct <name>`` optionally followed by a token consisting only of
    ``*`` characters.

    Raises:
        ParseError: if the token stream ends before the type is complete.
    """
    if parser._eof():
        raise ParseError(f"field type missing in cstruct '{struct_name}'")
    head = parser.next_token().lexeme

    if head != "struct":
        canonical = _canonical_c_type_name(head)
        return _canonical_c_type_name(_C_FIELD_TYPE_ALIASES.get(canonical, canonical))

    if parser._eof():
        raise ParseError(f"struct field type missing name in cstruct '{struct_name}'")
    type_name = f"struct {parser.next_token().lexeme}"
    if not parser._eof():
        stars = parser.peek_token()
        # Only a token made purely of '*' is taken as part of the type.
        if stars is not None and set(stars.lexeme) == {"*"}:
            type_name += stars.lexeme
            parser.next_token()
    return _canonical_c_type_name(type_name)
def macro_struct_begin(ctx: MacroContext) -> Optional[List[Op]]:
parser = ctx.parser
if parser._eof():
@@ -5402,6 +5757,84 @@ def macro_struct_begin(ctx: MacroContext) -> Optional[List[Op]]:
parser.tokens[parser.pos:parser.pos] = generated
return None
def macro_cstruct_begin(ctx: MacroContext) -> Optional[List[Op]]:
    """Handle ``cstruct <name> (cfield <field> <type>)* end``.

    Computes a C-style layout (each member aligned to its own alignment, total
    size rounded up to the largest alignment), records it in
    ``parser.cstruct_layouts`` and splices generated definitions into the
    token stream: ``<name>.size``, ``<name>.align``, and per member
    ``<name>.<field>.size`` / ``.offset``, plus ``@`` / ``!`` accessors for
    8-byte members only.

    Raises:
        ParseError: on a missing name, missing ``end``, an unexpected token,
            a missing field name, or a field type whose size is not positive.
    """
    parser = ctx.parser
    if parser._eof():
        raise ParseError("cstruct name missing after 'cstruct'")
    name_token = parser.next_token()
    struct_name = name_token.lexeme

    members: List[CStructField] = []
    cursor = 0
    struct_align = 1

    while True:
        if parser._eof():
            raise ParseError("unterminated cstruct definition (missing 'end')")
        keyword = parser.next_token().lexeme
        if keyword == "end":
            break
        if keyword != "cfield":
            raise ParseError(
                f"expected 'cfield' or 'end' in cstruct '{struct_name}' definition"
            )
        if parser._eof():
            raise ParseError("field name missing in cstruct definition")
        member_name = parser.next_token().lexeme
        type_name = _parse_cfield_type(parser, struct_name)
        member_size, member_align, _, _ = _c_type_size_align_class(type_name, parser.cstruct_layouts)
        if member_size <= 0:
            raise ParseError(
                f"invalid cfield type '{type_name}' for '{member_name}' in cstruct '{struct_name}'"
            )

        # Insert padding so the member starts on its own alignment boundary.
        cursor = _round_up(cursor, member_align)
        members.append(
            CStructField(
                name=member_name,
                type_name=type_name,
                offset=cursor,
                size=member_size,
                align=member_align,
            )
        )
        cursor += member_size
        struct_align = max(struct_align, member_align)

    total_size = _round_up(cursor, struct_align)
    parser.cstruct_layouts[struct_name] = CStructLayout(
        name=struct_name,
        size=total_size,
        align=struct_align,
        fields=members,
    )

    generated: List[Token] = []

    def define(word_name: str, body: List[str]) -> None:
        _struct_emit_definition(generated, name_token, word_name, body)

    define(f"{struct_name}.size", [str(total_size)])
    define(f"{struct_name}.align", [str(struct_align)])
    for member in members:
        size_word = f"{struct_name}.{member.name}.size"
        offset_word = f"{struct_name}.{member.name}.offset"
        define(size_word, [str(member.size)])
        define(offset_word, [str(member.offset)])
        if member.size != 8:
            continue
        # Accessors are generated for 8-byte members only.
        define(f"{struct_name}.{member.name}@", [offset_word, "+", "@"])
        define(f"{struct_name}.{member.name}!", ["swap", offset_word, "+", "swap", "!"])

    parser.tokens[parser.pos:parser.pos] = generated
    return None
def macro_here(ctx: MacroContext) -> Optional[List[Op]]:
tok = ctx.parser._last_token
if tok is None:
@@ -5422,6 +5855,7 @@ def bootstrap_dictionary() -> Dictionary:
dictionary.register(Word(name="with", immediate=True, macro=macro_with))
dictionary.register(Word(name="macro", immediate=True, macro=macro_begin_text_macro))
dictionary.register(Word(name="struct", immediate=True, macro=macro_struct_begin))
dictionary.register(Word(name="cstruct", immediate=True, macro=macro_cstruct_begin))
_register_compile_time_primitives(dictionary)
_register_runtime_intrinsics(dictionary)
return dictionary
@@ -5694,9 +6128,17 @@ class BuildCache:
key = self._hash_str(str(source.resolve()))
return self.cache_dir / f"{key}.json"
def flags_hash(
    self,
    debug: bool,
    folding: bool,
    static_list_folding: bool,
    peephole: bool,
    entry_mode: str,
) -> str:
    """Return a hash identifying the set of build flags that affect output.

    Every flag that can change the generated assembly is part of the hashed
    text, so toggling any of them (including static list folding) yields a
    different hash.
    """
    return self._hash_str(
        f"debug={debug},folding={folding},static_list_folding={static_list_folding},"
        f"peephole={peephole},entry_mode={entry_mode}"
    )
def _file_info(self, path: Path) -> dict:
@@ -5857,6 +6299,38 @@ def build_static_library(obj_path: Path, archive_path: Path) -> None:
subprocess.run(["ar", "rcs", str(archive_path), str(obj_path)], check=True)
def _load_sidecar_meta_libs(source: Path) -> List[str]:
"""Return additional linker libs from sibling <source>.meta.json."""
meta_path = source.with_suffix(".meta.json")
if not meta_path.exists():
return []
try:
payload = json.loads(meta_path.read_text())
except Exception as exc:
print(f"[warn] failed to read {meta_path}: {exc}")
return []
libs = payload.get("libs")
if not isinstance(libs, list):
return []
out: List[str] = []
for item in libs:
if isinstance(item, str) and item:
out.append(item)
return out
def _build_ct_sidecar_shared(source: Path, temp_dir: Path) -> Optional[Path]:
"""Build sibling <source>.c into a shared object for --ct-run-main externs."""
c_path = source.with_suffix(".c")
if not c_path.exists():
return None
temp_dir.mkdir(parents=True, exist_ok=True)
so_path = temp_dir / f"{source.stem}.ctlib.so"
cmd = ["cc", "-shared", "-fPIC", str(c_path), "-o", str(so_path)]
subprocess.run(cmd, check=True)
return so_path
def run_repl(
compiler: Compiler,
temp_dir: Path,
@@ -7150,6 +7624,11 @@ def cli(argv: Sequence[str]) -> int:
parser.add_argument("--repl", action="store_true", help="interactive REPL; source file is optional")
parser.add_argument("-l", dest="libs", action="append", default=[], help="pass library to linker (e.g. -l m or -l libc.so.6)")
parser.add_argument("--no-folding", action="store_true", help="disable constant folding optimization")
parser.add_argument(
"--no-static-list-folding",
action="store_true",
help="disable static list-literal folding (lists stay runtime-allocated)",
)
parser.add_argument("--no-peephole", action="store_true", help="disable peephole optimizations")
parser.add_argument("--no-cache", action="store_true", help="disable incremental build cache")
parser.add_argument("--ct-run-main", action="store_true", help="execute 'main' via the compile-time VM after parsing")
@@ -7203,6 +7682,7 @@ def cli(argv: Sequence[str]) -> int:
artifact_kind = args.artifact
folding_enabled = not args.no_folding
static_list_folding_enabled = not args.no_static_list_folding
peephole_enabled = not args.no_peephole
if args.ct_run_main and artifact_kind != "exe":
@@ -7253,8 +7733,28 @@ def cli(argv: Sequence[str]) -> int:
if not args.repl and artifact_kind in {"static", "obj"} and args.libs:
print("[warn] --libs ignored for static/object outputs")
ct_run_libs: List[str] = list(args.libs)
if args.source is not None:
for lib in _load_sidecar_meta_libs(args.source):
if lib not in args.libs:
args.libs.append(lib)
if lib not in ct_run_libs:
ct_run_libs.append(lib)
if args.ct_run_main and args.source is not None:
try:
ct_sidecar = _build_ct_sidecar_shared(args.source, args.temp_dir)
except subprocess.CalledProcessError as exc:
print(f"[error] failed to build compile-time sidecar library: {exc}")
return 1
if ct_sidecar is not None:
so_lib = str(ct_sidecar.resolve())
if so_lib not in ct_run_libs:
ct_run_libs.append(so_lib)
compiler = Compiler(include_paths=[Path("."), Path("./stdlib"), *args.include_paths])
compiler.assembler.enable_constant_folding = folding_enabled
compiler.assembler.enable_static_list_folding = static_list_folding_enabled
compiler.assembler.enable_peephole_optimization = peephole_enabled
cache: Optional[BuildCache] = None
@@ -7271,7 +7771,13 @@ def cli(argv: Sequence[str]) -> int:
asm_text: Optional[str] = None
fhash = ""
if cache and not args.ct_run_main:
fhash = cache.flags_hash(args.debug, folding_enabled, peephole_enabled, entry_mode)
fhash = cache.flags_hash(
args.debug,
folding_enabled,
static_list_folding_enabled,
peephole_enabled,
entry_mode,
)
manifest = cache.load_manifest(args.source)
if manifest and cache.check_fresh(manifest, fhash):
cached = cache.get_cached_asm(manifest)
@@ -7287,13 +7793,19 @@ def cli(argv: Sequence[str]) -> int:
if cache and not args.ct_run_main:
if not fhash:
fhash = cache.flags_hash(args.debug, folding_enabled, peephole_enabled, entry_mode)
fhash = cache.flags_hash(
args.debug,
folding_enabled,
static_list_folding_enabled,
peephole_enabled,
entry_mode,
)
has_ct = bool(compiler.parser.compile_time_vm._ct_executed)
cache.save(args.source, compiler._loaded_files, fhash, asm_text, has_ct_effects=has_ct)
if args.ct_run_main:
try:
compiler.run_compile_time_word("main", libs=args.libs)
compiler.run_compile_time_word("main", libs=ct_run_libs)
except CompileTimeError as exc:
print(f"[error] compile-time execution of 'main' failed: {exc}")
return 1