import java.lang.Iterable;
import java.util.Iterator;
import java.util.Collection;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.Arrays;
import java.util.List;
import java.util.LinkedList;
import java.util.ArrayList;
import java.util.Queue;
import java.util.Map;
import java.util.HashMap;

/**
 * Command-line driver: treats the program arguments as a token sequence,
 * builds the distribution of its ten most frequent 2-grams, prints 100
 * randomly drawn 2-grams and then the distribution itself.
 */
public class Ngrams {

    /**
     * Prints the n-grams separated by single spaces, starting a new output
     * line as soon as the current line has grown beyond {@code lineLength}
     * characters. A line may therefore exceed {@code lineLength}; the check
     * happens after an n-gram has been appended.
     *
     * @param ngrams     the n-grams to print, in iteration order
     * @param lineLength the length after which the current line is flushed
     */
    public static void print(Collection<Ngram> ngrams, int lineLength) {
        final StringBuilder line = new StringBuilder();
        for (Ngram ngram : ngrams) {
            line.append(ngram.toString()).append(" ");
            if (line.length() > lineLength) {
                System.out.println(line.toString());
                line.setLength(0);
            }
        }
        // Flush whatever is left of the last, partially filled line.
        if (line.length() > 0) {
            System.out.println(line.toString());
        }
    }

    /**
     * Program entry point.
     *
     * @param args the token sequence to analyse
     */
    public static void main(String[] args) {
        final List<String> tokens = Arrays.asList(args);
        final NgramSet ngrams = new NgramSet(tokens);
        final NgramDistribution dist = ngrams.distribution(2).top(10);
        Ngrams.print(dist.random(100), 80);
        System.out.println();
        System.out.println(dist);
    }
}

/**
 * Immutable pair of an n-gram and the upper bound of the probability
 * interval assigned to it by an {@link NgramDistribution}.
 */
class NgramUpperBound {
    private final Ngram ngram;
    private final double upperBound;

    /**
     * @param ngram      the n-gram
     * @param upperBound the upper bound of the n-gram's interval
     */
    public NgramUpperBound(Ngram ngram, double upperBound) {
        this.ngram = ngram;
        this.upperBound = upperBound;
    }

    /** @return the n-gram of this pair */
    public Ngram ngram() {
        return this.ngram;
    }

    /** @return the upper bound of this pair */
    public double upperBound() {
        return this.upperBound;
    }

    /** @return the n-gram text, {@code ": "} and the upper bound */
    @Override
    public String toString() {
        return this.ngram.toString() + ": " + this.upperBound;
    }
}

/**
 * Discrete distribution over a sorted set of n-grams. Each n-gram is given
 * an upper bound: the first one gets {@code 1.0} and every following one
 * gets the previous bound minus the previous n-gram's share of the summed
 * frequencies, so the bounds decrease in iteration order.
 */
class NgramDistribution {
    private final SortedSet<Ngram> sortedSet;
    private final List<NgramUpperBound> upperBounds;

    /**
     * @param sortedSet the n-grams to distribute over; the set is kept by
     *                  reference, not copied
     */
    public NgramDistribution(SortedSet<Ngram> sortedSet) {
        this.sortedSet = sortedSet;
        this.upperBounds = new ArrayList<NgramUpperBound>(sortedSet.size());
        final double sum = NgramUtil.sumFrequencies(sortedSet).doubleValue();
        double nextUpperBound = 1.0;
        // The division only runs when the set is non-empty.
        for (Ngram ngram : sortedSet) {
            this.upperBounds.add(new NgramUpperBound(ngram, nextUpperBound));
            nextUpperBound -= ngram.frequency() / sum;
        }
    }

    /**
     * Builds a new distribution over the first {@code n} n-grams of this
     * one; the upper bounds are recomputed for the reduced set.
     *
     * @param n maximum number of n-grams to keep; values below 1 yield an
     *          empty distribution
     * @return the distribution over the leading n-grams
     */
    public NgramDistribution top(int n) {
        final SortedSet<Ngram> topset = new TreeSet<Ngram>();
        for (Ngram ngram : this.sortedSet) {
            if (topset.size() >= n) {
                break;
            }
            topset.add(ngram);
        }
        return new NgramDistribution(topset);
    }

    /**
     * Looks up the n-gram for the value {@code p}: the last n-gram, in
     * iteration order, whose upper bound is strictly greater than
     * {@code p}.
     *
     * @param p the lookup value; must not be {@code null}
     * @return the matching n-gram, or {@code null} when {@code p >= 1} or
     *         no upper bound exceeds {@code p} (e.g. the distribution is
     *         empty)
     */
    public Ngram get(Double p) {
        if (p >= 1) {
            return null;
        }
        Ngram match = null;
        for (NgramUpperBound bound : this.upperBounds) {
            // Bounds decrease, so the scan can stop at the first miss.
            if (!(bound.upperBound() > p)) {
                break;
            }
            match = bound.ngram();
        }
        return match;
    }

    /** @return the n-gram looked up for a fresh {@link Math#random()} value */
    public Ngram random() {
        return this.get(Math.random());
    }

    /**
     * Draws {@code count} n-grams via {@link #random()}.
     *
     * @param count number of draws; negative values are treated as 0
     * @return the drawn n-grams; empty when this distribution is empty
     */
    public Collection<Ngram> random(int count) {
        if (count < 0) {
            count = 0;
        }
        final Collection<Ngram> ngrams = new ArrayList<Ngram>(count);
        if (!this.upperBounds.isEmpty()) {
            for (int i = 0; i < count; i++) {
                ngrams.add(this.random());
            }
        }
        return ngrams;
    }

    /** @return one line per n-gram, each terminated by {@code "\n"} */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder();
        for (NgramUpperBound upperBound : this.upperBounds) {
            text.append(upperBound.toString()).append("\n");
        }
        return text.toString();
    }
}

/**
 * An n-gram: a terminal token together with the chain of shorter n-grams
 * (ancestors) that precede it. Its frequency is the count of the terminal
 * token within the {@link NgramSet} it was created for.
 *
 * <p>The natural ordering is by descending frequency, then by token string.
 * {@code equals}/{@code hashCode} are not overridden, so the ordering is not
 * consistent with {@code equals}.
 */
class Ngram implements Comparable<Ngram> {
    private final String tokenString;
    private final String terminalToken;
    private final List<Ngram> ancestors;
    private final NgramSet siblings;
    private NgramSet descendents = null;
    private final Integer freq;

    /**
     * @param terminalToken the last token of this n-gram
     * @param ancestors     the preceding n-grams, shortest first; the list
     *                      is kept by reference
     * @param siblings      the set this n-gram belongs to; queried for the
     *                      frequency of {@code terminalToken}
     */
    public Ngram(String terminalToken, List<Ngram> ancestors, NgramSet siblings) {
        this.terminalToken = terminalToken;
        this.ancestors = ancestors;
        this.siblings = siblings;
        this.freq = siblings.frequencyOf(terminalToken);
        if (ancestors.isEmpty()) {
            this.tokenString = terminalToken;
        } else {
            // The direct parent already carries the text of all ancestors.
            final Ngram parent = ancestors.get(ancestors.size() - 1);
            this.tokenString = parent.tokenString + " " + terminalToken;
        }
    }

    /**
     * Orders by descending frequency; ties are broken by ascending token
     * string.
     */
    @Override
    public int compareTo(Ngram other) {
        final int byFrequency = this.freq.compareTo(other.freq);
        if (byFrequency != 0) {
            return -byFrequency;
        }
        return this.tokenString.compareTo(other.tokenString);
    }

    /** @return the space-separated tokens of this n-gram */
    @Override
    public String toString() {
        return this.tokenString;
    }

    /** @return the last token of this n-gram */
    public String terminalToken() {
        return this.terminalToken;
    }

    /** @return the frequency recorded at construction time */
    public int frequency() {
        return this.freq;
    }

    /** @return the preceding n-grams (the internal list, not a copy) */
    public Iterable<Ngram> ancestors() {
        return this.ancestors;
    }

    /** @return the number of ancestors */
    public int numPrevNgrams() {
        return this.ancestors.size();
    }

    /**
     * @return the set of n-grams that extend this one by one token; built
     *         on first call and cached afterwards
     */
    public NgramSet descendents() {
        if (this.descendents == null) {
            createDescendents();
        }
        return this.descendents;
    }

    /** Builds the descendant set using the ancestors plus this n-gram as prefix. */
    private void createDescendents() {
        final List<Ngram> newPrevNgrams = new ArrayList<Ngram>(this.ancestors.size() + 1);
        newPrevNgrams.addAll(this.ancestors);
        newPrevNgrams.add(this);
        this.descendents = new NgramSet(this.siblings.tokenList(), newPrevNgrams);
    }
}

/** Static helpers shared by the n-gram classes. */
class NgramUtil {

    /**
     * @param ngrams the n-grams to sum over
     * @return the sum of {@link Ngram#frequency()} over all elements
     */
    public static Integer sumFrequencies(Iterable<Ngram> ngrams) {
        int sum = 0;
        for (Ngram ngram : ngrams) {
            sum += ngram.frequency();
        }
        return sum;
    }

    /**
     * Collects every token of {@code sequence} that directly follows an
     * occurrence of the prefix, where an occurrence is a run of tokens equal
     * to the terminal tokens of the prefix n-grams, in order.
     *
     * @param sequence the token sequence to scan
     * @param prefix   the prefix n-grams; may be {@code null}
     * @return {@code sequence} itself when the prefix is {@code null} or
     *         empty, otherwise a new list of the following tokens
     */
    public static List<String> filter(List<String> sequence, List<Ngram> prefix) {
        if (prefix == null || prefix.isEmpty()) {
            return sequence;
        }
        final int n = prefix.size();
        // Sliding window over the last n tokens seen so far.
        final Queue<String> window = new LinkedList<String>();
        final List<String> matchingTokens = new ArrayList<String>();
        boolean tokenFollowsPrefix = false;
        for (String token : sequence) {
            if (tokenFollowsPrefix) {
                matchingTokens.add(token);
            }
            window.add(token);
            if (window.size() > n) {
                window.remove();
            }
            tokenFollowsPrefix = areEqual(window, prefix);
        }
        return matchingTokens;
    }

    /**
     * @param tokenList the tokens
     * @param ngramList the n-grams whose terminal tokens are compared
     * @return {@code true} when both have the same length and each token
     *         equals the terminal token of the n-gram at the same position
     */
    public static boolean areEqual(Iterable<String> tokenList, Iterable<Ngram> ngramList) {
        final Iterator<String> tokenIter = tokenList.iterator();
        final Iterator<Ngram> ngramIter = ngramList.iterator();
        while (tokenIter.hasNext() && ngramIter.hasNext()) {
            final String token = tokenIter.next();
            final String terminal = ngramIter.next().terminalToken();
            if (!token.equals(terminal)) {
                return false;
            }
        }
        // Equal only if both were exhausted at the same time.
        return !tokenIter.hasNext() && !ngramIter.hasNext();
    }

    /**
     * Adds to {@code bag} the n-grams found {@code n - 1} descendant levels
     * below {@code ngrams}; for {@code n == 1} that is {@code ngrams}
     * itself. Does nothing for {@code n < 1}.
     *
     * @param n      the level to gather, counted from 1
     * @param ngrams the n-grams at level 1
     * @param bag    receives the gathered n-grams
     */
    public static void gather(int n, Iterable<Ngram> ngrams, SortedSet<Ngram> bag) {
        if (n < 1) {
            return;
        }
        for (Ngram ngram : ngrams) {
            if (n == 1) {
                bag.add(ngram);
            } else {
                gather(n - 1, ngram.descendents(), bag);
            }
        }
    }
}

/**
 * The sorted set of n-grams that share one prefix: one n-gram per distinct
 * token that follows the prefix in the token sequence (or per distinct token
 * of the whole sequence when the prefix is empty), with occurrence counts.
 */
class NgramSet implements Iterable<Ngram> {
    private final List<String> tokenList;
    private final SortedSet<Ngram> sortedNgrams;
    private final int sumOfFrequencies;
    private final Map<String, Integer> ngramFreqMap;

    /**
     * @param sequence the full token sequence; kept by reference
     * @param prefix   the n-grams that must precede a token for it to be
     *                 counted; {@code null} is treated as empty
     */
    public NgramSet(List<String> sequence, List<Ngram> prefix) {
        if (prefix == null) {
            prefix = new ArrayList<Ngram>(0);
        }
        this.tokenList = sequence;
        this.sortedNgrams = new TreeSet<Ngram>();
        this.ngramFreqMap = new HashMap<String, Integer>();
        final Iterable<String> subset = NgramUtil.filter(sequence, prefix);
        this.sumOfFrequencies = mapFrequencies(subset);
        // One n-gram per distinct token is enough: repeated occurrences
        // would compare equal and be rejected by the sorted set anyway.
        // The frequency map is complete here, as the Ngram constructor
        // reads it through frequencyOf.
        for (String token : this.ngramFreqMap.keySet()) {
            this.sortedNgrams.add(new Ngram(token, prefix, this));
        }
    }

    /**
     * Creates the set of 1-grams of {@code sequence}.
     *
     * @param sequence the full token sequence
     */
    public NgramSet(List<String> sequence) {
        this(sequence, null);
    }

    /** @return the number of distinct n-grams in this set */
    public int size() {
        return this.sortedNgrams.size();
    }

    /** @return the token sequence passed to the constructor (not a copy) */
    public List<String> tokenList() {
        return this.tokenList;
    }

    /**
     * @param ngram the token to look up
     * @return how often the token was counted for this set; 0 when it was
     *         never seen
     */
    public int frequencyOf(String ngram) {
        final Integer count = this.ngramFreqMap.get(ngram);
        // A missing key yields null; unboxing it directly would throw.
        return count == null ? 0 : count;
    }

    /** @return the total number of tokens counted for this set */
    public Integer sumOfFrequencies() {
        return this.sumOfFrequencies;
    }

    /** @return an iterator over the n-grams in their natural order */
    @Override
    public Iterator<Ngram> iterator() {
        return this.sortedNgrams.iterator();
    }

    /**
     * @param ngramLength the n-gram length, counted from 1 for the n-grams
     *                    of this set; values below 1 give an empty
     *                    distribution
     * @return the distribution over all n-grams gathered at that length
     */
    public NgramDistribution distribution(int ngramLength) {
        final SortedSet<Ngram> superset = new TreeSet<Ngram>();
        NgramUtil.gather(ngramLength, this.sortedNgrams, superset);
        return new NgramDistribution(superset);
    }

    /**
     * Counts each token of {@code subset} into the frequency map.
     *
     * @param subset the tokens to count
     * @return the total number of tokens counted
     */
    private int mapFrequencies(Iterable<String> subset) {
        int sum = 0;
        for (String token : subset) {
            sum++;
            final Integer prevCount = this.ngramFreqMap.get(token);
            this.ngramFreqMap.put(token, prevCount == null ? 1 : prevCount + 1);
        }
        return sum;
    }
}