argwhere (tensorflow)

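Code example:

Get the indices of the elements of a tensor that meet a condition, then keep only those elements (the rest are dropped). With a single argument, tf.where returns the coordinates of the True elements, like NumPy's argwhere.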

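import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

a = tf.constant([3, 4, 4, 5, 5, 3])                 # 1-D tensor (vector)
b = tf.constant([0.2, 0.9, 0.9, 0.1, 0.8, 0.9])     # 1-D tensor (vector)

# tf.where(condition) returns the coordinates of the True elements
# as an int64 tensor of shape [num_true, rank].
indices1 = tf.where(tf.less(a, 5))                  # [[0] [1] [2] [5]]
indices2 = tf.where(tf.less_equal(b, .3))           # [[0] [3]]

# tf.gather_nd reads back the elements at those coordinates.
result_a = tf.gather_nd(a, indices1)
result_b = tf.gather_nd(b, indices2)

# No tf.Variable is created, so no initializer is needed before running.
with tf.Session() as sess:
    print(sess.run(indices1).T)       # [[0 1 2 5]]
    print(sess.run(indices2).T)       # [[0 3]]
    print(sess.run(result_a))         # [3 4 4 3]
    print(sess.run(result_b))         # [0.2 0.1]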
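To make the two steps concrete without a TensorFlow install, here is a minimal plain-Python sketch for 1-D data. The helpers `argwhere` and `gather_nd` below are hypothetical stand-ins written for illustration, not TensorFlow functions. The real ops also handle higher-rank tensors, where each returned coordinate has one entry per dimension.

```python
# Plain-Python sketch (no TensorFlow) of the example above, for 1-D inputs.
# argwhere/gather_nd here are hypothetical illustration helpers.

def argwhere(values, predicate):
    """Return [[i], ...] for every index i where predicate(values[i]) is True.

    Mirrors the [num_true, 1] shape that tf.where gives for a 1-D input.
    """
    return [[i] for i, v in enumerate(values) if predicate(v)]


def gather_nd(values, indices):
    """Return values[i] for each single-coordinate index [i]."""
    return [values[i] for (i,) in indices]


a = [3, 4, 4, 5, 5, 3]
b = [0.2, 0.9, 0.9, 0.1, 0.8, 0.9]

indices1 = argwhere(a, lambda x: x < 5)      # [[0], [1], [2], [5]]
indices2 = argwhere(b, lambda x: x <= 0.3)   # [[0], [3]]

result_a = gather_nd(a, indices1)            # [3, 4, 4, 3]
result_b = gather_nd(b, indices2)            # [0.2, 0.1]

print(indices1, indices2, result_a, result_b)
```

The values match the comments in the TensorFlow version: the condition picks the coordinates, and the gather step keeps only the elements at those coordinates.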
    
Manuel Cuevas

Hello, I'm Manuel Cuevas, a software engineer with a background in machine learning and artificial intelligence.