// Pages

// n-grams

import java.lang.Iterable;
import java.util.Iterator;
import java.util.Collection;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.Arrays;
import java.util.List;
import java.util.LinkedList;
import java.util.ArrayList;
import java.util.Queue;
import java.util.Map;
import java.util.HashMap;

/**
 * Command-line entry point and console printing helper for the n-gram
 * classes declared in this file.
 */
public class Ngrams
{
    /**
     * Prints the given n-grams to standard output, each followed by a single
     * space, wrapped into lines.
     *
     * A line is emitted as soon as its accumulated length (trailing space
     * included) exceeds {@code lineLength}; the check happens after an n-gram
     * has been appended, so a printed line may be longer than
     * {@code lineLength}. Any remaining partial line is printed at the end.
     *
     * @param ngrams     the n-grams to print, in iteration order
     * @param lineLength the length a line must exceed before it is flushed
     */
    public static void print(Collection<Ngram> ngrams, int lineLength)
    {
        // A single reusable builder avoids re-copying the whole line on every append.
        final StringBuilder line = new StringBuilder();
        for (Ngram ngram : ngrams)
        {
            line.append(ngram.toString()).append(' ');
            if (line.length() > lineLength)
            {
                System.out.println(line.toString());
                line.setLength(0);
            }
        }
        if (line.length() > 0) System.out.println(line.toString());
    }

    /**
     * Treats the command-line arguments as a token sequence, builds the
     * distribution of its 2-grams restricted to the first 10 entries, prints
     * 100 randomly drawn n-grams wrapped at 80 characters, then a blank line,
     * then the distribution itself.
     *
     * @param args the tokens to analyse
     */
    public static void main(String[] args)
    {
        final List<String> tokens = Arrays.asList(args);
        final NgramSet ngrams = new NgramSet(tokens);
        final NgramDistribution dist = ngrams.distribution(2).top(10);
        Ngrams.print(dist.random(100), 80);
        System.out.println();
        System.out.println(dist);
    }
}

/**
 * Immutable pairing of an n-gram with a numeric upper bound.
 */
class NgramUpperBound
{
    private final Ngram entry;
    private final double bound;

    /**
     * @param ngram      the n-gram this bound belongs to
     * @param upperBound the bound associated with the n-gram
     */
    public NgramUpperBound(Ngram ngram, double upperBound)
    {
        entry = ngram;
        bound = upperBound;
    }

    /** @return the n-gram passed to the constructor */
    public Ngram ngram()
    {
        return entry;
    }

    /** @return the upper bound passed to the constructor */
    public double upperBound()
    {
        return bound;
    }

    /** @return the n-gram's text, a colon and a space, then the bound */
    @Override
    public String toString()
    {
        final StringBuilder text = new StringBuilder();
        text.append(entry.toString());
        text.append(": ");
        text.append(bound);
        return text.toString();
    }
}

/**
 * A sorted set of n-grams together with a table of cumulative upper bounds
 * that allows an n-gram to be looked up from a number in [0, 1).
 *
 * The first n-gram in iteration order gets the upper bound 1.0; each
 * following n-gram gets the previous bound minus the previous n-gram's share
 * (its frequency divided by the sum of all frequencies in the set).
 */
class NgramDistribution
{
    private final SortedSet<Ngram> sortedSet;
    private final List<NgramUpperBound> upperBounds;

    /**
     * Builds the upper-bound table for the given set.
     *
     * @param sortedSet the n-grams of the distribution; the set is kept by
     *                  reference, not copied
     */
    public NgramDistribution(SortedSet<Ngram> sortedSet)
    {
        this.sortedSet = sortedSet;
        this.upperBounds = new ArrayList<NgramUpperBound>(sortedSet.size());
        // Primitive double: forces floating-point division below without boxing.
        final double sum = NgramUtil.sumFrequencies(sortedSet).doubleValue();
        double nextUpperBound = 1.0;
        for (Ngram ngram : sortedSet)
        {
            upperBounds.add(new NgramUpperBound(ngram, nextUpperBound));
            nextUpperBound -= ngram.frequency() / sum;
        }
    }

    /**
     * Returns a new distribution built from at most the first {@code n}
     * n-grams of this one, in iteration order. The upper bounds are
     * recomputed for the smaller set.
     *
     * @param n maximum number of n-grams to keep; values below 1 give an
     *          empty distribution
     * @return the truncated distribution
     */
    public NgramDistribution top(int n)
    {
        final SortedSet<Ngram> topset = new TreeSet<Ngram>();
        final Iterator<Ngram> source = this.sortedSet.iterator();
        // Counting upwards leaves the parameter untouched and cannot wrap around.
        for (int taken = 0; taken < n && source.hasNext(); taken++)
        {
            topset.add(source.next());
        }
        return new NgramDistribution(topset);
    }

    /**
     * Looks up the n-gram of the last table entry, in iteration order, whose
     * upper bound is strictly greater than {@code p}.
     *
     * Values of {@code p} below 0 are not rejected; they are compared against
     * the bounds like any other value.
     *
     * @param p the lookup value
     * @return the matching n-gram, or {@code null} if {@code p} is
     *         {@code null}, {@code p >= 1}, or no entry has a bound above
     *         {@code p} (for instance when the distribution is empty)
     */
    public Ngram get(Double p)
    {
        if (p == null) return null;
        final double target = p.doubleValue();
        if (target >= 1) return null;
        Ngram match = null;
        for (NgramUpperBound candidate : upperBounds)
        {
            // Bounds only decrease along the table, so the first miss ends the search.
            if (!(candidate.upperBound() > target)) break;
            match = candidate.ngram();
        }
        return match;
    }

    /**
     * Draws one n-gram by looking up a value from {@link Math#random()}.
     *
     * @return the drawn n-gram, or {@code null} if the lookup finds none
     */
    public Ngram random()
    {
        return this.get(Math.random());
    }

    /**
     * Draws {@code count} n-grams with {@link #random()}.
     *
     * @param count number of draws; negative values are treated as 0
     * @return the drawn n-grams in draw order; empty if this distribution
     *         has no entries
     */
    public Collection<Ngram> random(int count)
    {
        final int draws = Math.max(count, 0);
        final Collection<Ngram> ngrams = new ArrayList<Ngram>(draws);
        if (!upperBounds.isEmpty())
        {
            for (int i = 0; i < draws; i++) ngrams.add(this.random());
        }
        return ngrams;
    }

    /**
     * @return one line per table entry, each terminated by a newline
     */
    @Override
    public String toString()
    {
        final StringBuilder text = new StringBuilder();
        for (NgramUpperBound upperBound : upperBounds)
        {
            text.append(upperBound.toString()).append("\n");
        }
        return text.toString();
    }
}

/**
 * One n-gram: a terminal token together with the chain of n-grams that
 * precede it.
 *
 * The frequency is read from the sibling set at construction time. The text
 * form is the parent's text, a space, and the terminal token, or just the
 * terminal token when there are no ancestors.
 *
 * Natural ordering is by descending frequency, then by ascending text.
 * {@link #equals(Object)} and {@link #hashCode()} use the same two keys so
 * that ordering and equality agree.
 */
class Ngram implements Comparable<Ngram>
{
    private final String tokenString;
    private final String terminalToken;
    private final List<Ngram> ancestors;
    private final NgramSet siblings;
    // Created on first request by descendents(); null until then.
    private NgramSet descendents = null;
    private final int freq;

    /**
     * @param terminalToken the last token of this n-gram
     * @param ancestors     the preceding n-grams, shortest first; the last
     *                      element is used as the parent. Must not be null.
     * @param siblings      the set this n-gram belongs to; supplies the
     *                      frequency of {@code terminalToken}
     */
    public Ngram(String terminalToken, List<Ngram> ancestors, NgramSet siblings)
    {
        this.terminalToken = terminalToken;
        this.ancestors = ancestors;
        this.siblings = siblings;
        this.freq = siblings.frequencyOf(terminalToken);
        if (ancestors.isEmpty())
        {
            this.tokenString = terminalToken;
        }
        else
        {
            final Ngram parent = ancestors.get(ancestors.size() - 1);
            this.tokenString = parent.tokenString + " " + terminalToken;
        }
    }

    /**
     * Orders n-grams by descending frequency; ties are broken by the
     * natural ordering of their text.
     */
    @Override
    public int compareTo(Ngram other)
    {
        if (this.freq != other.freq)
        {
            // Higher frequency sorts first.
            return this.freq > other.freq ? -1 : 1;
        }
        return this.tokenString.compareTo(other.tokenString);
    }

    /**
     * Two n-grams are equal when they have the same frequency and the same
     * text, which is exactly when {@link #compareTo(Ngram)} returns 0.
     */
    @Override
    public boolean equals(Object obj)
    {
        if (this == obj) return true;
        if (!(obj instanceof Ngram)) return false;
        final Ngram other = (Ngram) obj;
        return this.freq == other.freq
            && this.tokenString.equals(other.tokenString);
    }

    /** Hash over the same keys as {@link #equals(Object)}. */
    @Override
    public int hashCode()
    {
        return 31 * this.tokenString.hashCode() + this.freq;
    }

    /** @return the space-separated tokens of this n-gram */
    @Override
    public String toString() { return this.tokenString; }

    /** @return the last token of this n-gram */
    public String terminalToken() { return this.terminalToken; }

    /** @return the frequency read from the sibling set at construction */
    public int frequency() { return this.freq; }

    /** @return the preceding n-grams, as passed to the constructor */
    public Iterable<Ngram> ancestors()  { return this.ancestors; }

    /** @return the number of ancestors */
    public int numPrevNgrams() { return this.ancestors.size(); }

    /**
     * Returns the set of n-grams that extend this one by one token. The set
     * is built on the first call and reused afterwards.
     *
     * @return the set of one-token extensions of this n-gram
     */
    public NgramSet descendents()
    {
        if (this.descendents == null) createDescendents();
        return this.descendents;
    }

    /**
     * Builds the descendant set from the siblings' token list, using this
     * n-gram's ancestors followed by this n-gram as the prefix.
     */
    private void createDescendents()
    {
        final List<Ngram> newPrevNgrams =
            new ArrayList<Ngram>(this.ancestors.size() + 1);
        newPrevNgrams.addAll(this.ancestors);
        newPrevNgrams.add(this);
        this.descendents = new NgramSet(this.siblings.tokenList(), newPrevNgrams);
    }
}

/**
 * Static helpers shared by the n-gram classes.
 */
class NgramUtil
{
    /**
     * Adds up the frequencies of all given n-grams.
     *
     * @param ngrams the n-grams to total
     * @return the sum of {@code frequency()} over {@code ngrams}
     */
    public static Integer sumFrequencies(Iterable<Ngram> ngrams)
    {
        int total = 0;
        final Iterator<Ngram> cursor = ngrams.iterator();
        while (cursor.hasNext())
        {
            total += cursor.next().frequency();
        }
        return Integer.valueOf(total);
    }

    /**
     * Collects every token of {@code sequence} that comes directly after an
     * occurrence of {@code prefix}, where an occurrence is a run of
     * consecutive tokens equal to the prefix's terminal tokens.
     *
     * @param sequence the tokens to scan
     * @param prefix   the n-grams whose terminal tokens form the pattern
     * @return {@code sequence} itself when {@code prefix} is null or empty,
     *         otherwise a new list of the following tokens in scan order
     */
    public static List<String> filter(List<String> sequence, List<Ngram> prefix)
    {
        final int prefixLength = (prefix == null) ? 0 : prefix.size();
        if (prefixLength == 0)
        {
            return sequence;
        }

        final List<String> followers = new LinkedList<String>();
        // Sliding window over the most recent prefixLength tokens.
        final Queue<String> window = new LinkedList<String>();
        boolean windowMatches = false;
        for (String current : sequence)
        {
            // The flag reflects the window as it stood before this token.
            if (windowMatches)
            {
                followers.add(current);
            }
            window.add(current);
            if (window.size() > prefixLength)
            {
                window.remove();
            }
            windowMatches = areEqual(window, prefix);
        }
        return followers;
    }

    /**
     * Tells whether the tokens and the n-grams' terminal tokens are the same
     * sequence: equal length and pairwise equal elements.
     *
     * @param tokenList the tokens to compare
     * @param ngramList the n-grams whose terminal tokens are compared
     * @return true when both sequences match element for element
     */
    public static boolean areEqual(Iterable<String> tokenList, Iterable<Ngram> ngramList)
    {
        final Iterator<String> tokens = tokenList.iterator();
        final Iterator<Ngram> grams = ngramList.iterator();
        for (;;)
        {
            final boolean tokensLeft = tokens.hasNext();
            if (!tokensLeft || !grams.hasNext())
            {
                // Equal only if both sides ran out together.
                return !tokensLeft && !grams.hasNext();
            }
            final String left = tokens.next();
            final String right = grams.next().terminalToken();
            if (!left.equals(right))
            {
                return false;
            }
        }
    }

    /**
     * Adds to {@code bag} the n-grams found {@code n - 1} levels of
     * descendants below the given n-grams; with {@code n == 1} the given
     * n-grams themselves are added. Does nothing when {@code n < 1}.
     *
     * @param n      the level to collect, counted from 1 at {@code ngrams}
     * @param ngrams the starting n-grams
     * @param bag    receives the collected n-grams
     */
    public static void gather(int n, Iterable<Ngram> ngrams, SortedSet<Ngram> bag)
    {
        if (n < 1)
        {
            return;
        }
        final boolean atTargetLevel = (n == 1);
        for (Ngram current : ngrams)
        {
            if (atTargetLevel)
            {
                bag.add(current);
            }
            else
            {
                gather(n - 1, current.descendents(), bag);
            }
        }
    }
}

/**
 * The set of n-grams that extend a given prefix by one token within a token
 * sequence, sorted by the n-grams' natural ordering.
 *
 * The tokens that follow the prefix are selected with
 * {@code NgramUtil.filter}; each distinct one becomes an {@link Ngram} whose
 * frequency is the number of times it was selected.
 */
class NgramSet implements Iterable<Ngram>
{
    private final List<String> tokenList;
    private final SortedSet<Ngram> sortedNgrams;
    private final int sumOfFrequencies;
    private final Map<String, Integer> ngramFreqMap;

    /**
     * @param sequence the full token sequence; kept by reference
     * @param prefix   the n-grams the tokens must follow; {@code null} is
     *                 treated as an empty prefix
     */
    public NgramSet(List<String> sequence, List<Ngram> prefix)
    {
        if (prefix == null) prefix = new ArrayList<Ngram>(0);
        this.tokenList = sequence;
        this.sortedNgrams = new TreeSet<Ngram>();
        this.ngramFreqMap = new HashMap<String, Integer>();
        // Frequencies must be complete before any Ngram is built, because
        // the Ngram constructor reads its frequency back from this set.
        this.sumOfFrequencies = mapFrequencies(NgramUtil.filter(sequence, prefix));
        // One Ngram per distinct token: repeated occurrences would compare
        // equal and be discarded by the sorted set anyway.
        for (String token : this.ngramFreqMap.keySet())
            this.sortedNgrams.add(new Ngram(token, prefix, this));
    }

    /**
     * Builds the set for an empty prefix, i.e. over every token of
     * {@code sequence}.
     *
     * @param sequence the full token sequence
     */
    public NgramSet(List<String> sequence)
    {
        this(sequence, null);
    }

    /** @return the number of distinct n-grams in this set */
    public int size() { return this.sortedNgrams.size(); }

    /** @return the token sequence passed to the constructor */
    public List<String> tokenList() { return this.tokenList; }

    /**
     * @param ngram the token to look up
     * @return how many times the token was counted for this set, or 0 if it
     *         was never counted
     */
    public int frequencyOf(String ngram)
    {
        final Integer count = this.ngramFreqMap.get(ngram);
        // Unboxing a missing entry would throw; an unseen token has frequency 0.
        return count == null ? 0 : count.intValue();
    }

    /** @return the total number of tokens counted for this set */
    public Integer sumOfFrequencies() { return this.sumOfFrequencies; }

    /** @return an iterator over the n-grams in sorted order */
    @Override
    public Iterator<Ngram> iterator() { return this.sortedNgrams.iterator(); }

    /**
     * Gathers the n-grams of the given length reachable from this set and
     * wraps them in a distribution.
     *
     * @param ngramLength the level passed to {@code NgramUtil.gather};
     *                    1 selects this set's own n-grams
     * @return the distribution over the gathered n-grams
     */
    public NgramDistribution distribution(int ngramLength)
    {
        SortedSet<Ngram> superset = new TreeSet<Ngram>();
        NgramUtil.gather(ngramLength, this.sortedNgrams, superset);
        return new NgramDistribution(superset);
    }

    /**
     * Counts each token of {@code subset} into the frequency map.
     *
     * @param subset the tokens to count
     * @return the number of tokens counted
     */
    private int mapFrequencies(Iterable<String> subset)
    {
        int sum = 0;
        for (String ngram : subset)
        {
            sum++;
            final Integer prevCount = this.ngramFreqMap.get(ngram);
            this.ngramFreqMap.put(ngram, prevCount == null ? 1 : prevCount + 1);
        }
        return sum;
    }
}