33// variable in order to cache the compiled artifacts and avoid recompiling too often.
44use anyhow:: { Context , Result } ;
55use rayon:: prelude:: * ;
6+ use std:: fs;
67use std:: path:: PathBuf ;
78use std:: str:: FromStr ;
89
9- const KERNEL_FILES : [ & str ; 4 ] = [
10- "flash_api.cu" ,
11- "fmha_fwd_hdim32.cu" ,
12- "fmha_fwd_hdim64.cu" ,
13- "fmha_fwd_hdim128.cu" ,
14- ] ;
10+ // const KERNEL_FILES: [&str; 4] = [
11+ // "flash_api.cu",
12+ // "fmha_fwd_hdim32.cu",
13+ // "fmha_fwd_hdim64.cu",
14+ // "fmha_fwd_hdim128.cu",
15+ // ];
16+
17+ /// Recursively reads the filenames in a directory and stores them in a Vec.
18+ fn _read_dir_recursively ( dir_path : & PathBuf , paths : & mut Vec < PathBuf > ) -> std:: io:: Result < ( ) > {
19+ for entry in fs:: read_dir ( dir_path) ? {
20+ let entry = entry?;
21+ let path = entry. path ( ) ;
22+
23+ if path. is_dir ( ) {
24+ _read_dir_recursively ( & path, paths) ?;
25+ } else {
26+ paths. push ( path) ;
27+ }
28+ }
29+
30+ Ok ( ( ) )
31+ }
32+
33+ /// Recursively reads the filenames in a directory and stores them in a Vec.
34+ fn read_dir_recursively ( dir_path : & PathBuf ) -> std:: io:: Result < Vec < PathBuf > > {
35+ let mut paths = Vec :: new ( ) ;
36+ _read_dir_recursively ( dir_path, & mut paths) ?;
37+ Ok ( paths)
38+ }
1539
1640fn main ( ) -> Result < ( ) > {
1741 let num_cpus = std:: env:: var ( "RAYON_NUM_THREADS" ) . map_or_else (
@@ -25,12 +49,11 @@ fn main() -> Result<()> {
2549 . unwrap ( ) ;
2650
2751 println ! ( "cargo:rerun-if-changed=build.rs" ) ;
28- for kernel_file in KERNEL_FILES . iter ( ) {
29- println ! ( "cargo:rerun-if-changed=kernels/{kernel_file}" ) ;
52+
53+ let paths = read_dir_recursively ( & PathBuf :: from_str ( "kernels" ) ?) ?;
54+ for file in paths. iter ( ) {
55+ println ! ( "cargo:rerun-if-changed={}" , file. display( ) ) ;
3056 }
31- println ! ( "cargo:rerun-if-changed=kernels/**.h" ) ;
32- println ! ( "cargo:rerun-if-changed=kernels/**.cuh" ) ;
33- println ! ( "cargo:rerun-if-changed=kernels/fmha/**.h" ) ;
3457 let out_dir = PathBuf :: from ( std:: env:: var ( "OUT_DIR" ) . context ( "OUT_DIR not set" ) ?) ;
3558 let build_dir = match std:: env:: var ( "CANDLE_FLASH_ATTN_BUILD_DIR" ) {
3659 Err ( _) =>
@@ -57,12 +80,17 @@ fn main() -> Result<()> {
5780 let out_file = build_dir. join ( "libflashattentionv1.a" ) ;
5881
5982 let kernel_dir = PathBuf :: from ( "kernels" ) ;
60- let cu_files: Vec < _ > = KERNEL_FILES
83+ let kernels: Vec < _ > = paths
84+ . iter ( )
85+ . filter ( |f| f. extension ( ) . map ( |ext| ext == "cu" ) . unwrap_or_default ( ) )
86+ . collect ( ) ;
87+ let cu_files: Vec < _ > = kernels
6188 . iter ( )
6289 . map ( |f| {
6390 let mut obj_file = out_dir. join ( f) ;
91+ fs:: create_dir_all ( obj_file. parent ( ) . unwrap ( ) ) . unwrap ( ) ;
6492 obj_file. set_extension ( "o" ) ;
65- ( kernel_dir . join ( f ) , obj_file)
93+ ( f , obj_file)
6694 } )
6795 . collect ( ) ;
6896 let out_modified: Result < _ , _ > = out_file. metadata ( ) . and_then ( |m| m. modified ( ) ) ;
0 commit comments