#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABCMeta, abstractmethod
from typing import Any, Generic, Optional, List, Type, TypeVar, TYPE_CHECKING

from pyspark import since
from pyspark.sql import DataFrame
from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model
from pyspark.ml.base import _PredictorParams
from pyspark.ml.param import Param, Params
from pyspark.ml.util import _jvm
from pyspark.ml.common import inherit_doc, _java2py, _py2java

if TYPE_CHECKING:
    from pyspark.ml._typing import ParamMap
    from py4j.java_gateway import JavaObject, JavaClass

T = TypeVar("T")
JW = TypeVar("JW", bound="JavaWrapper")
JM = TypeVar("JM", bound="JavaTransformer")
JP = TypeVar("JP", bound="JavaParams")


class JavaWrapper:
    """
    Wrapper class for a Java companion object
    """

    def __init__(self, java_obj: Optional["JavaObject"] = None):
        super(JavaWrapper, self).__init__()
        self._java_obj = java_obj

    def __del__(self) -> None:
        from pyspark.core.context import SparkContext

        if SparkContext._active_spark_context and self._java_obj is not None:
            SparkContext._active_spark_context._gateway.detach(  # type: ignore[union-attr]
                self._java_obj
            )

    @classmethod
    def _create_from_java_class(cls: Type[JW], java_class: str, *args: Any) -> JW:
        """
        Construct this object from given Java classname and arguments
        """
        java_obj = JavaWrapper._new_java_obj(java_class, *args)
        return cls(java_obj)

    def _call_java(self, name: str, *args: Any) -> Any:
        from pyspark.core.context import SparkContext

        m = getattr(self._java_obj, name)
        sc = SparkContext._active_spark_context
        assert sc is not None
        java_args = [_py2java(sc, arg) for arg in args]
        return _java2py(sc, m(*java_args))

    @staticmethod
    def _new_java_obj(java_class: str, *args: Any) -> "JavaObject":
        """
        Returns a new Java object.
        """
        from pyspark.core.context import SparkContext

        sc = SparkContext._active_spark_context
        assert sc is not None

        java_obj = _jvm()
        for name in java_class.split("."):
            java_obj = getattr(java_obj, name)
        java_args = [_py2java(sc, arg) for arg in args]
        return java_obj(*java_args)
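
    # -----------------------------------------------------------------------
    # Illustrative note (editor's sketch, not part of the Spark API; assumes
    # an active SparkContext `sc`): the helpers above and `_new_java_array`
    # below are typically used like this:
    #
    #     wrapper = JavaWrapper._create_from_java_class("java.lang.StringBuilder")
    #     wrapper._call_java("append", "hello")   # arguments converted via _py2java
    #     text = wrapper._call_java("toString")   # result converted via _java2py
    #
    #     # _new_java_array turns a jagged 2D list into a square
    #     # java.lang.Double[][], padding missing slots with null:
    #     arr = JavaWrapper._new_java_array(
    #         [[1.0, 2.0], [3.0]], sc._gateway.jvm.java.lang.Double
    #     )
    # -----------------------------------------------------------------------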

    @staticmethod
    def _new_java_array(pylist: List[Any], java_class: "JavaClass") -> "JavaObject":
        """
        Create a Java array of given java_class type. Useful for
        calling a method with a Scala Array from Python with Py4J.
        If the param pylist is a 2D array, then a 2D java array will be returned.
        The returned 2D java array is a square, non-jagged 2D array that is big
        enough for all elements. The empty slots in the inner Java arrays will
        be filled with null to make the non-jagged 2D array.

        Parameters
        ----------
        pylist : list
            Python list to convert to a Java Array.
        java_class : :py:class:`py4j.java_gateway.JavaClass`
            Java class to specify the type of Array. Should be in the
            form of sc._gateway.jvm.* (sc is a valid Spark Context).

            Example primitive Java classes:

            - str -> sc._gateway.jvm.java.lang.String
            - int -> sc._gateway.jvm.java.lang.Integer
            - float -> sc._gateway.jvm.java.lang.Double
            - bool -> sc._gateway.jvm.java.lang.Boolean

        Returns
        -------
        :py:class:`py4j.java_collections.JavaArray`
            Java Array of converted pylist.
        """
        from pyspark.core.context import SparkContext

        sc = SparkContext._active_spark_context
        assert sc is not None
        assert sc._gateway is not None

        java_array = None
        if len(pylist) > 0 and isinstance(pylist[0], list):
            # If pylist is a 2D array, then a 2D java array will be created.
            # The 2D array is a square, non-jagged 2D array that is big enough for all elements.
            inner_array_length = 0
            for i in range(len(pylist)):
                inner_array_length = max(inner_array_length, len(pylist[i]))
            java_array = sc._gateway.new_array(java_class, len(pylist), inner_array_length)
            for i in range(len(pylist)):
                for j in range(len(pylist[i])):
                    java_array[i][j] = pylist[i][j]
        else:
            java_array = sc._gateway.new_array(java_class, len(pylist))
            for i in range(len(pylist)):
                java_array[i] = pylist[i]
        return java_array


@inherit_doc
class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """

    #: The param values in the Java object should be
    #: synced with the Python wrapper in fit/transform/evaluate/copy.

    def _make_java_param_pair(self, param: Param[T], value: T) -> "JavaObject":
        """
        Makes a Java param pair.
        """
        from pyspark.core.context import SparkContext

        sc = SparkContext._active_spark_context
        assert sc is not None and self._java_obj is not None
        param = self._resolveParam(param)
        java_param = self._java_obj.getParam(param.name)
        java_value = _py2java(sc, value)
        return java_param.w(java_value)

    def _transfer_params_to_java(self) -> None:
        """
        Transforms the embedded params to the companion Java object.
        """
        from pyspark.core.context import SparkContext

        assert self._java_obj is not None

        pair_defaults = []
        for param in self.params:
            if self.isSet(param):
                pair = self._make_java_param_pair(param, self._paramMap[param])
                self._java_obj.set(pair)
            if self.hasDefault(param):
                pair = self._make_java_param_pair(param, self._defaultParamMap[param])
                pair_defaults.append(pair)
        if len(pair_defaults) > 0:
            sc = SparkContext._active_spark_context
            assert sc is not None and sc._jvm is not None
            pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
            self._java_obj.setDefault(pair_defaults_seq)

    def _transfer_param_map_to_java(self, pyParamMap: "ParamMap") -> "JavaObject":
        """
        Transforms a Python ParamMap into a Java ParamMap.
        """
        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
        for param in self.params:
            if param in pyParamMap:
                pair = self._make_java_param_pair(param, pyParamMap[param])
                paramMap.put([pair])
        return paramMap

    def _create_params_from_java(self) -> None:
        """
        SPARK-10931: Temporary fix to create params that are defined in the Java obj but not here
        """
        assert self._java_obj is not None
        java_params = list(self._java_obj.params())
        from pyspark.ml.param import Param

        for java_param in java_params:
            java_param_name = java_param.name()
            if not hasattr(self, java_param_name):
                param: Param[Any] = Param(self, java_param_name, java_param.doc())
                setattr(param, "created_from_java_param", True)
                setattr(self, java_param_name, param)
                self._params = None  # need to reset so self.params will discover new params

    def _transfer_params_from_java(self) -> None:
        """
        Transforms the embedded params from the companion Java object.
        """
        from pyspark.core.context import SparkContext

        sc = SparkContext._active_spark_context
        assert sc is not None and self._java_obj is not None
        for param in self.params:
            if self._java_obj.hasParam(param.name):
                java_param = self._java_obj.getParam(param.name)
                # SPARK-14931: Only check set params back to avoid default params mismatch.
                if self._java_obj.isSet(java_param):
                    java_value = self._java_obj.getOrDefault(java_param)
                    if param.typeConverter.__name__.startswith("toList"):
                        value = [_java2py(sc, x) for x in list(java_value)]
                    else:
                        value = _java2py(sc, java_value)
                    self._set(**{param.name: value})
                # SPARK-10931: Temporary fix for params that have a default in Java
                if self._java_obj.hasDefault(java_param) and not self.isDefined(param):
                    value = _java2py(sc, self._java_obj.getDefault(java_param)).get()
                    self._setDefault(**{param.name: value})

    def _transfer_param_map_from_java(self, javaParamMap: "JavaObject") -> "ParamMap":
        """
        Transforms a Java ParamMap into a Python ParamMap.
        """
        from pyspark.core.context import SparkContext

        sc = SparkContext._active_spark_context
        assert sc is not None
        paramMap = dict()
        for pair in javaParamMap.toList():
            param = pair.param()
            if self.hasParam(str(param.name())):
                paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
        return paramMap

    @staticmethod
    def _empty_java_param_map() -> "JavaObject":
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance's Params to the wrapped Java object, and return the Java object.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """
        self._transfer_params_to_java()
        return self._java_obj

    @staticmethod
    def _from_java(java_stage: "JavaObject") -> "JP":  # type: ignore
        """
        Given a Java object, create and return a Python wrapper of it.
        Used for ML persistence.

        Meta-algorithms such as Pipeline should override this method as a classmethod.
        """

        def __get_class(clazz: str) -> Type[JP]:
            """
            Loads Python class from its name.
            """
            parts = clazz.split(".")
            module = ".".join(parts[:-1])
            m = __import__(module, fromlist=[parts[-1]])
            return getattr(m, parts[-1])

        stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
        # Generate a default new instance from the stage_name class.
        py_type = __get_class(stage_name)
        if issubclass(py_type, JavaParams):
            # Load information from java_stage to the instance.
            py_stage = py_type()
            py_stage._java_obj = java_stage

            # SPARK-10931: Temporary fix so that persisted models would own params from Estimator
            if issubclass(py_type, JavaModel):
                py_stage._create_params_from_java()

            py_stage._resetUid(java_stage.uid())
            py_stage._transfer_params_from_java()
        elif hasattr(py_type, "_from_java"):
            py_stage = py_type._from_java(java_stage)
        else:
            raise NotImplementedError(
                "This Java stage cannot be loaded into Python currently: %r" % stage_name
            )
        return py_stage

    def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
        """
        Creates a copy of this instance with the same uid and some
        extra params. This implementation first calls Params.copy and
        then makes a copy of the companion Java pipeline component with
        extra params. So both the Python wrapper and the Java pipeline
        component get copied.

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`JavaParams`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = super(JavaParams, self).copy(extra)
        if self._java_obj is not None:
            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
            that._transfer_params_to_java()
        return that

    def clear(self, param: Param) -> None:
        """
        Clears a param from the param map if it has been explicitly set.
        """
        assert self._java_obj is not None

        super(JavaParams, self).clear(param)
        java_param = self._java_obj.getParam(param.name)
        self._java_obj.clear(java_param)


@inherit_doc
class JavaEstimator(JavaParams, Estimator[JM], metaclass=ABCMeta):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    @abstractmethod
    def _create_model(self, java_model: "JavaObject") -> JM:
        """
        Creates a model from the input Java model reference.
        """
        raise NotImplementedError()

    def _fit_java(self, dataset: DataFrame) -> "JavaObject":
        """
        Fits a Java model to the input dataset.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            input dataset

        Returns
        -------
        py4j.java_gateway.JavaObject
            fitted Java model
        """
        assert self._java_obj is not None

        self._transfer_params_to_java()
        return self._java_obj.fit(dataset._jdf)

    def _fit(self, dataset: DataFrame) -> JM:
        java_model = self._fit_java(dataset)
        model = self._create_model(java_model)
        return self._copyValues(model)


@inherit_doc
class JavaTransformer(JavaParams, Transformer, metaclass=ABCMeta):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations. Subclasses should ensure they have the transformer Java object
    available as _java_obj.
    """

    def _transform(self, dataset: DataFrame) -> DataFrame:
        assert self._java_obj is not None

        self._transfer_params_to_java()
        return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession)


@inherit_doc
class JavaModel(JavaTransformer, Model, metaclass=ABCMeta):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations. Subclasses should inherit this class before
    param mix-ins, because this sets the UID from the Java model.
    """

    def __init__(self, java_model: Optional["JavaObject"] = None):
        """
        Initialize this instance with a Java model object.
        Subclasses should call this constructor, initialize params,
        and then call _transfer_params_from_java.

        This instance can be instantiated without specifying java_model;
        it will be assigned later, but this scenario is only used by
        :py:class:`JavaMLReader` to load models.  This is a bit of a
        hack, but it is easiest since a proper fix would require
        MLReader (in pyspark.ml.util) to depend on these wrappers, but
        these wrappers depend on pyspark.ml.util (both directly and via
        other ML classes).
        """
        super(JavaModel, self).__init__(java_model)
        if java_model is not None:
            # SPARK-10931: This is a temporary fix to allow models to own params
            # from estimators. Eventually, these params should be in models through
            # using common base classes between estimators and models.
            self._create_params_from_java()

            self._resetUid(java_model.uid())

    def __repr__(self) -> str:
        return self._call_java("toString")
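
# ---------------------------------------------------------------------------
# Editor's sketch (not part of the Spark API): a minimal estimator/model pair
# built on the base classes above. "org.example.ml.SimpleScaler" is a made-up
# Scala class used purely for illustration; real wrappers in pyspark.ml follow
# the same shape.
# ---------------------------------------------------------------------------
class _SketchScalerModel(JavaModel):
    """
    Sketch only: model returned by :class:`_SketchScaler`. The JavaModel
    constructor copies the params and the UID from the fitted Java model.
    """


class _SketchScaler(JavaEstimator[_SketchScalerModel]):
    """
    Sketch only: estimator wrapping a hypothetical Scala implementation.
    """

    def __init__(self) -> None:
        super(_SketchScaler, self).__init__()
        # Instantiate the (hypothetical) Scala companion object with this UID.
        self._java_obj = self._new_java_obj("org.example.ml.SimpleScaler", self.uid)

    def _create_model(self, java_model: "JavaObject") -> _SketchScalerModel:
        # Wrap the fitted Java model; _fit() then copies this estimator's
        # params onto the new Python model via _copyValues().
        return _SketchScalerModel(java_model)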


@inherit_doc
class JavaPredictor(
    Predictor, JavaEstimator[JM], _PredictorParams, Generic[JM], metaclass=ABCMeta
):
    """
    (Private) Java Estimator for prediction tasks (regression and classification).
    """

    pass


@inherit_doc
class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams):
    """
    (Private) Java Model for prediction tasks (regression and classification).
    """

    @property
    @since("2.1.0")
    def numFeatures(self) -> int:
        """
        Returns the number of features the model was trained on. If unknown, returns -1
        """
        return self._call_java("numFeatures")

    @since("3.0.0")
    def predict(self, value: T) -> float:
        """
        Predict label for the given features.
        """
        return self._call_java("predict", value)
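
# ---------------------------------------------------------------------------
# Illustrative note (editor's sketch): concrete prediction models such as
# pyspark.ml.regression.LinearRegressionModel inherit numFeatures/predict from
# JavaPredictionModel, so
#
#     model.numFeatures
#     model.predict(features)
#
# are thin delegations to the JVM model via _call_java, roughly equivalent to
#
#     _java2py(sc, model._java_obj.numFeatures())
#     _java2py(sc, model._java_obj.predict(_py2java(sc, features)))
# ---------------------------------------------------------------------------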