GitHub - eaplatanios/tensorflow_scala: TensorFlow API for the Scala Programming Language

This library is a Scala API for https://www.tensorflow.org. It attempts to provide most of the functionality of the official Python API, while at the same time being strongly typed and adding some new features. It is a work in progress that I started for my personal research. Much of the API should be relatively stable by now, but things are still likely to change.

Please refer to the main website for documentation and tutorials.

Citation

If you use this project in your work, please cite it using the following BibTeX entry:

@misc{Platanios:2018:tensorflow-scala,
  title        = {{TensorFlow Scala}},
  author       = {Platanios, Emmanouil Antonios},
  howpublished = {\url{https://github.com/eaplatanios/tensorflow_scala}},
  year         = {2018}
}

Main Features

  • Easy manipulation of tensors and computations involving tensors (similar to NumPy in Python):

    val t1 = Tensor(1.2, 4.5)
    val t2 = Tensor(-0.2, 1.1)
    t1 + t2 == Tensor(1.0, 5.6)
    
  • Low-level graph construction API, similar to that of the Python API, but strongly typed wherever possible:

    val inputs      = tf.placeholder[Float](Shape(-1, 10))
    val outputs     = tf.placeholder[Float](Shape(-1, 1)) // one target per example, matching the shape of predictions
    val predictions = tf.nameScope("Linear") {
      val weights = tf.variable[Float]("weights", Shape(10, 1), tf.ZerosInitializer)
      tf.matmul(inputs, weights)
    }
    val loss        = tf.sum(tf.square(predictions - outputs))
    val optimizer   = tf.train.AdaGrad(1.0f)
    val trainOp     = optimizer.minimize(loss)
    
  • NumPy-like indexing/slicing for tensors. For example:

    tensor(2 :: 5, ---, 1) // equivalent to NumPy's tensor[2:5, ..., 1]
    
  • High-level API for creating, training, and using neural networks. For example, the following code trains a multi-layer perceptron (MLP) on MNIST using TensorFlow for Scala. For simplicity, it omits several powerful features, such as summary and checkpoint savers, although these are also simple to use.

    // Load and batch data using pre-fetching.
    val dataset = MNISTLoader.load(Paths.get("/tmp"))
    val trainImages = tf.data.datasetFromTensorSlices(dataset.trainImages.toFloat)
    val trainLabels = tf.data.datasetFromTensorSlices(dataset.trainLabels.toLong)
    val trainData =
      trainImages.zip(trainLabels)
          .repeat()
          .shuffle(10000)
          .batch(256)
          .prefetch(10)
    
    // Create the MLP model.
    val input = Input(FLOAT32, Shape(-1, 28, 28))
    val trainInput = Input(INT64, Shape(-1))
    val layer = Flatten[Float]("Input/Flatten") >>
        Linear[Float]("Layer_0", 128) >> ReLU[Float]("Layer_0/Activation", 0.1f) >>
        Linear[Float]("Layer_1", 64) >> ReLU[Float]("Layer_1/Activation", 0.1f) >>
        Linear[Float]("Layer_2", 32) >> ReLU[Float]("Layer_2/Activation", 0.1f) >>
        Linear[Float]("OutputLayer", 10)
    val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
        Mean("Loss/Mean")
    val optimizer = tf.train.GradientDescent(1e-6f)
    val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)
    
    // Create an estimator and train the model.
    val estimator = InMemoryEstimator(model)
    estimator.train(() => trainData, StopCriteria(maxSteps = Some(1000000)))
    

    By changing a few lines as follows, you also get checkpointing, summaries, and seamless integration with TensorBoard:

    val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
        Mean("Loss/Mean") >>
        ScalarSummary(name = "Loss", tag = "Loss")
    val summariesDir = Paths.get("/tmp/summaries")
    val estimator = InMemoryEstimator(
      modelFunction = model,
      configurationBase = Configuration(Some(summariesDir)),
      trainHooks = Set(
        SummarySaver(summariesDir, StepHookTrigger(100)),
        CheckpointSaver(summariesDir, StepHookTrigger(1000))),
      tensorBoardConfig = TensorBoardConfig(summariesDir))
    estimator.train(() => trainData, StopCriteria(maxSteps = Some(100000)))
    

    If you now browse to http://127.0.0.1:6006 while training, you can monitor the training progress in TensorBoard.

  • Efficient interaction with the native library that avoids unnecessary copying of data. All tensors are created and managed by the native TensorFlow library. When they are passed to the Scala API (e.g., fetched from a TensorFlow session), a combination of weak references and a disposing thread running in the background releases the underlying native memory once the corresponding Scala objects are no longer reachable. Please refer to tensorflow/src/main/scala/org/platanios/tensorflow/api/utilities/Disposer.scala for the implementation.
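
The last feature above, weak references combined with a background disposing thread, can be illustrated with a minimal plain-Scala sketch built on the JDK's java.lang.ref.PhantomReference and ReferenceQueue. This is not the library's Disposer.scala: the object DisposerSketch, its register method, and the cleanup closures are hypothetical, and in a real binding the closure would call the native function that frees a tensor handle.

```scala
import java.lang.ref.{PhantomReference, ReferenceQueue}
import java.util.concurrent.ConcurrentHashMap

// Hypothetical disposer sketch: names are illustrative, not the library's API.
object DisposerSketch {
  // The JVM enqueues a phantom reference here once its referent is unreachable.
  private val queue = new ReferenceQueue[AnyRef]()

  // Keeps each phantom reference strongly reachable (otherwise the reference itself
  // could be collected before being enqueued) and maps it to its cleanup action.
  private val pending = new ConcurrentHashMap[PhantomReference[AnyRef], () => Unit]()

  // Background daemon thread: waits for enqueued references and runs their cleanup.
  private val thread = new Thread(new Runnable {
    override def run(): Unit = {
      while (true) {
        val ref = queue.remove() // blocks until the GC enqueues a reference
        val cleanup = pending.remove(ref)
        if (cleanup != null) cleanup()
      }
    }
  }, "disposer-sketch")
  thread.setDaemon(true) // do not keep the JVM alive just for this thread
  thread.start()

  // Arranges for `cleanup` to run after `owner` becomes unreachable. The closure
  // must not capture `owner`, or `owner` would never become unreachable.
  def register(owner: AnyRef, cleanup: () => Unit): Unit = {
    pending.put(new PhantomReference[AnyRef](owner, queue), cleanup)
  }

  // Number of registered owners whose cleanup has not run yet.
  def pendingCount: Int = pending.size()
}
```

Registering each tensor wrapper together with a closure over its native handle (for example, a Long pointer), and never over the wrapper itself, is what allows the wrapper to become unreachable and its native memory to be released in the background.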

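To show roughly what optimizer.minimize(loss) does in the low-level graph example under Main Features, the following plain-Scala sketch applies the AdaGrad update rule to the same kind of loss: the sum of squared errors of a linear model whose weights start at zero. It does not use TensorFlow; the object and method names are hypothetical, and the initial accumulator value of 0.1 is an assumption chosen for illustration. Each step adds the squared gradient to a per-coordinate accumulator and scales the step by its inverse square root: acc += g * g, then w -= learningRate * g / sqrt(acc).

```scala
// Plain-Scala sketch of AdaGrad minimizing sum_i (x_i . w - y_i)^2.
// Names are hypothetical; this does not call TensorFlow.
object AdaGradSketch {
  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(j => a(j) * b(j)).sum

  // Sum of squared errors of the linear model w on the examples (xs, ys).
  def loss(xs: Array[Array[Double]], ys: Array[Double], w: Array[Double]): Double =
    xs.indices.map { i =>
      val r = dot(xs(i), w) - ys(i)
      r * r
    }.sum

  // Gradient of the loss with respect to w: 2 * sum_i (x_i . w - y_i) * x_i.
  def gradient(xs: Array[Array[Double]], ys: Array[Double], w: Array[Double]): Array[Double] = {
    val g = Array.fill(w.length)(0.0)
    for (i <- xs.indices) {
      val r = dot(xs(i), w) - ys(i)
      for (j <- w.indices) g(j) += 2.0 * r * xs(i)(j)
    }
    g
  }

  // Runs `steps` AdaGrad updates, starting from all-zero weights.
  def train(
      xs: Array[Array[Double]],
      ys: Array[Double],
      learningRate: Double,
      steps: Int,
      initialAccumulator: Double = 0.1): Array[Double] = {
    val w = Array.fill(xs.head.length)(0.0)
    // Per-coordinate running sum of squared gradients.
    val acc = Array.fill(w.length)(initialAccumulator)
    for (_ <- 1 to steps) {
      val g = gradient(xs, ys, w)
      for (j <- w.indices) {
        acc(j) += g(j) * g(j)
        w(j) -= learningRate * g(j) / math.sqrt(acc(j))
      }
    }
    w
  }
}
```

For example, with xs = Array(Array(1.0, 0.0), Array(0.0, 1.0), Array(1.0, 1.0)) and ys = Array(2.0, -1.0, 1.0), which are fit exactly by w = (2, -1), the loss at the zero initialization is 6.0, and training with learningRate = 1.0 (the rate used in the graph example) should drive it toward zero. Because the accumulator only grows, the effective step size per coordinate can only shrink, which is what keeps the fairly large nominal learning rate stable here.
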
Compiling from Source

To compile TensorFlow Scala on your machine, you first need to install the TensorFlow Python API. You also need a python3 alias for your Python binary, which CMake uses to locate the TensorFlow header files in your installation.

Tutorials

Funding

Funding for the development of this library has been generously provided by the following sponsors:

  • CMU Presidential Fellowship, awarded to Emmanouil Antonios Platanios
  • National Science Foundation, Grant #: IIS1250956
  • Air Force Office of Scientific Research, Grant #: FA95501710218

TensorFlow, the TensorFlow logo, and any related marks are trademarks of Google Inc.
