#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Sequence, TypeVar, TYPE_CHECKING

from pyspark import since
from pyspark.ml.linalg import Vector
from pyspark.ml.param import Params
from pyspark.ml.param.shared import (
    HasCheckpointInterval,
    HasSeed,
    HasWeightCol,
    Param,
    TypeConverters,
    HasMaxIter,
    HasStepSize,
    HasValidationIndicatorCol,
)
from pyspark.ml.wrapper import JavaPredictionModel
from pyspark.ml.common import inherit_doc

if TYPE_CHECKING:
    # "P" is a TypeVar bound to Params, used for fluent setters that return self.
    from pyspark.ml._typing import P

T = TypeVar("T")


@inherit_doc
class _DecisionTreeModel(JavaPredictionModel[T]):
    """
    Abstraction for Decision Tree models.

    .. versionadded:: 1.5.0
    """

    @property
    @since("1.5.0")
    def numNodes(self) -> int:
        """Return number of nodes of the decision tree."""
        return self._call_java("numNodes")

    @property
    @since("1.5.0")
    def depth(self) -> int:
        """Return depth of the decision tree."""
        return self._call_java("depth")

    @property
    @since("2.0.0")
    def toDebugString(self) -> str:
        """Full description of model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value: Vector) -> float:
        """
        Predict the indices of the leaves corresponding to the feature vector.
        """
        return self._call_java("predictLeaf", value)


class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
    """
    Mixin for Decision Tree parameters.
    """

    leafCol: Param[str] = Param(
        Params._dummy(),
        "leafCol",
        "Leaf indices column name. Predicted leaf "
        + "index of each instance in each tree by preorder.",
        typeConverter=TypeConverters.toString,
    )

    maxDepth: Param[int] = Param(
        Params._dummy(),
        "maxDepth",
        "Maximum depth of the tree. (>= 0) E.g., "
        + "depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. "
        + "Must be in range [0, 30].",
        typeConverter=TypeConverters.toInt,
    )

    maxBins: Param[int] = Param(
        Params._dummy(),
        "maxBins",
        "Max number of bins for discretizing continuous "
        + "features. Must be >=2 and >= number of categories for any categorical "
        + "feature.",
        typeConverter=TypeConverters.toInt,
    )

    minInstancesPerNode: Param[int] = Param(
        Params._dummy(),
        "minInstancesPerNode",
        "Minimum number of "
        + "instances each child must have after split. If a split causes "
        + "the left or right child to have fewer than "
        + "minInstancesPerNode, the split will be discarded as invalid. "
        + "Should be >= 1.",
        typeConverter=TypeConverters.toInt,
    )

    minWeightFractionPerNode: Param[float] = Param(
        Params._dummy(),
        "minWeightFractionPerNode",
        "Minimum "
        "fraction of the weighted sample count that each child "
        "must have after split. If a split causes the fraction "
        "of the total weight in the left or right child to be "
        "less than minWeightFractionPerNode, the split will be "
        "discarded as invalid. Should be in interval [0.0, 0.5).",
        typeConverter=TypeConverters.toFloat,
    )

    minInfoGain: Param[float] = Param(
        Params._dummy(),
        "minInfoGain",
        "Minimum information gain for a split "
        + "to be considered at a tree node.",
        typeConverter=TypeConverters.toFloat,
    )

    maxMemoryInMB: Param[int] = Param(
        Params._dummy(),
        "maxMemoryInMB",
        "Maximum memory in MB allocated to "
        + "histogram aggregation. If too small, then 1 node will be split per "
        + "iteration, and its aggregates may exceed this size.",
        typeConverter=TypeConverters.toInt,
    )

    cacheNodeIds: Param[bool] = Param(
        Params._dummy(),
        "cacheNodeIds",
        "If false, the algorithm will pass "
        + "trees to executors to match instances with nodes. If true, the "
        + "algorithm will cache node IDs for each instance. Caching can speed "
        + "up training of deeper trees. Users can set how often should the cache "
        + "be checkpointed or disable it by setting checkpointInterval.",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self) -> None:
        super(_DecisionTreeParams, self).__init__()

    def setLeafCol(self: "P", value: str) -> "P":
        """
        Sets the value of :py:attr:`leafCol`.
        """
        return self._set(leafCol=value)

    def getLeafCol(self) -> str:
        """
        Gets the value of leafCol or its default value.
        """
        return self.getOrDefault(self.leafCol)

    def getMaxDepth(self) -> int:
        """
        Gets the value of maxDepth or its default value.
        """
        return self.getOrDefault(self.maxDepth)

    def getMaxBins(self) -> int:
        """
        Gets the value of maxBins or its default value.
        """
        return self.getOrDefault(self.maxBins)

    def getMinInstancesPerNode(self) -> int:
        """
        Gets the value of minInstancesPerNode or its default value.
        """
        return self.getOrDefault(self.minInstancesPerNode)

    def getMinWeightFractionPerNode(self) -> float:
        """
        Gets the value of minWeightFractionPerNode or its default value.
        """
        return self.getOrDefault(self.minWeightFractionPerNode)

    def getMinInfoGain(self) -> float:
        """
        Gets the value of minInfoGain or its default value.
        """
        return self.getOrDefault(self.minInfoGain)

    def getMaxMemoryInMB(self) -> int:
        """
        Gets the value of maxMemoryInMB or its default value.
        """
        return self.getOrDefault(self.maxMemoryInMB)

    def getCacheNodeIds(self) -> bool:
        """
        Gets the value of cacheNodeIds or its default value.
        """
        return self.getOrDefault(self.cacheNodeIds)


@inherit_doc
class _TreeEnsembleModel(JavaPredictionModel[T]):
    """
    (private abstraction)
    Represents a tree ensemble model.
    """

    @property
    @since("2.0.0")
    def trees(self) -> Sequence["_DecisionTreeModel"]:
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        return [_DecisionTreeModel(m) for m in list(self._call_java("trees"))]

    @property
    @since("2.0.0")
    def getNumTrees(self) -> int:
        """Number of trees in ensemble."""
        return self._call_java("getNumTrees")

    @property
    @since("1.5.0")
    def treeWeights(self) -> List[float]:
        """Return the weights for each tree"""
        return list(self._call_java("javaTreeWeights"))

    @property
    @since("2.0.0")
    def totalNumNodes(self) -> int:
        """Total number of nodes, summed over all trees in the ensemble."""
        return self._call_java("totalNumNodes")

    @property
    @since("2.0.0")
    def toDebugString(self) -> str:
        """Full description of model."""
        return self._call_java("toDebugString")

    @since("3.0.0")
    def predictLeaf(self, value: Vector) -> float:
        """
        Predict the indices of the leaves corresponding to the feature vector.
        """
        return self._call_java("predictLeaf", value)


class _TreeEnsembleParams(_DecisionTreeParams):
    """
    Mixin for Decision Tree-based ensemble algorithms parameters.
    """

    subsamplingRate: Param[float] = Param(
        Params._dummy(),
        "subsamplingRate",
        "Fraction of the training data "
        + "used for learning each decision tree, in range (0, 1].",
        typeConverter=TypeConverters.toFloat,
    )

    supportedFeatureSubsetStrategies: List[str] = ["auto", "all", "onethird", "sqrt", "log2"]

    featureSubsetStrategy: Param[str] = Param(
        Params._dummy(),
        "featureSubsetStrategy",
        "The number of features to consider for splits at each tree node. Supported "
        + "options: 'auto' (choose automatically for task: If numTrees == 1, set to "
        + "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to "
        + "'onethird' for regression), 'all' (use all features), 'onethird' (use "
        + "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use "
        + "log2(number of features)), 'n' (when n is in the range (0, 1.0], use "
        + "n * number of features. When n is in the range (1, number of features), use"
        + " n features). default = 'auto'",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super(_TreeEnsembleParams, self).__init__()

    @since("1.4.0")
    def getSubsamplingRate(self) -> float:
        """
        Gets the value of subsamplingRate or its default value.
        """
        return self.getOrDefault(self.subsamplingRate)

    @since("1.4.0")
    def getFeatureSubsetStrategy(self) -> str:
        """
        Gets the value of featureSubsetStrategy or its default value.
        """
        return self.getOrDefault(self.featureSubsetStrategy)


class _RandomForestParams(_TreeEnsembleParams):
    """
    Private class to track supported random forest parameters.
    """

    numTrees: Param[int] = Param(
        Params._dummy(),
        "numTrees",
        "Number of trees to train (>= 1).",
        typeConverter=TypeConverters.toInt,
    )

    bootstrap: Param[bool] = Param(
        Params._dummy(),
        "bootstrap",
        "Whether bootstrap samples are used " "when building trees.",
        typeConverter=TypeConverters.toBoolean,
    )

    def __init__(self) -> None:
        super(_RandomForestParams, self).__init__()

    @since("1.4.0")
    def getNumTrees(self) -> int:
        """
        Gets the value of numTrees or its default value.
        """
        return self.getOrDefault(self.numTrees)

    @since("3.0.0")
    def getBootstrap(self) -> bool:
        """
        Gets the value of bootstrap or its default value.
        """
        return self.getOrDefault(self.bootstrap)


class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
    """
    Private class to track supported GBT params.
    """

    stepSize: Param[float] = Param(
        Params._dummy(),
        "stepSize",
        "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking "
        + "the contribution of each estimator.",
        typeConverter=TypeConverters.toFloat,
    )

    validationTol: Param[float] = Param(
        Params._dummy(),
        "validationTol",
        "Threshold for stopping early when fit with validation is used. "
        + "If the error rate on the validation input changes by less than the "
        + "validationTol, then learning will stop early (before `maxIter`). "
        + "This parameter is ignored when fit without validation is used.",
        typeConverter=TypeConverters.toFloat,
    )

    @since("3.0.0")
    def getValidationTol(self) -> float:
        """
        Gets the value of validationTol or its default value.
        """
        return self.getOrDefault(self.validationTol)


class _HasVarianceImpurity(Params):
    """
    Private class to track supported impurity measures.
    """

    supportedImpurities: List[str] = ["variance"]

    impurity: Param[str] = Param(
        Params._dummy(),
        "impurity",
        "Criterion used for information gain calculation (case-insensitive). "
        + "Supported options: "
        + ", ".join(supportedImpurities),
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super(_HasVarianceImpurity, self).__init__()

    @since("1.4.0")
    def getImpurity(self) -> str:
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)


class _TreeClassifierParams(Params):
    """
    Private class to track supported impurity measures.

    .. versionadded:: 1.4.0
    """

    supportedImpurities: List[str] = ["entropy", "gini"]

    impurity: Param[str] = Param(
        Params._dummy(),
        "impurity",
        "Criterion used for information gain calculation (case-insensitive). "
        + "Supported options: "
        + ", ".join(supportedImpurities),
        typeConverter=TypeConverters.toString,
    )

    def __init__(self) -> None:
        super(_TreeClassifierParams, self).__init__()

    @since("1.6.0")
    def getImpurity(self) -> str:
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)


class _TreeRegressorParams(_HasVarianceImpurity):
    """
    Private class to track supported impurity measures.
    """

    pass