import java.lang.Iterable;
import java.util.Iterator;
import java.util.Collection;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.Arrays;
import java.util.List;
import java.util.LinkedList;
import java.util.ArrayList;
import java.util.Queue;
import java.util.Map;
import java.util.HashMap;
public class Ngrams
{
    /**
     * Prints the given n-grams separated by single spaces, wrapping so that no
     * line is longer than {@code lineLength} characters. A line is flushed
     * before appending an n-gram that would push it past the limit; an n-gram
     * that is by itself longer than {@code lineLength} is printed on its own line.
     *
     * @param ngrams     the n-grams to print, in iteration order
     * @param lineLength the maximum desired line length, in characters
     */
    public static void print(Collection<Ngram> ngrams, int lineLength)
    {
        final StringBuilder line = new StringBuilder();
        for (Ngram ngram : ngrams)
        {
            final String text = String.valueOf(ngram);
            // The "+ 1" accounts for the space that would separate this n-gram
            // from the text already on the line.
            if (line.length() > 0 && line.length() + 1 + text.length() > lineLength)
            {
                System.out.println(line);
                line.setLength(0);
            }
            if (line.length() > 0) line.append(' ');
            line.append(text);
        }
        if (line.length() > 0) System.out.println(line);
    }

    /**
     * Builds bigram statistics from the command-line tokens, keeps the ten most
     * frequent bigrams, prints 100 bigrams sampled from that distribution, and
     * then prints the distribution itself.
     */
    public static void main(String[] args)
    {
        final List<String> tokens = Arrays.asList(args);
        final NgramSet ngrams = new NgramSet(tokens);
        final NgramDistribution dist = ngrams.distribution(2).top(10);
        Ngrams.print(dist.random(100), 80);
        System.out.println();
        System.out.println(dist);
    }
}
class NgramUpperBound
{
    // Immutable pairing of an n-gram with a double-valued upper bound.
    private final double upperBound;
    private final Ngram ngram;

    /**
     * @param ngram      the n-gram this bound belongs to
     * @param upperBound the bound associated with {@code ngram}
     */
    public NgramUpperBound(Ngram ngram, double upperBound)
    {
        this.upperBound = upperBound;
        this.ngram = ngram;
    }

    /** @return the n-gram passed to the constructor */
    public Ngram ngram()
    {
        return ngram;
    }

    /** @return the bound passed to the constructor */
    public double upperBound()
    {
        return upperBound;
    }

    /** Renders as {@code "<ngram>: <upperBound>"}. */
    @Override
    public String toString()
    {
        return ngram.toString().concat(": ").concat(Double.toString(upperBound));
    }
}
/**
 * A frequency-weighted distribution over a sorted set of n-grams that can be
 * sampled.
 *
 * <p>The n-grams are laid out in the iteration order of the sorted set. Each
 * entry of {@code upperBounds} records the upper end of that n-gram's slice of
 * [0, 1]: the first slice ends at 1.0, and every later slice ends where the
 * previous one began. Each slice's width is the n-gram's frequency divided by
 * the total frequency. The bounds are therefore non-increasing, which lets
 * {@link #get(Double)} binary-search them.
 */
class NgramDistribution
{
    private final SortedSet<Ngram> sortedSet;
    private final ArrayList<NgramUpperBound> upperBounds;

    /**
     * @param sortedSet the n-grams to distribute over; retained by reference,
     *                  not copied
     */
    public NgramDistribution(SortedSet<Ngram> sortedSet)
    {
        this.sortedSet = sortedSet;
        this.upperBounds = new ArrayList<NgramUpperBound>(sortedSet.size());
        final double sum = NgramUtil.sumFrequencies(sortedSet).doubleValue();
        double nextUpperBound = 1.0;
        for (Ngram ngram : sortedSet)
        {
            upperBounds.add(new NgramUpperBound(ngram, nextUpperBound));
            nextUpperBound -= ngram.frequency() / sum;
        }
    }

    /**
     * Returns a new distribution over only the first {@code n} n-grams, in
     * iteration order, renormalized over their own total frequency.
     * Returns an empty distribution if {@code n <= 0}.
     */
    public NgramDistribution top(int n)
    {
        final SortedSet<Ngram> topset = new TreeSet<Ngram>();
        int remaining = n;
        for (Ngram ngram : this.sortedSet)
        {
            // Post-decrement comparison avoids overflow when n == Integer.MIN_VALUE.
            if (remaining-- <= 0) break;
            topset.add(ngram);
        }
        return new NgramDistribution(topset);
    }

    /**
     * Returns the n-gram whose slice contains {@code p}. This is the last n-gram,
     * in order, whose upper bound is strictly greater than {@code p}.
     *
     * @param p a value, normally in [0, 1)
     * @return the selected n-gram, or {@code null} if {@code p >= 1}, if the
     *         distribution is empty, or if no upper bound exceeds {@code p}
     */
    public Ngram get(Double p)
    {
        if (p >= 1) return null;
        final double target = p;
        // Find the first index whose upper bound is NOT greater than target.
        // "upperBound > target" is true for a prefix of the non-increasing
        // bounds and false afterwards, so a binary search finds that boundary.
        int lo = 0;
        int hi = upperBounds.size();
        while (lo < hi)
        {
            final int mid = (lo + hi) >>> 1;
            if (upperBounds.get(mid).upperBound() > target) lo = mid + 1;
            else hi = mid;
        }
        return lo == 0 ? null : upperBounds.get(lo - 1).ngram();
    }

    /** Samples one n-gram using {@link Math#random()}. */
    public Ngram random()
    {
        return this.get(Math.random());
    }

    /**
     * Samples {@code count} n-grams independently.
     *
     * @param count number of samples; negative values are treated as 0
     * @return the samples, or an empty collection if the distribution is empty
     */
    public Collection<Ngram> random(int count)
    {
        if (count < 0) count = 0;
        final Collection<Ngram> ngrams = new ArrayList<Ngram>(count);
        if (upperBounds.size() > 0)
        {
            while (count-- > 0)
                ngrams.add(this.random());
        }
        return ngrams;
    }

    /** Lists each n-gram with its upper bound, one per line. */
    @Override
    public String toString()
    {
        final StringBuilder s = new StringBuilder();
        for (NgramUpperBound upperBound : upperBounds)
            s.append(upperBound.toString()).append('\n');
        return s.toString();
    }
}
/**
 * One n-gram: its final token, the list of shorter n-grams that precede it
 * (its ancestors), and the {@link NgramSet} it belongs to (its siblings).
 *
 * <p>Natural ordering is by descending frequency, then ascending token string.
 * {@link #equals(Object)} and {@link #hashCode()} use the same two fields, so
 * the ordering is consistent with equals.
 */
class Ngram implements Comparable<Ngram>
{
    /** The whole n-gram as space-separated tokens (parent's string + " " + terminal token). */
    private final String tokenString;
    private final String terminalToken;
    /** Preceding n-grams; the same list instance may be shared by several Ngrams. */
    private final List<Ngram> ancestors;
    private final NgramSet siblings;
    /** Created lazily by {@link #descendents()}. */
    private NgramSet descendents = null;
    private final int freq;

    /**
     * @param terminalToken the last token of this n-gram
     * @param ancestors     the n-grams forming its prefix; empty for a unigram
     * @param siblings      the set this n-gram belongs to, queried for its frequency
     */
    public Ngram(String terminalToken, List<Ngram> ancestors, NgramSet siblings)
    {
        this.terminalToken = terminalToken;
        this.ancestors = ancestors;
        this.siblings = siblings;
        this.freq = siblings.frequencyOf(terminalToken);
        if (ancestors.isEmpty())
        {
            this.tokenString = terminalToken;
        }
        else
        {
            final Ngram parent = ancestors.get(ancestors.size() - 1);
            this.tokenString = parent.tokenString + " " + terminalToken;
        }
    }

    /** Orders by higher frequency first, then by token string. */
    @Override
    public int compareTo(Ngram other)
    {
        if (this.freq != other.freq) return this.freq > other.freq ? -1 : 1;
        return this.tokenString.compareTo(other.tokenString);
    }

    /** Equal exactly when {@link #compareTo(Ngram)} returns 0. */
    @Override
    public boolean equals(Object obj)
    {
        if (this == obj) return true;
        if (!(obj instanceof Ngram)) return false;
        final Ngram other = (Ngram) obj;
        return this.freq == other.freq && this.tokenString.equals(other.tokenString);
    }

    @Override
    public int hashCode()
    {
        return 31 * this.tokenString.hashCode() + this.freq;
    }

    @Override
    public String toString() { return this.tokenString; }
    public String terminalToken() { return this.terminalToken; }
    public int frequency() { return this.freq; }
    public Iterable<Ngram> ancestors() { return this.ancestors; }
    public int numPrevNgrams() { return this.ancestors.size(); }

    /** Returns the set of n-grams that extend this one by one token, creating it on first use. */
    public NgramSet descendents()
    {
        if (this.descendents == null) createDescendents();
        return this.descendents;
    }

    // The descendants' prefix is this n-gram's ancestors followed by this n-gram.
    private void createDescendents()
    {
        final List<Ngram> newPrevNgrams = new ArrayList<Ngram>(this.ancestors.size() + 1);
        newPrevNgrams.addAll(this.ancestors);
        newPrevNgrams.add(this);
        this.descendents = new NgramSet(this.siblings.tokenList(), newPrevNgrams);
    }
}
/** Static helpers over token sequences and n-gram collections. */
class NgramUtil
{
    /** @return the total of {@code frequency()} across all given n-grams */
    public static Integer sumFrequencies(Iterable<Ngram> ngrams)
    {
        int total = 0;
        final Iterator<Ngram> it = ngrams.iterator();
        while (it.hasNext())
        {
            total += it.next().frequency();
        }
        return total;
    }

    /**
     * Collects every token in {@code sequence} that immediately follows an
     * occurrence of {@code prefix}. Tokens are matched against each prefix
     * n-gram's terminal token.
     *
     * @return {@code sequence} itself when {@code prefix} is null or empty,
     *         otherwise a new list of the following tokens in sequence order
     */
    public static List<String> filter(List<String> sequence, List<Ngram> prefix)
    {
        if (prefix == null || prefix.isEmpty()) return sequence;
        final int width = prefix.size();
        // Sliding window holding at most the last `width` tokens seen.
        final Queue<String> window = new LinkedList<String>();
        final List<String> followers = new LinkedList<String>();
        boolean windowMatched = false;
        for (String token : sequence)
        {
            if (windowMatched)
            {
                followers.add(token);
            }
            window.add(token);
            if (window.size() > width)
            {
                window.remove();
            }
            windowMatched = areEqual(window, prefix);
        }
        return followers;
    }

    /**
     * @return true iff both sequences have the same length and each token equals
     *         the terminal token of the n-gram at the same position
     */
    public static boolean areEqual(Iterable<String> tokenList, Iterable<Ngram> ngramList)
    {
        final Iterator<String> tokens = tokenList.iterator();
        final Iterator<Ngram> grams = ngramList.iterator();
        for (;;)
        {
            final boolean moreTokens = tokens.hasNext();
            final boolean moreGrams = grams.hasNext();
            if (!moreTokens || !moreGrams)
            {
                return !moreTokens && !moreGrams;
            }
            if (!tokens.next().equals(grams.next().terminalToken()))
            {
                return false;
            }
        }
    }

    /**
     * Adds to {@code bag} every n-gram found {@code n - 1} levels of
     * {@code descendents()} below {@code ngrams}. Does nothing when {@code n < 1}.
     */
    public static void gather(int n, Iterable<Ngram> ngrams, SortedSet<Ngram> bag)
    {
        if (n == 1)
        {
            for (Ngram ngram : ngrams)
            {
                bag.add(ngram);
            }
        }
        else if (n > 1)
        {
            for (Ngram ngram : ngrams)
            {
                gather(n - 1, ngram.descendents(), bag);
            }
        }
    }
}
/**
 * The distinct tokens that immediately follow a given prefix within a token
 * sequence. Each token is wrapped as an {@link Ngram} and kept in Ngram's
 * natural order.
 */
class NgramSet implements Iterable<Ngram>
{
    private final List<String> tokenList;
    private final SortedSet<Ngram> sortedNgrams;
    private final int sumOfFrequencies;
    /** Number of occurrences of each token among those following the prefix. */
    private final Map<String, Integer> ngramFreqMap;

    /**
     * @param sequence the full token sequence; retained by reference
     * @param prefix   the n-grams that must precede a token for it to be counted;
     *                 null or empty means every token in {@code sequence} counts
     */
    public NgramSet(List<String> sequence, List<Ngram> prefix)
    {
        if (prefix == null) prefix = new ArrayList<Ngram>(0);
        this.tokenList = sequence;
        this.sortedNgrams = new TreeSet<Ngram>();
        this.ngramFreqMap = new HashMap<String, Integer>();
        final Iterable<String> subset = NgramUtil.filter(sequence, prefix);
        this.sumOfFrequencies = mapFrequencies(subset);
        // Build one Ngram per distinct token instead of one per occurrence. The
        // TreeSet would discard the duplicates anyway, since they compare equal.
        for (String token : this.ngramFreqMap.keySet())
            this.sortedNgrams.add(new Ngram(token, prefix, this));
    }

    public NgramSet(List<String> sequence)
    {
        this(sequence, null);
    }

    public int size() { return this.sortedNgrams.size(); }
    public List<String> tokenList() { return this.tokenList; }

    /** @return how often {@code ngram} follows the prefix; 0 if it never does */
    public int frequencyOf(String ngram)
    {
        final Integer count = this.ngramFreqMap.get(ngram);
        return count == null ? 0 : count;
    }

    /** @return the number of tokens that followed the prefix, counting repeats */
    public Integer sumOfFrequencies() { return this.sumOfFrequencies; }

    /** Iterates the n-grams in natural order. */
    public Iterator<Ngram> iterator() { return this.sortedNgrams.iterator(); }

    /** Returns the distribution over all n-grams {@code ngramLength} levels deep from this set. */
    public NgramDistribution distribution(int ngramLength)
    {
        final SortedSet<Ngram> superset = new TreeSet<Ngram>();
        NgramUtil.gather(ngramLength, this.sortedNgrams, superset);
        return new NgramDistribution(superset);
    }

    // Counts each token in subset into ngramFreqMap and returns the total token count.
    private int mapFrequencies(Iterable<String> subset)
    {
        int sum = 0;
        for (String ngram : subset)
        {
            sum++;
            final Integer prevCount = this.ngramFreqMap.get(ngram);
            this.ngramFreqMap.put(ngram, prevCount == null ? 1 : prevCount + 1);
        }
        return sum;
    }
}