diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java index 8be3d4e938893..78a84028db0fd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java @@ -13,7 +13,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.license.License; import org.elasticsearch.license.LicensedFeature; -import org.elasticsearch.license.XPackLicenseState; import java.math.BigInteger; import java.nio.charset.StandardCharsets; @@ -41,12 +40,6 @@ public final class MachineLearningField { License.OperationMode.PLATINUM ); - public static final LicensedFeature.Momentary ML_MODEL_INFERENCE_PLATINUM_FEATURE = LicensedFeature.momentary( - MachineLearningField.ML_FEATURE_FAMILY, - "model-inference-platinum-check", - License.OperationMode.PLATINUM - ); - private MachineLearningField() {} public static String valuesToId(String... values) { @@ -59,10 +52,4 @@ public static String valuesToId(String... values) { return new BigInteger(hashedBytes) + "_" + combined.length(); } - public static boolean featureCheckForMode(License.OperationMode mode, XPackLicenseState licenseState) { - if (mode.equals(License.OperationMode.PLATINUM)) { - return ML_MODEL_INFERENCE_PLATINUM_FEATURE.check(licenseState); - } - return true; - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 76c1babbf53f9..a9dc7e7bb465d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -351,7 +351,7 @@ public long getEstimatedOperations() { } // TODO if we ever support anything other than "basic" and platinum, we need to adjust our feature tracking logic - // Additionally, see `MachineLearningField. featureCheckForMode` for handling modes + // and we need to adjust our license checks to validate more than "is basic" or not public License.OperationMode getLicenseLevel() { return licenseLevel; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index 1a3782225a037..8a0ef52d74288 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.license.License; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -42,7 +43,6 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.core.ml.MachineLearningField.featureCheckForMode; public class TransportInternalInferModelAction extends HandledTransportAction { @@ -83,7 +83,9 @@ protected void doExecute(Task task, Request request, ActionListener li request.getModelId(), GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> { - final boolean allowed = featureCheckForMode(trainedModelConfig.getLicenseLevel(), licenseState); + // Since we just checked MachineLearningField.ML_API_FEATURE.check(licenseState) and that check failed + // That means we don't have a plat+ license. The only licenses for trained models are basic (free) and plat. + boolean allowed = trainedModelConfig.getLicenseLevel() == License.OperationMode.BASIC; responseBuilder.setLicensed(allowed); if (allowed || request.isPreviouslyLicensed()) { doInfer(request, responseBuilder, listener); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java index 37e7a5cefbcd0..e1237ef7684f5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.logging.HeaderWarning; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.license.License; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; @@ -47,7 +48,6 @@ import java.util.Set; import java.util.function.Predicate; -import static org.elasticsearch.xpack.core.ml.MachineLearningField.featureCheckForMode; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY; public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMasterNodeAction { @@ -93,7 +93,8 @@ protected void masterOperation( ) throws Exception { final boolean mlSupported = MachineLearningField.ML_API_FEATURE.check(licenseState); final Predicate isLicensed = (model) -> mlSupported - || featureCheckForMode(model.getLicenseLevel(), licenseState); + // Either we support plat+ or the model is basic licensed + || model.getLicenseLevel() == License.OperationMode.BASIC; final String oldModelId = ModelAliasMetadata.fromState(state).getModelId(request.getModelAlias()); if (oldModelId != null && (request.isReassign() == false)) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java index e2e13810cd28c..5b40b1308b021 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.license.License; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.SearchPlugin; @@ -51,7 +52,6 @@ import java.util.function.Supplier; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xpack.core.ml.MachineLearningField.featureCheckForMode; import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable; public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder { @@ -267,8 +267,8 @@ public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) .getModelForSearch(modelId, listener.delegateFailure((delegate, model) -> { loadedModel.set(model); - boolean isLicensed = MachineLearningField.ML_API_FEATURE.check(licenseState) - || featureCheckForMode(model.getLicenseLevel(), licenseState); + boolean isLicensed = model.getLicenseLevel() == License.OperationMode.BASIC + || MachineLearningField.ML_API_FEATURE.check(licenseState); if (isLicensed) { delegate.onResponse(null); } else {