small update to linux.sl and the syscall instruction, including some small optimizations

This commit is contained in:
IgorCielniak
2026-03-13 19:42:15 +01:00
parent ecf90feab9
commit fd115f31dc
7 changed files with 5501 additions and 15111 deletions

View File

@@ -48,7 +48,6 @@ word sh
!
syscall.fork
syscall
dup 0 < if
>r
1 rpick
@@ -67,11 +66,9 @@ word sh
dup
32 +
syscall.execve
syscall
drop
127
syscall.exit
syscall
else
mem
40 +
@@ -79,7 +76,6 @@ word sh
0
0
syscall.wait4
syscall
dup 0 < if
>r
rdrop

62
main.py
View File

@@ -8427,6 +8427,68 @@ class Compiler:
word.intrinsic = self._emit_syscall_intrinsic
def _emit_syscall_intrinsic(self, builder: FunctionEmitter) -> None:
def _try_pop_known_syscall_setup() -> Optional[Tuple[int, int]]:
"""Recognize and remove literal setup for known-argc syscalls.
Supported forms right before `syscall`:
1) <argc> <nr>
2) <nr> <argc> ___linux_swap
Returns (argc, nr) when recognized.
"""
# Form 1: ... push argc ; push nr ; syscall
nr = Assembler._pop_preceding_literal(builder)
if nr is not None:
argc = Assembler._pop_preceding_literal(builder)
if argc is not None and 0 <= argc <= 6:
return argc, nr
# rollback if second literal wasn't argc
builder.push_literal(nr)
# Form 2: ... push nr ; push argc ; ___linux_swap ; syscall
text = builder.text
swap_tail = [
"mov rax, [r12]",
"mov rbx, [r12 + 8]",
"mov [r12], rbx",
"mov [r12 + 8], rax",
]
if len(text) >= 4 and [s.strip() for s in text[-4:]] == swap_tail:
del text[-4:]
argc2 = Assembler._pop_preceding_literal(builder)
nr2 = Assembler._pop_preceding_literal(builder)
if argc2 is not None and nr2 is not None and 0 <= argc2 <= 6:
return argc2, nr2
# rollback conservatively if match fails
if nr2 is not None:
builder.push_literal(nr2)
if argc2 is not None:
builder.push_literal(argc2)
text.extend(swap_tail)
return None
known = _try_pop_known_syscall_setup()
if known is not None:
argc, nr = known
builder.push_literal(nr)
builder.pop_to("rax")
if argc >= 6:
builder.pop_to("r9")
if argc >= 5:
builder.pop_to("r8")
if argc >= 4:
builder.pop_to("r10")
if argc >= 3:
builder.pop_to("rdx")
if argc >= 2:
builder.pop_to("rsi")
if argc >= 1:
builder.pop_to("rdi")
builder.emit(" syscall")
builder.push_from("rax")
return
label_id = self._syscall_label_counter
self._syscall_label_counter += 1

File diff suppressed because it is too large Load Diff

39
test.py
View File

@@ -224,6 +224,7 @@ class TestCase:
expected_stdout: Path
expected_stderr: Path
compile_expected: Path
asm_forbid: Path
stdin_path: Path
args_path: Path
meta_path: Path
@@ -324,6 +325,7 @@ class TestRunner:
expected_stdout=source.with_suffix(".expected"),
expected_stderr=source.with_suffix(".stderr"),
compile_expected=source.with_suffix(".compile.expected"),
asm_forbid=source.with_suffix(".asm.forbid"),
stdin_path=source.with_suffix(".stdin"),
args_path=source.with_suffix(".args"),
meta_path=meta_path,
@@ -391,6 +393,10 @@ class TestRunner:
return CaseResult(case, compile_status, "compile", compile_note, compile_details, duration)
if compile_status == "updated" and compile_note:
updated_notes.append(compile_note)
asm_status, asm_note, asm_details = self._check_asm_forbidden_patterns(case)
if asm_status == "failed":
duration = time.perf_counter() - start
return CaseResult(case, asm_status, "asm", asm_note, asm_details, duration)
if case.config.compile_only:
duration = time.perf_counter() - start
if updated_notes:
@@ -633,6 +639,39 @@ class TestRunner:
parts.append(proc.stderr)
return "".join(parts)
def _check_asm_forbidden_patterns(self, case: TestCase) -> Tuple[str, str, Optional[str]]:
"""Fail test if generated asm contains forbidden markers listed in *.asm.forbid."""
if not case.asm_forbid.exists():
return "passed", "", None
asm_path = case.build_dir / f"{case.binary_stub}.asm"
if not asm_path.exists():
return "failed", f"missing generated asm file {asm_path.name}", None
asm_text = asm_path.read_text(encoding="utf-8")
patterns: List[str] = []
for raw in case.asm_forbid.read_text(encoding="utf-8").splitlines():
line = raw.strip()
if not line or line.startswith("#"):
continue
patterns.append(line)
hits: List[str] = []
for pattern in patterns:
if pattern.startswith("re:"):
expr = pattern[3:]
if re.search(expr, asm_text, re.MULTILINE):
hits.append(pattern)
continue
if pattern in asm_text:
hits.append(pattern)
if not hits:
return "passed", "", None
detail = "forbidden asm pattern(s) matched:\n" + "\n".join(f"- {p}" for p in hits)
return "failed", "assembly contains forbidden patterns", detail
def _compare_nob_test_stdout(
self,
case: TestCase,

View File

@@ -0,0 +1,4 @@
# Ensure known-argc syscall lowering avoids generic dynamic syscall boilerplate.
clamp arg count to [0, 6]
re:syscall_\d+_count_
re:syscall_\d+_skip_

View File

@@ -6,7 +6,6 @@ word main
1
"hello"
syscall.write
syscall
#drop
1

220
tools/gen_linux_sl.py Normal file
View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python3
"""Generate stdlib/linux.sl from syscall_64.tbl metadata."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import re
ROOT = Path(__file__).resolve().parent.parent
SRC = ROOT / "syscall_64.tbl"
DST = ROOT / "stdlib" / "linux.sl"
def _sanitize_alias(alias: str) -> str:
name = alias.strip()
if not name:
return ""
if name.startswith("__x64_sys_"):
name = name[len("__x64_sys_") :]
elif name.startswith("sys_"):
name = name[len("sys_") :]
name = re.sub(r"[^A-Za-z0-9_]", "_", name)
name = re.sub(r"_+", "_", name).strip("_")
if not name:
return ""
if name[0].isdigit():
name = "n_" + name
return name
@dataclass(frozen=True)
class SyscallEntry:
argc: int
num: int
aliases: tuple[str, ...]
def _parse_table(path: Path) -> list[SyscallEntry]:
entries: list[SyscallEntry] = []
for raw in path.read_text(encoding="utf-8").splitlines():
line = raw.strip()
if not line or line.startswith("#"):
continue
parts = line.split(maxsplit=2)
if len(parts) < 3:
continue
try:
argc = int(parts[0])
num = int(parts[1])
except ValueError:
continue
aliases = tuple(a for a in parts[2].split("/") if a)
if not aliases:
continue
entries.append(SyscallEntry(argc=argc, num=num, aliases=aliases))
return entries
def _emit_header(lines: list[str]) -> None:
lines.extend(
[
"# Autogenerated from syscall_64.tbl",
"# Generated by tools/gen_linux_sl.py",
"# Linux syscall constants + convenience wrappers for L2",
"",
"# File descriptor constants",
"macro fd_stdin 0 0 ;",
"macro fd_stdout 0 1 ;",
"macro fd_stderr 0 2 ;",
"",
"# Common open(2) flags",
"macro O_RDONLY 0 0 ;",
"macro O_WRONLY 0 1 ;",
"macro O_RDWR 0 2 ;",
"macro O_CREAT 0 64 ;",
"macro O_EXCL 0 128 ;",
"macro O_NOCTTY 0 256 ;",
"macro O_TRUNC 0 512 ;",
"macro O_APPEND 0 1024 ;",
"macro O_NONBLOCK 0 2048 ;",
"macro O_CLOEXEC 0 524288 ;",
"",
"# lseek(2)",
"macro SEEK_SET 0 0 ;",
"macro SEEK_CUR 0 1 ;",
"macro SEEK_END 0 2 ;",
"",
"# mmap(2)",
"macro PROT_NONE 0 0 ;",
"macro PROT_READ 0 1 ;",
"macro PROT_WRITE 0 2 ;",
"macro PROT_EXEC 0 4 ;",
"macro MAP_PRIVATE 0 2 ;",
"macro MAP_ANONYMOUS 0 32 ;",
"macro MAP_SHARED 0 1 ;",
"",
"# Socket constants",
"macro AF_UNIX 0 1 ;",
"macro AF_INET 0 2 ;",
"macro AF_INET6 0 10 ;",
"macro SOCK_STREAM 0 1 ;",
"macro SOCK_DGRAM 0 2 ;",
"macro SOCK_NONBLOCK 0 2048 ;",
"macro SOCK_CLOEXEC 0 524288 ;",
"",
"macro INADDR_ANY 0 0 ;",
"",
"# Generic syscall helpers with explicit argument count",
"# Stack form:",
"# syscall -> <argN> ... <arg0> <argc> <nr> syscall",
"# syscallN -> <argN-1> ... <arg0> <nr> syscallN",
"",
"# swap impl is provided so this can be used without stdlib",
"# ___linux_swap [*, x1 | x2] -> [*, x2 | x1]",
":asm ___linux_swap {",
" mov rax, [r12]",
" mov rbx, [r12 + 8]",
" mov [r12], rbx",
" mov [r12 + 8], rax",
"}",
";",
"",
"macro syscall0 0",
" 0",
" ___linux_swap",
" syscall",
";",
"",
"macro syscall1 0",
" 1",
" ___linux_swap",
" syscall",
";",
"",
"macro syscall2 0",
" 2",
" ___linux_swap",
" syscall",
";",
"",
"macro syscall3 0",
" 3",
" ___linux_swap",
" syscall",
";",
"",
"macro syscall4 0",
" 4",
" ___linux_swap",
" syscall",
";",
"",
"macro syscall5 0",
" 5",
" ___linux_swap",
" syscall",
";",
"",
"macro syscall6 0",
" 6",
" ___linux_swap",
" syscall",
";",
"",
]
)
def _emit_entry(lines: list[str], alias: str, argc: int, num: int) -> None:
safe_argc = max(0, min(argc, 6))
lines.extend(
[
f"macro syscall.{alias} 0",
f" {num}",
f" syscall{safe_argc}",
";",
"",
f"macro syscall.{alias}.num 0",
f" {num}",
";",
"",
f"macro syscall.{alias}.argc 0",
f" {safe_argc}",
";",
"",
]
)
def generate() -> str:
entries = _parse_table(SRC)
lines: list[str] = []
_emit_header(lines)
emitted: set[str] = set()
for entry in sorted(entries, key=lambda e: (e.num, e.aliases[0])):
for alias in entry.aliases:
name = _sanitize_alias(alias)
if not name:
continue
key = f"syscall.{name}"
if key in emitted:
continue
_emit_entry(lines, name, entry.argc, entry.num)
emitted.add(key)
return "\n".join(lines).rstrip() + "\n"
def main() -> None:
output = generate()
DST.parent.mkdir(parents=True, exist_ok=True)
DST.write_text(output, encoding="utf-8")
print(f"wrote {DST}")
if __name__ == "__main__":
main()