package de.uni_freiburg.informatik.ultimate.modelcheckerutils.smt;

import de.uni_freiburg.informatik.ultimate.core.model.services.ILogger;
import de.uni_freiburg.informatik.ultimate.core.model.services.IUltimateServiceProvider;
import de.uni_freiburg.informatik.ultimate.lib.smtlibutils.SmtSortUtils;
import de.uni_freiburg.informatik.ultimate.lib.smtlibutils.SmtUtils;
import de.uni_freiburg.informatik.ultimate.lib.smtlibutils.solverbuilder.SMTFeatureExtractionTermClassifier;
import de.uni_freiburg.informatik.ultimate.logic.Logics;
import de.uni_freiburg.informatik.ultimate.logic.Script;
import de.uni_freiburg.informatik.ultimate.logic.Sort;
import de.uni_freiburg.informatik.ultimate.logic.Term;
import de.uni_freiburg.informatik.ultimate.smtsolver.external.TermParseUtils;
import de.uni_freiburg.informatik.ultimate.test.mocks.UltimateMocks;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.hamcrest.MatcherAssert;
import org.hamcrest.core.IsEqual;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:de/uni_freiburg/informatik/ultimate/modelcheckerutils/smt/SmtFeatureExtractionTest.class */
public class SmtFeatureExtractionTest {
    private static final String ABCDE = "ABCDE";
    private Script mScript;
    private ILogger mLogger;

    @Before
    public void setUp() {
        IUltimateServiceProvider createUltimateServiceProviderMock = UltimateMocks.createUltimateServiceProviderMock(ILogger.LogLevel.DEBUG);
        this.mScript = UltimateMocks.createZ3Script(ILogger.LogLevel.INFO);
        this.mLogger = createUltimateServiceProviderMock.getLoggingService().getLogger("lol");
        this.mScript.setLogic(Logics.ALL);
    }

    @Test
    public void checkSingleTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(= A 0)"), Set.of(Set.of("A")), 1, Map.of("=", 1), Map.of("Int", 1), 0, Collections.emptyMap());
    }

    @Test
    public void checkAdditionTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(= A (+ B 1))"), Set.of(Set.of("A", "B")), 2, Map.of("=", 1, "+", 1), Map.of("Int", 2), 0, Collections.emptyMap());
    }

    @Test
    public void checkSubtractionTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(= A (- B 1))"), Set.of(Set.of("A", "B")), 2, Map.of("=", 1, "-", 1), Map.of("Int", 2), 0, Collections.emptyMap());
    }

    @Test
    public void checkMultiplicationTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(= A (* B 1))"), Set.of(Set.of("A", "B")), 2, Map.of("=", 1, "*", 1), Map.of("Int", 2), 0, Collections.emptyMap());
    }

    @Test
    public void checkAndTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(and (= A (+ B 1)) (= C 0))"), Set.of(Set.of("A", "B", "C")), 3, Map.of("and", 1, "+", 1, "=", 2), Map.of("Int", 3), 0, Collections.emptyMap());
    }

    @Test
    public void checkOrTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(or (= A (+ B 1)) (= C 0))"), Set.of(Set.of("A", "B"), Set.of("C")), 3, Map.of("=", 2, "+", 1, "or", 1), Map.of("Int", 3), 0, Collections.emptyMap());
    }

    @Test
    public void checkOrAndTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(or (and (= A (+ B 1)) (= D 1)) (= C 0))"), Set.of(Set.of("A", "B", "D"), Set.of("C")), 4, Map.of("=", 3, "+", 1, "or", 1, "and", 1), Map.of("Int", 4), 0, Collections.emptyMap());
    }

    @Test
    public void checkOrOrTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(or (or (= A (+ B 1)) (= D 1)) (= C 0))"), Set.of(Set.of("A", "B"), Set.of("D"), Set.of("C")), 4, Map.of("=", 3, "+", 1, "or", 2), Map.of("Int", 4), 0, Collections.emptyMap());
    }

    @Test
    public void checkOrOrAndTerm() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        for (int i = 0; i < ABCDE.length(); i++) {
            declareVar(String.valueOf(ABCDE.charAt(i)), intSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(or (or (and (= A 0) (= D 1)) (= E 0)) (= C 0))"), Set.of(Set.of("A", "D"), Set.of("E"), Set.of("C")), 4, Map.of("=", 4, "and", 1, "or", 2), Map.of("Int", 4), 0, Collections.emptyMap());
    }

    @Test
    public void checkCountSorts() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        Sort realSort = SmtSortUtils.getRealSort(this.mScript);
        Sort boolSort = SmtSortUtils.getBoolSort(this.mScript);
        for (int i = 0; i < "AB".length(); i++) {
            declareVar(String.valueOf("AB".charAt(i)), intSort);
        }
        for (int i2 = 0; i2 < "CD".length(); i2++) {
            declareVar(String.valueOf("CD".charAt(i2)), realSort);
        }
        for (int i3 = 0; i3 < "EF".length(); i3++) {
            declareVar(String.valueOf("EF".charAt(i3)), boolSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(and (= A 1) (= B A) (= C 1.0) (= D 1.5) E F)"), Set.of(Set.of("A", "B", "C", "D", "E", "F")), 6, Map.of("=", 4, "and", 1), Map.of("Int", 2, "Bool", 2, "Real", 2), 0, Collections.emptyMap());
    }

    @Test
    public void checkCountFunctions() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        Sort realSort = SmtSortUtils.getRealSort(this.mScript);
        Sort boolSort = SmtSortUtils.getBoolSort(this.mScript);
        for (int i = 0; i < "AB".length(); i++) {
            declareVar(String.valueOf("AB".charAt(i)), intSort);
        }
        for (int i2 = 0; i2 < "CD".length(); i2++) {
            declareVar(String.valueOf("CD".charAt(i2)), realSort);
        }
        for (int i3 = 0; i3 < "EF".length(); i3++) {
            declareVar(String.valueOf("EF".charAt(i3)), boolSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(and (= A (+ A 1)) (= B (- A 1)) (= C (* 1.0 D)) (= D (* 1.5 4.0)) E F)"), Set.of(Set.of("A", "B", "C", "D", "E", "F")), 6, Map.of("=", 4, "and", 1, "-", 1, "*", 2, "+", 1), Map.of("Int", 2, "Bool", 2, "Real", 2), 0, Collections.emptyMap());
    }

    @Test
    public void checkNot() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        Sort realSort = SmtSortUtils.getRealSort(this.mScript);
        Sort boolSort = SmtSortUtils.getBoolSort(this.mScript);
        for (int i = 0; i < "AB".length(); i++) {
            declareVar(String.valueOf("AB".charAt(i)), intSort);
        }
        for (int i2 = 0; i2 < "CD".length(); i2++) {
            declareVar(String.valueOf("CD".charAt(i2)), realSort);
        }
        for (int i3 = 0; i3 < "EF".length(); i3++) {
            declareVar(String.valueOf("EF".charAt(i3)), boolSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(and (not E) F)"), Set.of(Set.of("E", "F")), 2, Map.of("and", 1, "not", 1), Map.of("Bool", 2), 0, Collections.emptyMap());
    }

    @Test
    public void checkCountQuantifiers() {
        Sort intSort = SmtSortUtils.getIntSort(this.mScript);
        Sort realSort = SmtSortUtils.getRealSort(this.mScript);
        Sort boolSort = SmtSortUtils.getBoolSort(this.mScript);
        for (int i = 0; i < "AB".length(); i++) {
            declareVar(String.valueOf("AB".charAt(i)), intSort);
        }
        for (int i2 = 0; i2 < "CD".length(); i2++) {
            declareVar(String.valueOf("CD".charAt(i2)), realSort);
        }
        for (int i3 = 0; i3 < "EF".length(); i3++) {
            declareVar(String.valueOf("EF".charAt(i3)), boolSort);
        }
        check(TermParseUtils.parseTerm(this.mScript, "(exists ((A Int))(and (= A (+ A 1)) (= B (- A 1)) (= C (* 1.0 D)) (= D (* 1.5 4.0)) E F))"), Set.of(Set.of("A", "B", "C", "D", "E", "F")), 6, Map.of("=", 4, "and", 1, "-", 1, "*", 2, "+", 1), Map.of("Int", 4, "Bool", 2, "Real", 2), 1, Map.of(0, 1));
    }

    private void check(Term term, Set<Set<String>> set, int i, Map<String, Integer> map, Map<String, Integer> map2, int i2, Map<Integer, Integer> map3) {
        Script.LBool checkSatTerm = SmtUtils.checkSatTerm(this.mScript, term);
        SMTFeatureExtractionTermClassifier sMTFeatureExtractionTermClassifier = new SMTFeatureExtractionTermClassifier();
        sMTFeatureExtractionTermClassifier.checkTerm(term);
        this.mLogger.info("Original:               " + term.toStringDirect());
        this.mLogger.info("Original isSat:         " + checkSatTerm);
        this.mLogger.info("Original equiv classes: " + sMTFeatureExtractionTermClassifier.getEquivalenceClasses());
        this.mLogger.info("Original #Vars:         " + sMTFeatureExtractionTermClassifier.getNumberOfVariables());
        this.mLogger.info("Original Functions:     " + sMTFeatureExtractionTermClassifier.getOccuringFunctionNames());
        this.mLogger.info("Original Sorts:         " + sMTFeatureExtractionTermClassifier.getOccuringSortNames());
        this.mLogger.info("Original Quantifiers:   " + sMTFeatureExtractionTermClassifier.getNumberOfQuantifiers());
        MatcherAssert.assertThat("equiv classes", (Set) sMTFeatureExtractionTermClassifier.getEquivalenceClasses().getAllEquivalenceClasses().stream().map(set2 -> {
            return (Set) set2.stream().map((v0) -> {
                return v0.toString();
            }).collect(Collectors.toSet());
        }).collect(Collectors.toSet()), IsEqual.equalTo(set));
        MatcherAssert.assertThat("#Vars", Integer.valueOf(sMTFeatureExtractionTermClassifier.getNumberOfVariables()), IsEqual.equalTo(Integer.valueOf(i)));
        MatcherAssert.assertThat("Functions", sMTFeatureExtractionTermClassifier.getOccuringFunctionNames(), IsEqual.equalTo(map));
        MatcherAssert.assertThat("Sorts", sMTFeatureExtractionTermClassifier.getOccuringSortNames(), IsEqual.equalTo(map2));
        MatcherAssert.assertThat("#Quantifier", Integer.valueOf(sMTFeatureExtractionTermClassifier.getNumberOfQuantifiers()), IsEqual.equalTo(Integer.valueOf(i2)));
        MatcherAssert.assertThat("#Quantifier", sMTFeatureExtractionTermClassifier.getOccuringQuantifiers(), IsEqual.equalTo(map3));
    }

    private Term declareVar(String str, Sort sort) {
        this.mScript.declareFun(str, new Sort[0], sort);
        return this.mScript.term(str, new Term[0]);
    }
}
