Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 87 additions & 11 deletions src/irx/builders/llvmliteir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
from irx.builders.base import Builder, BuilderVisitor
from irx.tools.typing import typechecked

DATE_PARTS = 3
MIN_MONTH = 1
MAX_MONTH = 12
MIN_DAY = 1
MAX_DAY = 31
MIN_YEAR = 1
MAX_YEAR = 9999


@typechecked
def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function:
Expand All @@ -45,6 +53,7 @@ class VariablesLLVM:
STRING_TYPE: ir.types.Type
ASCII_STRING_TYPE: ir.types.Type
UTF8_STRING_TYPE: ir.types.Type
DATE_TYPE: ir.types.Type

context: ir.context.Context
module: ir.module.Module
Expand Down Expand Up @@ -89,6 +98,8 @@ def get_data_type(self, type_name: str) -> ir.types.Type:
return self.UTF8_STRING_TYPE
elif type_name == "nonetype":
return self.VOID_TYPE
elif type_name == "date":
return self.DATE_TYPE

raise Exception(f"[EE]: Type name {type_name} not valid.")

Expand Down Expand Up @@ -129,23 +140,22 @@ def translate(self, node: astx.AST) -> str:
return str(self._llvm.module)

def initialize(self) -> None:
"""Initialize self."""
# self._llvm.context = ir.context.Context()
"""Initialize LLVM module and types safely."""
self._llvm = VariablesLLVM()
self._llvm.module = ir.module.Module("Arx")

# initialize the target registry etc.
llvm.initialize()
llvm.initialize_all_asmprinters()
llvm.initialize_all_targets()
llvm.initialize_native_target()
llvm.initialize_native_asmparser()
llvm.initialize_native_asmprinter()
# (llvmlite handles most automatically now)
try:
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
except (RuntimeError, AttributeError):
# These may already be initialized — safe to ignore
pass

# Create a new builder for the module.
# Create a new builder for the module
self._llvm.ir_builder = ir.IRBuilder()

# Data Types
# Define basic data types
self._llvm.FLOAT_TYPE = ir.FloatType()
self._llvm.FLOAT16_TYPE = ir.HalfType()
self._llvm.DOUBLE_TYPE = ir.DoubleType()
Expand All @@ -160,6 +170,9 @@ def initialize(self) -> None:
)
self._llvm.ASCII_STRING_TYPE = ir.IntType(8).as_pointer()
self._llvm.UTF8_STRING_TYPE = self._llvm.STRING_TYPE
self._llvm.DATE_TYPE = ir.LiteralStructType(
[ir.IntType(32), ir.IntType(32), ir.IntType(32)]
)

def _add_builtins(self) -> None:
# The C++ tutorial adds putchard() simply by defining it in the host
Expand Down Expand Up @@ -657,6 +670,69 @@ def visit(self, node: astx.IfStmt) -> None:

self.result_stack.append(phi)

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.LiteralDate) -> None:
"""
Lower a LiteralDate to LLVM IR.

Representation:
{ i32 year, i32 month, i32 day }
-- emitted as a constant struct.

Expected format: YYYY-MM-DD (ISO),
but also accepts single-digit month/day.
"""
s = node.value.strip()

# Split by "-"
parts = s.split("-")
if len(parts) != DATE_PARTS:
raise Exception(
"LiteralDate: invalid date format "
f"'{node.value}'. Expected 'YYYY-MM-DD'."
)

try:
# Convert to integers even if month/day are single-digit
year = int(parts[0])
month = int(parts[1])
day = int(parts[2])
except Exception:
raise Exception(
f"LiteralDate: invalid year/month/day in '{node.value}'."
)

# Basic range checks
if not (1 <= month <= MAX_MONTH):
raise Exception(
"LiteralDate: month out of range in "
f"'{node.value}'. Expected 1-12."
)
if not (1 <= day <= MAX_DAY):
raise Exception(
"LiteralDate: day out of range in "
f"'{node.value}'. Expected 1-31."
)
if not (1 <= year <= MAX_YEAR):
raise Exception(
"LiteralDate: year out of range in "
f"'{node.value}'. Expected 1-9999."
)

# Build constant struct { i32, i32, i32 }
i32 = self._llvm.INT32_TYPE
date_ty = ir.LiteralStructType([i32, i32, i32])
const_date = ir.Constant(
date_ty,
[
ir.Constant(i32, year),
ir.Constant(i32, month),
ir.Constant(i32, day),
],
)

self.result_stack.append(const_date)

@dispatch # type: ignore[no-redef]
def visit(self, expr: astx.WhileStmt) -> None:
"""Translate ASTx While Loop to LLVM-IR."""
Expand Down
183 changes: 183 additions & 0 deletions tests/test_literal_date.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Tests for LiteralDate support."""

from typing import Type

import astx
import pytest

from irx.builders.base import Builder
from irx.builders.llvmliteir import LLVMLiteIR


@pytest.mark.parametrize(
"date_str,expected_year,expected_month,expected_day",
[
("2024-01-15", 2024, 1, 15),
("2000-12-31", 2000, 12, 31),
("1970-01-01", 1970, 1, 1),
("2023-06-15", 2023, 6, 15),
("1999-02-28", 1999, 2, 28),
("2024-11-30", 2024, 11, 30),
],
)
@pytest.mark.parametrize("builder_class", [LLVMLiteIR])
def test_literal_date_basic(
builder_class: Type[Builder],
date_str: str,
expected_year: int,
expected_month: int,
expected_day: int,
) -> None:
"""Test basic LiteralDate parsing and IR generation."""
builder = builder_class()
module = builder.module()

# Create date literal
date_literal = astx.LiteralDate(date_str)

# Create a function that stores the date
proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=astx.Int32()
)
block = astx.Block()

# Store date in variable
date_decl = astx.VariableDeclaration(
name="d", type_=astx.Date(), value=date_literal
)
block.append(date_decl)

# Return 0 (just testing that it compiles)
block.append(astx.FunctionReturn(astx.LiteralInt32(0)))

fn = astx.FunctionDef(prototype=proto, body=block)
module.block.append(fn)

# Check it translates without error
ir_code = builder.translate(module)
assert "i32" in ir_code
# Verify the struct contains our values
assert str(expected_year) in ir_code
assert str(expected_month) in ir_code
assert str(expected_day) in ir_code


@pytest.mark.parametrize(
"invalid_date,error_msg",
[
("2024-13-01", "month out of range"), # Invalid month
("2024-00-01", "month out of range"), # Month = 0
("2024-01-32", "day out of range"), # Invalid day
("2024-01-00", "day out of range"), # Day = 0
("10000-01-01", "year out of range"), # Year too large
("0-01-01", "year out of range"), # Year = 0
("2024/01/01", "invalid date format"), # Wrong separator
("2024-Jan-01", "invalid year/month/day"), # Text month
("2024-01", "invalid date format"), # Missing day
],
)
@pytest.mark.parametrize("builder_class", [LLVMLiteIR])
def test_literal_date_invalid(
builder_class: Type[Builder],
invalid_date: str,
error_msg: str,
) -> None:
"""Test that invalid date formats raise appropriate exceptions."""
builder = builder_class()
module = builder.module()

date_literal = astx.LiteralDate(invalid_date)

proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=astx.Int32()
)
block = astx.Block()

date_decl = astx.VariableDeclaration(
name="d", type_=astx.Date(), value=date_literal
)
block.append(date_decl)
block.append(astx.FunctionReturn(astx.LiteralInt32(0)))

fn = astx.FunctionDef(prototype=proto, body=block)
module.block.append(fn)

# Should raise an exception with expected message
with pytest.raises(Exception) as exc_info:
builder.translate(module)

assert error_msg in str(exc_info.value).lower()


@pytest.mark.parametrize(
"date_str",
[
"2024-02-29", # Leap year - valid
"2024-12-31", # End of year
"2024-01-01", # Start of year
],
)
@pytest.mark.parametrize("builder_class", [LLVMLiteIR])
def test_literal_date_edge_cases(
builder_class: Type[Builder],
date_str: str,
) -> None:
"""Test edge cases for valid dates."""
builder = builder_class()
module = builder.module()

date_literal = astx.LiteralDate(date_str)

proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=astx.Int32()
)
block = astx.Block()

date_decl = astx.VariableDeclaration(
name="d", type_=astx.Date(), value=date_literal
)
block.append(date_decl)
block.append(astx.FunctionReturn(astx.LiteralInt32(0)))

fn = astx.FunctionDef(prototype=proto, body=block)
module.block.append(fn)

# Should compile successfully
ir_code = builder.translate(module)
assert "i32" in ir_code


@pytest.mark.parametrize("builder_class", [LLVMLiteIR])
def test_literal_date_multiple_variables(
builder_class: Type[Builder],
) -> None:
"""Test multiple date variables in the same function."""
builder = builder_class()
module = builder.module()

date1 = astx.LiteralDate("2024-01-15")
date2 = astx.LiteralDate("2023-12-25")

proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=astx.Int32()
)
block = astx.Block()

# Store multiple dates
date_decl1 = astx.VariableDeclaration(
name="d1", type_=astx.Date(), value=date1
)
date_decl2 = astx.VariableDeclaration(
name="d2", type_=astx.Date(), value=date2
)
block.append(date_decl1)
block.append(date_decl2)
block.append(astx.FunctionReturn(astx.LiteralInt32(0)))

fn = astx.FunctionDef(prototype=proto, body=block)
module.block.append(fn)

# Should compile successfully
ir_code = builder.translate(module)
assert "2024" in ir_code
assert "2023" in ir_code
Loading