Skip to content

Commit f79413f

Browse files
committed
Add support for program target in AST and related code generation
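This update adds a new optional field, `prog_target`, to the `program_def` record so that the target named in a `@kprobe("...")` or `@tracepoint("...")` attribute travels with the program definition. The field is initialized to `None` in the program creation functions and set to `Some target` for kprobe and tracepoint programs during multi-program lowering; IR generation then passes it to each lowered function, and C code generation uses it to emit the tracepoint `SEC(...)` name instead of guessing it from the function name. Additionally, existing tests have been updated to include the new field, and new tests cover target propagation and section-name generation.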
1 parent 91811a2 commit f79413f

File tree

10 files changed

+105
-32
lines changed

10 files changed

+105
-32
lines changed

src/ast.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ type program_def = {
276276
prog_functions: function_def list;
277277
prog_maps: map_declaration list; (* Maps local to this program *)
278278
prog_structs: struct_def list; (* Structs local to this program *)
279+
prog_target: string option; (* Target for kprobe/tracepoint programs *)
279280
prog_pos: position;
280281
}
281282

@@ -402,6 +403,7 @@ let make_program name prog_type functions pos = {
402403
prog_functions = functions;
403404
prog_maps = [];
404405
prog_structs = [];
406+
prog_target = None;
405407
prog_pos = pos;
406408
}
407409

@@ -411,6 +413,7 @@ let make_program_with_maps name prog_type functions maps pos = {
411413
prog_functions = functions;
412414
prog_maps = maps;
413415
prog_structs = [];
416+
prog_target = None;
414417
prog_pos = pos;
415418
}
416419

@@ -420,6 +423,7 @@ let make_program_with_all name prog_type functions maps structs pos = {
420423
prog_functions = functions;
421424
prog_maps = maps;
422425
prog_structs = structs;
426+
prog_target = None;
423427
prog_pos = pos;
424428
}
425429

src/ebpf_c_codegen.ml

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3352,24 +3352,16 @@ let generate_c_function ctx ir_func =
33523352
| Some target ->
33533353
(* If we have the target, convert KernelScript format to raw tracepoint format *)
33543354
if String.contains target '/' then
3355-
let event_name = String.map (function '/' -> '_' | c -> c) target in
3356-
sprintf "SEC(\"raw_tracepoint/%s\")" event_name
3355+
(* For raw tracepoints, extract just the event name (part after slash) *)
3356+
let parts = String.split_on_char '/' target in
3357+
(match parts with
3358+
| [_category; event] -> sprintf "SEC(\"raw_tracepoint/%s\")" event
3359+
| _ -> sprintf "SEC(\"raw_tracepoint/%s\")" target)
33573360
else
33583361
sprintf "SEC(\"raw_tracepoint/%s\")" target
33593362
| None ->
3360-
(* Fallback: try to extract from function name *)
3361-
let func_name = ir_func.func_name in
3362-
if String.contains func_name '_' then
3363-
(* Function name like "sched_sched_switch_handler" -> extract "sched_switch" *)
3364-
let parts = String.split_on_char '_' func_name in
3365-
(match parts with
3366-
| category :: event :: "handler" :: _ ->
3367-
sprintf "SEC(\"raw_tracepoint/%s_%s\")" category event
3368-
| category :: event :: _ when List.length parts >= 2 ->
3369-
sprintf "SEC(\"raw_tracepoint/%s_%s\")" category event
3370-
| _ -> "SEC(\"raw_tracepoint\")")
3371-
else
3372-
"SEC(\"raw_tracepoint\")")
3363+
(* This should not happen now that we properly pass targets through *)
3364+
failwith "Tracepoint function is missing target information")
33733365
| _ ->
33743366
(* For non-struct_ops, non-kprobe, and non-tracepoint functions, only generate SEC if it's a main function *)
33753367
if ir_func.is_main then

src/ir_generator.ml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2731,12 +2731,8 @@ let lower_single_program ctx prog_def _global_ir_maps _kernel_shared_functions =
27312731
(* For attributed functions (single function programs), the function IS the entry function *)
27322732
(* But struct_ops functions should NOT be marked as main functions *)
27332733
let is_attributed_entry = (List.length prog_def.prog_functions = 1 && index = 0 && prog_def.prog_type <> Ast.StructOps) in
2734-
(* Extract target from attributed function if available *)
2735-
let func_target =
2736-
(* This is a hack to find the original attributed function and extract the target *)
2737-
(* For now, we'll pass None and fix this in the eBPF codegen *)
2738-
None
2739-
in
2734+
(* Extract target from program definition *)
2735+
let func_target = prog_def.prog_target in
27402736
let temp_func = lower_function ctx prog_def.prog_name ~program_type:(Some prog_def.prog_type) ~func_target func in
27412737
if is_attributed_entry then
27422738
(* Mark the attributed function as entry by updating the is_main field *)
@@ -2830,9 +2826,10 @@ let lower_multi_program ast symbol_table source_name =
28302826
prog_functions = [attr_func.attr_function];
28312827
prog_maps = [];
28322828
prog_structs = [];
2829+
prog_target = None;
28332830
prog_pos = attr_func.attr_pos;
28342831
})
2835-
| AttributeWithArg (attr_name, _target_func) :: _ ->
2832+
| AttributeWithArg (attr_name, target_func) :: _ ->
28362833
(* Handle attributes with arguments like @kprobe("sys_read") *)
28372834
(match attr_name with
28382835
| "kprobe" ->
@@ -2842,6 +2839,7 @@ let lower_multi_program ast symbol_table source_name =
28422839
prog_functions = [attr_func.attr_function];
28432840
prog_maps = [];
28442841
prog_structs = [];
2842+
prog_target = Some target_func;
28452843
prog_pos = attr_func.attr_pos;
28462844
}
28472845
| "tracepoint" ->
@@ -2851,6 +2849,7 @@ let lower_multi_program ast symbol_table source_name =
28512849
prog_functions = [attr_func.attr_function];
28522850
prog_maps = [];
28532851
prog_structs = [];
2852+
prog_target = Some target_func;
28542853
prog_pos = attr_func.attr_pos;
28552854
}
28562855
| _ -> None)
@@ -2871,6 +2870,7 @@ let lower_multi_program ast symbol_table source_name =
28712870
prog_functions = [func];
28722871
prog_maps = [];
28732872
prog_structs = [];
2873+
prog_target = None; (* struct_ops don't have targets *)
28742874
prog_pos = func.func_pos;
28752875
}
28762876
| Ast.ImplStaticField (_, _) -> None (* Static fields are not programs *)

src/main.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,8 @@ let compile_source input_file output_dir _verbose generate_makefile btf_vmlinux_
814814

815815
(* Create a program structure for safety analysis *)
816816
let safety_program = {
817-
Ast.prog_name = base_name;
817+
Ast.prog_name = base_name;
818+
prog_target = None;
818819
prog_type = Xdp; (* Default - not used by safety checker *)
819820
prog_functions = all_functions;
820821
prog_maps = all_maps;

src/multi_program_analyzer.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ let extract_programs (ast: declaration list) : program_def list =
122122
prog_functions = [attr_func.attr_function];
123123
prog_maps = [];
124124
prog_structs = [];
125+
prog_target = None;
125126
prog_pos = attr_func.attr_pos;
126127
})
127128
| _ -> None)

src/type_checker.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,6 +2567,7 @@ let typed_program_to_program tprog original_prog =
25672567
prog_functions = List.map typed_function_to_function tprog.tprog_functions;
25682568
prog_maps = original_prog.prog_maps; (* Preserve original map declarations *)
25692569
prog_structs = original_prog.prog_structs; (* Preserve original struct declarations *)
2570+
prog_target = original_prog.prog_target; (* Preserve original target *)
25702571
prog_pos = tprog.tprog_pos }
25712572

25722573
(** Convert typed AST back to annotated AST declarations *)

tests/test_function_generation.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ let test_function_parameters _ =
5151
(* Create program containing this function *)
5252
let prog_def = {
5353
prog_name = "test_prog";
54+
prog_target = None;
5455
prog_type = Xdp;
5556
prog_maps = [];
5657
prog_structs = [];
@@ -127,6 +128,7 @@ let test_program_function_calls _ =
127128
(* Create program with both functions *)
128129
let prog_def = {
129130
prog_name = "test_prog";
131+
prog_target = None;
130132
prog_type = Xdp;
131133
prog_maps = [];
132134
prog_structs = [];
@@ -195,6 +197,7 @@ let test_multiple_parameters _ =
195197

196198
let prog_def = {
197199
prog_name = "test_prog";
200+
prog_target = None;
198201
prog_type = Xdp;
199202
prog_maps = [];
200203
prog_structs = [];

tests/test_symbol_table.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ let create_test_function name params return_type =
8888

8989
let create_test_program name functions =
9090
{
91-
prog_name = name;
91+
prog_name = name;
92+
prog_target = None;
9293
prog_type = Xdp;
9394
prog_functions = functions;
9495
prog_maps = [];

tests/test_tracepoint.ml

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,75 @@ fn sched_switch_handler(ctx: *trace_event_raw_sched_switch) -> i32 {
201201
check bool "Function should be marked as main" true main_func.is_main;
202202
check string "Function name should match" "sched_switch_handler" main_func.func_name
203203

204+
(* NEW: Target Propagation Tests *)
205+
let test_tracepoint_target_propagation _ =
206+
let source = "@tracepoint(\"sched/sched_switch\")
207+
fn sched_switch_handler(ctx: *trace_event_raw_sched_switch) -> i32 {
208+
return 0
209+
}" in
210+
let ast = parse_string source in
211+
let typed_ast = type_check_ast ast in
212+
let symbol_table = Kernelscript.Symbol_table.build_symbol_table typed_ast in
213+
let ir_multi_prog = generate_ir typed_ast symbol_table "test_tracepoint" in
214+
let program = List.hd ir_multi_prog.programs in
215+
let main_func = program.entry_function in
216+
217+
(* Test that the target is properly propagated through IR generation *)
218+
check (option string) "Function should have correct target" (Some "sched/sched_switch") main_func.func_target
219+
220+
let test_multiple_tracepoint_targets _ =
221+
(* Test various tracepoint targets to ensure they all work correctly *)
222+
let test_cases = [
223+
("sched/sched_switch", "SEC(\"raw_tracepoint/sched_switch\")");
224+
("net/netif_rx", "SEC(\"raw_tracepoint/netif_rx\")");
225+
("syscalls/sys_enter_read", "SEC(\"raw_tracepoint/sys_enter_read\")");
226+
("syscalls/sys_exit_write", "SEC(\"raw_tracepoint/sys_exit_write\")");
227+
("irq/irq_handler_entry", "SEC(\"raw_tracepoint/irq_handler_entry\")");
228+
] in
229+
230+
List.iter (fun (target, expected_sec) ->
231+
let source = Printf.sprintf "@tracepoint(\"%s\")
232+
fn handler(ctx: *trace_event_raw_context) -> i32 {
233+
return 0
234+
}" target in
235+
let ast = parse_string source in
236+
let typed_ast = type_check_ast ast in
237+
let symbol_table = Kernelscript.Symbol_table.build_symbol_table typed_ast in
238+
let ir_multi_prog = generate_ir typed_ast symbol_table "test" in
239+
let c_code = generate_c_multi_program ir_multi_prog in
240+
241+
check bool (Printf.sprintf "Should generate %s for target %s" expected_sec target) true
242+
(Str.search_forward (Str.regexp_string expected_sec) c_code 0 >= 0)
243+
) test_cases
244+
245+
let test_sched_switch_bug_regression _ =
246+
(* Regression test: Ensure we don't generate the buggy SEC("raw_tracepoint/sched_sched") *)
247+
let source = "@tracepoint(\"sched/sched_switch\")
248+
fn sched_switch_handler(ctx: *trace_event_raw_sched_switch) -> i32 {
249+
return 0
250+
}" in
251+
let ast = parse_string source in
252+
let typed_ast = type_check_ast ast in
253+
let symbol_table = Kernelscript.Symbol_table.build_symbol_table typed_ast in
254+
let ir_multi_prog = generate_ir typed_ast symbol_table "test_regression" in
255+
let c_code = generate_c_multi_program ir_multi_prog in
256+
257+
(* Ensure correct SEC() is generated *)
258+
check bool "Should generate correct SEC(raw_tracepoint/sched_switch)" true
259+
(Str.search_forward (Str.regexp_string "SEC(\"raw_tracepoint/sched_switch\")") c_code 0 >= 0);
260+
261+
(* Ensure buggy SEC() is NOT generated *)
262+
check bool "Should NOT generate buggy SEC(raw_tracepoint/sched_sched)" true
263+
(try
264+
let _ = Str.search_forward (Str.regexp_string "SEC(\"raw_tracepoint/sched_sched\")") c_code 0 in
265+
false (* Found the buggy pattern - test should fail *)
266+
with Not_found ->
267+
true (* Didn't find the buggy pattern - test should pass *)
268+
)
269+
204270
(* 4. Code Generation Tests *)
205271
let test_raw_tracepoint_section_name_generation _ =
206-
(* Test minimal raw tracepoint section name conversion logic *)
272+
(* Test correct raw tracepoint section name generation *)
207273
let source = "@tracepoint(\"sched/sched_switch\")
208274
fn sched_switch_handler(ctx: *trace_event_raw_sched_switch) -> i32 {
209275
return 0
@@ -214,9 +280,9 @@ fn sched_switch_handler(ctx: *trace_event_raw_sched_switch) -> i32 {
214280
let ir_multi_prog = generate_ir typed_ast symbol_table "test_raw_tracepoint" in
215281
let c_code = generate_c_multi_program ir_multi_prog in
216282

217-
(* Check that forward slash is converted to underscore in section name *)
218-
check bool "Should contain raw_tracepoint section with underscore" true
219-
(String.contains c_code (String.get "SEC(\"raw_tracepoint/sched_sched_switch\")" 0))
283+
(* Check that the correct SEC() is generated with just the event name *)
284+
check bool "Should contain correct raw_tracepoint/sched_switch section" true
285+
(Str.search_forward (Str.regexp_string "SEC(\"raw_tracepoint/sched_switch\")") c_code 0 >= 0)
220286

221287
let test_tracepoint_ebpf_codegen _ =
222288
let source = "@tracepoint(\"sched/sched_switch\")
@@ -230,8 +296,8 @@ fn sched_switch_handler(ctx: *trace_event_raw_sched_switch) -> i32 {
230296
let c_code = generate_c_multi_program ir_multi_prog in
231297

232298
(* Check for tracepoint-specific C code elements *)
233-
check bool "Should contain SEC(\"tracepoint\")" true
234-
(String.contains c_code (String.get "SEC(\"tracepoint\")" 0));
299+
check bool "Should contain correct raw_tracepoint SEC" true
300+
(Str.search_forward (Str.regexp_string "SEC(\"raw_tracepoint/sched_switch\")") c_code 0 >= 0);
235301
check bool "Should contain function definition" true
236302
(String.contains c_code (String.get "sched_switch_handler" 0));
237303
check bool "Should contain struct parameter" true
@@ -342,8 +408,8 @@ fn sys_enter_open_handler(ctx: *trace_event_raw_sys_enter) -> i32 {
342408
let c_code = generate_c_multi_program ir_multi_prog in
343409

344410
(* Comprehensive end-to-end validation *)
345-
check bool "Contains tracepoint section" true
346-
(String.contains c_code (String.get "SEC(\"tracepoint\")" 0));
411+
check bool "Contains correct raw_tracepoint section" true
412+
(Str.search_forward (Str.regexp_string "SEC(\"raw_tracepoint/sys_enter_open\")") c_code 0 >= 0);
347413
check bool "Contains function name" true
348414
(String.contains c_code (String.get "sys_enter_open_handler" 0));
349415
check bool "Contains context struct" true
@@ -381,6 +447,9 @@ let type_checking_tests = [
381447
let ir_generation_tests = [
382448
"tracepoint IR generation", `Quick, test_tracepoint_ir_generation;
383449
"tracepoint function signature validation", `Quick, test_tracepoint_function_signature_validation;
450+
"tracepoint target propagation", `Quick, test_tracepoint_target_propagation;
451+
"multiple tracepoint targets", `Quick, test_multiple_tracepoint_targets;
452+
"sched_switch bug regression", `Quick, test_sched_switch_bug_regression;
384453
]
385454

386455
let code_generation_tests = [

tests/test_utils.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ module Helpers = struct
185185
let make_test_program name prog_type main_func =
186186
{
187187
prog_name = name;
188+
prog_target = None;
188189
prog_type = prog_type;
189190
prog_functions = [main_func];
190191
prog_maps = [];

0 commit comments

Comments
 (0)