# Example code showing how to run distributed training across multiple GPUs or nodes
# using horovod and MPI
import tflon
import tensorflow as tf
import pandas as pd
from pkg_resources import resource_filename
class NeuralNet(tflon.model.Model):
    """Small feed-forward classifier: 210 input descriptors -> sigmoid score.

    Built with tflon's pipe-composition toolkit; losses are cross-entropy
    plus an L2 weight penalty, with AUC tracked as a metric.
    """

    def _model(self):
        # Placeholders for the descriptor matrix and the binary target column.
        inputs = self.add_input('desc', shape=[None, 210])
        targets = self.add_target('targ', shape=[None, 1])

        # Compose the network with tflon's `|` pipe operator: input windowing,
        # two tanh hidden layers (20 and 5 units), then a linear output unit.
        pipeline = (tflon.toolkit.WindowInput()
                    | tflon.toolkit.Dense(20, activation=tf.tanh)
                    | tflon.toolkit.Dense(5, activation=tf.tanh)
                    | tflon.toolkit.Dense(1))
        logits = pipeline(inputs)

        # Expose the sigmoid probability; train on the raw logits.
        self.add_output("pred", tf.nn.sigmoid(logits))
        self.add_loss("xent", tflon.toolkit.xent_uniform_sum(targets, logits))
        self.add_loss("l2", tflon.toolkit.l2_penalty(self.weights))
        self.add_metric('auc', tflon.toolkit.auc(targets, logits))
if __name__ == '__main__':
    # Allow slow distributed startup before the tensor queue times out.
    tflon.data.TensorQueue.DEFAULT_TIMEOUT = 1000

    # Initialize horovod and set up GPU resources (returns a session config).
    config = tflon.distributed.init_distributed_resources()

    graph = tf.Graph()
    with graph.as_default():
        # Add a model instance
        NN = NeuralNet(use_gpu=True)

        # Create the distributed trainer
        trainer = tflon.distributed.DistributedTrainer(
            tf.train.AdamOptimizer(1e-3), iterations=1000)

        # Create the data feed; use the same feed for all process instances.
        # tflon.distributed.DistributedTable adds MPI synchronization to the
        # Table API min and max ops. Usually, different data would be loaded
        # on each process (see tflon.distributed.make_distributed_table_feed).
        tsv_reader = lambda fpath: pd.read_csv(
            fpath, sep='\t', dtype={'ID': str}).set_index('ID')
        schema = NN.schema.map(desc=('descriptors.tsv', tsv_reader),
                               targ=('targets.tsv', tsv_reader))
        # Look at tflon_test/data/distributed to see how shards are organized on disk.
        feed = tflon.distributed.make_distributed_table_feed(
            resource_filename('tflon_test.data', 'distributed'), schema,
            master_table='desc', partition_strategy='all')

        with tf.Session(graph=graph, config=config):
            # Train with minibatch size 100
            NN.fit(feed.shuffle(batch_size=100), trainer,
                   restarts=2, source_tables=feed)

            # Perform inference on the master process only.
            if tflon.distributed.is_master():
                auc = NN.evaluate(feed, query='auc')
                # BUG FIX: original used the Python 2 print statement
                # (`print "AUC:", auc`), a SyntaxError under Python 3.
                print("AUC:", auc)
                assert auc > 0.8