Mixtures of Experts

A few experiments with the mixture-of-experts technique.

Models were trained on the CIFAR-10 dataset. Ten experts were pretrained as two-class classifiers on pairs of similar classes (cat vs. dog, deer vs. horse, etc.).

Gating network was trained using the following methods:

  • Standard gating
  • Sparse gating (top-k selection over experts, k = 2)
  • Two gate layers
  • Two gate layers with reduce_max, keeping for each class the prediction from the gate with the higher confidence

The gate layers use ELU activations (the sparse variant applies an additional ReLU after the top-k selection). The best accuracy came from standard gating, at roughly 84%.
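
Conceptually, every variant computes the same combination step: the gate produces one weight per expert and the final prediction is the weighted sum of the expert outputs. A minimal NumPy sketch of that step (array names and sizes are illustrative only, not from the notebook):

In [ ]:
import numpy as np

# hypothetical batch of 4 images, 10 experts, 10 classes
gate_weights = np.random.rand(4, 10)                    # one weight per expert
gate_weights /= gate_weights.sum(axis=1, keepdims=True)
expert_outputs = np.random.rand(10, 4, 10)              # per-expert class scores

# final prediction: sum_i gate[:, i] * expert_i(x)
mixed = np.einsum('bi,ibc->bc', gate_weights, expert_outputs)
print(mixed.shape)  # (4, 10)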

Data Setup and Model Definitions

In [ ]:
import os
import math
import numpy as np
import tensorflow as tf
import keras
from keras import backend as K
from keras import regularizers
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import (Input, Dense, Dropout, Activation, Flatten,
                          MaxPooling2D, Conv2D, Lambda, BatchNormalization)
from keras.callbacks import (EarlyStopping, ReduceLROnPlateau,
                             ModelCheckpoint, History)
from keras.optimizers import Adam

batch_size = 50
num_classes = 10
weight_decay = 1e-4

datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        # randomly shift images horizontally (fraction of total width)
        width_shift_range=0.1,
        # randomly shift images vertically (fraction of total height)
        height_shift_range=0.1,
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # do not flip images vertically
        # set rescaling factor (applied before any other transformation)
        rescale=None,
)
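
# With every centering/normalization option disabled, datagen.fit() is not
# needed; only the random shifts and horizontal flips are active.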

(x_train, y_train), (x_test, y_test) = cifar10.load_data()


x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

labels=['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
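
# Quick sanity check on the expected CIFAR-10 shapes:
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)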


inputs = Input(shape=x_train.shape[1:])


def base_model(filters, name):
    # Expert network: three conv blocks (doubling the filter count each time)
    # ending in a 10-way softmax. Every expert is built on the global `inputs`
    # tensor so it can later be merged with the gate into a single Model.
    c1=Conv2D(filters, (3,3), padding='same',name='base1_'+name, kernel_regularizer=regularizers.l2(weight_decay))(inputs)
    c2=Activation('elu',name='base2_'+name)(c1)
    c3=BatchNormalization(name='base3_'+name)(c2)
    c4=Conv2D(filters, (3,3),name='base4_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c3)
    c5=Activation('elu',name='base5_'+name)(c4)
    c6=BatchNormalization(name='base6_'+name)(c5)
    c7=MaxPooling2D(pool_size=(2,2),name='base7_'+name)(c6)
    c8=Dropout(0.2,name='base8_'+name)(c7)
     
    c9=Conv2D(filters*2, (3,3),name='base9_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c8)
    c10=Activation('elu',name='base10_'+name)(c9)
    c11=BatchNormalization(name='base11_'+name)(c10)
    c12=Conv2D(filters*2, (3,3),name='base12_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c11)
    c13=Activation('elu',name='base13_'+name)(c12)
    c14=BatchNormalization(name='base14_'+name)(c13)
    c15=MaxPooling2D(pool_size=(2,2),name='base15_'+name)(c14)
    c16=Dropout(0.3,name='base16_'+name)(c15)
     
    c17=Conv2D(filters*4, (3,3),name='base17_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c16)
    c18=Activation('elu',name='base18_'+name)(c17)
    c19=BatchNormalization(name='base19_'+name)(c18)
    c20=Conv2D(filters*4, (3,3),name='base20_'+name, padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c19)
    c21=Activation('elu',name='base21_'+name)(c20)
    c22=BatchNormalization(name='base22_'+name)(c21)
    c23=MaxPooling2D(pool_size=(2,2),name='base23_'+name)(c22)
    c24=Dropout(0.4,name='base24_'+name)(c23)
     
    c25=Flatten(name='base25_'+name)(c24)
    c26=Dense(num_classes,name='base26_'+name)(c25)
    c27=Activation('softmax',name='base27_'+name)(c26)
    return Model(inputs=inputs, outputs=c27)
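
# Note: because each expert is built on the shared `inputs` tensor, Keras
# reuses one input layer across all sub-models, e.g. base_model(32, "x").input
# is the very same tensor as `inputs`. This is what allows the gate and all
# ten experts to be connected into one graph below.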


def gatingNetwork():
    # Gate network: same conv trunk as the experts, ending in a 10-way
    # ELU-activated dense layer that scores the ten experts.
    c1=Conv2D(32, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay),name='gate1')(inputs)
    c2=Activation('elu',name='gate2')(c1)
    c3=BatchNormalization(name='gate3')(c2)
    c4=Conv2D(32, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay),name='gate4')(c3)
    c5=Activation('elu',name='gate5')(c4)
    c6=BatchNormalization(name='gate6')(c5)
    c7=MaxPooling2D(pool_size=(2,2),name='gate7')(c6)
    c8=Dropout(0.2,name='gate26')(c7)
     
    c9=Conv2D(32*2, (3,3),name='gate8', padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c8)
    c10=Activation('elu',name='gate9')(c9)
    c11=BatchNormalization(name='gate25')(c10)
    c12=Conv2D(32*2, (3,3),name='gate10', padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c11)
    c13=Activation('elu',name='gate11')(c12)
    c14=BatchNormalization(name='gate12')(c13)
    c15=MaxPooling2D(pool_size=(2,2),name='gate13')(c14)
    c16=Dropout(0.3,name='gate14')(c15)
     
    c17=Conv2D(32*4, (3,3), padding='same',name='gate15', kernel_regularizer=regularizers.l2(weight_decay))(c16)
    c18=Activation('elu',name='gate16')(c17)
    c19=BatchNormalization(name='gate17')(c18)
    c20=Conv2D(32*4, (3,3),name='gate18', padding='same', kernel_regularizer=regularizers.l2(weight_decay))(c19)
    c21=Activation('elu',name='gate19')(c20)
    c22=BatchNormalization(name='gate20')(c21)
    c23=MaxPooling2D(pool_size=(2,2),name='gate21')(c22)
    c24=Dropout(0.4,name='gate22')(c23)
     
    c25=Flatten(name='gate23')(c24)
    c26=Dense(10,name='gate24',activation='elu')(c25)

    model=Model(inputs=inputs, outputs=c26)
    return model

Training expert models

In [ ]:
# labels = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
# Experts 0-4: each separates class i from a visually similar partner class.
def train_base_models(weights_file_in):
    pairs = {0: 8, 1: 9, 2: 6, 3: 5, 4: 7}  # airplane/ship, automobile/truck, bird/frog, cat/dog, deer/horse
    for i, second in pairs.items():
        model = models[i]
        weights_file = weights_file_in + str(i)
        checkpointer = ModelCheckpoint(weights_file+'.h5', monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=True, mode='auto')
        reduce_lr = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=2, min_lr=0.000001, verbose=1)
        earlystopper = EarlyStopping(monitor='val_acc', min_delta=0.00001, patience=10, verbose=1, mode='auto')
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])
        callbacks_list = [reduce_lr, earlystopper, checkpointer]

        # Restrict training/validation data to this expert's two classes.
        idx = [j for j in range(len(y_train)) if y_train[j][i] == 1 or y_train[j][second] == 1]
        x, y = x_train[idx], y_train[idx]

        idx_val = [j for j in range(len(y_test)) if y_test[j][i] == 1 or y_test[j][second] == 1]
        x_val, y_val = x_test[idx_val], y_test[idx_val]

        model.fit_generator(datagen.flow(x, y, batch_size=batch_size),
            epochs=100,
            steps_per_epoch=math.ceil(len(x) / batch_size),
            validation_data=(x_val, y_val), callbacks=callbacks_list,
            workers=4, verbose=2)

# Experts 5-9: a second set of pairings over the same ten classes.
def train_base_models2(weights_file_in):
    pairs = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5}  # airplane/truck, automobile/ship, bird/horse, cat/frog, deer/dog
    for i, second in pairs.items():
        model = models[i+5]
        weights_file = weights_file_in + str(i+5)
        checkpointer = ModelCheckpoint(weights_file+'.h5', monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=True, mode='auto')
        reduce_lr = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=2, min_lr=0.000001, verbose=1)
        earlystopper = EarlyStopping(monitor='val_acc', min_delta=0.00001, patience=10, verbose=1, mode='auto')
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])
        callbacks_list = [reduce_lr, earlystopper, checkpointer]

        idx = [j for j in range(len(y_train)) if y_train[j][i] == 1 or y_train[j][second] == 1]
        x, y = x_train[idx], y_train[idx]

        idx_val = [j for j in range(len(y_test)) if y_test[j][i] == 1 or y_test[j][second] == 1]
        x_val, y_val = x_test[idx_val], y_test[idx_val]

        model.fit_generator(datagen.flow(x, y, batch_size=batch_size),
            epochs=100,
            steps_per_epoch=math.ceil(len(x) / batch_size),
            validation_data=(x_val, y_val), callbacks=callbacks_list,
            workers=4, verbose=2)
        
        
models=[base_model(32,"1"),base_model(32,"2"),base_model(32,"3"),base_model(32,"4"),base_model(32,"5"),
        base_model(32,"6"),base_model(32,"7"),base_model(32,"8"),base_model(32,"9"),base_model(32,"10")]

base_weights_file='weights/base_model_'
train_base=False
if(train_base):
    train_base_models(base_weights_file)
    train_base_models2(base_weights_file)
    
    
    
# Load the pretrained expert weights into the merged model, then freeze
# everything except the gate (and any lambda/encoder/decoder layers).
def load_weights(model, weights_file):
    for a in range(len(models)):
        m = models[a]
        file = weights_file + str(a) + '.h5'
        m.load_weights(file, by_name=True)
        for b in m.layers:
            for l in model.layers:
                if l.name == b.name:
                    l.set_weights(b.get_weights())

    for l in model.layers:
        if 'gate' in l.name or 'lambda' in l.name or 'encoder' in l.name or 'decoder' in l.name:
            l.trainable = True
        else:
            l.trainable = False

Functions for merging the gating network with the expert networks

In [ ]:
def gating_multiplier(gate, branches):
    # Weighted sum of the expert outputs: out = sum_i gate[:, i] * branch_i.
    # The transposes broadcast each sample's gate weight across its class scores.
    forLambda = [gate] + list(branches)
    weighted = Lambda(lambda t: K.tf.transpose(
        sum(K.tf.transpose(t[i]) * t[0][:, i-1]
            for i in range(1, len(t)))
    ))(forLambda)
    return weighted
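
# The same broadcast in NumPy terms (illustrative shapes only):
#   g = np.random.rand(4, 3)    # gate weights: batch of 4, 3 experts
#   o = np.random.rand(4, 10)   # one expert's output: batch of 4, 10 classes
#   (o.T * g[:, 0]).T           # scales each sample's class scores by its gate weight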


def merge_two_gates(gates, branches):
    # Combine the two gate layers additively: each expert's output is weighted
    # by the sum of the two gates' weights for that expert,
    # out = sum_i (gates[0][:, i] + gates[1][:, i]) * branch_i.
    def combine(tensors):
        g0, g1, outs = tensors[0], tensors[1], tensors[2:]
        return K.tf.transpose(
            sum(K.tf.transpose(o) * (g0[:, i] + g1[:, i])
                for i, o in enumerate(outs)))
    return Lambda(combine)([gates[0], gates[1]] + list(branches))




def slices_to_dims(slice_indices):
    """
    Args:
      slice_indices: An [N, k] Tensor of column indices.
    Returns:
      An [N * k, 2] Tensor of (row, column) indices suitable for passing to
      SparseTensor.
    """
    slice_indices = tf.cast(slice_indices, tf.int64)
    num_rows = tf.shape(slice_indices, out_type=tf.int64)[0]
    row_range = tf.range(num_rows)
    item_numbers = slice_indices * num_rows + tf.expand_dims(row_range, axis=1)
    item_numbers_flat = tf.reshape(item_numbers, [-1])
    return tf.stack([item_numbers_flat % num_rows,
                     item_numbers_flat // num_rows], axis=1)
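
# Example: for slice_indices [[1, 3], [0, 2]] this returns the (row, column)
# pairs [[0, 1], [0, 3], [1, 0], [1, 2]]: each row index repeated k times
# alongside its top-k column indices.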


def sparseGating(inputs_, gates=2):
    # Keep only the top-k gate activations and push the rest to -inf, so the
    # non-selected experts receive zero weight downstream.
    top = tf.math.top_k(inputs_, gates, sorted=False)
    indi = tf.cast(top.indices, dtype=tf.int64)
    v = top.values
    sparse_indices = slices_to_dims(indi)
    sparse = tf.sparse_reorder(tf.SparseTensor(
        indices=sparse_indices, values=tf.reshape(v, [-1]),
        dense_shape=tf.cast(tf.shape(inputs_), dtype=tf.int64)))
    # Dense tensor holding the top-k values, zero elsewhere.
    d = tf.sparse_add(tf.zeros_like(inputs_), sparse)

    # -inf * 0 yields NaN at the selected positions, which is replaced by 0
    # below; all other positions stay at -inf.
    z = tf.ones_like(inputs_) * -np.inf
    mask = tf.less_equal(d, tf.zeros_like(d))
    new_tensor = tf.multiply(z, tf.cast(mask, dtype=tf.float32))
    g = tf.where(tf.is_nan(new_tensor), tf.zeros_like(new_tensor), new_tensor)
    g = tf.sparse_add(g, sparse)

    return Lambda(lambda a: g)(inputs_)
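
# Toy example of the masking idea: with gate logits [0.1, 2.0, -0.5, 1.5] and
# k = 2, the masked vector is [-inf, 2.0, -inf, 1.5]; a softmax over it gives
# [0, 0.622, 0, 0.378], so non-selected experts get exactly zero weight. In
# this notebook the sparse gate is instead passed through a ReLU, which
# likewise maps the -inf entries to 0.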



def stack(a, b):
    # Stack the two gated predictions into shape (2, batch, classes).
    return Lambda(lambda t: tf.stack([t[0], t[1]]))([a, b])

def sort(stacked):
    # Element-wise max over the two gates: keep, per class, the prediction
    # from the gate with the higher confidence.
    return Lambda(lambda t: tf.math.reduce_max(t, axis=0))(stacked)

Main training loop

In [ ]:
def train_lambda_model(model, weights_file):
    # Manual training loop: one epoch at a time, reloading the best weights
    # before each epoch, checkpointing on validation-accuracy improvements,
    # and halving the learning rate after 3 epochs without improvement.
    history = History()
    highest_acc = 0.8
    iterationsWithoutImprovement = 0
    lr = .001
    for i in range(100):
        if os.path.isfile(weights_file + '.hdf5'):
            model.load_weights(weights_file + '.hdf5')
        hist = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                            epochs=1,
                            steps_per_epoch=math.ceil(len(x_train) / batch_size),
                            validation_data=(x_test, y_test), callbacks=[history],
                            workers=4, verbose=1)
        val_acc = history.history['val_acc'][-1]
        if val_acc > highest_acc:
            model.save_weights(weights_file + '.hdf5')
            print("Saving weights, new highest accuracy: " + str(val_acc))
            highest_acc = val_acc
            iterationsWithoutImprovement = 0
        else:
            iterationsWithoutImprovement += 1
            if iterationsWithoutImprovement > 3:
                lr *= .5
                K.set_value(model.optimizer.lr, lr)
                print("Learning rate reduced to: " + str(lr))
                iterationsWithoutImprovement = 0
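
The Main cell below references gatingNetwork_2outputs() and load_weights_gate(), neither of which is defined in this notebook. The sketch below is a hypothetical reconstruction consistent with gatingNetwork() above: a shared conv trunk with two independent 10-way gate heads, plus a by-name loader for previously saved gate weights.

In [ ]:
def gatingNetwork_2outputs():
    # Hypothetical sketch: reuse the single-gate trunk up to its Flatten layer
    # ('gate23') and attach two independent 10-way gate heads.
    trunk = gatingNetwork()
    flat = trunk.get_layer('gate23').output
    g1 = Dense(10, activation='elu', name='gate_head_1')(flat)
    g2 = Dense(10, activation='elu', name='gate_head_2')(flat)
    return Model(inputs=inputs, outputs=[g1, g2])

def load_weights_gate(model, weights_file):
    # Hypothetical sketch: restore saved gate weights by layer name, if a
    # checkpoint exists.
    if os.path.isfile(weights_file):
        model.load_weights(weights_file, by_name=True)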

Main

In [ ]:
which_model_to_run='full_multiple_gates'
if(which_model_to_run=='sparse'):
    gate_network=gatingNetwork()
    weights_file='weights/sparse'
    layer=sparseGating(gate_network.layers[-1].output,gates=2)
    b=Activation('relu',name='sparse_relu')(layer)
    merged=gating_multiplier(b,[m.layers[-1].output for m in models])
    b=Activation('softmax',name='gatex')(merged)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model,base_weights_file)
    model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001),metrics=['accuracy'])

elif(which_model_to_run=='twoGates_takeMax'):
    gate_network=gatingNetwork_2outputs()
    weights_file='weights/twoGates_takeMax'
    merged1=gating_multiplier(gate_network.output[0],[m.layers[-1].output for m in models])
    merged2=gating_multiplier(gate_network.output[1],[m.layers[-1].output for m in models])
    b=stack(merged1,merged2)
    b=sort(b)
    b=Activation('softmax',name='gatex')(b)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model,base_weights_file)
    model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=.0005),metrics=['accuracy'])


elif(which_model_to_run=='full_multiple_gates'):
    gate_network=gatingNetwork_2outputs()
    weights_file='weights/full_multiple_gates'
    merged1=merge_two_gates(gate_network.output,[m.layers[-1].output for m in models])

    b=Activation('softmax',name='gatex')(merged1)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model,base_weights_file)
    load_weights_gate(model, 'weights/full_multiple_gates.hdf5')
    model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=.0005),metrics=['accuracy'])

elif(which_model_to_run=='full'):
    # best observed validation accuracy: 0.8364
    gate_network=gatingNetwork()
    weights_file='weights/moe_full'
    merged=gating_multiplier(gate_network.layers[-1].output,[m.layers[-1].output for m in models])
    b=Activation('softmax',name='gatex')(merged)
    model = Model(inputs=inputs, outputs=b)
    load_weights(model,base_weights_file)
    model.compile(loss='categorical_crossentropy',optimizer=Adam(),metrics=['accuracy'])


train=True
if(train):
    train_lambda_model(model, weights_file)
    # Inspect the merged (pre-softmax) mixture outputs on the training set.
    intermediate_layer_model = Model(inputs=model.input, outputs=model.get_layer('lambda_1').output)
    val = intermediate_layer_model.predict(x_train)
else:
    load_weights(model,base_weights_file)
    if(which_model_to_run=='full'):
        model.load_weights('weights/moe_full.hdf5')