Skip to content

Commit ebbb28e

Browse files
committed
Basic gradient stepping and cosine similarity with gradients
1 parent a2da876 commit ebbb28e

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

src/eff_conv/spacial.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import numpy as np
2+
3+
class SimilarityFunction:
    """Abstract interface for similarity functions used in gradient-based updates.

    Concrete subclasses implement ``calculate_gradients`` for their particular
    similarity measure; this base class only defines the contract.
    """

    @staticmethod
    def calculate_gradients(referent_locations: np.ndarray, word_locations: np.ndarray) -> np.ndarray:
        """Return the similarity gradients with respect to the referent locations.

        Inputs:
        - referent_locations: ||U|| x L matrix, where U is the set of referents and
          L is the dimensionality of the space.
        - word_locations: ||W|| x L matrix, where W is the set of words and L is
          the dimensionality of the space.

        Output:
        - ||U|| x ||W|| x L matrix, the gradients for each of the referents and words.
        """
        # Abstract hook: subclasses must override this.
        message = "SimilarityFunction is an abstract class, do not call calculate_gradients"
        raise NotImplementedError(message)
18+
19+
20+
class CosineSimilarity(SimilarityFunction):
21+
@staticmethod
22+
def calculate_gradients(referent_locations: np.ndarray, word_locations: np.ndarray) -> np.ndarray:
23+
"""
24+
Calculates the gradients for referents and words based on the derivative of the cosine similarity function
25+
for vectors.
26+
27+
For inputs and outputs refer to SimilarityFunction.calculate_gradients
28+
29+
The derivative this is based on is the following:
30+
d/dX cos(X, Y) = (Y - proj(X, Y))/(||X||*||Y||)
31+
"""
32+
word_norms = np.linalg.norm(word_locations, axis=1)
33+
referent_norms = np.linalg.norm(referent_locations, axis=1)
34+
# Dot products between X and Y
35+
dot_products = (word_locations @ referent_locations.T).T
36+
# Projections from Y onto X
37+
# This is kept as a ||U|| x ||W|| x L matrix
38+
# Evil broadcasting gets this to work
39+
projections = referent_locations[:, None, :]*(dot_products/(referent_norms**2)[:, None])[:, :, None]
40+
subtractions = word_locations[None, :, :] - projections
41+
combined_norms = (word_norms*referent_norms[:, None])[:, :, None]
42+
43+
return subtractions/combined_norms
44+
45+
def run_grad_step(
    referent_locations: np.ndarray,
    word_locations: np.ndarray,
    priors: np.ndarray,
    sim_func: "SimilarityFunction",
) -> np.ndarray:
    """
    Recalculates referent locations based on gradient ascent of the following function:
        score(u) = sum_w q(u|w) * sim_func(u, w)

    Inputs:
    - referent_locations: ||U|| x L matrix, where U is the set of referents and L is the dimensionality of the space.
    - word_locations: ||W|| x L matrix, where W is the set of words and L is the dimensionality of the space.
    - priors: The q(u|w) matrix, ||U|| x ||W||
    - sim_func: The similarity function used for scoring

    Output:
    - ||U|| x L matrix, the new positions for each referent.

    Raises:
    - ValueError: if any input is not 2-D, or if the shapes disagree on the
      dimensionality, the number of referents, or the number of words.
    """
    # Check dimensions because it's very easy to get mixed up
    if referent_locations.ndim != 2 or word_locations.ndim != 2 or priors.ndim != 2:
        raise ValueError("Invalid shape for input array")
    ref_num, dimensionality = referent_locations.shape
    word_num = word_locations.shape[0]
    if dimensionality != word_locations.shape[1]:
        raise ValueError(f"Conflicting dimensionality arguments: {dimensionality} and {word_locations.shape[1]}")
    if ref_num != priors.shape[0]:
        raise ValueError(f"Conflicting referent numbers: {ref_num} and {priors.shape[0]}")
    if word_num != priors.shape[1]:
        raise ValueError(f"Conflicting word numbers: {word_num} and {priors.shape[1]}")

    # Gradients for every referent and word combination, weighted by priors.
    # Shape is ||U|| x ||W|| x L; priors broadcast along the trailing L axis.
    weighted_gradients = sim_func.calculate_gradients(referent_locations, word_locations) * priors[:, :, None]

    # Sum over the word axis to get one aggregate gradient per referent,
    # then take a unit-size ascent step.
    return referent_locations + np.sum(weighted_gradients, axis=1)

0 commit comments

Comments
 (0)