#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""User-defined table function related classes and functions"""
import pickle
from dataclasses import dataclass, field
import inspect
import sys
import warnings
from typing import Any, Type, TYPE_CHECKING, Optional, Sequence, Union

from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.types import DataType, StructType, _parse_datatype_string
from pyspark.sql.udf import _wrap_function

if TYPE_CHECKING:
    from py4j.java_gateway import JavaObject
    from pyspark.sql._typing import ColumnOrName
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.session import SparkSession

__all__ = [
    "AnalyzeArgument",
    "AnalyzeResult",
    "PartitioningColumn",
    "OrderingColumn",
    "SelectedColumn",
    "SkipRestOfInputTableException",
    "UDTFRegistration",
]


@dataclass(frozen=True)
class AnalyzeArgument:
    """
    The argument for Python UDTF's analyze static method.

    Parameters
    ----------
    dataType : :class:`DataType`
        The argument's data type
    value : any, optional
        The calculated value if the argument is foldable; otherwise None
    isTable : bool
        If True, the argument is a table argument.
    isConstantExpression : bool
        If True, the argument is a constant-foldable scalar expression. Then the 'value' field
        contains None if the argument is a NULL literal, or a non-None value if the argument is a
        non-NULL literal. In this way, we can distinguish between a literal NULL argument and other
        types of arguments such as complex expression trees or table arguments where the 'value'
        field is always None.
    """

    dataType: DataType
    value: Optional[Any]
    isTable: bool
    isConstantExpression: bool


@dataclass(frozen=True)
class PartitioningColumn:
    """
    Represents an expression that the UDTF is specifying for Catalyst to partition the input table
    by. This can be either the name of a single column from the input table (such as "columnA"), or
    a SQL expression based on the column names of the input table (such as "columnA + columnB").

    Parameters
    ----------
    name : str
        The contents of the partitioning column name or expression represented as a SQL string.
    """

    name: str


@dataclass(frozen=True)
class OrderingColumn:
    """
    Represents an expression that the UDTF is specifying for Catalyst to order the input partition
    by. This can be either the name of a single column from the input table (such as "columnA"),
    or a SQL expression based on the column names of the input table (such as "columnA + columnB").

    Parameters
    ----------
    name : str
        The contents of the ordering column name or expression represented as a SQL string.
    ascending : bool, default True
        Whether this expression specifies an ascending sort order; if False, the order is
        descending.
    overrideNullsFirst : bool, optional
        If this is None, use the default behavior to sort NULL values first when sorting in
        ascending order, or last when sorting in descending order. Otherwise, if this is True or
        False, we override the default behavior accordingly.
    """

    name: str
    ascending: bool = True
    overrideNullsFirst: Optional[bool] = None


@dataclass(frozen=True)
class SelectedColumn:
    """
    Represents an expression that the UDTF is specifying for Catalyst to evaluate against the
    columns in the input TABLE argument. The UDTF then receives one input column for each expression
    in the list, in the order they are listed.

    Parameters
    ----------
    name : str
        The contents of the selected column name or expression represented as a SQL string.
    alias : str, default ''
        If non-empty, this is the alias for the column or expression as visible from the UDTF's
        'eval' method. This is required if the expression is not a simple column reference.
    """

    name: str
    alias: str = ""


# Note: this class is a "dataclass" for purposes of convenience, but it is not marked "frozen"
# because the intention is that users may create subclasses of it for purposes of returning custom
# information from the "analyze" method.
@dataclass
class AnalyzeResult:
    """
    The return of Python UDTF's analyze static method.

    Parameters
    ----------
    schema : :class:`StructType`
        The schema that the Python UDTF will return.
    withSinglePartition : bool
        If true, the UDTF is specifying for Catalyst to repartition all rows of the input TABLE
        argument to one collection for consumption by exactly one instance of the corresponding
        UDTF class.
    partitionBy : sequence of :class:`PartitioningColumn`
        If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
        partition the input TABLE argument by. In this case, calls to the UDTF may not include any
        explicit PARTITION BY clause, in which case Catalyst will return an error. This option is
        mutually exclusive with 'withSinglePartition'.
    orderBy : sequence of :class:`OrderingColumn`
        If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
        sort the input TABLE argument by. Note that the 'partitionBy' list must also be non-empty
        in this case.
    select : sequence of :class:`SelectedColumn`
        If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
        evaluate against the columns in the input TABLE argument. The UDTF then receives one input
        attribute for each name in the list, in the order they are listed.
    """

    schema: StructType
    withSinglePartition: bool = False
    # Immutable empty tuples are used as defaults so instances never share a mutable container.
    partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
    orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
    select: Sequence[SelectedColumn] = field(default_factory=tuple)


class SkipRestOfInputTableException(Exception):
    """
    This represents an exception that the 'eval' method may raise to indicate that it is done
    consuming rows from the current partition of the input table. Then the UDTF's 'terminate'
    method runs (if any).
    """


def _create_udtf(
    cls: Type,
    returnType: Optional[Union[StructType, str]],
    name: Optional[str] = None,
    evalType: int = PythonEvalType.SQL_TABLE_UDF,
    deterministic: bool = False,
) -> "UserDefinedTableFunction":
    """Create a Python UDTF with the given eval type.

    Parameters
    ----------
    cls : type
        The UDTF handler class.
    returnType : :class:`StructType` or str, optional
        The declared output schema, or None when the handler provides a static ``analyze``.
    name : str, optional
        The UDTF name; defaults to the handler class name.
    evalType : int
        One of the :class:`PythonEvalType` UDTF eval types.
    deterministic : bool
        Whether the UDTF is deterministic.

    Returns
    -------
    :class:`UserDefinedTableFunction`
    """
    return UserDefinedTableFunction(
        cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
    )


def _create_py_udtf(
    cls: Type,
    returnType: Optional[Union[StructType, str]],
    name: Optional[str] = None,
    deterministic: bool = False,
    useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
    """Create a regular or an Arrow-optimized Python UDTF.

    Parameters
    ----------
    cls : type
        The UDTF handler class.
    returnType : :class:`StructType` or str, optional
        The declared output schema, or None when the handler provides a static ``analyze``.
    name : str, optional
        The UDTF name; defaults to the handler class name.
    deterministic : bool
        Whether the UDTF is deterministic.
    useArrow : bool, optional
        Whether to use Arrow optimization. If None, the decision comes from the
        ``spark.sql.execution.pythonUDTF.arrow.enabled`` config of the already-instantiated
        SparkSession; when there is no such session, Arrow is not used.

    Notes
    -----
    If Arrow optimization is requested but the minimum pandas/pyarrow requirements are not met,
    a :class:`UserWarning` is emitted and a regular Python UDTF is created instead.
    """
    # Determine whether to create Arrow-optimized UDTFs.
    if useArrow is not None:
        arrow_enabled = useArrow
    else:
        from pyspark.sql import SparkSession

        session = SparkSession._instantiatedSession
        arrow_enabled = False
        if session is not None:
            value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
            # Only a (case-insensitive) "true" string value turns Arrow on.
            if isinstance(value, str) and value.lower() == "true":
                arrow_enabled = True

    eval_type: int = PythonEvalType.SQL_TABLE_UDF

    if arrow_enabled:
        # Return the regular UDTF if the required dependencies are not satisfied.
        try:
            require_minimum_pandas_version()
            require_minimum_pyarrow_version()
            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
        except ImportError as e:
            warnings.warn(
                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
                "Falling back to using regular Python UDTFs.",
                UserWarning,
            )

    return _create_udtf(
        cls=cls,
        returnType=returnType,
        name=name,
        evalType=eval_type,
        deterministic=deterministic,
    )


def _validate_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None:
    """Validate the handler class of a UDTF.

    Parameters
    ----------
    cls : any
        The candidate handler; must be a class with an ``eval`` attribute.
    returnType : :class:`StructType` or str, optional
        The declared return type, if any.

    Raises
    ------
    PySparkTypeError
        If ``cls`` is not a class (``INVALID_UDTF_HANDLER_TYPE``).
    PySparkAttributeError
        If ``cls`` has no ``eval`` (``INVALID_UDTF_NO_EVAL``); if there is neither a return
        type nor a static ``analyze`` method (``INVALID_UDTF_RETURN_TYPE``); or if a return type
        is given while ``cls`` also has an ``analyze`` attribute
        (``INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE``).
    """
    if not isinstance(cls, type):
        raise PySparkTypeError(
            error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__}
        )

    if not hasattr(cls, "eval"):
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": cls.__name__}
        )

    has_analyze = hasattr(cls, "analyze")
    # `getattr_static` avoids triggering the descriptor protocol, so the raw `staticmethod`
    # wrapper stored in the class namespace can be inspected.
    has_analyze_staticmethod = has_analyze and isinstance(
        inspect.getattr_static(cls, "analyze"), staticmethod
    )
    if returnType is None and not has_analyze_staticmethod:
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_RETURN_TYPE", message_parameters={"name": cls.__name__}
        )
    if returnType is not None and has_analyze:
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE",
            message_parameters={"name": cls.__name__},
        )
class UserDefinedTableFunction:
    """
    Python-side handle for a user-defined table function (UDTF).

    It keeps the handler class together with its declared return type, name, eval type and
    determinism flag. The JVM ``UserDefinedPythonTableFunction`` is only built the first time
    it is needed.

    .. versionadded:: 3.5.0

    Notes
    -----
    The constructor of this class is not supposed to be directly called.
    Use :meth:`pyspark.sql.functions.udtf` to create this instance.

    This API is evolving.
    """

    def __init__(
        self,
        func: Type,
        returnType: Optional[Union[StructType, str]],
        name: Optional[str] = None,
        evalType: int = PythonEvalType.SQL_TABLE_UDF,
        deterministic: bool = False,
    ):
        # Reject malformed handler classes before any state is stored.
        _validate_udtf_handler(func, returnType)

        self.func = func
        self._name = name if name else func.__name__
        self.evalType = evalType
        self.deterministic = deterministic

        # The declared return type is stored raw and only parsed on demand by `returnType`.
        self._returnType = returnType

        # Caches that are filled lazily.
        self._returnType_placeholder: Optional[StructType] = None
        self._inputTypes_placeholder = None
        self._judtf_placeholder = None

    @property
    def returnType(self) -> Optional[StructType]:
        """
        The declared output schema as a :class:`StructType`, or None if no return type was
        declared (the handler must then provide a static ``analyze`` method).

        A DDL string is parsed on first access; a successfully resolved schema is cached.

        Raises
        ------
        PySparkTypeError
            If the declared return type does not resolve to a :class:`StructType`.
        """
        declared = self._returnType
        if declared is None:
            return None

        cached = self._returnType_placeholder
        if cached is not None:
            return cached

        # `_parse_datatype_string` goes through the JVM to parse a DDL-formatted string, so
        # parsing is deferred to first access, after the SparkContext is initialized.
        parsed = _parse_datatype_string(declared) if isinstance(declared, str) else declared
        if not isinstance(parsed, StructType):
            raise PySparkTypeError(
                error_class="UDTF_RETURN_TYPE_MISMATCH",
                message_parameters={
                    "name": self._name,
                    "return_type": f"{parsed}",
                },
            )
        self._returnType_placeholder = parsed
        return parsed

    @property
    def _judtf(self) -> "JavaObject":
        """The JVM-side UDTF object, built on first access and reused afterwards."""
        cached = self._judtf_placeholder
        if cached is None:
            cached = self._create_judtf(self.func)
            self._judtf_placeholder = cached
        return cached

    def _create_judtf(self, func: Type) -> "JavaObject":
        """
        Pickle ``func`` and build the JVM ``UserDefinedPythonTableFunction`` around it.

        Raises
        ------
        PySparkPicklingError
            If ``func`` cannot be serialized. A dedicated message is used when the pickling
            error mentions ``CONTEXT_ONLY_VALID_ON_DRIVER``.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext

        try:
            wrapped_func = _wrap_function(sc, func)
        except pickle.PicklingError as e:
            references_session = "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e)
            if references_session:
                # Suppress the original pickling traceback: the message below says it all.
                raise PySparkPicklingError(
                    error_class="UDTF_SERIALIZATION_ERROR",
                    message_parameters={
                        "name": self._name,
                        "message": "it appears that you are attempting to reference SparkSession "
                        "inside a UDTF. SparkSession can only be used on the driver, "
                        "not in code that runs on workers. Please remove the reference "
                        "and try again.",
                    },
                ) from None
            raise PySparkPicklingError(
                error_class="UDTF_SERIALIZATION_ERROR",
                message_parameters={
                    "name": self._name,
                    "message": "Please check the stack trace and make sure the "
                    "function is serializable.",
                },
            )

        assert sc._jvm is not None
        schema = self.returnType
        if schema is None:
            # No static schema: the JVM constructor variant without a data type is used.
            ctor_args = (self._name, wrapped_func, self.evalType, self.deterministic)
        else:
            jdt = spark._jsparkSession.parseDataType(schema.json())
            ctor_args = (self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        jvm_udtf_class = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction
        return jvm_udtf_class(*ctor_args)

    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame":
        """
        Invoke the UDTF and return its output as a :class:`DataFrame`.

        Positional arguments become Java columns. Keyword arguments are wrapped in Catalyst
        ``NamedArgumentExpression`` nodes so that they are passed by name.
        """
        from pyspark.sql.classic.column import _to_java_column, _to_java_expr, _to_seq

        from pyspark.sql import DataFrame, SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext

        assert sc._jvm is not None
        jvm = sc._jvm

        positional_cols = [_to_java_column(arg) for arg in args]
        named_cols = [
            jvm.Column(
                jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression(
                    key, _to_java_expr(value)
                )
            )
            for key, value in kwargs.items()
        ]

        jPythonUDTF = self._judtf.apply(
            spark._jsparkSession, _to_seq(sc, positional_cols + named_cols)
        )
        return DataFrame(jPythonUDTF, spark)

    def asDeterministic(self) -> "UserDefinedTableFunction":
        """
        Mark this UDTF as deterministic and return the same, now-updated, instance.
        """
        self.deterministic = True
        # Drop the cached JVM UDTF so the next use rebuilds it with the new flag.
        self._judtf_placeholder = None
        return self
class UDTFRegistration:
    """
    Entry point for registering Python user-defined table functions with a session. An
    instance is reachable as :attr:`spark.udtf` or :attr:`sqlContext.udtf`.

    .. versionadded:: 3.5.0
    """

    def __init__(self, sparkSession: "SparkSession"):
        self.sparkSession = sparkSession

    def register(
        self,
        name: str,
        f: "UserDefinedTableFunction",
    ) -> "UserDefinedTableFunction":
        """Register a Python user-defined table function as a SQL table function.

        .. versionadded:: 3.5.0

        Parameters
        ----------
        name : str
            The name under which the table function is callable from SQL statements.
        f : function or :meth:`pyspark.sql.functions.udtf`
            The user-defined table function, as produced by
            :meth:`pyspark.sql.functions.udtf`.

        Returns
        -------
        function
            A new user-defined table function object carrying ``name``; it is the one that
            gets registered.

        Notes
        -----
        The registered SQL function takes its return type, eval type and determinism from
        ``f``. To register a nondeterministic Python table function, build a nondeterministic
        user-defined table function first and then register it as a SQL function.

        Examples
        --------
        >>> from pyspark.sql.functions import udtf
        >>> @udtf(returnType="c1: int, c2: int")
        ... class PlusOne:
        ...     def eval(self, x: int):
        ...         yield x, x + 1
        ...
        >>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
        >>> spark.sql("SELECT * FROM plus_one(1)").collect()
        [Row(c1=1, c2=2)]

        Use it with lateral join

        >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
        [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
        """
        eval_type = f.evalType
        supported_eval_types = (PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF)
        if eval_type not in supported_eval_types:
            raise PySparkTypeError(
                error_class="INVALID_UDTF_EVAL_TYPE",
                message_parameters={
                    "name": name,
                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
                },
            )

        # Re-create the UDTF so the registered copy carries the SQL-visible `name`.
        registered = _create_udtf(
            cls=f.func,
            returnType=f.returnType,
            name=name,
            evalType=f.evalType,
            deterministic=f.deterministic,
        )
        jvm_registry = self.sparkSession._jsparkSession.udtf()
        jvm_registry.registerPython(name, registered._judtf)
        return registered