import tensorflow as tf
import os
import numpy as np
import pandas as pd
from scipy.misc import imread
from scipy.misc import imresize
from scipy.stats import spearmanr

from sklearn.metrics import accuracy_score
from sklearn import datasets, linear_model
import sklearn as sk
import time

# Side length (pixels) of the square input images: the network below feeds
# s*s values per row and reshapes them to (s, s, 1).
s = 64
# Seed used for the TensorFlow variable initializers below.
seed = 128
# NOTE(review): rng is not used anywhere in the visible script — confirm
# whether it was meant to shuffle the data before the train/test split.
rng = np.random.RandomState(seed)

# Dataset root.  It can be overridden with the IMAGE_REGRESSION_DATA_DIR
# environment variable and defaults to the original hard-coded location.
# data.csv has no header row; its columns are addressed by integer position.
data_dir = os.path.abspath(
    os.environ.get("IMAGE_REGRESSION_DATA_DIR",
                   "/home/yihong/temp/image-regression-dataset3"))
data = pd.read_csv(os.path.join(data_dir, "data.csv"), header=None)
print(len(data))

# Load every image named in column 0 of data.csv ("<name>_resized.jpg" under
# data_dir) as a 2-D float32 array, then split all inputs into train/test.
read_start_time = time.time()
temp = []
for img_name in data[0]:
    filepath = os.path.join(data_dir, img_name)
    # flatten=True makes scipy.misc.imread collapse colour channels into a
    # single grey-scale layer, so img is 2-D.
    img = imread(filepath + "_resized.jpg", flatten=True).astype(np.float32)
    # Fail early with a clear message: each image is later flattened into one
    # row of the x placeholder (s*s columns), so any other size would not
    # line up with the model's input.
    if img.shape != (s, s):
        raise ValueError("%s_resized.jpg has shape %r, expected (%d, %d)"
                         % (filepath, img.shape, s, s))
    temp.append(img)
if not temp:
    raise ValueError("no images listed in %s"
                     % os.path.join(data_dir, "data.csv"))
data_x = np.stack(temp)
# Single-argument print: identical output on Python 2 and Python 3.
print("finished reading images. Time used:  %s"
      % (time.time() - read_start_time))
# Extra per-sample scalar inputs: columns 1 and 2 of data.csv.
data_x2 = data.iloc[:, 1:3].values

# First half of the rows (file order, no shuffling) trains; the rest tests.
split = int(data_x.shape[0] * 0.5)
train_x, test_x = data_x[:split], data_x[split:]
train_x2, test_x2 = data_x2[:split], data_x2[split:]


# Number of units in each layer.
input_num_units = s*s       # flattened image pixels per sample
input_num_units2 = 2        # extra scalar features per sample
full1_num_units = 100
full2_num_units = 60

# Placeholders: flattened images, extra features, and one target per sample.
x = tf.placeholder(tf.float32, [None, input_num_units])
x2 = tf.placeholder(tf.float32, [None, input_num_units2])

y = tf.placeholder(tf.float32, [None])

# Restore the image layout (batch, s, s, 1 channel) for the conv layers.
input_layer = tf.reshape(x, [-1, s, s, 1])

# Two blocks of 5x5 conv ("same" padding, ReLU) followed by 2x2 / stride-2
# max pooling; each pooling step halves (floors) the spatial size.
conv1 = tf.layers.conv2d(inputs=input_layer, filters=32, kernel_size=[5, 5],
                         padding="same", activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

conv2 = tf.layers.conv2d(inputs=pool1, filters=64, kernel_size=[5, 5],
                         padding="same", activation=tf.nn.relu)

pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

# Flattened feature size, taken from pool2's static shape rather than the
# hard-coded 16 * 16 * 64, so changing `s` or the conv/pool settings keeps
# the dense layers consistent (the value is 16 * 16 * 64 for s = 64).
pool2_out_num = int(np.prod(pool2.get_shape().as_list()[1:]))
pool2flat = tf.reshape(pool2, [-1, pool2_out_num])

# Dense head: [pool2flat | x2] -> full1 (ReLU) -> full2 (ReLU) -> 1 output.
# Each variable gets its own op-level seed.  With one shared seed, every
# initializer draws from the same random sequence, so the initial tensors
# would not be independent of one another.
# NOTE(review): tf.random_normal defaults to stddev=1.0, which is large for a
# layer with pool2_out_num + 2 inputs; consider a fan-in-scaled initializer.
weights = {
    'full1': tf.Variable(tf.random_normal(
        [pool2_out_num + input_num_units2, full1_num_units], seed=seed)),
    'full2': tf.Variable(tf.random_normal(
        [full1_num_units, full2_num_units], seed=seed + 1)),
    'output': tf.Variable(tf.random_normal(
        [full2_num_units, 1], seed=seed + 2)),
}

biases = {
    'full1': tf.Variable(tf.random_normal([full1_num_units], seed=seed + 3)),
    'full2': tf.Variable(tf.random_normal([full2_num_units], seed=seed + 4)),
    'output': tf.Variable(tf.random_normal([1], seed=seed + 5)),
}

# Append the extra scalar features to the flattened conv features.
full1_input = tf.concat([pool2flat, x2], 1)

full_layer1 = tf.add(tf.matmul(full1_input, weights['full1']), biases['full1'])
full_layer1 = tf.nn.relu(full_layer1)

full_layer2 = tf.add(tf.matmul(full_layer1, weights['full2']), biases['full2'])
full_layer2 = tf.nn.relu(full_layer2)

# Shape (batch, 1): one regression output per sample.
output_layer = tf.matmul(full_layer2, weights['output']) + biases['output']

# Mean squared error.  Squeeze the (batch, 1) output to (batch,) so it lines
# up element-wise with y instead of relying on transpose + broadcasting.
cost = tf.reduce_mean(tf.square(tf.squeeze(output_layer, axis=1) - y))
learning_rate = 0.001
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# Created after minimize() so it also initializes Adam's slot variables.
init = tf.global_variables_initializer()
epochs = 2000
batch_size = 128

# NOTE(review): saver is never used in the visible script.
saver = tf.train.Saver()

# Regression targets: column 4 of data.csv, split at the same row as inputs.
data_y = data[4]
train_y, test_y = data_y[:split], data_y[split:]


def _batch_slices(num_rows, size):
    """Yield consecutive ``slice`` objects covering rows ``[0, num_rows)``.

    Every row is covered exactly once. The last slice is shorter than
    ``size`` when ``num_rows`` is not a multiple of it. (The original
    ``range(total_batch - 1)`` loops dropped the last full batch and the
    remainder.)
    """
    for start in range(0, num_rows, size):
        yield slice(start, min(start + size, num_rows))


def _predict(sess, images, extras):
    """Evaluate ``output_layer`` on every row of ``images`` / ``extras``.

    Runs in mini-batches of ``batch_size``. Returns a flat 1-D array with one
    prediction per input row, or an empty array when there are no rows.
    """
    chunks = []
    for rows in _batch_slices(images.shape[0], batch_size):
        out = sess.run(output_layer, feed_dict={
            x: images[rows].reshape(-1, input_num_units),
            x2: extras[rows].reshape(-1, input_num_units2),
        })
        chunks.append(out.ravel())
    if not chunks:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(chunks)


# Plain NumPy copies of the targets: positional slicing is unambiguous and
# they can be fed to TensorFlow / compared with predictions directly.
train_y_values = np.asarray(train_y, dtype=np.float32)
test_y_values = np.asarray(test_y, dtype=np.float64)
num_train = train_x.shape[0]

# Single-argument prints below give the same output on Python 2 and 3 (the
# old `print (a, b, ...)` printed a tuple repr under Python 2).
print("Starting training wrt nlike...")
for run_i in range(10):
    print("=== run  %d  ===" % run_i)

    # Each run trains from freshly initialized variables in a new session.
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(epochs):
            start_time = time.time()
            # Sum of per-sample squared errors over the whole epoch, so the
            # reported cost is a true per-sample mean even with a short
            # final batch.
            cost_sum = 0.0
            for rows in _batch_slices(num_train, batch_size):
                batch_y = train_y_values[rows]
                _, c = sess.run([optimizer, cost], feed_dict={
                    x: train_x[rows].reshape(-1, input_num_units),
                    x2: train_x2[rows].reshape(-1, input_num_units2),
                    y: batch_y,
                })
                cost_sum += c * len(batch_y)
            avg_cost = cost_sum / num_train if num_train else float("nan")

            if (epoch + 1) % 20 == 0:
                print("Epoch: %04d cost= %.9f  time= %s"
                      % (epoch + 1, avg_cost, time.time() - start_time))

                # Evaluate on the full test set (every row, no dropped batch).
                y_pred = _predict(sess, test_x, test_x2)
                if y_pred.size == 0:
                    print("Test set is empty; skipping evaluation.")
                else:
                    mse = np.mean((y_pred - test_y_values) ** 2)
                    print("Test MSE: %s" % mse)
                    cor, pval = spearmanr(test_y_values, y_pred, axis=None)
                    print("Spearman correlation: %s" % cor)

