From 383b39f38f62a916d4be00ce86fddb0f70847bce Mon Sep 17 00:00:00 2001 From: JianYan-g Date: Sun, 23 May 2021 01:18:52 +0900 Subject: [PATCH] Update the weight matrix and the matrix product There is no need to calculate the distance in a quadratic form by generating the diagonal matrix intermediately. Just arranging the weights as a vector and calculating the linear combination solves the calculation. This also reduces the overhead of computation. --- .../04_mixed_distance_functions_knn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/05_Nearest_Neighbor_Methods/04_Computing_with_Mixed_Distance_Functions/04_mixed_distance_functions_knn.py b/05_Nearest_Neighbor_Methods/04_Computing_with_Mixed_Distance_Functions/04_mixed_distance_functions_knn.py index fb87178c2..500203585 100644 --- a/05_Nearest_Neighbor_Methods/04_Computing_with_Mixed_Distance_Functions/04_mixed_distance_functions_knn.py +++ b/05_Nearest_Neighbor_Methods/04_Computing_with_Mixed_Distance_Functions/04_mixed_distance_functions_knn.py @@ -49,7 +49,7 @@ ## Create distance metric weight matrix weighted by standard deviation weight_diagonal = x_vals.std(0) -weight_matrix = tf.cast(tf.diag(weight_diagonal), dtype=tf.float32) +weight_matrix = tf.cast(tf.expand_dims(weight_diagonal,1), dtype=tf.float32) # Split the data into train and test sets np.random.seed(13) # reproducible results @@ -73,9 +73,8 @@ # Declare weighted distance metric # Weighted L2 = sqrt((x-y)^T * A * (x-y)) subtraction_term = tf.subtract(x_data_train, tf.expand_dims(x_data_test,1)) -first_product = tf.matmul(subtraction_term, tf.tile(tf.expand_dims(weight_matrix,0), [batch_size,1,1])) -second_product = tf.matmul(first_product, tf.transpose(subtraction_term, perm=[0,2,1])) -distance = tf.sqrt(tf.matrix_diag_part(second_product)) +product = tf.matmul(tf.square(subtraction_term), tf.tile(tf.expand_dims(weight_matrix,0), [batch_size,1,1])) +distance = tf.sqrt(tf.squeeze(product,axis=2)) # Predict: Get min distance index (Nearest neighbor) top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k) @@ -113,4 +112,4 @@ plt.xlabel('Med Home Value in $1,000s') plt.ylabel('Frequency') plt.legend(loc='upper right') -plt.show() \ No newline at end of file +plt.show()