001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.HashSet;
020import java.util.Map;
021import java.util.Set;
022
/**
 * Measures the cosine similarity of two sparse term vectors, i.e. the cosine of
 * the angle between them in an inner product space.
 *
 * <p>
 * Each vector is represented as a map from a term to its integer count. Terms
 * absent from a map are treated as having a count of zero.
 * </p>
 *
 * <p>
 * For further explanation about the Cosine Similarity, refer to
 * <a href="https://en.wikipedia.org/wiki/Cosine_similarity">https://en.wikipedia.org/wiki/Cosine_similarity</a>.
 * </p>
 *
 * <p>
 * This class holds no state.
 * </p>
 *
 * @since 1.0
 */
public class CosineSimilarity {

    /**
     * Calculates the cosine similarity for two given vectors.
     *
     * <p>
     * The result is the dot product of the two vectors divided by the product of
     * their Euclidean norms. If either vector has a zero norm (for example, it is
     * empty), the result is {@code 0.0}.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @return cosine similarity between the two vectors
     * @throws IllegalArgumentException if either vector is {@code null}
     * @throws NullPointerException if either vector contains a {@code null} count
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
                                   final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }

        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);

        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);

        final double cosineSimilarity;
        if (d1 <= 0.0 || d2 <= 0.0) {
            // A zero-length vector has no direction; avoid dividing by zero.
            cosineSimilarity = 0.0;
        } else {
            cosineSimilarity = dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
        }
        return cosineSimilarity;
    }

    /**
     * Computes the sum of the squared counts of a vector, i.e. its squared
     * Euclidean norm.
     *
     * @param vector the vector whose counts are squared and summed
     * @return the sum of squares of all counts in the vector
     */
    private double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            sum += Math.pow(value, 2);
        }
        return sum;
    }

    /**
     * Computes the dot product of two vectors. Only the keys contained in
     * {@code intersection} contribute to the result; a key present in just one
     * of the vectors is ignored, because its count in the other vector is zero.
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @param intersection common elements
     * @return The dot product
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        // Accumulate in double: a single product of two ints always fits in a long,
        // but the sum of several such products can exceed Long.MAX_VALUE and would
        // silently wrap around to a negative value.
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            final long product = leftVector.get(key) * (long) rightVector.get(key);
            dotProduct += product;
        }
        return dotProduct;
    }

    /**
     * Returns a set with strings common to the two given maps.
     *
     * @param leftVector left vector map
     * @param rightVector right vector map
     * @return common strings
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        // Copy the left keys so that retainAll does not modify the caller's map.
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }

}