diff --git a/build.sbt b/build.sbt index a016809f6..53959b59e 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ organization in ThisBuild := "org.platanios" autoCompilerPlugins in ThisBuild := true -val tensorFlowVersion = "1.11.0" +val tensorFlowVersion = "1.12.0" val circeVersion = "0.10.1" // Use for working with JSON. // addCompilerPlugin(MetalsPlugin.semanticdbScalac) @@ -157,7 +157,12 @@ lazy val jni = (project in file("./modules/jni")) "Sparse" -> Seq("SparseToDense"), "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", - "StringToHashBucketStrong") + "StringToHashBucketStrong"), + "Linalg" -> Seq( + "Cholesky", "CholeskyGrad", "LogMatrixDeterminant", "MatrixDeterminant", + "MatrixInverse", "MatrixSolve", "MatrixSolveLs", + /* "MatrixSquareRoot", */ "MatrixTriangularSolve", "Qr", + "SelfAdjointEigV2", "Svd") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala index 8546282c4..c6646427d 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala @@ -165,6 +165,7 @@ package object types { type Quantized = Union[QByte]#or[QShort]#or[QInt]#or[QUByte]#or[QUShort]#create type Numeric = Union[TruncatedHalf]#or[Half]#or[Float]#or[Double]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#or[ComplexFloat]#or[ComplexDouble]#or[QByte]#or[QShort]#or[QInt]#or[QUByte]#or[QUShort]#create type BooleanOrNumeric = Union[Boolean]#or[Half]#or[Float]#or[Double]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#or[ComplexFloat]#or[ComplexDouble]#or[QByte]#or[QShort]#or[QInt]#or[QUByte]#or[QUShort]#create + type 
RealOrComplex = Union[Float]#or[Double]#or[ComplexFloat]#or[ComplexDouble]#create type IsFloatOrDouble[T] = Contains[T, FloatOrDouble] type IsHalfOrFloat[T] = Contains[T, HalfOrFloat] @@ -186,6 +187,7 @@ package object types { type IsQuantized[T] = Contains[T, Quantized] type IsNumeric[T] = Contains[T, Numeric] type IsBooleanOrNumeric[T] = Contains[T, BooleanOrNumeric] + type IsRealOrComplex[T] = Contains[T, RealOrComplex] object IsFloatOrDouble { def apply[T: IsFloatOrDouble]: IsFloatOrDouble[T] = implicitly[IsFloatOrDouble[T]] @@ -266,4 +268,8 @@ package object types { object IsBooleanOrNumeric { def apply[T: IsBooleanOrNumeric]: IsBooleanOrNumeric[T] = implicitly[IsBooleanOrNumeric[T]] } + + object IsRealOrComplex { + def apply[T: IsRealOrComplex]: IsRealOrComplex[T] = implicitly[IsRealOrComplex[T]] + } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala new file mode 100644 index 000000000..0a5afd60a --- /dev/null +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -0,0 +1,296 @@ +/* Copyright 2019, T.AI Labs. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
/* Copyright 2019, T.AI Labs. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.platanios.tensorflow.api.ops

// NOTE(review): several of the imports below appear unused by this trait
// (Shape, InvalidArgumentException, tensors, Tensor, IntDefault, AttrValue,
// FieldDescriptor, postfixOps) — kept as-is pending a full-file check.
import org.platanios.tensorflow.api.core.Shape
import org.platanios.tensorflow.api.core.exception.InvalidArgumentException
import org.platanios.tensorflow.api.core.types._
import org.platanios.tensorflow.api.implicits.Implicits._
import org.platanios.tensorflow.api.tensors
import org.platanios.tensorflow.api.tensors.Tensor
import org.platanios.tensorflow.api.utilities.DefaultsTo.IntDefault
//import com.google.protobuf.ByteString.Output

import org.tensorflow.framework.AttrValue

import scala.language.postfixOps
import com.google.protobuf.Descriptors.FieldDescriptor

/** Defines linear algebra ops similar to the ones defined in the `tf.linalg`
  * package of the Python TF API.
  */
trait Linalg {

  /** Computes the Cholesky decomposition of one or more self-adjoint matrices.
    *
    * The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
    * form square matrices. The input has to be symmetric and positive definite.
    * Only the lower-triangular part of the input will be used for this
    * operation; the upper-triangular part will not be read.
    *
    * '''NOTE:''' The gradient computation on GPU is faster for large matrices
    * but not for large batch dimensions when the submatrices are small. In this
    * case it might be faster to use the CPU.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix The input tensor of shape `[..., M, M]`.
    * @param name   An optional name to assign to the op.
    * @return Tensor of the same shape as the input, containing the Cholesky
    *         decompositions of all input submatrices `[..., :, :]`.
    */
  def cholesky[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "Cholesky"): Output[T] =
    Op.Builder[Output[T], Output[T]](
      opType = "Cholesky",
      name = name,
      input = matrix
    ).setGradientFn(choleskyGrad(_, _)(TF[T], IsRealOrComplex[T])).build().output

  /** Gradient op for [[cholesky]], registered via `setGradientFn`.
    *
    * @param l              Op producing the Cholesky factor.
    * @param outputGradient Gradient flowing into the Cholesky output.
    * @return Gradient with respect to the Cholesky input.
    */
  protected def choleskyGrad[T: TF: IsRealOrComplex](
      l: Op[Output[T], Output[T]],
      outputGradient: Output[T]
  ): Output[T] =
    Op.Builder[(Output[T], Output[T]), Output[T]](
      opType = "CholeskyGrad",
      name = "CholeskyGrad",
      input = (l.output, outputGradient)
    ).build().output

  /** Computes the determinant of one or more square matrices.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix Tensor of shape `[..., M, M]`.
    * @param name   An optional name to assign to the op.
    * @return Tensor of shape `[...]` holding the determinants.
    */
  def matrixDeterminant[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "MatrixDeterminant"): Output[T] = {
    Op.Builder[Output[T], Output[T]](
      opType = "MatrixDeterminant",
      name = name,
      input = matrix
    ).build().output
  }

  /** Computes `(sign(det(x)), log(|det(x)|))` for an input `x`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix Tensor of shape `[..., M, M]`.
    * @param name   An optional name to assign to the op.
    * @return A tuple `(sign, logAbsDet)` with the results.
    */
  def logMatrixDeterminant[T: TF: IsRealOrComplex](
      matrix: Output[T],
      name: String = "LogMatrixDeterminant"
  ): (Output[T], Output[T]) = {
    Op.Builder[Output[T], (Output[T], Output[T])](
      opType = "LogMatrixDeterminant",
      name = name,
      input = matrix
    ).build().output
  }

  /** Computes `inv(A)`, assuming matrix `A` is invertible and of shape `[..., M, M]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix  The matrix to invert.
    * @param adjoint If set to `true`, inverts the adjoint instead. Defaults to `false`.
    * @param name    An optional name to assign to the op.
    * @return Tensor of the same shape as the input, containing the inverses.
    */
  def matrixInverse[T: TF: IsRealOrComplex](
      matrix: Output[T],
      adjoint: Boolean = false,
      name: String = "MatrixInverse"
  ): Output[T] =
    Op.Builder[Output[T], Output[T]](
      opType = "MatrixInverse",
      name = name,
      input = matrix
    ).setAttribute("adjoint", adjoint).build().output

  /** Solves systems of linear equations `Ax = b`.
    *
    * The matrix `A` must be of shape `[..., M, M]` whose inner-most 2
    * dimensions form square matrices. The right hand side `b` is a tensor of
    * shape `[..., M, K]`. The output `x` is a tensor of shape `[..., M, K]`.
    *
    * If `adjoint` is `true` then each output matrix satisfies
    * `adjoint(A[..., :, :]) * x[..., :, :] = b[..., :, :]`.
    * If `adjoint` is `false` then each output matrix satisfies
    * `A[..., :, :] * x[..., :, :] = b[..., :, :]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix  The matrix (`A`) on the left hand side.
    * @param rhs     The right hand side (`b`).
    * @param adjoint Defaults to `false`.
    * @param name    An optional name to assign to the op.
    * @return The solution tensor `x`.
    */
  def matrixSolve[T: TF: IsRealOrComplex](
      matrix: Output[T],
      rhs: Output[T],
      adjoint: Boolean = false,
      name: String = "MatrixSolve"
  ): Output[T] =
    Op.Builder[(Output[T], Output[T]), Output[T]](
      opType = "MatrixSolve",
      name = name,
      input = (matrix, rhs)
    ).setAttribute("adjoint", adjoint).build().output

  /** Solves systems of linear equations `Ax = b` in the regularised
    * least-squares sense.
    *
    * The matrix `A` must be of shape `[..., M, N]` whose inner-most 2
    * dimensions form matrices of size `[M, N]`. The right hand side `b` is a
    * tensor of shape `[..., M, K]`. The output `x` is a tensor of shape
    * `[..., N, K]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix The matrix (`A`) on the left hand side.
    * @param rhs    The right hand side (`b`).
    * @param reg    The L2 regularisation constant.
    * @param fast   Selects the fast (Cholesky-based) solver path. Defaults to `true`.
    * @param name   An optional name to assign to the op.
    * @return The least-squares solution tensor `x`.
    */
  def matrixSolveLS[T: TF: IsRealOrComplex](
      matrix: Output[T],
      rhs: Output[T],
      reg: Output[T],
      fast: Boolean = true,
      name: String = "MatrixSolveLs"
  ): Output[T] =
    Op.Builder[(Output[T], Output[T], Output[T]), Output[T]](
      opType = "MatrixSolveLs",
      name = name,
      input = (matrix, rhs, reg)
    ).setAttribute("fast", fast).build().output

  // Disabled: the "MatrixSquareRoot" op is also commented out of the generated
  // JNI bindings list in build.sbt.
  /* def matrixSquareRoot[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "MatrixSquareRoot"): Output[T] = {
    Op.Builder[Output[T], Output[T]](
      opType = "MatrixSquareRoot",
      name = name,
      input = matrix
    ).build().output
  } */

  /** Solves systems of linear equations with upper- or lower-triangular
    * matrices by backsubstitution.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix  Triangular matrix (`A`) of shape `[..., M, M]`.
    * @param rhs     The right hand side (`b`).
    * @param lower   If `true` (the default), `A` is lower-triangular;
    *                otherwise upper-triangular.
    * @param adjoint If `true`, solves with the adjoint of `A`. Defaults to `false`.
    * @param name    An optional name to assign to the op.
    * @return The solution tensor `x`.
    */
  def matrixTriangularSolve[T: TF: IsRealOrComplex](
      matrix: Output[T],
      rhs: Output[T],
      lower: Boolean = true,
      adjoint: Boolean = false,
      name: String = "MatrixTriangularSolve"
  ): Output[T] =
    Op.Builder[(Output[T], Output[T]), Output[T]](
      opType = "MatrixTriangularSolve",
      name = name,
      input = (matrix, rhs)
    ).setAttribute("lower", lower).setAttribute("adjoint", adjoint).build().output

  /** Performs QR decomposition of a matrix.
    *
    * The input must be of shape `[..., M, N]` whose inner-most 2 dimensions
    * form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix        The input.
    * @param full_matrices If `true`, compute full-sized `q` and `r`. If `false`
    *                      (the default), compute only the leading `P` columns
    *                      of `q`. (Name kept snake_case to match existing
    *                      callers using named arguments.)
    * @param name          An optional name to assign to the op.
    * @return A tuple `(q, r)` where `q` is an orthonormal basis for the range
    *         of the input (shape `[..., M, P]`, or `[..., M, M]` when
    *         `full_matrices` is `true`) and `r` is the triangular factor
    *         (shape `[..., P, N]`, or `[..., M, N]` when `full_matrices` is
    *         `true`).
    */
  def qr[T: TF: IsRealOrComplex](
      matrix: Output[T],
      full_matrices: Boolean = false,
      name: String = "Qr"
  ): (Output[T], Output[T]) =
    Op.Builder[Output[T], (Output[T], Output[T])](
      opType = "Qr",
      name = name,
      input = matrix
    ).setAttribute("full_matrices", full_matrices).build().output

  /** Computes the eigen decomposition of one or more self-adjoint matrices.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix    Tensor of shape `[..., M, M]`.
    * @param compute_v If `true` (the default), eigenvectors are computed and
    *                  returned alongside the eigenvalues.
    * @param name      An optional name to assign to the op.
    * @return A tuple `(e, v)` of eigenvalues and eigenvectors (`v` is
    *         meaningful only when `compute_v` is `true`).
    */
  def selfAdjointEig[T: TF: IsRealOrComplex](
      matrix: Output[T],
      compute_v: Boolean = true,
      name: String = "SelfAdjointEigV2"
  ): (Output[T], Output[T]) =
    Op.Builder[Output[T], (Output[T], Output[T])](
      opType = "SelfAdjointEigV2",
      name = name,
      input = matrix
    ).setAttribute("compute_v", compute_v).build().output

  /** Performs singular value decomposition of a matrix.
    *
    * The input must be of shape `[..., M, N]` whose inner-most 2 dimensions
    * form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix        The matrix to decompose.
    * @param compute_uv    If `true`, left and right singular vectors will be
    *                      computed and returned in `u` and `v`, respectively.
    * @param full_matrices If `true`, compute full-sized `u` and `v`. If `false`
    *                      (the default), compute only the leading `P` singular
    *                      vectors.
    * @param name          An optional name to assign to the op.
    * @return A tuple `(s, u, v)` where:
    *         `s` are the singular values, of shape `[..., P]`;
    *         `u` are the left singular vectors — shape `[..., M, P]` if
    *         `full_matrices` is `false`, `[..., M, M]` otherwise; undefined if
    *         `compute_uv` is `false`;
    *         `v` are the right singular vectors — shape `[..., N, P]` if
    *         `full_matrices` is `false`, `[..., N, N]` otherwise; undefined if
    *         `compute_uv` is `false`.
    */
  def svd[T: TF: IsRealOrComplex](
      matrix: Output[T],
      compute_uv: Boolean = true,
      full_matrices: Boolean = false,
      name: String = "Svd"
  ): (Output[T], Output[T], Output[T]) =
    Op.Builder[Output[T], (Output[T], Output[T], Output[T])](
      opType = "Svd",
      name = name,
      input = matrix
    ).setAttribute("compute_uv", compute_uv).setAttribute("full_matrices", full_matrices).build().output
}
tensors.Tensor.type = tensors.Tensor val TensorIndexedSlices: tensors.TensorIndexedSlices.type = tensors.TensorIndexedSlices - val SparseTensor : tensors.SparseTensor.type = tensors.SparseTensor + val SparseTensor: tensors.SparseTensor.type = tensors.SparseTensor - type Op[I, O] = ops.Op[I, O] + type Op[I, O] = ops.Op[I, O] type UntypedOp = ops.UntypedOp - type OutputLike[T] = ops.OutputLike[T] - type Output[T] = ops.Output[T] + type OutputLike[T] = ops.OutputLike[T] + type Output[T] = ops.Output[T] type OutputIndexedSlices[T] = ops.OutputIndexedSlices[T] - type SparseOutput[T] = ops.SparseOutput[T] + type SparseOutput[T] = ops.SparseOutput[T] type TensorArray[T] = ops.TensorArray[T] - val Op : ops.Op.type = ops.Op - val Output : ops.Output.type = ops.Output + val Op: ops.Op.type = ops.Op + val Output: ops.Output.type = ops.Output val OutputIndexedSlices: ops.OutputIndexedSlices.type = ops.OutputIndexedSlices - val SparseOutput : ops.SparseOutput.type = ops.SparseOutput - val TensorArray : ops.TensorArray.type = ops.TensorArray + val SparseOutput: ops.SparseOutput.type = ops.SparseOutput + val TensorArray: ops.TensorArray.type = ops.TensorArray type VariableLike[T] = ops.variables.VariableLike[T] - type Variable[T] = ops.variables.Variable[T] + type Variable[T] = ops.variables.Variable[T] //region Types //region Value Classes - type Half = core.types.Half + type Half = core.types.Half type TruncatedHalf = core.types.TruncatedHalf - type ComplexFloat = core.types.ComplexFloat + type ComplexFloat = core.types.ComplexFloat type ComplexDouble = core.types.ComplexDouble - type UByte = core.types.UByte - type UShort = core.types.UShort - type UInt = core.types.UInt - type ULong = core.types.ULong - type QByte = core.types.QByte - type QShort = core.types.QShort - type QInt = core.types.QInt - type QUByte = core.types.QUByte - type QUShort = core.types.QUShort - type Resource = core.types.Resource - type Variant = core.types.Variant - - val Half : core.types.Half.type = 
core.types.Half + type UByte = core.types.UByte + type UShort = core.types.UShort + type UInt = core.types.UInt + type ULong = core.types.ULong + type QByte = core.types.QByte + type QShort = core.types.QShort + type QInt = core.types.QInt + type QUByte = core.types.QUByte + type QUShort = core.types.QUShort + type Resource = core.types.Resource + type Variant = core.types.Variant + + val Half: core.types.Half.type = core.types.Half val TruncatedHalf: core.types.TruncatedHalf.type = core.types.TruncatedHalf - val ComplexFloat : core.types.ComplexFloat.type = core.types.ComplexFloat + val ComplexFloat: core.types.ComplexFloat.type = core.types.ComplexFloat val ComplexDouble: core.types.ComplexDouble.type = core.types.ComplexDouble - val UByte : core.types.UByte.type = core.types.UByte - val UShort : core.types.UShort.type = core.types.UShort - val UInt : core.types.UInt.type = core.types.UInt - val ULong : core.types.ULong.type = core.types.ULong - val QByte : core.types.QByte.type = core.types.QByte - val QShort : core.types.QShort.type = core.types.QShort - val QInt : core.types.QInt.type = core.types.QInt - val QUByte : core.types.QUByte.type = core.types.QUByte - val QUShort : core.types.QUShort.type = core.types.QUShort - val Resource : core.types.Resource.type = core.types.Resource - val Variant : core.types.Variant.type = core.types.Variant + val UByte: core.types.UByte.type = core.types.UByte + val UShort: core.types.UShort.type = core.types.UShort + val UInt: core.types.UInt.type = core.types.UInt + val ULong: core.types.ULong.type = core.types.ULong + val QByte: core.types.QByte.type = core.types.QByte + val QShort: core.types.QShort.type = core.types.QShort + val QInt: core.types.QInt.type = core.types.QInt + val QUByte: core.types.QUByte.type = core.types.QUByte + val QUShort: core.types.QUShort.type = core.types.QUShort + val Resource: core.types.Resource.type = core.types.Resource + val Variant: core.types.Variant.type = core.types.Variant //endregion 
Value Classes @@ -149,96 +149,100 @@ package object api extends implicits.Implicits with Documentation { type DataType[T] = core.types.DataType[T] - type STRING = core.types.DataType[String] - type BOOLEAN = core.types.DataType[Boolean] - type FLOAT16 = core.types.DataType[core.types.Half] - type FLOAT32 = core.types.DataType[Float] - type FLOAT64 = core.types.DataType[Double] - type BFLOAT16 = core.types.DataType[core.types.TruncatedHalf] - type COMPLEX64 = core.types.DataType[core.types.ComplexFloat] + type STRING = core.types.DataType[String] + type BOOLEAN = core.types.DataType[Boolean] + type FLOAT16 = core.types.DataType[core.types.Half] + type FLOAT32 = core.types.DataType[Float] + type FLOAT64 = core.types.DataType[Double] + type BFLOAT16 = core.types.DataType[core.types.TruncatedHalf] + type COMPLEX64 = core.types.DataType[core.types.ComplexFloat] type COMPLEX128 = core.types.DataType[core.types.ComplexDouble] - type INT8 = core.types.DataType[Byte] - type INT16 = core.types.DataType[Short] - type INT32 = core.types.DataType[Int] - type INT64 = core.types.DataType[Long] - type UINT8 = core.types.DataType[core.types.UByte] - type UINT16 = core.types.DataType[core.types.UShort] - type UINT32 = core.types.DataType[core.types.UInt] - type UINT64 = core.types.DataType[core.types.ULong] - type QINT8 = core.types.DataType[core.types.QByte] - type QINT16 = core.types.DataType[core.types.QShort] - type QINT32 = core.types.DataType[core.types.QInt] - type QUINT8 = core.types.DataType[core.types.QUByte] - type QUINT16 = core.types.DataType[core.types.QUShort] - type RESOURCE = core.types.DataType[core.types.Resource] - type VARIANT = core.types.DataType[core.types.Variant] - - val STRING : STRING = core.types.STRING - val BOOLEAN : BOOLEAN = core.types.BOOLEAN - val FLOAT16 : FLOAT16 = core.types.FLOAT16 - val FLOAT32 : FLOAT32 = core.types.FLOAT32 - val FLOAT64 : FLOAT64 = core.types.FLOAT64 - val BFLOAT16 : BFLOAT16 = core.types.BFLOAT16 - val COMPLEX64 : COMPLEX64 
= core.types.COMPLEX64 + type INT8 = core.types.DataType[Byte] + type INT16 = core.types.DataType[Short] + type INT32 = core.types.DataType[Int] + type INT64 = core.types.DataType[Long] + type UINT8 = core.types.DataType[core.types.UByte] + type UINT16 = core.types.DataType[core.types.UShort] + type UINT32 = core.types.DataType[core.types.UInt] + type UINT64 = core.types.DataType[core.types.ULong] + type QINT8 = core.types.DataType[core.types.QByte] + type QINT16 = core.types.DataType[core.types.QShort] + type QINT32 = core.types.DataType[core.types.QInt] + type QUINT8 = core.types.DataType[core.types.QUByte] + type QUINT16 = core.types.DataType[core.types.QUShort] + type RESOURCE = core.types.DataType[core.types.Resource] + type VARIANT = core.types.DataType[core.types.Variant] + + val STRING: STRING = core.types.STRING + val BOOLEAN: BOOLEAN = core.types.BOOLEAN + val FLOAT16: FLOAT16 = core.types.FLOAT16 + val FLOAT32: FLOAT32 = core.types.FLOAT32 + val FLOAT64: FLOAT64 = core.types.FLOAT64 + val BFLOAT16: BFLOAT16 = core.types.BFLOAT16 + val COMPLEX64: COMPLEX64 = core.types.COMPLEX64 val COMPLEX128: COMPLEX128 = core.types.COMPLEX128 - val INT8 : INT8 = core.types.INT8 - val INT16 : INT16 = core.types.INT16 - val INT32 : INT32 = core.types.INT32 - val INT64 : INT64 = core.types.INT64 - val UINT8 : UINT8 = core.types.UINT8 - val UINT16 : UINT16 = core.types.UINT16 - val UINT32 : UINT32 = core.types.UINT32 - val UINT64 : UINT64 = core.types.UINT64 - val QINT8 : QINT8 = core.types.QINT8 - val QINT16 : QINT16 = core.types.QINT16 - val QINT32 : QINT32 = core.types.QINT32 - val QUINT8 : QUINT8 = core.types.QUINT8 - val QUINT16 : QUINT16 = core.types.QUINT16 - val RESOURCE : RESOURCE = core.types.RESOURCE - val VARIANT : VARIANT = core.types.VARIANT + val INT8: INT8 = core.types.INT8 + val INT16: INT16 = core.types.INT16 + val INT32: INT32 = core.types.INT32 + val INT64: INT64 = core.types.INT64 + val UINT8: UINT8 = core.types.UINT8 + val UINT16: UINT16 = 
core.types.UINT16 + val UINT32: UINT32 = core.types.UINT32 + val UINT64: UINT64 = core.types.UINT64 + val QINT8: QINT8 = core.types.QINT8 + val QINT16: QINT16 = core.types.QINT16 + val QINT32: QINT32 = core.types.QINT32 + val QUINT8: QUINT8 = core.types.QUINT8 + val QUINT16: QUINT16 = core.types.QUINT16 + val RESOURCE: RESOURCE = core.types.RESOURCE + val VARIANT: VARIANT = core.types.VARIANT //endregion Data Type Instances //region Type Traits - type TF[T] = core.types.TF[T] - type IsFloatOrDouble[T] = core.types.IsFloatOrDouble[T] - type IsHalfOrFloatOrDouble[T] = core.types.IsHalfOrFloatOrDouble[T] - type IsTruncatedHalfOrFloatOrDouble[T] = core.types.IsTruncatedHalfOrFloatOrDouble[T] - type IsTruncatedHalfOrHalfOrFloat[T] = core.types.IsTruncatedHalfOrHalfOrFloat[T] - type IsDecimal[T] = core.types.IsDecimal[T] - type IsIntOrLong[T] = core.types.IsIntOrLong[T] - type IsIntOrLongOrFloatOrDouble[T] = core.types.IsIntOrLongOrFloatOrDouble[T] + type TF[T] = core.types.TF[T] + type IsFloatOrDouble[T] = core.types.IsFloatOrDouble[T] + type IsHalfOrFloatOrDouble[T] = core.types.IsHalfOrFloatOrDouble[T] + type IsTruncatedHalfOrFloatOrDouble[T] = core.types.IsTruncatedHalfOrFloatOrDouble[T] + type IsTruncatedHalfOrHalfOrFloat[T] = core.types.IsTruncatedHalfOrHalfOrFloat[T] + type IsDecimal[T] = core.types.IsDecimal[T] + type IsIntOrLong[T] = core.types.IsIntOrLong[T] + type IsIntOrLongOrFloatOrDouble[T] = core.types.IsIntOrLongOrFloatOrDouble[T] type IsIntOrLongOrHalfOrFloatOrDouble[T] = core.types.IsIntOrLongOrHalfOrFloatOrDouble[T] - type IsIntOrLongOrUByte[T] = core.types.IsIntOrLongOrUByte[T] - type IsIntOrUInt[T] = core.types.IsIntOrUInt[T] - type IsStringOrInteger[T] = core.types.IsStringOrInteger[T] - type IsStringOrFloatOrLong[T] = core.types.IsStringOrFloatOrLong[T] - type IsReal[T] = core.types.IsReal[T] - type IsComplex[T] = core.types.IsComplex[T] - type IsNotQuantized[T] = core.types.IsNotQuantized[T] - type IsQuantized[T] = core.types.IsQuantized[T] - type 
IsNumeric[T] = core.types.IsNumeric[T] - type IsBooleanOrNumeric[T] = core.types.IsBooleanOrNumeric[T] - - val TF : core.types.TF.type = core.types.TF - val IsFloatOrDouble : core.types.IsFloatOrDouble.type = core.types.IsFloatOrDouble - val IsHalfOrFloatOrDouble : core.types.IsHalfOrFloatOrDouble.type = core.types.IsHalfOrFloatOrDouble - val IsTruncatedHalfOrFloatOrDouble : core.types.IsTruncatedHalfOrFloatOrDouble.type = core.types.IsTruncatedHalfOrFloatOrDouble - val IsTruncatedHalfOrHalfOrFloat : core.types.IsTruncatedHalfOrHalfOrFloat.type = core.types.IsTruncatedHalfOrHalfOrFloat - val IsDecimal : core.types.IsDecimal.type = core.types.IsDecimal - val IsIntOrLong : core.types.IsIntOrLong.type = core.types.IsIntOrLong - val IsIntOrLongOrFloatOrDouble : core.types.IsIntOrLongOrFloatOrDouble.type = core.types.IsIntOrLongOrFloatOrDouble - val IsIntOrLongOrHalfOrFloatOrDouble: core.types.IsIntOrLongOrHalfOrFloatOrDouble.type = core.types.IsIntOrLongOrHalfOrFloatOrDouble - val IsIntOrLongOrUByte : core.types.IsIntOrLongOrUByte.type = core.types.IsIntOrLongOrUByte - val IsIntOrUInt : core.types.IsIntOrUInt.type = core.types.IsIntOrUInt - val IsStringOrFloatOrLong : core.types.IsStringOrFloatOrLong.type = core.types.IsStringOrFloatOrLong - val IsReal : core.types.IsReal.type = core.types.IsReal - val IsComplex : core.types.IsComplex.type = core.types.IsComplex - val IsNotQuantized : core.types.IsNotQuantized.type = core.types.IsNotQuantized - val IsQuantized : core.types.IsQuantized.type = core.types.IsQuantized - val IsNumeric : core.types.IsNumeric.type = core.types.IsNumeric - val IsBooleanOrNumeric : core.types.IsBooleanOrNumeric.type = core.types.IsBooleanOrNumeric + type IsIntOrLongOrUByte[T] = core.types.IsIntOrLongOrUByte[T] + type IsIntOrUInt[T] = core.types.IsIntOrUInt[T] + type IsStringOrInteger[T] = core.types.IsStringOrInteger[T] + type IsStringOrFloatOrLong[T] = core.types.IsStringOrFloatOrLong[T] + type IsReal[T] = core.types.IsReal[T] + type 
IsComplex[T] = core.types.IsComplex[T] + type IsNotQuantized[T] = core.types.IsNotQuantized[T] + type IsQuantized[T] = core.types.IsQuantized[T] + type IsNumeric[T] = core.types.IsNumeric[T] + type IsBooleanOrNumeric[T] = core.types.IsBooleanOrNumeric[T] + + val TF: core.types.TF.type = core.types.TF + val IsFloatOrDouble: core.types.IsFloatOrDouble.type = core.types.IsFloatOrDouble + val IsHalfOrFloatOrDouble: core.types.IsHalfOrFloatOrDouble.type = core.types.IsHalfOrFloatOrDouble + val IsTruncatedHalfOrFloatOrDouble: core.types.IsTruncatedHalfOrFloatOrDouble.type = + core.types.IsTruncatedHalfOrFloatOrDouble + val IsTruncatedHalfOrHalfOrFloat: core.types.IsTruncatedHalfOrHalfOrFloat.type = + core.types.IsTruncatedHalfOrHalfOrFloat + val IsDecimal: core.types.IsDecimal.type = core.types.IsDecimal + val IsIntOrLong: core.types.IsIntOrLong.type = core.types.IsIntOrLong + val IsIntOrLongOrFloatOrDouble: core.types.IsIntOrLongOrFloatOrDouble.type = core.types.IsIntOrLongOrFloatOrDouble + val IsIntOrLongOrHalfOrFloatOrDouble: core.types.IsIntOrLongOrHalfOrFloatOrDouble.type = + core.types.IsIntOrLongOrHalfOrFloatOrDouble + val IsIntOrLongOrUByte: core.types.IsIntOrLongOrUByte.type = core.types.IsIntOrLongOrUByte + val IsIntOrUInt: core.types.IsIntOrUInt.type = core.types.IsIntOrUInt + val IsStringOrFloatOrLong: core.types.IsStringOrFloatOrLong.type = core.types.IsStringOrFloatOrLong + val IsReal: core.types.IsReal.type = core.types.IsReal + val IsComplex: core.types.IsComplex.type = core.types.IsComplex + val IsNotQuantized: core.types.IsNotQuantized.type = core.types.IsNotQuantized + val IsQuantized: core.types.IsQuantized.type = core.types.IsQuantized + val IsNumeric: core.types.IsNumeric.type = core.types.IsNumeric + val IsBooleanOrNumeric: core.types.IsBooleanOrNumeric.type = core.types.IsBooleanOrNumeric + val IsRealOrComplex: core.types.IsRealOrComplex.type = core.types.IsRealOrComplex //endregion Type Traits @@ -291,9 +295,7 @@ package object api extends 
implicits.Implicits with Documentation { * @groupname CallbackOps Ops / Callback * @groupprio CallbackOps 280 */ - object tf - extends core.API - with ops.API { + object tf extends core.API with ops.API { object learn extends api.learn.API } @@ -336,7 +338,5 @@ package object api extends implicits.Implicits with Documentation { * @groupname CallbackOps Ops / Callback * @groupprio CallbackOps 280 */ - object tfi - extends core.API - with tensors.API + object tfi extends core.API with tensors.API } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala new file mode 100644 index 000000000..a17990553 --- /dev/null +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -0,0 +1,168 @@ +/* Copyright 2019, T.AI Labs. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
/* Copyright 2019, T.AI Labs. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package org.platanios.tensorflow.api.tensors.ops

// NOTE(review): Shape, InvalidShapeException, and IntDefault appear unused by
// this trait — kept as-is pending a full-file check.
import org.platanios.tensorflow.api.core.Shape
import org.platanios.tensorflow.api.core.exception.InvalidShapeException
import org.platanios.tensorflow.api.core.types._
import org.platanios.tensorflow.api.implicits.Implicits._
import org.platanios.tensorflow.api.tensors._
import org.platanios.tensorflow.api.utilities.DefaultsTo.IntDefault
import org.platanios.tensorflow.jni.generated.tensors.{Linalg => NativeTensorOpsLinAlg}

import scala.language.postfixOps

/** Defines eager (tensor) linear algebra ops similar to the ones defined in
  * the `tf.linalg` package of the Python TF API.
  */
trait Linalg {

  /** Computes the Cholesky decomposition of one or more self-adjoint,
    * positive-definite matrices of shape `[..., M, M]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix The input tensor.
    * @return Tensor of the same shape, containing the Cholesky factors.
    */
  def cholesky[T: TF: IsRealOrComplex](matrix: Tensor[T]): Tensor[T] =
    Tensor.fromNativeHandle[T](
      NativeTensorOpsLinAlg.cholesky(executionContext.value.nativeHandle, matrix.nativeHandle)
    )

  /** Computes the determinant of one or more square matrices of shape `[..., M, M]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix The input tensor.
    * @return Tensor of shape `[...]` holding the determinants.
    */
  def matrixDeterminant[T: TF: IsRealOrComplex](matrix: Tensor[T]): Tensor[T] = {
    Tensor.fromNativeHandle[T](
      NativeTensorOpsLinAlg.matrixDeterminant(executionContext.value.nativeHandle, matrix.nativeHandle)
    )
  }

  /** Computes `(sign(det(x)), log(|det(x)|))` for an input `x`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix Tensor of shape `[..., M, M]`.
    * @return A tuple `(sign, logAbsDet)` with the results.
    */
  def logMatrixDeterminant[T: TF: IsRealOrComplex](matrix: Tensor[T]): (Tensor[T], Tensor[T]) = {
    // The native op returns its two outputs as an array of handles.
    val results = NativeTensorOpsLinAlg
      .logMatrixDeterminant(executionContext.value.nativeHandle, matrix.nativeHandle).map(
        h => Tensor.fromNativeHandle[T](h)
      )
    (results.head, results.last)
  }

  /** Computes `inv(A)`, assuming matrix `A` is invertible and of shape `[..., M, M]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix  The matrix to invert.
    * @param adjoint If set to `true`, inverts the adjoint instead. Defaults to `false`.
    * @return Tensor of the same shape, containing the inverses.
    */
  def matrixInverse[T: TF: IsRealOrComplex](matrix: Tensor[T], adjoint: Boolean = false): Tensor[T] = {
    Tensor.fromNativeHandle[T](
      NativeTensorOpsLinAlg.matrixInverse(executionContext.value.nativeHandle, matrix.nativeHandle, adjoint)
    )
  }

  /** Solves systems of linear equations `Ax = b`.
    *
    * The matrix `A` must be of shape `[..., M, M]` whose inner-most 2
    * dimensions form square matrices. The right hand side `b` is a tensor of
    * shape `[..., M, K]`; the output `x` is a tensor of shape `[..., M, K]`.
    *
    * If `adjoint` is `true` then each output matrix satisfies
    * `adjoint(A[..., :, :]) * x[..., :, :] = b[..., :, :]`;
    * if `false`, `A[..., :, :] * x[..., :, :] = b[..., :, :]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix  The matrix (`A`) on the left hand side.
    * @param rhs     The right hand side (`b`).
    * @param adjoint Defaults to `false`.
    * @return The solution tensor `x`.
    */
  def matrixSolve[T: TF: IsRealOrComplex](matrix: Tensor[T], rhs: Tensor[T], adjoint: Boolean = false): Tensor[T] = {
    Tensor.fromNativeHandle[T](
      NativeTensorOpsLinAlg
        .matrixSolve(executionContext.value.nativeHandle, matrix.nativeHandle, rhs.nativeHandle, adjoint)
    )
  }

  /** Solves systems of linear equations `Ax = b` in the regularised
    * least-squares sense, for `A` of shape `[..., M, N]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix The matrix (`A`) on the left hand side.
    * @param rhs    The right hand side (`b`), of shape `[..., M, K]`.
    * @param reg    The L2 regularisation constant.
    * @param fast   Selects the fast (Cholesky-based) solver path. Defaults to `true`.
    * @return The least-squares solution `x`, of shape `[..., N, K]`.
    */
  def matrixSolveLS[T: TF: IsRealOrComplex](
      matrix: Tensor[T],
      rhs: Tensor[T],
      reg: Tensor[T],
      fast: Boolean = true
  ): Tensor[T] = {
    Tensor.fromNativeHandle[T](
      NativeTensorOpsLinAlg.matrixSolveLs(
        executionContext.value.nativeHandle,
        matrix.nativeHandle,
        rhs.nativeHandle,
        reg.nativeHandle,
        fast
      )
    )
  }

  /** Solves systems of linear equations with upper- or lower-triangular
    * matrices by backsubstitution.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix  Triangular matrix (`A`) of shape `[..., M, M]`.
    * @param rhs     The right hand side (`b`).
    * @param lower   If `true` (the default), `A` is lower-triangular; otherwise
    *                upper-triangular.
    * @param adjoint If `true`, solves with the adjoint of `A`. Defaults to `false`.
    * @return The solution tensor `x`.
    */
  def matrixTriangularSolve[T: TF: IsRealOrComplex](
      matrix: Tensor[T],
      rhs: Tensor[T],
      lower: Boolean = true,
      adjoint: Boolean = false
  ): Tensor[T] = {
    Tensor.fromNativeHandle[T](
      NativeTensorOpsLinAlg.matrixTriangularSolve(
        executionContext.value.nativeHandle,
        matrix.nativeHandle,
        rhs.nativeHandle,
        lower,
        adjoint
      )
    )
  }

  /** Performs QR decomposition of a matrix of shape `[..., M, N]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix        The input.
    * @param full_matrices If `true`, compute full-sized `q` and `r`; defaults
    *                      to `false`. (Name kept snake_case to match existing
    *                      callers using named arguments.)
    * @return A tuple `(q, r)` of the orthonormal and triangular factors.
    */
  def qr[T: TF: IsRealOrComplex](matrix: Tensor[T], full_matrices: Boolean = false): (Tensor[T], Tensor[T]) = {
    // The native op returns its two outputs as an array of handles.
    val results = NativeTensorOpsLinAlg
      .qr(executionContext.value.nativeHandle, matrix.nativeHandle, full_matrices).map(
        h => Tensor.fromNativeHandle[T](h)
      )
    (results.head, results.last)
  }

  /** Computes the eigen decomposition of one or more self-adjoint matrices of
    * shape `[..., M, M]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix    The input.
    * @param compute_v If `true` (the default), eigenvectors are computed and
    *                  returned alongside the eigenvalues.
    * @return A tuple `(e, v)` of eigenvalues and eigenvectors (`v` is
    *         meaningful only when `compute_v` is `true`).
    */
  def selfAdjointEig[T: TF: IsRealOrComplex](matrix: Tensor[T], compute_v: Boolean = true): (Tensor[T], Tensor[T]) = {
    val results = NativeTensorOpsLinAlg
      .selfAdjointEigV2(executionContext.value.nativeHandle, matrix.nativeHandle, compute_v).map(
        h => Tensor.fromNativeHandle[T](h)
      )
    (results.head, results.last)
  }

  /** Performs singular value decomposition of a matrix of shape `[..., M, N]`.
    *
    * @tparam T The underlying Scala type of the matrix elements.
    * @param matrix        The matrix to decompose.
    * @param compute_uv    If `true`, left and right singular vectors will be
    *                      computed and returned in `u` and `v`, respectively.
    * @param full_matrices If `true`, compute full-sized `u` and `v`; defaults
    *                      to `false`.
    * @return A tuple `(s, u, v)` of singular values, left singular vectors,
    *         and right singular vectors (`u` and `v` are undefined when
    *         `compute_uv` is `false`).
    */
  def svd[T: TF: IsRealOrComplex](
      matrix: Tensor[T],
      compute_uv: Boolean = true,
      full_matrices: Boolean = false
  ): (Tensor[T], Tensor[T], Tensor[T]) = {
    // The native op returns its three outputs (s, u, v) as an array of handles.
    val results = NativeTensorOpsLinAlg
      .svd(executionContext.value.nativeHandle, matrix.nativeHandle, compute_uv, full_matrices).map(
        h => Tensor.fromNativeHandle[T](h)
      )
    (results.head, results(1), results.last)
  }
}
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +#include "tensor_linalg_ops.h" +#include "exception.h" +#include "utilities.h" + +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_cholesky( + JNIEnv* env, jobject object, jlong context_handle, jlong input) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "Cholesky", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, 0); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlong 
JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_choleskyGrad( + JNIEnv* env, jobject object, jlong context_handle, jlong l, jlong grad) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "CholeskyGrad", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(l_handle, TFE_TensorHandle, l, 0); + TFE_OpAddInput(op.get(), l_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(grad_handle, TFE_TensorHandle, grad, 0); + TFE_OpAddInput(op.get(), grad_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_l_handle, TFE_TensorHandle, l, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_l_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_grad_handle, TFE_TensorHandle, grad, 0); + const TF_DataType attr_T_grad = TFE_TensorHandleDataType(attr_T_grad_handle); + if (attr_T != attr_T_grad) { + std::stringstream error_msg; + error_msg + << "Argument 'grad' of 'choleskyGrad' op with data type '" + << attr_T_grad + << "' must match data type '" + << attr_T + << "' of argument 'l'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_logMatrixDeterminant( + JNIEnv* env, jobject object, jlong context_handle, jlong input) { + REQUIRE_HANDLE(context, 
TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "LogMatrixDeterminant", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + const int num_outputs = 2; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixDeterminant( + JNIEnv* env, jobject object, jlong context_handle, jlong input) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixDeterminant", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, 0); + 
TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixInverse( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean adjoint) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixInverse", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, 0); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "adjoint", static_cast(adjoint)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlong JNICALL 
Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolve( + JNIEnv* env, jobject object, jlong context_handle, jlong matrix, jlong rhs, jboolean adjoint) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixSolve", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(matrix_handle, TFE_TensorHandle, matrix, 0); + TFE_OpAddInput(op.get(), matrix_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(rhs_handle, TFE_TensorHandle, rhs, 0); + TFE_OpAddInput(op.get(), rhs_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_matrix_handle, TFE_TensorHandle, matrix, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_matrix_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_rhs_handle, TFE_TensorHandle, rhs, 0); + const TF_DataType attr_T_rhs = TFE_TensorHandleDataType(attr_T_rhs_handle); + if (attr_T != attr_T_rhs) { + std::stringstream error_msg; + error_msg + << "Argument 'rhs' of 'matrixSolve' op with data type '" + << attr_T_rhs + << "' must match data type '" + << attr_T + << "' of argument 'matrix'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + TFE_OpSetAttrBool(op.get(), "adjoint", static_cast(adjoint)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolveLs( + 
JNIEnv* env, jobject object, jlong context_handle, jlong matrix, jlong rhs, jlong l2_regularizer, jboolean fast) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixSolveLs", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(matrix_handle, TFE_TensorHandle, matrix, 0); + TFE_OpAddInput(op.get(), matrix_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(rhs_handle, TFE_TensorHandle, rhs, 0); + TFE_OpAddInput(op.get(), rhs_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(l2_regularizer_handle, TFE_TensorHandle, l2_regularizer, 0); + TFE_OpAddInput(op.get(), l2_regularizer_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_matrix_handle, TFE_TensorHandle, matrix, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_matrix_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_rhs_handle, TFE_TensorHandle, rhs, 0); + const TF_DataType attr_T_rhs = TFE_TensorHandleDataType(attr_T_rhs_handle); + if (attr_T != attr_T_rhs) { + std::stringstream error_msg; + error_msg + << "Argument 'rhs' of 'matrixSolveLs' op with data type '" + << attr_T_rhs + << "' must match data type '" + << attr_T + << "' of argument 'matrix'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + TFE_OpSetAttrBool(op.get(), "fast", static_cast(fast)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return 
reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixTriangularSolve( + JNIEnv* env, jobject object, jlong context_handle, jlong matrix, jlong rhs, jboolean lower, jboolean adjoint) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixTriangularSolve", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(matrix_handle, TFE_TensorHandle, matrix, 0); + TFE_OpAddInput(op.get(), matrix_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(rhs_handle, TFE_TensorHandle, rhs, 0); + TFE_OpAddInput(op.get(), rhs_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_matrix_handle, TFE_TensorHandle, matrix, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_matrix_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_rhs_handle, TFE_TensorHandle, rhs, 0); + const TF_DataType attr_T_rhs = TFE_TensorHandleDataType(attr_T_rhs_handle); + if (attr_T != attr_T_rhs) { + std::stringstream error_msg; + error_msg + << "Argument 'rhs' of 'matrixTriangularSolve' op with data type '" + << attr_T_rhs + << "' must match data type '" + << attr_T + << "' of argument 'matrix'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + TFE_OpSetAttrBool(op.get(), "lower", static_cast(lower)); + + TFE_OpSetAttrBool(op.get(), "adjoint", static_cast(adjoint)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, 
status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_qr( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean full_matrices) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "Qr", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "full_matrices", static_cast(full_matrices)); + + const int num_outputs = 2; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_selfAdjointEigV2( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean compute_v) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr 
status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "SelfAdjointEigV2", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "compute_v", static_cast(compute_v)); + + const int num_outputs = 2; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_svd( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean compute_uv, jboolean full_matrices) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "Svd", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + 
REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "compute_uv", static_cast(compute_uv)); + + TFE_OpSetAttrBool(op.get(), "full_matrices", static_cast(full_matrices)); + + const int num_outputs = 3; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h new file mode 100644 index 000000000..0ebd73aac --- /dev/null +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -0,0 +1,101 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_platanios_tensorflow_jni_generated_tensors_Linalg__ */ + +#ifndef _Included_org_platanios_tensorflow_jni_generated_tensors_Linalg__ +#define _Included_org_platanios_tensorflow_jni_generated_tensors_Linalg__ +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: cholesky + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_cholesky + (JNIEnv *, jobject, jlong, jlong); 
+ +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: choleskyGrad + * Signature: (JJJ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_choleskyGrad + (JNIEnv *, jobject, jlong, jlong, jlong); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: logMatrixDeterminant + * Signature: (JJ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_logMatrixDeterminant + (JNIEnv *, jobject, jlong, jlong); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixDeterminant + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixDeterminant + (JNIEnv *, jobject, jlong, jlong); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixInverse + * Signature: (JJZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixInverse + (JNIEnv *, jobject, jlong, jlong, jboolean); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixSolve + * Signature: (JJJZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolve + (JNIEnv *, jobject, jlong, jlong, jlong, jboolean); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixSolveLs + * Signature: (JJJJZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolveLs + (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jboolean); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixTriangularSolve + * Signature: (JJJZZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixTriangularSolve + (JNIEnv *, jobject, jlong, jlong, jlong, jboolean, jboolean); + 
+/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: qr + * Signature: (JJZ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_qr + (JNIEnv *, jobject, jlong, jlong, jboolean); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: selfAdjointEigV2 + * Signature: (JJZ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_selfAdjointEigV2 + (JNIEnv *, jobject, jlong, jlong, jboolean); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: svd + * Signature: (JJZZ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_svd + (JNIEnv *, jobject, jlong, jlong, jboolean, jboolean); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala new file mode 100644 index 000000000..123b20223 --- /dev/null +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -0,0 +1,36 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ + +/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
/* DO NOT EDIT THIS FILE - it is machine generated */
/* (Apache License 2.0 header as in the sibling generated files.) */

package org.platanios.tensorflow.jni.generated.tensors

import org.platanios.tensorflow.jni.TensorFlow

/** JNI bindings for the eager-tensor linear-algebra ops
  * (implemented natively in `tensor_linalg_ops.cc`).
  *
  * All arguments/results typed `Long` are native `TFE_TensorHandle*` (or, for
  * `contextHandle`, `TFE_Context*`) pointers; `Array[Long]` results hold the
  * handles of multi-output ops in op-definition order.
  */
object Linalg {
  // Ensure the native library is loaded before any @native method is invoked.
  TensorFlow.load()

  @native def cholesky(contextHandle: Long, input: Long): Long
  @native def choleskyGrad(contextHandle: Long, l: Long, grad: Long): Long
  @native def logMatrixDeterminant(contextHandle: Long, input: Long): Array[Long]
  @native def matrixDeterminant(contextHandle: Long, input: Long): Long
  @native def matrixInverse(contextHandle: Long, input: Long, adjoint: Boolean): Long
  @native def matrixSolve(contextHandle: Long, matrix: Long, rhs: Long, adjoint: Boolean): Long
  @native def matrixSolveLs(contextHandle: Long, matrix: Long, rhs: Long, l2_regularizer: Long, fast: Boolean): Long
  @native def matrixTriangularSolve(contextHandle: Long, matrix: Long, rhs: Long, lower: Boolean, adjoint: Boolean): Long
  @native def qr(contextHandle: Long, input: Long, full_matrices: Boolean): Array[Long]
  @native def selfAdjointEigV2(contextHandle: Long, input: Long, compute_v: Boolean): Array[Long]
  @native def svd(contextHandle: Long, input: Long, compute_uv: Boolean, full_matrices: Boolean): Array[Long]
}