在Python里访问Java对象 – Py4J

Py4J是一套java+python软件包,利用它开发者可以在python里动态的访问jvm里的java对象,其基本原理是在java端创建一个Socket服务,在python端使用Socket客户端调用,Py4J在python语法上的进行了精心的包装,使得在python里的这些远程java对象使用起来与本地的python对象十分接近。

环境信息

Java Python Py4J
11.0.10 3.8.5 0.10.9.7

Java服务端

创建一个新的java工程,并在pom.xml里引入py4j依赖:

<dependency>
    <groupId>net.sf.py4j</groupId>
    <artifactId>py4j</artifactId>
    <version>0.10.9.7</version>
</dependency>

以下借用py4j官方提供的例子。首先定义一个Java类提供一些功能,这里是用一个堆栈类作为例子,提供了push, pop等功能:

package py4j.examples;

import java.util.LinkedList;
import java.util.List;

public class Stack {
    private List<String> internalList = new LinkedList<String>();

    public void push(String element) {
        internalList.add(0, element);
    }

    public String pop() {
        return internalList.remove(0);
    }

    public List<String> getInternalList() {
        return internalList;
    }

    public void pushAll(List<String> elements) {
        for (String element : elements) {
            this.push(element);
        }
    }
}

然后要提供一个对应的EntryPoint类作为服务提供者,这个类实现main函数,执行后会启动一个Socket服务,用于接收python客户端的请求。

package py4j.examples;

import py4j.GatewayServer;

public class StackEntryPoint {

    private Stack stack;

    public StackEntryPoint() {
        stack = new Stack();
        stack.push("Initial Item");
    }

    public Stack getStack() {
        return stack;
    }

    public static void main(String[] args) {
        // 默认端口号25335,可以在构造方法里配置其他端口号
        GatewayServer gatewayServer = new GatewayServer(new StackEntryPoint());
        gatewayServer.start();
        System.out.println("Gateway Server Started");
    }

}

直接执行StackEntryPoint就启动了Java服务端。

Python客户端

首先安装py4j的最新版本:

> pip install py4j

在python命令行里执行下面的代码就可以获得一个Stack对象,其中.entry_point得到的是GatewayServer里的StackEntryPoint实例。

>>> from py4j.java_gateway import JavaGateway
>>> gateway = JavaGateway()
>>> stack = gateway.entry_point.getStack()

测试一下这个Stack对象的功能:

>>> stack
JavaObject id=o0

>>> stack.getInternalList()
['Initial Item']

>>> stack.push('item2')
>>> stack.getInternalList()
['item2', 'Initial Item']

Stack里可以直接push字符串,也可以push一个java的list:

>>> list = gateway.jvm.java.util.ArrayList()
>>> list.add('item3')
>>> list.add('item4')
>>> stack.pushAll(list)
>>> stack.getInternalList()
['item4', 'item3', 'item2', 'Initial Item']

但不可以直接push一个python的list,将会抛出异常。需要借助py4j提供的ListConverter先转换为java list:

>>> stack.pushAll(['item3','item4'])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\zhanghao\anaconda3\lib\site-packages\py4j\java_gateway.py", line 1314, in __call__
    args_command, temp_args = self._build_args(*args)
  File "C:\Users\zhanghao\anaconda3\lib\site-packages\py4j\java_gateway.py", line 1283, in _build_args
    [get_command_part(arg, self.pool) for arg in new_args])
  File "C:\Users\zhanghao\anaconda3\lib\site-packages\py4j\java_gateway.py", line 1283, in <listcomp>
    [get_command_part(arg, self.pool) for arg in new_args])
  File "C:\Users\zhanghao\anaconda3\lib\site-packages\py4j\protocol.py", line 298, in get_command_part
    command_part = REFERENCE_TYPE + parameter._get_object_id()
AttributeError: 'list' object has no attribute '_get_object_id'

如果要在python里使用任意的Java对象,可以用gateway.jvm引用完整的java类名来创建对象实例:

>>> random = gateway.jvm.java.util.Random()
>>> random.nextInt()
433815240

所以在本文的例子里,gateway.entry_point.getStack()gateway.jvm.py4j.examples.Stack()基本是等价的,除了前者是单例的,初始化的内容有所不同。

参考资料

https://www.py4j.org/getting_started.html