@@ -90,8 +90,11 @@ func (s *setup) Run(ctx context.Context) (err error) {
9090 closeProgress (err )
9191 }()
9292
93- isDistro , _ := isDebianOrUbuntu ()
94- if ! isDistro {
93+ osid , osversion , err := getOSRelease ()
94+ if err != nil {
95+ return err
96+ }
97+ if osid != "debian" && osid != "ubuntu" {
9598 return errors .Errorf ("NVIDIA setup is currently only supported on Debian/Ubuntu" )
9699 }
97100
@@ -131,7 +134,7 @@ func (s *setup) Run(ctx context.Context) (err error) {
131134 return err
132135 }
133136
134- if err := installPackages (ctx , dv , pw , dgst ); err != nil {
137+ if err := installPackages (ctx , osid , osversion , dv , pw , dgst ); err != nil {
135138 return err
136139 }
137140
@@ -167,8 +170,20 @@ func run(ctx context.Context, args []string, pw progress.Writer, dgst digest.Dig
167170 return cmd .Run ()
168171}
169172
170- func installPackages (ctx context.Context , dv string , pw progress.Writer , dgst digest.Digest ) error {
171- const aptDistro = "ubuntu2404"
173+ func installPackages (ctx context.Context , osid string , osversion string , dv string , pw progress.Writer , dgst digest.Digest ) error {
174+ aptDistro := "ubuntu2404"
175+ switch osid {
176+ case "debian" :
177+ if osversion == "" {
178+ aptDistro = "debian12"
179+ } else {
180+ aptDistro = "debian" + osversion
181+ }
182+ case "ubuntu" :
183+ if osversion != "" {
184+ aptDistro = "ubuntu" + strings .ReplaceAll (osversion , "." , "" )
185+ }
186+ }
172187
173188 var arch string
174189 switch runtime .GOARCH {
@@ -274,36 +289,33 @@ func hasNvidiaDevices() (bool, error) {
274289 return found , nil
275290}
276291
277- func getOSID () (string , error ) {
292+ func getOSRelease () (string , string , error ) {
278293 file , err := os .Open ("/etc/os-release" )
279294 if err != nil {
280- return "" , err
295+ return "" , "" , err
281296 }
282297 defer file .Close ()
283298
299+ var id , versionID string
284300 scanner := bufio .NewScanner (file )
285301 for scanner .Scan () {
286302 line := scanner .Text ()
287303 if strings .HasPrefix (line , "ID=" ) {
288- id := strings .TrimPrefix (line , "ID=" )
289- return strings .Trim (id , `"` ), nil // Remove potential quotes
304+ id = strings .Trim (strings .TrimPrefix (line , "ID=" ), `"` ) // Remove potential quotes
305+ } else if strings .HasPrefix (line , "VERSION_ID=" ) {
306+ versionID = strings .Trim (strings .TrimPrefix (line , "VERSION_ID=" ), `"` )
290307 }
291308 }
292309
293310 if err := scanner .Err (); err != nil {
294- return "" , err
311+ return "" , "" , err
295312 }
296313
297- return "" , errors .Errorf ("ID not found in /etc/os-release" )
298- }
299-
300- func isDebianOrUbuntu () (bool , error ) {
301- id , err := getOSID ()
302- if err != nil {
303- return false , err
314+ if id == "" {
315+ return "" , "" , errors .Errorf ("ID not found in /etc/os-release" )
304316 }
305317
306- return id == "debian" || id == "ubuntu" , nil
318+ return id , versionID , nil
307319}
308320
309321func hasWSLGPU () bool {
0 commit comments