时间: 2020-09-3|tag: 25次围观|0 条评论

K-means的步骤

输入: 含n个样本的数据集,簇的数目K

输出: K 个簇

算法步骤:

1. 初始化K个聚类中心C1, C2, ……, Ck(通常随机选择);

2. 重复执行步骤3、4:

3. 将数据集中的每个样本分配到与之距离最近的中心Ci所对应的簇;

4. 更新聚类中心Ci,即计算各个簇的样本均值;

5. 直到样本分配不再改变为止。

上代码:

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a field as one of the dimensions used by {@code Kmeans} when measuring distance.
 *
 * <p>Only numeric fields are supported. The annotation is retained at runtime so it can be
 * discovered through reflection, and it may only be placed on fields.
 *
 * @author 阿飞哥
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface KmeanField {
}

 

import java.lang.annotation.Annotation;import java.lang.reflect.Field;import java.lang.reflect.Method;import java.util.ArrayList;import java.util.List;/** *  * @author 阿飞哥 *  */public class Kmeans<T> {    /**     * 所有数据列表     */    private List<T> players = new ArrayList<T>();    /**     * 数据类别     */    private Class<T> classT;    /**     * 初始化列表     */    private List<T> initPlayers;    /**     * 需要纳入kmeans算法的属性名称     */    private List<String> fieldNames = new ArrayList<String>();    /**     * 分类数     */    private int k = 1;    public Kmeans() {    }    /**     * 初始化列表     *      * @param list     * @param k     */    public Kmeans(List<T> list, int k) {        this.players = list;        this.k = k;        T t = list.get(0);        this.classT = (Class<T>) t.getClass();        Field[] fields = this.classT.getDeclaredFields();        System.out.println("fields---------------------------------------------="+fields.length);        for (int i = 0; i < fields.length; i++) {            Annotation kmeansAnnotation = fields[i]                    .getAnnotation(KmeanField.class);            if (kmeansAnnotation != null) {                fieldNames.add(fields[i].getName());                System.out.println("fieldNames.add"+ fields[i].getName());                            }        }        initPlayers = new ArrayList<T>();        for (int i = 0; i < k; i++) {            initPlayers.add(players.get(i));        }    }    public List<T>[] comput() {        List<T>[] results = new ArrayList[k];        boolean centerchange = true;        while (centerchange) {            centerchange = false;            for (int i = 0; i < k; i++) {                results[i] = new ArrayList<T>();            }            for (int i = 0; i < players.size(); i++) {                T p = players.get(i);                double[] dists = new double[k];                for (int j = 0; j < initPlayers.size(); j++) {                    T initP = initPlayers.get(j);                    /* 计算距离 */         
           double dist = distance(initP, p);//                    double dist = 1.0;//                    double dist = LevenshteinDistance.levenshteinDistance(initP, p);//                    System.out.println("dist="+dist);                                    dists[j] = dist;                }                int dist_index = computOrder(dists);//                System.out.println("dist_index="+dist_index);                results[dist_index].add(p);            }            //            System.out.println("results[0].size()="+results[0].size());            for (int i = 0; i < k; i++) { // 在每一个簇中寻找中心点                T player_new = findNewCenter(results[i]);//                System.out.println( "results[i]"+i+"----"+k+"---===="+results[i].size() +"===="+player_new.toString());                T player_old = initPlayers.get(i);                if (!IsPlayerEqual(player_new, player_old)) {                    centerchange = true;                    initPlayers.set(i, player_new);                }            }        }//        System.out.println( "results+"+results.length);        return results;    }    /**     * 比较是否两个对象是否属性一致     *      * @param p1     * @param p2     * @return     */    public boolean IsPlayerEqual(T p1, T p2) {        if (p1 == p2) {            return true;        }        if (p1 == null || p2 == null) {            return false;        }                boolean flag = true;        try {            for (int i = 0; i < fieldNames.size(); i++) {                                String fieldName=fieldNames.get(i);                String getName = "get"                        + fieldName.substring(0, 1).toUpperCase()                        + fieldName.substring(1);        //                System.out.println(fieldNames);                Object value1 = invokeMethod(p1,getName,null);                Object value2 = invokeMethod(p2,getName,null);                if (!value1.equals(value2)) {                    flag = false;                    break;                
}            }        } catch (Exception e) {            e.printStackTrace();            flag = false;        }        return flag;    }    /**     * 得到新聚类中心对象     *      * @param ps     * @return     */    public T findNewCenter(List<T> ps) {        try {            T t = classT.newInstance();            if (ps == null || ps.size() == 0) {                return t;            }            double[] ds = new double[fieldNames.size()];            for (T vo : ps) {                for (int i = 0; i < fieldNames.size(); i++) {                    String fieldName=fieldNames.get(i);                    String getName = "get"                            + fieldName.substring(0, 1).toUpperCase()                            + fieldName.substring(1);                    Object obj=invokeMethod(vo,getName,null);                    Double fv=(obj==null?0:Double.parseDouble(obj+""));                    ds[i] += fv;                }            }            //            System.out.println("-----------------");            for (int i = 0; i < fieldNames.size(); i++) {                ds[i] = ds[i] / ps.size();    // 平均距离                String fieldName = fieldNames.get(i);                                /* 给对象设值 */                String setName = "set"                        + fieldName.substring(0, 1).toUpperCase()                        + fieldName.substring(1);//                invokeMethod(t,setName,new Class[]{double.class},ds[i]);                System.out.println("ds[i] ++="+ds[i]+"----ps.size()"+ps.size());                invokeMethod(t,setName,new Class[]{double.class},ds[i]);            }                                                return t;        } catch (Exception ex) {            ex.printStackTrace();        }        return null;    }    /**     * 得到最短距离,并返回最短距离索引     *      * @param dists     * @return     */    public int computOrder(double[] dists) {        double min = 0;        int index = 0;        for (int i = 0; i < dists.length - 1; i++) {            double 
dist0 = dists[i];            if (i == 0) {                min = dist0;                index = 0;            }            double dist1 = dists[i + 1];            if (min > dist1) {                min = dist1;                index = i + 1;            }        }        return index;    }    /**     * 计算距离(相似性) 采用欧几里得算法     *      * @param p0     * @param p1     * @return     */    public double distance(T p0, T p1) {        double dis = 0;        try {            for (int i = 0; i < fieldNames.size(); i++) {                String fieldName = fieldNames.get(i);                String getName = "get"                        + fieldName.substring(0, 1).toUpperCase()                        + fieldName.substring(1);                //                System.out.println("fieldNames-----="+fieldNames.size());                Double field0Value=Double.parseDouble(invokeMethod(p0,getName,null)+"");                Double field1Value=Double.parseDouble(invokeMethod(p1,getName,null)+"");//                System.out.println("field0Value="+field0Value);                dis += Math.pow(field0Value - field1Value, 2);                                                             }                } catch (Exception ex) {            ex.printStackTrace();        }        return Math.sqrt(dis);    }        /*------公共方法-----*/    public Object invokeMethod(Object owner, String methodName,Class[] argsClass,            Object... args) {        Class ownerClass = owner.getClass();                try {            Method method=ownerClass.getDeclaredMethod(methodName,argsClass);                        return method.invoke(owner, args);        } catch (SecurityException e) {            e.printStackTrace();        } catch (NoSuchMethodException e) {            e.printStackTrace();        } catch (Exception ex) {            ex.printStackTrace();        }        return null;    }}

public class Player {private int id;//@KmeanFieldprivate String name;private int age;/* 得分 */@KmeanFieldprivate double goal;/* 助攻 *///@KmeanFieldprivate double assists;/* 篮板 *///@KmeanFieldprivate double backboard;/* 抢断 *///@KmeanFieldprivate double steals;public int getId() {    return id;}public void setId(int id) {    this.id = id;}public String getName() {    return name;}public void setName(String name) {    this.name = name;}public int getAge() {    return age;}public void setAge(int age) {    this.age = age;}public double getGoal() {    return goal;}public void setGoal(double goal) {    this.goal = goal;}public double getAssists() {    return assists;}public void setAssists(double assists) {    this.assists = assists;}public double getBackboard() {    return backboard;}public void setBackboard(double backboard) {    this.backboard = backboard;}public double getSteals() {    return steals;}public void setSteals(double steals) {    this.steals = steals;}@Override    public String toString() {        // TODO Auto-generated method stub        return name;    }}

 

 
import java.util.ArrayList;import java.util.List;import java.util.Random;public class TestMain {    public static void main(String[] args) {       List<Player> listPlayers=new ArrayList<Player>();                for(int i=0;i<15;i++){                        Player p1=new Player();            p1.setName("afei-"+i);            p1.setAssists(i);            p1.setBackboard(i);                        //p1.setGoal(new Random(100*i).nextDouble());            p1.setGoal(i*10);            p1.setSteals(i);            //listPlayers.add(p1);            }                Player p1=new Player();        p1.setName("afei1");        p1.setGoal(1);        p1.setAssists(8);        listPlayers.add(p1);               Player p2=new Player();        p2.setName("afei2");        p2.setGoal(2);        listPlayers.add(p2);                 Player p3=new Player();        p3.setName("afei3");        p3.setGoal(3);        listPlayers.add(p3);                 Player p4=new Player();        p4.setName("afei4");        p4.setGoal(7);        listPlayers.add(p4);                 Player p5=new Player();        p5.setName("afei5");        p5.setGoal(8);        listPlayers.add(p5);                 Player p6=new Player();        p6.setName("afei6");        p6.setGoal(25);        listPlayers.add(p6);                 Player p7=new Player();        p7.setName("afei7");        p7.setGoal(26);        listPlayers.add(p7);                 Player p8=new Player();        p8.setName("afei8");        p8.setGoal(27);        listPlayers.add(p8);                 Player p9=new Player();        p9.setName("afei9");        p9.setGoal(28);        listPlayers.add(p9);                        Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers,2);        List<Player>[] results = kmeans.comput();        for (int i = 0; i < results.length; i++) {            System.out.println("===========类别" + (i + 1) + "================");            List<Player> list = results[i];            for (Player p : list) {                
System.out.println(p.getName() + "--->"                        + p.getGoal() + "," + p.getAssists() + ","                        + p.getSteals() + "," + p.getBackboard());            }        }                                  }}

 

源码:https://github.com/chaoren399/dkdemo/tree/master/kmeans/src

文章转载于:https://www.cnblogs.com/chaoren399/p/5006563.html

原著是一个有趣的人,若有侵权,请通知删除

本博客所有文章如无特别注明均为原创。
复制或转载请以超链接形式注明转自起风了,原文地址《3.聚类–K-means的Java实现》
   

还没有人抢沙发呢~