贝叶斯模型
package bayes; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; public class Model { public Set<String> categorySet = new HashSet<String>(); public Set<String> keyWordsSet = new HashSet<String>(); public Map<String, Long> probabilityMap = new HashMap<String, Long>(); }
贝叶斯主类
package bayes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Keyword-based naive Bayes text classifier working on raw occurrence counts.
 *
 * <p>Training stores three kinds of counters in {@link Model#probabilityMap}:
 * <ul>
 *   <li>{@code category} - number of training texts of that category,</li>
 *   <li>{@code keyword} - number of training texts containing the keyword,</li>
 *   <li>{@code category + "-" + keyword} - number of texts of that category
 *       containing the keyword.</li>
 * </ul>
 * No smoothing is applied: a keyword never seen in a category yields a score
 * of zero for that category.
 */
public class Bayes {

    /**
     * Computes the posterior probability of every category for a text.
     *
     * <p>For matched keywords a1..an the unnormalised score of category B is
     * <pre>
     *   count(B,a1) * count(B,a2) * ... * count(B,an) / count(B)^(n-1)
     * </pre>
     * and the returned value is that score divided by the sum of the scores of
     * all categories. If every score is zero, each category receives the same
     * share {@code 1 / numberOfCategories}.
     *
     * @param source text to classify
     * @param model  trained model
     * @return map from category to probability, or {@code null} when an
     *         argument (or one of the model's sets) is {@code null} or when
     *         the text contains none of the model's keywords
     */
    public static Map<String, Double> getValue(String source, Model model) {
        if (model == null || model.keyWordsSet == null || model.categorySet == null || source == null) {
            return null;
        }

        Set<String> matchedKeywords = new HashSet<String>();
        for (String key : model.keyWordsSet) {
            if (source.contains(key)) {
                matchedKeywords.add(key);
            }
        }
        if (matchedKeywords.isEmpty()) {
            return null;
        }

        Map<String, Double> scores = new HashMap<String, Double>();
        double scoreSum = 0;
        for (String category : model.categorySet) {
            double categoryCount = getProbalityValue(model, category);
            double numerator = 1;
            double denominator = 1;
            int index = 0;
            for (String keyword : matchedKeywords) {
                // Only n-1 divisions by count(B): the remaining factor is the prior.
                if (index > 0) {
                    denominator = denominator * categoryCount;
                }
                numerator = numerator * getProbalityValue(model, category + "-" + keyword);
                index = index + 1;
            }
            // A category without any training text has count 0; dividing would
            // give NaN and poison the sum for every other category.
            double score = (denominator == 0) ? 0 : numerator / denominator;
            scoreSum = scoreSum + score;
            scores.put(category, score);
        }

        Map<String, Double> rtnMap = new HashMap<String, Double>();
        if (scoreSum > 0) {
            for (String category : model.categorySet) {
                rtnMap.put(category, scores.get(category) / scoreSum);
            }
        } else {
            // No category scored at all: fall back to a uniform distribution.
            double uniform = 1.0 / model.categorySet.size();
            for (String category : model.categorySet) {
                rtnMap.put(category, uniform);
            }
        }
        return rtnMap;
    }

    /**
     * Looks up a counter in the model.
     *
     * @param model model whose {@code probabilityMap} is consulted
     * @param key   counter key
     * @return the stored count, or {@code 0} when the key is absent or mapped
     *         to {@code null}
     */
    public static long getProbalityValue(Model model, String key) {
        Long count = model.probabilityMap.get(key);
        return count == null ? 0L : count.longValue();
    }

    /**
     * Builds a model from labelled texts.
     *
     * @param categorys category labels; at least two are required
     * @param data      {@code data[i]} holds the training texts of
     *                  {@code categorys[i]}; must have the same length as
     *                  {@code categorys}
     * @param keyWords  keywords to look for; at least two are required
     * @return the trained model, or {@code null} (after printing
     *         {@code "data error!"}) when the arguments are invalid
     */
    public static Model train(String[] categorys, String[][] data, String[] keyWords) {
        boolean valid = categorys != null && data != null
                && data.length == categorys.length && categorys.length > 1
                && keyWords != null && keyWords.length > 1;
        if (!valid) {
            System.out.println("data error!");
            return null;
        }

        Model model = new Model();
        model.categorySet.addAll(Arrays.asList(categorys));
        model.keyWordsSet.addAll(Arrays.asList(keyWords));
        for (int i = 0; i < categorys.length; i++) {
            calculateProbability(categorys[i], data[i], model);
        }
        return model;
    }

    /**
     * Adds the counts contributed by all texts of one category.
     * A {@code null} text array or {@code null} text is skipped.
     */
    private static void calculateProbability(String category, String[] categoryData, Model model) {
        if (categoryData == null) {
            return;
        }
        for (String source : categoryData) {
            if (source == null) {
                continue;
            }
            addCategoryKeywordCount(category, model);
            for (String keyword : model.keyWordsSet) {
                if (source.contains(keyword)) {
                    addCategoryKeywordCount(keyword, model);
                    addCategoryKeywordCount(category + "-" + keyword, model);
                }
            }
        }
    }

    /** Increments the counter stored under {@code key}, starting at 1. */
    private static void addCategoryKeywordCount(String key, Model model) {
        Long count = model.probabilityMap.get(key);
        model.probabilityMap.put(key, count == null ? 1L : count + 1);
    }

    /**
     * Writes the model as text: one comma-terminated line of categories, one
     * comma-terminated line of keywords, then one {@code key:count} line per
     * counter and a final blank line.
     *
     * <p>NOTE(review): entries are not escaped, so categories or keywords that
     * contain {@code ','}, {@code ':'} or a line break will not survive a
     * save/load round trip. The file is written with the JVM default charset.
     *
     * <p>Any failure is reported by printing {@code "save Model error"}.
     *
     * @param fileName path of the file to write
     * @param model    model to persist
     */
    public static void saveModel(String fileName, Model model) {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
            for (String category : model.categorySet) {
                writer.write(category);
                writer.write(",");
            }
            writer.write("\n");
            for (String keyword : model.keyWordsSet) {
                writer.write(keyword);
                writer.write(",");
            }
            writer.write("\n");
            for (Map.Entry<String, Long> entry : model.probabilityMap.entrySet()) {
                writer.write(entry.getKey());
                writer.write(":");
                writer.write(entry.getValue().toString());
                writer.write("\n");
            }
            writer.write("\n");
        } catch (Exception e) {
            System.out.println("save Model error");
        }
    }

    /**
     * Reads a model in the format produced by {@link #saveModel}.
     *
     * <p>Blank lines are ignored. A counter line that is not of the form
     * {@code key:count} with a numeric count is reported with
     * {@code "Error model line:"} and skipped; loading continues with the next
     * line. Any other failure prints {@code "load model error"}.
     *
     * @param fileName path of the file to read
     * @return the loaded model; never {@code null}, but possibly empty or
     *         partially filled when reading failed
     */
    public static Model loadModel(String fileName) {
        Model model = new Model();
        try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
            model.categorySet.addAll(getStringSet(reader.readLine(), ","));
            model.keyWordsSet.addAll(getStringSet(reader.readLine(), ","));

            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue;
                }
                String[] parts = line.split(":");
                if (parts.length != 2) {
                    System.out.println("Error model line:" + line);
                    continue;
                }
                try {
                    // parts[0] is the key, parts[1] the count.
                    model.probabilityMap.put(parts[0], Long.valueOf(parts[1].trim()));
                } catch (NumberFormatException e) {
                    System.out.println("Error model line:" + line);
                }
            }
        } catch (Exception e) {
            System.out.println("load model error");
        }
        return model;
    }

    /**
     * Splits a string and returns the trimmed pieces as a set.
     *
     * @param sourceStr text to split; {@code null} yields an empty set
     * @param splitor   separator, interpreted as a regular expression by
     *                  {@link String#split(String)}; {@code null} yields an
     *                  empty set
     * @return a new mutable set of the trimmed pieces
     */
    public static Set<String> getStringSet(String sourceStr, String splitor) {
        Set<String> rtn = new HashSet<String>();
        if (sourceStr == null || splitor == null) {
            return rtn;
        }
        for (String str : sourceStr.split(splitor)) {
            if (str != null) {
                rtn.add(str.trim());
            }
        }
        return rtn;
    }
}
相关推荐
java实现朴素贝叶斯分类算法
朴素贝叶斯算法的java实现,具有很好的分类效果
条件概率:$P(X\mid Y) = P(X,Y)/P(Y)$,因此 $P(X,Y) = P(X\mid Y)\,P(Y) = P(Y\mid X)\,P(X)$,进而得到贝叶斯公式 $P(X\mid Y) = P(Y\mid X)\,P(X)/P(Y)$。对于类别 $y_i$:$P(y_i\mid X) = P(y_i)\,P(X\mid y_i)/P(X)$,其中 $P(X)$ 为常数,故 $P(y_i\mid X) \propto P(y_i)\,P(X\mid y_i)$。$P(y_i\mid X)$ -> 某特征下是某类别的...
朴素贝叶斯java代码参考
朴素贝叶斯算法文本分类JAVA实现
对指定数据集进行分类问题的分析,选择适当的分类算法,编写程序实现,提交程序和结果报告 数据集: balance-scale.data(见附件一) ,已有数据集构建贝叶斯分类器。 数据包括四个属性:五个属性值 第一个属性值...
将一些随机爬下来的帖子进行分类,利用了朴素贝叶斯算法
代码是我实验课完成的,Java实现分类问题,朴素贝叶斯算法,对网上需手动输入数据的代码稍微改进,数据是文件夹里的txt文件,读者可以自己更改数据文件,非常方便,注释详细。
主要介绍了Java实现的朴素贝叶斯算法,结合实例形式分析了基于java的朴素贝叶斯算法定义及样本数据训练操作相关使用技巧,需要的朋友可以参考下
树型朴素贝叶斯算法java数据挖掘算法源码
NaiveBays朴素贝叶斯算法在JAVA中的实现
朴素贝叶斯算法java数据挖掘算法源码 数据挖掘算法是根据数据创建数据挖掘模型的一组试探法和计算。 为了创建模型,算法将首先分析您提供的数据,并查找特定类型的模式和趋势。概念描述算法使用此分析的结果来定义...
这个主要是利用spark的api,朴素贝叶斯算法,来预测股票,其中包含的股票的原始数据和处理后适合spark api处理的训练模型
Spring-Boot集成Neo4j并利用Spark的朴素贝叶斯分类器实现基于电影知识图谱的智能问答系统
对于朴素贝叶斯算法过程与步骤的简单描述,并用java语言实现
java实现的NB算法,在UCI的三个数据集上进行了测试,包含测试结果和实验报告,还有UCI测试数据 人工智能大作业实验
基于贝叶斯算法的分类实现,java源代码。下载后即可导入测试。
Java编写的朴素贝叶斯分类器,用于学习机器学习算法,使用Java原生sdk实现,内有数据集,可以运行。