package com.xforceplus.ultraman.bocp.ai.dsl.impl;

import com.fasterxml.jackson.databind.JsonNode;
import com.xforceplus.ultraman.bocp.ai.AIService;
import com.xforceplus.ultraman.bocp.ai.dsl.DSLGenerator;
import com.xforceplus.ultraman.bocp.ai.entity.ChatCompletionRequest;
import com.xforceplus.ultraman.bocp.ai.entity.ChatCompletionResult;
import com.xforceplus.ultraman.bocp.metadata.util.JsonUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import okhttp3.Response;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.http.ResponseEntity;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;

/* loaded from: input_file:com/xforceplus/ultraman/bocp/ai/dsl/impl/DSLGeneratorImpl.class */
/**
 * {@link DSLGenerator} implementation that delegates to an {@link AIService}
 * chat-completion endpoint, either in one shot ({@link #generate}) or as a
 * server-sent-event stream ({@link #generateStream}).
 *
 * <p>Every request is overwritten with a fixed prompt id and fixed sampling
 * parameters before it is sent, and model output is passed through a
 * best-effort JSON repair step ({@code checkAndFix}) before being returned.
 */
public class DSLGeneratorImpl implements DSLGenerator {
    private static final Logger log = LogManager.getLogger(DSLGeneratorImpl.class);

    /** SSE field name that introduces an event payload line. */
    private static final String DATA_FIELD = "data:";
    /** Payload marker the upstream sends when the stream is complete. */
    private static final String DONE_FLAG = "[DONE]";
    /** Returned by {@link #generate} when the completion has no choices. */
    private static final String EMPTY_DSL = "[]";

    private static final Long PROMPT_ID = 56L;
    private static final Integer MAX_TOKEN = 2999;
    private static final Double TOP_P = 0.9d;
    private static final Double FREQUENCY_PENALTY = 0.6d;
    private static final Double PRESENCE_PENALTY = 0.6d;
    private static final Double TEMPERATURE = 0.0d;

    private final AIService aiService;

    /** JSON field names read from each streamed completion event. */
    private static class DataKeys {
        public static final String CHOICES = "choices";
        public static final String DELTA = "delta";
        public static final String CONTENT = "content";
        public static final String FINISH_REASON = "finish_reason";

        private DataKeys() {
        }
    }

    /**
     * @param aIService service used to call the completion endpoints
     */
    public DSLGeneratorImpl(AIService aIService) {
        this.aiService = aIService;
    }

    /**
     * Runs a blocking completion and returns the content of the first choice
     * after JSON repair.
     *
     * @param chatCompletionRequest request to send; its prompt id and sampling
     *                              parameters are overwritten in place
     * @return the repaired content of the first choice, or {@code "[]"} when
     *         there is no choice or its content is {@code null}
     */
    @Override // com.xforceplus.ultraman.bocp.ai.dsl.DSLGenerator
    public String generate(ChatCompletionRequest chatCompletionRequest) {
        applyGenerationDefaults(chatCompletionRequest);
        ChatCompletionResult result = this.aiService.generateCompletion(chatCompletionRequest);
        // Usage is only logged; do not let a missing usage block fail the generation.
        if (result.getUsage() != null) {
            log.info("Token used {}", result.getUsage().getTotalTokens());
        }
        String content = result.getChoices().stream()
                .findFirst()
                .map(choice -> choice.getMessage().getContent())
                .orElse(EMPTY_DSL);
        return checkAndFix(content);
    }

    /**
     * Starts a streaming completion and returns a response body that relays
     * the generated DSL to the client one top-level object at a time.
     *
     * @param chatCompletionRequest request to send; its prompt id and sampling
     *                              parameters are overwritten in place
     * @return a 200 response whose body, when written, consumes and closes
     *         the upstream completion stream
     */
    @Override // com.xforceplus.ultraman.bocp.ai.dsl.DSLGenerator
    public ResponseEntity<StreamingResponseBody> generateStream(ChatCompletionRequest chatCompletionRequest) {
        applyGenerationDefaults(chatCompletionRequest);
        Response upstream = this.aiService.generateCompletionStream(chatCompletionRequest);
        return ResponseEntity.ok().body(outputStream -> relayStream(upstream, outputStream));
    }

    /** Overwrites the prompt id and sampling parameters with this generator's fixed values. */
    private static void applyGenerationDefaults(ChatCompletionRequest request) {
        request.setPromptId(PROMPT_ID);
        request.setMax_tokens(MAX_TOKEN);
        request.setTop_p(TOP_P);
        request.setFrequency_penalty(FREQUENCY_PENALTY);
        request.setPresence_penalty(PRESENCE_PENALTY);
        request.setTemperature(TEMPERATURE);
    }

    /**
     * Reads the upstream event stream line by line, buffers content tokens and
     * writes the buffer to {@code outputStream} each time the brace counter
     * returns to zero. The upstream response is always closed on exit.
     */
    private static void relayStream(Response upstream, OutputStream outputStream) throws IOException {
        // Closing the Response releases the underlying connection even when
        // reading or writing fails part-way through.
        try (Response response = upstream) {
            if (response.body() == null) {
                log.warn("completion stream response has no body");
                return;
            }
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(response.body().byteStream(), StandardCharsets.UTF_8))) {
                StringBuilder pending = new StringBuilder();
                StringBuilder whole = new StringBuilder();
                int braceCount = 0;
                String line;
                while ((line = reader.readLine()) != null) {
                    log.info("read line : {}", line);
                    String payload = extractPayload(line);
                    if (payload == null) {
                        continue;
                    }
                    if (payload.contains(DONE_FLAG)) {
                        outputStream.write(payload.getBytes(StandardCharsets.UTF_8));
                        continue;
                    }
                    JsonNode event = JsonUtils.readTree(payload);
                    if (event == null || isFinish(event)) {
                        continue;
                    }
                    String token = getDataContent(event);
                    whole.append(token);
                    pending.append(token);
                    if (opensBlock(token)) {
                        braceCount++;
                        log.info("braceCount:{}", braceCount);
                    } else if (closesBlock(token)) {
                        braceCount--;
                        log.info("braceCount:{}", braceCount);
                        if (braceCount == 0) {
                            String fixed = checkAndFix(pending.toString());
                            outputStream.write(fixed.getBytes(StandardCharsets.UTF_8));
                            outputStream.flush();
                            log.info(fixed);
                            pending.setLength(0);
                        }
                    }
                }
                // NOTE(review): anything still in `pending` here was never sent to the
                // client (the counter only reacts to exact token matches) - confirm
                // whether a final flush is wanted.
                log.info("total dsl : {}", checkAndFix(whole.toString()));
            }
        }
    }

    /**
     * Returns the payload of an SSE {@code data:} line, or {@code null} when the
     * line is blank, is not a data line, or carries an empty payload.
     */
    private static String extractPayload(String line) {
        String trimmed = line.trim();
        // Skip anything that is not a data field instead of slicing it at a
        // bogus offset (indexOf == -1), which corrupted or crashed on such lines.
        if (!trimmed.startsWith(DATA_FIELD)) {
            return null;
        }
        String payload = trimmed.substring(DATA_FIELD.length()).trim();
        return payload.isEmpty() ? null : payload;
    }

    /** Tokens that increase the brace counter. */
    private static boolean opensBlock(String token) {
        switch (token) {
            case " {\n":
            case "{\n":
            case " {":
            case "{":
                return true;
            default:
                return false;
        }
    }

    /** Tokens that decrease the brace counter. */
    private static boolean closesBlock(String token) {
        switch (token) {
            case " },\n":
            case "]":
            case " },":
            case "},\n":
            case "},":
                return true;
            default:
                return false;
        }
    }

    /**
     * Reads {@code choices[0].delta.content} from a stream event.
     *
     * @return the content text, or an empty string when any part of the path is
     *         missing or the content is JSON {@code null}
     */
    private static String getDataContent(JsonNode jsonNode) {
        return jsonNode.path(DataKeys.CHOICES).path(0)
                .path(DataKeys.DELTA).path(DataKeys.CONTENT).asText("");
    }

    /**
     * @return {@code true} when {@code choices[0].finish_reason} holds non-blank text
     */
    private static boolean isFinish(JsonNode jsonNode) {
        String reason = jsonNode.path(DataKeys.CHOICES).path(0)
                .path(DataKeys.FINISH_REASON).asText("");
        return StringUtils.isNotBlank(reason);
    }

    /**
     * Best-effort repair of near-JSON text. When {@code str} does not pass
     * {@link #isValidFormat}, single-quoted and bare keys are double-quoted,
     * single-quoted and curly-quoted strings are converted to double quotes, and
     * full-width colons/commas are replaced with their ASCII forms. The result
     * is not re-validated.
     */
    private static String checkAndFix(String str) {
        if (isValidFormat(str)) {
            return str;
        }
        return str
                .replaceAll("'([^']+)'\\s*:", "\"$1\":")
                .replaceAll("([a-zA-Z0-9]+):", "\"$1\":")
                .replaceAll("'([^']+)'", "\"$1\"")
                .replaceAll("“([^”]+)”", "\"$1\"")
                .replace("“", "\"")
                .replace("”", "\"")
                .replace("：", ":")
                .replace("，", ",");
    }

    /**
     * @return {@code false} if {@code JsonUtils.readTree} throws for {@code str},
     *         {@code true} otherwise
     */
    private static boolean isValidFormat(String str) {
        try {
            JsonUtils.readTree(str);
            return true;
        } catch (Exception e) {
            // Any parse failure simply means "needs repair".
            return false;
        }
    }
}
