Skip to content

Commit b3921c4

Browse files
parse context from first function argument to local symbol table
1 parent adf3256 commit b3921c4

File tree

8 files changed

+64
-8
lines changed

8 files changed

+64
-8
lines changed

pythonbpf/allocation_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
5151
return
5252

5353
# When allocating a variable, check if it's a vmlinux struct type
54-
if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
54+
if isinstance(rval, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
5555
stmt.value.id
5656
):
5757
# Handle vmlinux struct allocation

pythonbpf/codegen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def finalize_module(original_str):
3636
replacement = r'\1 "btf_ama"'
3737
return re.sub(pattern, replacement, original_str)
3838

39+
3940
def bpf_passthrough_gen(module):
4041
i32_ty = ir.IntType(32)
4142
ptr_ty = ir.PointerType(ir.IntType(8))

pythonbpf/expr/vmlinux_registry.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import ast
22

3+
from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler
4+
35

46
class VmlinuxHandlerRegistry:
57
"""Registry for vmlinux handler operations"""
68

79
_handler = None
810

911
@classmethod
10-
def set_handler(cls, handler):
12+
def set_handler(cls, handler: VmlinuxHandler):
1113
"""Set the vmlinux handler"""
1214
cls._handler = handler
1315

@@ -43,3 +45,10 @@ def is_vmlinux_struct(cls, name):
4345
if cls._handler is None:
4446
return False
4547
return cls._handler.is_vmlinux_struct(name)
48+
49+
@classmethod
50+
def get_struct_type(cls, name):
51+
"""Try to handle a struct name as vmlinux struct"""
52+
if cls._handler is None:
53+
return None
54+
return cls._handler.get_vmlinux_struct_type(name)

pythonbpf/functions/functions_pass.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
reset_scratch_pool,
88
)
99
from pythonbpf.type_deducer import ctypes_to_ir
10-
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
10+
from pythonbpf.expr import (
11+
eval_expr,
12+
handle_expr,
13+
convert_to_bool,
14+
VmlinuxHandlerRegistry,
15+
)
1116
from pythonbpf.assign_pass import (
1217
handle_variable_assignment,
1318
handle_struct_field_assignment,
@@ -337,6 +342,35 @@ def process_func_body(
337342
structs_sym_tab,
338343
)
339344

345+
# Add the context parameter (first function argument) to the local symbol table
346+
if func_node.args.args and len(func_node.args.args) > 0:
347+
context_arg = func_node.args.args[0]
348+
context_name = context_arg.arg
349+
350+
if hasattr(context_arg, "annotation") and context_arg.annotation:
351+
if isinstance(context_arg.annotation, ast.Name):
352+
context_type_name = context_arg.annotation.id
353+
elif isinstance(context_arg.annotation, ast.Attribute):
354+
context_type_name = context_arg.annotation.attr
355+
else:
356+
raise TypeError(
357+
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
358+
)
359+
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
360+
resolved_type = VmlinuxHandlerRegistry.get_struct_type(
361+
context_type_name
362+
)
363+
context_type = {"type": ir.PointerType(resolved_type), "ptr": True}
364+
else:
365+
try:
366+
resolved_type = ctypes_to_ir(context_type_name)
367+
context_type = {"type": ir.PointerType(resolved_type), "ptr": True}
368+
except Exception:
369+
raise TypeError(f"Type '{context_type_name}' not declared")
370+
371+
local_sym_tab[context_name] = context_type
372+
logger.info(f"Added argument '{context_name}' to local symbol table")
373+
340374
logger.info(f"Local symbol table: {local_sym_tab.keys()}")
341375

342376
for stmt in func_node.body:

pythonbpf/helper/bpf_helper_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ def bpf_map_lookup_elem_emitter(
7373
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
7474

7575
# TODO: I have changed the return type to i64*, as we are
76-
# allocating space for that type in allocate_mem. This is
77-
# temporary, and we will honour other widths later. But this
78-
# allows us to have cool binary ops on the returned value.
76+
# allocating space for that type in allocate_mem. This is
77+
# temporary, and we will honour other widths later. But this
78+
# allows us to have cool binary ops on the returned value.
7979
fn_type = ir.FunctionType(
8080
ir.PointerType(ir.IntType(64)), # Return type: void*
8181
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)

pythonbpf/vmlinux_parser/import_detector.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
import importlib
44
import inspect
5-
import llvmlite.ir as ir
65

76
from .assignment_info import AssignmentInfo, AssignmentType
87
from .dependency_handler import DependencyHandler

pythonbpf/vmlinux_parser/vmlinux_exports_handler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ def is_vmlinux_enum(self, name):
3939
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT
4040
)
4141

42+
def get_vmlinux_struct_type(self, name):
43+
"""Check if name is a vmlinux struct type"""
44+
if (
45+
name in self.vmlinux_symtab
46+
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT
47+
):
48+
return self.vmlinux_symtab[name]["python_type"]
49+
else:
50+
raise ValueError(f"{name} is not a vmlinux struct type")
51+
4252
def is_vmlinux_struct(self, name):
4353
"""Check if name is a vmlinux struct"""
4454
return (

tests/failing_tests/vmlinux/struct_field_access.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from pythonbpf import compile # noqa: F401
55
from vmlinux import TASK_COMM_LEN # noqa: F401
66
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
7-
from ctypes import c_int64
7+
from ctypes import c_int64, c_void_p # noqa: F401
8+
89

910
# from vmlinux import struct_uinput_device
1011
# from vmlinux import struct_blk_integrity_iter
@@ -14,7 +15,9 @@
1415
@section("tracepoint/syscalls/sys_enter_execve")
1516
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
1617
a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
18+
# b = ctx
1719
print(f"Hello, World{TASK_COMM_LEN} and {a}")
20+
# print(f"This is context field {b}")
1821
return c_int64(TASK_COMM_LEN + 2)
1922

2023

0 commit comments

Comments
 (0)