package com.yanzuoguang.util.helper;

import com.yanzuoguang.util.YzgError;
import com.yanzuoguang.util.contants.SystemContants;
import com.yanzuoguang.util.exception.CodeException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Base64Utils;

import javax.crypto.Cipher;
import java.io.ByteArrayOutputStream;
import java.security.*;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * RSA encryption, decryption and signing helpers.
 * <p>
 * # Key points
 * 1. RSA maps a fixed-length byte block to another fixed-length byte block. The lengths depend on the key size:
 * with PKCS#1 v1.5 padding a key whose modulus is k bytes long encrypts at most k - 11 plaintext bytes into
 * exactly k ciphertext bytes (117 -> 128 for a 1024-bit key, 245 -> 256 for a 2048-bit key).
 * 2. An RSA key pair consists of a public key and a private key.
 * 3. Data can be encrypted with the public key and decrypted with the private key, or encrypted with the
 * private key and decrypted with the public key.
 * 4. The cipher only works on bytes, so keys and ciphertext are exchanged as Base64 strings.
 * <p>
 * # Workflow
 * 1. Generate a public/private key pair.
 * 2. Encode the public key to bytes, then to a Base64 string.
 * 3. Encode the private key to bytes, then to a Base64 string.
 * 4. Convert the source string to bytes, split them into blocks of the maximum plaintext length for the key,
 * encrypt each block with the key rebuilt from its Base64 string, and Base64-encode the concatenated result.
 * 5. Base64-decode the string to be decrypted, split it into blocks of the ciphertext length for the key and
 * decrypt each block with the matching key.
 *
 * @author 颜佐光
 */
public final class RsaHelper {

    private static final Logger logger = LoggerFactory.getLogger(RsaHelper.class);

    private static final String ALGORITHM_RSA = "RSA";

    /**
     * Signature algorithm. MD5 is considered weak, but switching would invalidate every signature produced so
     * far, so it is kept for compatibility.
     */
    private static final String ALGORITHM_SIGN = "MD5withRSA";

    /**
     * Default key size in bits used by {@link #generatorKeyPair()}.
     */
    private static final int KEYPAIR_LEN = 1024;

    /**
     * Overhead of PKCS#1 v1.5 encryption padding: a k-byte modulus can encrypt at most k - 11 bytes per block.
     */
    private static final int PKCS1_PADDING_LEN = 11;

    /**
     * Fallback maximum plaintext block size (matches a 1024-bit key); used only when the key size cannot be read.
     */
    private static final int MAX_ENCRYPT_BLOCK = 117;

    /**
     * Fallback ciphertext block size (matches a 1024-bit key); used only when the key size cannot be read.
     */
    private static final int MAX_DECRYPT_BLOCK = 128;

    private RsaHelper() {
        super();
    }

    /**
     * Transforms one slice {@code from[offset, offset + len)} of a byte array.
     */
    @FunctionalInterface
    private interface HandleBytes {
        byte[] handle(byte[] from, int offset, int len) throws Exception;
    }

    /**
     * Splits {@code froms} into consecutive slices of at most {@code size} bytes, passes each slice to
     * {@code handle} and concatenates the outputs.
     *
     * @param froms  input bytes
     * @param size   maximum slice length, must be positive
     * @param handle transformation applied to every slice
     * @return concatenation of all slice outputs (empty when {@code froms} is empty)
     * @throws IllegalArgumentException if {@code size} is not positive
     * @throws Exception                whatever {@code handle} throws
     */
    private static byte[] handle(byte[] froms, int size, HandleBytes handle) throws Exception {
        if (size <= 0) {
            // A non-positive step would never advance the offset and loop forever.
            throw new IllegalArgumentException("block size must be positive: " + size);
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (int offset = 0; offset < froms.length; offset += size) {
            int len = Math.min(size, froms.length - offset);
            byte[] cache = handle.handle(froms, offset, len);
            out.write(cache, 0, cache.length);
        }
        // ByteArrayOutputStream holds no system resources, so there is nothing to close.
        return out.toByteArray();
    }

    /**
     * Returns the RSA ciphertext block length in bytes for {@code key}, i.e. the byte length of its modulus.
     * Falls back to the 1024-bit value when the key does not expose its modulus.
     */
    private static int getMaxDecryptBlock(Key key) {
        if (key instanceof RSAKey) {
            return (((RSAKey) key).getModulus().bitLength() + 7) / 8;
        }
        return MAX_DECRYPT_BLOCK;
    }

    /**
     * Returns the largest plaintext block in bytes that {@code key} can encrypt with PKCS#1 v1.5 padding.
     * Falls back to the 1024-bit value when the key does not expose its modulus.
     */
    private static int getMaxEncryptBlock(Key key) {
        if (key instanceof RSAKey) {
            return getMaxDecryptBlock(key) - PKCS1_PADDING_LEN;
        }
        return MAX_ENCRYPT_BLOCK;
    }

    /**
     * Runs a cipher for {@code key}'s algorithm over {@code input} block by block, sizing the blocks from the key.
     *
     * @param key   RSA public or private key
     * @param mode  {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     * @param input bytes to process
     * @return concatenated output of all blocks
     * @throws Exception on any JCA failure (unknown algorithm, invalid key, bad padding, ...)
     */
    private static byte[] doCipher(Key key, int mode, byte[] input) throws Exception {
        Cipher cipher = Cipher.getInstance(key.getAlgorithm());
        cipher.init(mode, key);
        int blockSize = mode == Cipher.ENCRYPT_MODE ? getMaxEncryptBlock(key) : getMaxDecryptBlock(key);
        // Every block is finished independently, so each doFinal call yields one complete RSA block.
        return handle(input, blockSize, cipher::doFinal);
    }

    /**
     * Generates an RSA key pair of the default size (1024 bits) and logs both keys as Base64 strings.
     * <p>
     * NOTE: the private key is written to the log at INFO level; only use this as a development tool.
     *
     * @throws RuntimeException if key generation fails
     */
    public static void generatorKeyPair() {
        generatorKeyPair(KEYPAIR_LEN);
    }

    /**
     * Generates an RSA key pair of the given size and logs both keys as Base64 strings
     * (public key in X.509 encoding, private key in PKCS#8 encoding).
     * <p>
     * NOTE: the private key is written to the log at INFO level; only use this as a development tool.
     *
     * @param keySize key size in bits, e.g. 1024 or 2048
     * @throws RuntimeException if key generation fails or the size is not supported
     */
    public static void generatorKeyPair(int keySize) {
        try {
            KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(ALGORITHM_RSA);
            keyPairGen.initialize(keySize);
            KeyPair keyPair = keyPairGen.generateKeyPair();

            String publicKey = encodeBase64(keyPair.getPublic().getEncoded());
            logger.info("生成的公钥:\t{}", publicKey);
            String privateKey = encodeBase64(keyPair.getPrivate().getEncoded());
            logger.info("生成的私钥:\t{}", privateKey);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Rebuilds a public key from its Base64-encoded X.509 form.
     *
     * @param publicKey Base64 string of the X.509-encoded public key
     * @return the public key
     * @throws RuntimeException if the string is not a valid key
     */
    private static PublicKey getPublicKey(String publicKey) {
        try {
            X509EncodedKeySpec publicKeySpec = new X509EncodedKeySpec(decodeBase64(publicKey));
            KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_RSA);
            return keyFactory.generatePublic(publicKeySpec);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Rebuilds a private key from its Base64-encoded PKCS#8 form.
     *
     * @param privateKey Base64 string of the PKCS#8-encoded private key
     * @return the private key
     * @throws RuntimeException if the string is not a valid key
     */
    private static PrivateKey getPrivateKey(String privateKey) {
        try {
            PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(decodeBase64(privateKey));
            KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_RSA);
            return keyFactory.generatePrivate(privateKeySpec);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Encrypts a string with a public key.
     *
     * @param source       plain text
     * @param publicKeyStr Base64 string of the X.509-encoded public key
     * @return Base64 string of the ciphertext
     * @throws RuntimeException on any encoding or cipher failure
     */
    public static String encryptionByPublicKey(String source, String publicKeyStr) {
        try {
            PublicKey publicKey = getPublicKey(publicKeyStr);
            byte[] to = doCipher(publicKey, Cipher.ENCRYPT_MODE, source.getBytes(SystemContants.UTF8));
            return encodeBase64(to);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Decrypts, with a public key, a string that was encrypted with the matching private key.
     *
     * @param target       Base64 string of the ciphertext
     * @param publicKeyStr Base64 string of the X.509-encoded public key
     * @return the decrypted plain text
     * @throws RuntimeException on any decoding or cipher failure
     */
    public static String decryptionByPublicKey(String target, String publicKeyStr) {
        try {
            byte[] bytes = decodeBase64(target);
            PublicKey publicKey = getPublicKey(publicKeyStr);
            byte[] to = doCipher(publicKey, Cipher.DECRYPT_MODE, bytes);
            return new String(to, SystemContants.UTF8);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Checks a signature produced by {@link #signByPrivateKey(String, String)}.
     *
     * @param target       the signed text
     * @param sign         Base64 string of the signature
     * @param publicKeyStr Base64 string of the X.509-encoded public key
     * @return {@code true} if the signature matches {@code target}
     * @throws RuntimeException if the key or signature cannot be processed
     */
    public static boolean isValidSignByPublicKey(String target, String sign, String publicKeyStr) {
        try {
            PublicKey publicKey = getPublicKey(publicKeyStr);
            Signature signature = Signature.getInstance(ALGORITHM_SIGN);
            signature.initVerify(publicKey);
            signature.update(target.getBytes(SystemContants.UTF8));
            return signature.verify(decodeBase64(sign));
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Verifies a signature with a public key and logs the outcome ("sign true" / "sign false").
     * Use {@link #isValidSignByPublicKey(String, String, String)} to obtain the result programmatically.
     *
     * @param target       the signed text
     * @param sign         Base64 string of the signature
     * @param publicKeyStr Base64 string of the X.509-encoded public key
     * @throws RuntimeException if the key or signature cannot be processed
     */
    public static void verifyByPublicKey(String target, String sign, String publicKeyStr) {
        if (isValidSignByPublicKey(target, sign, publicKeyStr)) {
            logger.info("sign true");
        } else {
            logger.info("sign false");
        }
    }

    /**
     * Encrypts a string with a private key.
     *
     * @param source        plain text
     * @param privateKeyStr Base64 string of the PKCS#8-encoded private key
     * @return Base64 string of the ciphertext
     * @throws RuntimeException on any encoding or cipher failure
     */
    public static String encryptionByPrivateKey(String source, String privateKeyStr) {
        try {
            byte[] bytes = source.getBytes(SystemContants.UTF8);
            PrivateKey privateKey = getPrivateKey(privateKeyStr);
            return encodeBase64(doCipher(privateKey, Cipher.ENCRYPT_MODE, bytes));
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Decrypts, with a private key, a string that was encrypted with the matching public key.
     *
     * @param target        Base64 string of the ciphertext
     * @param privateKeyStr Base64 string of the PKCS#8-encoded private key
     * @return the decrypted plain text
     * @throws RuntimeException built by {@code YzgError.getRuntimeException} with code "056" on any failure
     */
    public static String decryptionByPrivateKey(String target, String privateKeyStr) {
        try {
            byte[] bytes = decodeBase64(target);
            PrivateKey privateKey = getPrivateKey(privateKeyStr);
            byte[] to = doCipher(privateKey, Cipher.DECRYPT_MODE, bytes);
            return new String(to, SystemContants.UTF8);
        } catch (Exception ex) {
            throw YzgError.getRuntimeException(ex, "056", ex.getMessage());
        }
    }

    /**
     * Signs a string with a private key using {@value #ALGORITHM_SIGN}.
     *
     * @param target        text to sign
     * @param privateKeyStr Base64 string of the PKCS#8-encoded private key
     * @return Base64 string of the signature
     * @throws RuntimeException on any key or signature failure
     */
    public static String signByPrivateKey(String target, String privateKeyStr) {
        try {
            PrivateKey privateKey = getPrivateKey(privateKeyStr);
            Signature signature = Signature.getInstance(ALGORITHM_SIGN);
            signature.initSign(privateKey);
            signature.update(target.getBytes(SystemContants.UTF8));
            return encodeBase64(signature.sign());
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Base64-encodes bytes into a string.
     *
     * @param source bytes to encode
     * @return the Base64 string
     * @throws RuntimeException if the encoded bytes cannot be turned into a string
     */
    public static String encodeBase64(byte[] source) {
        try {
            byte[] to = Base64Utils.encode(source);
            return new String(to, SystemContants.UTF8);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Decodes a Base64 string into bytes.
     *
     * @param target Base64 string
     * @return the decoded bytes
     * @throws RuntimeException if the string is not valid Base64
     */
    public static byte[] decodeBase64(String target) {
        try {
            byte[] from = target.getBytes(SystemContants.UTF8);
            return Base64Utils.decode(from);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }
}