Skip to content

Commit 1d7a4e4

Browse files
authored
Merge pull request #35 from Living-Technologies/urlclassloader-fix
Removed the URLClassLoader for ClassGraph to check for native gpu.
2 parents 299045c + 3248d30 commit 1d7a4e4

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Wisconsin-Madison and Google, Inc.</license.copyrightOwners>
119119
<releaseProfiles>sign,deploy-to-scijava</releaseProfiles>
120120

121121
<imagej-updater.version>2.0.0</imagej-updater.version>
122+
<classgraph.version>4.8.172</classgraph.version>
122123
</properties>
123124

124125
<repositories>
@@ -183,6 +184,12 @@ Wisconsin-Madison and Google, Inc.</license.copyrightOwners>
183184
<groupId>org.apache.commons</groupId>
184185
<artifactId>commons-compress</artifactId>
185186
</dependency>
187+
<dependency>
188+
<groupId>io.github.classgraph</groupId>
189+
<artifactId>classgraph</artifactId>
190+
<version>${classgraph.version}</version>
191+
</dependency>
192+
186193

187194
<!-- Test dependencies -->
188195
<dependency>

src/main/java/net/imagej/tensorflow/util/TensorFlowUtil.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
package net.imagej.tensorflow.util;
3232

33+
import io.github.classgraph.ClassGraph;
3334
import net.imagej.tensorflow.TensorFlowVersion;
3435
import net.imagej.updater.util.Platforms;
3536
import org.scijava.log.Logger;
@@ -75,9 +76,12 @@ public static TensorFlowVersion getTensorFlowJARVersion(URL jar) {
7576
if(matcher.find()) {
7677
// guess GPU support by looking for tensorflow_jni_gpu in the class path
7778
boolean supportsGPU = false;
78-
ClassLoader cl = ClassLoader.getSystemClassLoader();
79-
for(URL url: ((URLClassLoader)cl).getURLs()){
80-
if(url.getFile().contains("libtensorflow_jni_gpu")) {
79+
ClassGraph cg = new ClassGraph();
80+
String cp = cg.getClasspath();
81+
String[] jars = cp.split(File.pathSeparator);
82+
83+
for(String j: jars){
84+
if(j.contains("libtensorflow_jni_gpu")) {
8185
supportsGPU = true;
8286
break;
8387
}

0 commit comments

Comments
 (0)