A few experiments with the mixture-of-experts (MoE) technique, trained on the CIFAR-10 dataset. Ten experts were pretrained as binary classifiers on pairs of similar classes (cat vs dog, deer vs horse, etc.).
The gating network was trained using the following methods:
- full (standard) gating: every expert's output is weighted by its gate score
- sparse gating: only the top-k gate scores are kept; the rest are masked out
- two gates, take max: two gate heads each form a mixture and the elementwise maximum is taken
- two gates, merged: each expert is weighted by the sum of the two gate heads' scores
Gate layers used elu activations (with a relu applied after the sparse top-k masking). The best accuracy was with the standard (full) gating, at roughly 84%.
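All variants share the same combination rule: the gate emits one score per expert, and the final prediction is the gate-weighted sum of the experts' outputs. A minimal NumPy sketch with made-up values (the script below implements the same rule with Keras Lambda layers over batched tensors):

import numpy as np
gate_scores = np.array([0.1, 0.7, 0.2])      # hypothetical gate output, one score per expert
expert_outputs = np.random.rand(3, 10)       # each expert emits 10 class scores
mixed = (gate_scores[:, None] * expert_outputs).sum(axis=0)  # gate-weighted sum, shape (10,)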
# Written for TF 1.x with standalone Keras 2.x (fit_generator, val_acc,
# tf.sparse_reorder, etc.).
import os
import numpy as np
import tensorflow as tf
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import (Input, Dense, Dropout, Activation, Flatten,
                          MaxPooling2D, Conv2D, BatchNormalization, Lambda)
from keras import backend as K
from keras import regularizers
from keras.callbacks import (EarlyStopping, ReduceLROnPlateau,
                             ModelCheckpoint, History)
from keras.optimizers import Adam
batch_size = 50
num_classes = 10
weight_decay = 1e-4
datagen = ImageDataGenerator(
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    zca_epsilon=1e-06,  # epsilon for ZCA whitening
    rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
    shear_range=0.,  # set range for random shear
    zoom_range=0.,  # set range for random zoom
    channel_shift_range=0.,  # set range for random channel shifts
    fill_mode='nearest',  # set mode for filling points outside the input boundaries
    cval=0.,  # value used for fill_mode = "constant"
    horizontal_flip=True,  # randomly flip images horizontally
    vertical_flip=False,  # no vertical flips
    rescale=None,  # set rescaling factor (applied before any other transformation)
)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
labels=['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
inputs = Input(shape=x_train.shape[1:])
def base_model(filters, name):
    # Expert CNN: three conv blocks (elu + batch norm + max pool + dropout),
    # doubling the filter count per block, then a 10-way softmax head.
    # Layer names carry the expert's name so weights can be matched by name,
    # and all experts share the global `inputs` tensor so they can later be
    # merged into a single graph.
    c1 = Conv2D(filters, (3, 3), padding='same', name='base1_'+name, kernel_regularizer=regularizers.l2(weight_decay))(inputs)
    c2 = Activation('elu', name='base2_'+name)(c1)
    c3 = BatchNormalization(name='base3_'+name)(c2)
    c4 = Conv2D(filters, (3, 3), name='base4_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c3)
    c5 = Activation('elu', name='base5_'+name)(c4)
    c6 = BatchNormalization(name='base6_'+name)(c5)
    c7 = MaxPooling2D(pool_size=(2, 2), name='base7_'+name)(c6)
    c8 = Dropout(0.2, name='base8_'+name)(c7)
    c9 = Conv2D(filters*2, (3, 3), name='base9_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c8)
    c10 = Activation('elu', name='base10_'+name)(c9)
    c11 = BatchNormalization(name='base11_'+name)(c10)
    c12 = Conv2D(filters*2, (3, 3), name='base12_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c11)
    c13 = Activation('elu', name='base13_'+name)(c12)
    c14 = BatchNormalization(name='base14_'+name)(c13)
    c15 = MaxPooling2D(pool_size=(2, 2), name='base15_'+name)(c14)
    c16 = Dropout(0.3, name='base16_'+name)(c15)
    c17 = Conv2D(filters*4, (3, 3), name='base17_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c16)
    c18 = Activation('elu', name='base18_'+name)(c17)
    c19 = BatchNormalization(name='base19_'+name)(c18)
    c20 = Conv2D(filters*4, (3, 3), name='base20_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c19)
    c21 = Activation('elu', name='base21_'+name)(c20)
    c22 = BatchNormalization(name='base22_'+name)(c21)
    c23 = MaxPooling2D(pool_size=(2, 2), name='base23_'+name)(c22)
    c24 = Dropout(0.4, name='base24_'+name)(c23)
    c25 = Flatten(name='base25_'+name)(c24)
    c26 = Dense(num_classes, name='base26_'+name)(c25)
    c27 = Activation('softmax', name='base27_'+name)(c26)
    return Model(inputs=inputs, outputs=c27)
def gatingNetwork():
    # Gating CNN: same three-block conv trunk as the experts, ending in a
    # 10-unit elu layer that emits one score per expert. Layer names (including
    # the out-of-order 'gate25'/'gate26') are kept exactly as in the saved
    # checkpoints, since weights are matched by name.
    c1 = Conv2D(32, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay), name='gate1')(inputs)
    c2 = Activation('elu', name='gate2')(c1)
    c3 = BatchNormalization(name='gate3')(c2)
    c4 = Conv2D(32, (3, 3), padding='same', kernel_regularizer=regularizers.l2(weight_decay), name='gate4')(c3)
    c5 = Activation('elu', name='gate5')(c4)
    c6 = BatchNormalization(name='gate6')(c5)
    c7 = MaxPooling2D(pool_size=(2, 2), name='gate7')(c6)
    c8 = Dropout(0.2, name='gate26')(c7)
    c9 = Conv2D(32*2, (3, 3), name='gate8', padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c8)
    c10 = Activation('elu', name='gate9')(c9)
    c11 = BatchNormalization(name='gate25')(c10)
    c12 = Conv2D(32*2, (3, 3), name='gate10', padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c11)
    c13 = Activation('elu', name='gate11')(c12)
    c14 = BatchNormalization(name='gate12')(c13)
    c15 = MaxPooling2D(pool_size=(2, 2), name='gate13')(c14)
    c16 = Dropout(0.3, name='gate14')(c15)
    c17 = Conv2D(32*4, (3, 3), padding='same', name='gate15', kernel_regularizer=regularizers.l2(weight_decay))(c16)
    c18 = Activation('elu', name='gate16')(c17)
    c19 = BatchNormalization(name='gate17')(c18)
    c20 = Conv2D(32*4, (3, 3), name='gate18', padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c19)
    c21 = Activation('elu', name='gate19')(c20)
    c22 = BatchNormalization(name='gate20')(c21)
    c23 = MaxPooling2D(pool_size=(2, 2), name='gate21')(c22)
    c24 = Dropout(0.4, name='gate22')(c23)
    c25 = Flatten(name='gate23')(c24)
    c26 = Dense(10, name='gate24', activation='elu')(c25)
    model = Model(inputs=inputs, outputs=c26)
    return model
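# The 'twoGates_takeMax' and 'full_multiple_gates' variants below call
# gatingNetwork_2outputs(), which is not defined in this file. The following
# is a minimal sketch under the assumption that it reuses the single-gate conv
# trunk and ends in two independent 10-unit heads (one score per expert from
# each gate); the original definition may differ.
def gatingNetwork_2outputs():
    trunk = gatingNetwork()           # shared conv trunk (assumption)
    shared = trunk.layers[-2].output  # flattened features before the 10-unit head
    g1 = Dense(10, name='gate2out_a', activation='elu')(shared)  # hypothetical head names
    g2 = Dense(10, name='gate2out_b', activation='elu')(shared)
    return Model(inputs=inputs, outputs=[g1, g2])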
def train_base_models(weights_file_in):
    # Experts 0-4, each trained on a pair of easily confused classes:
    # 0: airplane/ship, 1: automobile/truck, 2: bird/frog, 3: cat/dog, 4: deer/horse
    pairs = {0: 8, 1: 9, 2: 6, 3: 5, 4: 7}
    for i in range(5):
        second = pairs[i]
        model = models[i]
        weights_file = weights_file_in + str(i)
        checkpointer = ModelCheckpoint(weights_file+'.h5', monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=True, mode='auto')
        reduce_lr = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=2, min_lr=0.000001, verbose=1)
        earlystopper = EarlyStopping(monitor='val_acc', min_delta=0.00001, patience=10, verbose=1, mode='auto')
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])
        callbacks_list = [reduce_lr, earlystopper, checkpointer]
        # Select only the samples belonging to this expert's two classes.
        idx = [j for j in range(len(y_train)) if y_train[j][i] == 1 or y_train[j][second] == 1]
        x = x_train[idx]
        y = y_train[idx]
        idx_val = [j for j in range(len(y_test)) if y_test[j][i] == 1 or y_test[j][second] == 1]
        x_val = x_test[idx_val]
        y_val = y_test[idx_val]
        model.fit_generator(datagen.flow(x, y, batch_size=batch_size),
                            epochs=100,
                            steps_per_epoch=len(x) // batch_size,
                            validation_data=(x_val, y_val), callbacks=callbacks_list,
                            workers=4, verbose=2)
def train_base_models2(weights_file_in):
    # Experts 5-9, trained on a second set of class pairs:
    # 5: airplane/truck, 6: automobile/ship, 7: bird/horse, 8: cat/frog, 9: deer/dog
    pairs = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5}
    for i in range(5):
        second = pairs[i]
        model = models[i+5]
        weights_file = weights_file_in + str(i+5)
        checkpointer = ModelCheckpoint(weights_file+'.h5', monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=True, mode='auto')
        reduce_lr = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=2, min_lr=0.000001, verbose=1)
        earlystopper = EarlyStopping(monitor='val_acc', min_delta=0.00001, patience=10, verbose=1, mode='auto')
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])
        callbacks_list = [reduce_lr, earlystopper, checkpointer]
        idx = [j for j in range(len(y_train)) if y_train[j][i] == 1 or y_train[j][second] == 1]
        x = x_train[idx]
        y = y_train[idx]
        idx_val = [j for j in range(len(y_test)) if y_test[j][i] == 1 or y_test[j][second] == 1]
        x_val = x_test[idx_val]
        y_val = y_test[idx_val]
        model.fit_generator(datagen.flow(x, y, batch_size=batch_size),
                            epochs=100,
                            steps_per_epoch=len(x) // batch_size,
                            validation_data=(x_val, y_val), callbacks=callbacks_list,
                            workers=4, verbose=2)
models=[base_model(32,"1"),base_model(32,"2"),base_model(32,"3"),base_model(32,"4"),base_model(32,"5"),
base_model(32,"6"),base_model(32,"7"),base_model(32,"8"),base_model(32,"9"),base_model(32,"10")]
base_weights_file='weights/base_model_'
train_base=False
if train_base:
    train_base_models(base_weights_file)
    train_base_models2(base_weights_file)
# Loading base model weights
def load_weights(model, weights_file):
    # Load each pretrained expert's weights, then copy any layer whose name
    # matches into the merged model. Finally freeze everything except the
    # gate (and any lambda/encoder/decoder) layers, so only the gating is
    # trained further.
    for a in range(10):
        m = models[a]
        file = weights_file + str(a) + '.h5'
        m.load_weights(file, by_name=True)
        for b in m.layers:
            for l in model.layers:
                if l.name == b.name:
                    l.set_weights(b.get_weights())
    for l in model.layers:
        if 'gate' in l.name or 'lambda' in l.name or 'encoder' in l.name or 'decoder' in l.name:
            l.trainable = True
        else:
            l.trainable = False
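# The 'full_multiple_gates' branch below calls load_weights_gate(), which is
# not defined in this file. A minimal sketch under the assumption that it
# simply restores previously saved gate weights by name, skipping silently
# when no checkpoint exists yet:
def load_weights_gate(model, weights_file):
    if os.path.isfile(weights_file):
        model.load_weights(weights_file, by_name=True)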
def gating_multiplier(gate, branches):
    # Gate-weighted sum of expert outputs: branch i is scaled by the gate's
    # i-th score. The transposes broadcast each per-sample gate score across
    # the class dimension.
    forLambda = [gate]
    forLambda.extend(branches)
    add = Lambda(lambda x: tf.transpose(
        sum(tf.transpose(forLambda[i]) * forLambda[0][:, i-1]
            for i in range(1, len(forLambda)))
    ))(forLambda)
    return add
def merge_two_gates(gates, branches):
    # Two gate heads: each expert output is weighted by the sum of the two
    # gates' scores for that expert (equivalent to the original unrolled
    # o1..o10 formulation).
    add = Lambda(lambda x: tf.transpose(
        sum(tf.transpose(branches[i]) * (gates[0][:, i] + gates[1][:, i])
            for i in range(len(branches)))
    ))([gates[0], gates[1]] + list(branches))
    return add
def slices_to_dims(slice_indices):
    """
    Args:
      slice_indices: An [N, k] Tensor mapping to column indices.
    Returns:
      An index Tensor with shape [N * k, 2], corresponding to indices suitable
      for passing to SparseTensor.
    """
    slice_indices = tf.cast(slice_indices, tf.int64)
    num_rows = tf.shape(slice_indices, out_type=tf.int64)[0]
    row_range = tf.range(num_rows)
    item_numbers = slice_indices * num_rows + tf.expand_dims(row_range, axis=1)
    item_numbers_flat = tf.reshape(item_numbers, [-1])
    return tf.stack([item_numbers_flat % num_rows,
                     item_numbers_flat // num_rows], axis=1)
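# A quick worked example (hypothetical values): for slice_indices [[0, 2], [1, 3]]
# (N=2 rows, k=2 picks per row), the result is [[0, 0], [0, 2], [1, 1], [1, 3]]:
# one (row, column) pair per selected entry, ready for tf.SparseTensor.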
def sparseGating(inputs_, gates=2):
    # Keep only the top-k gate scores; every other position is set to -inf so
    # that the relu applied afterwards zeroes those experts out. (inputs_ is
    # expected to be gate_network's final layer output, as at the call site.)
    indi = tf.cast(tf.math.top_k(inputs_, gates, sorted=False).indices, dtype=tf.int64)
    v = tf.math.top_k(gate_network.layers[-1].output, gates, sorted=False).values
    sparse_indices = slices_to_dims(indi)
    sparse = tf.sparse_reorder(tf.SparseTensor(
        indices=sparse_indices, values=tf.reshape(v, [-1]),
        dense_shape=tf.cast(tf.shape(gate_network.layers[-1].output), dtype=tf.int64)))
    c = tf.zeros_like(gate_network.layers[-1].output)
    d = tf.sparse_add(c, sparse)  # dense tensor: top-k values, 0 elsewhere
    z = tf.ones_like(gate_network.layers[-1].output) * -np.inf
    mask = tf.less_equal(d, tf.zeros_like(d))
    new_tensor = tf.multiply(z, tf.cast(mask, dtype=tf.float32))
    # -inf * 0 produces NaN at the selected positions; replace those with 0.
    g = tf.where(tf.is_nan(new_tensor), tf.zeros_like(new_tensor), new_tensor)
    g = tf.sparse_add(g, sparse)  # final: top-k values, -inf elsewhere
    b = Lambda(lambda a: g)(gate_network.layers[-1].output)
    return b
def stack(a, b, gates=1):
    # Stack the two gated mixtures along a new leading axis.
    b = Lambda(lambda z: tf.stack([a, b]))([a, b])
    return b
def sort(stacked, gates=1):
    # Despite the name, this takes the elementwise maximum of the stacked mixtures.
    x = tf.math.reduce_max(stacked, axis=0)
    b = Lambda(lambda a: x)(stacked)
    return b
def train_lambda_model(model, weights_file):
    # Manual training loop: one epoch at a time, reloading the best checkpoint
    # each round, saving on improvement, and halving the learning rate after
    # 3 epochs without improvement.
    history = History()
    highest_acc = 0.8
    iterationsWithoutImprovement = 0
    lr = .001
    for i in range(100):
        if os.path.isfile(weights_file+'.hdf5'):
            model.load_weights(weights_file+'.hdf5')
        model.fit_generator(datagen.flow(x_train, y_train, batch_size=50),
                            epochs=1,
                            steps_per_epoch=len(x_train) // 50,
                            validation_data=(x_test, y_test), callbacks=[history],
                            workers=4, verbose=1)
        val_acc = history.history['val_acc'][-1]
        if val_acc > highest_acc:
            model.save_weights(weights_file+'.hdf5')
            print("Saving weights, new highest accuracy: "+str(val_acc))
            highest_acc = val_acc
            iterationsWithoutImprovement = 0
        else:
            iterationsWithoutImprovement += 1
            if iterationsWithoutImprovement > 3:
                lr *= .5
                K.set_value(model.optimizer.lr, lr)
                print("Learning rate reduced to: "+str(lr))
                iterationsWithoutImprovement = 0
which_model_to_run = 'full_multiple_gates'
if which_model_to_run == 'sparse':
    gate_network = gatingNetwork()
    weights_file = 'weights/sparse'
    layer = sparseGating(gate_network.layers[-1].output, gates=2)
    b = Activation('relu', name='sparse_relu')(layer)  # relu turns the -inf slots into 0
    merged = gating_multiplier(b, [m.layers[-1].output for m in models])
    b = Activation('softmax', name='gatex')(merged)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model, base_weights_file)
    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.0001), metrics=['accuracy'])
elif which_model_to_run == 'twoGates_takeMax':
    gate_network = gatingNetwork_2outputs()
    weights_file = 'weights/twoGates_takeMax'
    merged1 = gating_multiplier(gate_network.output[0], [m.layers[-1].output for m in models])
    merged2 = gating_multiplier(gate_network.output[1], [m.layers[-1].output for m in models])
    b = stack(merged1, merged2)
    b = sort(b)  # elementwise max of the two mixtures
    b = Activation('softmax', name='gatex')(b)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model, base_weights_file)
    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=.0005), metrics=['accuracy'])
elif which_model_to_run == 'full_multiple_gates':
    gate_network = gatingNetwork_2outputs()
    weights_file = 'weights/full_multiple_gates'
    merged1 = merge_two_gates(gate_network.output, [m.layers[-1].output for m in models])
    b = Activation('softmax', name='gatex')(merged1)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model, base_weights_file)
    load_weights_gate(model, 'weights/full_multiple_gates.hdf5')
    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=.0005), metrics=['accuracy'])
elif which_model_to_run == 'full':
    # new highest accuracy: 0.8363999933004379
    gate_network = gatingNetwork()
    weights_file = 'weights/moe_full'
    merged = gating_multiplier(gate_network.layers[-1].output, [m.layers[-1].output for m in models])
    b = Activation('softmax', name='gatex')(merged)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model, base_weights_file)
    model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
train = True
if train:
    train_lambda_model(model, weights_file)
    # Inspect the gate-weighted mixture by reading the first Lambda layer's output.
    intermediate_layer_model = Model(inputs=model.input, outputs=model.get_layer('lambda_1').output)
    val = intermediate_layer_model.predict(x_train)
else:
    load_weights(model, base_weights_file)
    if which_model_to_run == 'vae':
        vae.load_weights('weights/vae.h5')  # the VAE variant is not defined in this script
    elif which_model_to_run == 'full':
        model.load_weights('weights/moe_full.hdf5')
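# To check the reported ~84% figure, evaluate on the held-out test set once the
# chosen variant's weights are in place (a minimal sketch; assumes `model` was
# built and loaded above):
loss, acc = model.evaluate(x_test, y_test, verbose=0)
print('test accuracy: ' + str(acc))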