forked from Chat_things/NeuroDock
374 lines
13 KiB
Java
package me.zacharias.chat.display;
|
|
|
|
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.async.ResultCallback;
import com.github.dockerjava.api.command.LogContainerCmd;
import com.github.dockerjava.api.model.Frame;
import com.github.dockerjava.core.DefaultDockerClientConfig;
import com.github.dockerjava.core.DockerClientBuilder;
import me.zacharias.chat.core.Core;
import me.zacharias.chat.core.Pair;
import me.zacharias.chat.ollama.*;
import me.zacharias.chat.ollama.exceptions.OllamaToolErrorException;
import org.json.JSONArray;
import org.json.JSONObject;

import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
import java.util.logging.Logger;

import static me.zacharias.chat.core.Core.writeLog;
|
|
|
|
/**
|
|
* A tool that runs python code.
|
|
* This is a wrapper around a docker container.
|
|
* This is partly meant as a proof of concept, but also as a way to run python code while keeping the executed code in a secure environment.
|
|
*/
|
|
public class PythonRunner extends OllamaFunctionTool {
|
|
/**
|
|
* The DockerClient instance.
|
|
*/
|
|
private DockerClient dockerClient;
|
|
/**
|
|
* The Core instance.
|
|
*/
|
|
private Core core;
|
|
|
|
/**
|
|
* The ServerSocket instance.
|
|
*/
|
|
private ServerSocket serverSocket;
|
|
|
|
/**
|
|
* Creates a new instance of PythonRunner.
|
|
* @param core The Core instance
|
|
*/
|
|
public PythonRunner(Core core) {
|
|
this.core = core;
|
|
|
|
try {
|
|
serverSocket = new ServerSocket(6050);
|
|
Thread thread = new Thread(() -> {
|
|
while (true) {
|
|
try {
|
|
Socket socket = serverSocket.accept();
|
|
|
|
BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
|
|
|
|
String inputLine = in.readLine();
|
|
|
|
BufferedWriter out = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream()));
|
|
|
|
try {
|
|
JSONObject data = new JSONObject(inputLine);
|
|
List<Pair<OllamaFunctionTool, String>> list = core.getFuntionTools().stream().filter(funtionTool -> funtionTool.getKey().name().equalsIgnoreCase(data.optString("function", ""))).toList();
|
|
|
|
if (list.isEmpty()) {
|
|
out.write(new JSONObject().put("error", "Function dose't exist").toString());
|
|
out.newLine();
|
|
out.flush();
|
|
out.close();
|
|
in.close();
|
|
socket.close();
|
|
continue;
|
|
}
|
|
|
|
ArrayList<OllamaFunctionArgument> args = new ArrayList<>();
|
|
|
|
for (Object o : data.optJSONArray("arguments", new JSONArray())) {
|
|
if (o instanceof JSONObject obj) {
|
|
OllamaFunctionArgument arg = new OllamaFunctionArgument(obj.getString("name"), obj.getString("value"));
|
|
args.add(arg);
|
|
}
|
|
}
|
|
|
|
out.write(list.getFirst().getKey().function(args.toArray(new OllamaFunctionArgument[0])).getResponse());
|
|
out.newLine();
|
|
out.flush();
|
|
out.close();
|
|
in.close();
|
|
socket.close();
|
|
} catch (Exception e) {
|
|
}
|
|
} catch (Exception e) {
|
|
|
|
}
|
|
}
|
|
});
|
|
|
|
thread.start();
|
|
}catch (Exception e) {
|
|
e.printStackTrace();
|
|
}
|
|
|
|
DefaultDockerClientConfig.Builder config
|
|
= DefaultDockerClientConfig.createDefaultConfigBuilder()
|
|
.withDockerHost("tcp://localhost:2375")
|
|
.withDockerTlsVerify(false);
|
|
dockerClient = DockerClientBuilder
|
|
.getInstance(config)
|
|
.build();
|
|
}
|
|
|
|
@Override
|
|
public String name() {
|
|
return "python_runner";
|
|
}
|
|
|
|
@Override
|
|
public String description() {
|
|
return "Runs python code";
|
|
}
|
|
|
|
@Override
|
|
public OllamaPerameter parameters() {
|
|
return OllamaPerameter.builder()
|
|
.addProperty("code", OllamaPerameter.OllamaPerameterBuilder.Type.STRING, "The code to be executed", true)
|
|
.addProperty("name", OllamaPerameter.OllamaPerameterBuilder.Type.STRING, "The name of the python code")
|
|
.build();
|
|
}
|
|
|
|
@Override
|
|
public OllamaToolRespnce function(OllamaFunctionArgument... args) {
|
|
if(args.length == 0)
|
|
{
|
|
throw new OllamaToolErrorException(name(), "Missing code argument");
|
|
}
|
|
|
|
String name = null;
|
|
String code = null;
|
|
|
|
for(OllamaFunctionArgument arg : args)
|
|
{
|
|
if(arg.argument().equals("name"))
|
|
{
|
|
name = (String) arg.value();
|
|
if(!name.endsWith(".py"))
|
|
{
|
|
name += ".py";
|
|
}
|
|
} else if (arg.argument().equals("code")) {
|
|
code = (String) arg.value();
|
|
}
|
|
}
|
|
|
|
if(name == null)
|
|
{
|
|
try {
|
|
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
|
byte[] encodedhash = digest.digest(String.valueOf(args[0].value()).getBytes(StandardCharsets.UTF_8));
|
|
StringBuffer hexString = new StringBuffer();
|
|
for(byte b : encodedhash)
|
|
{
|
|
hexString.append(String.format("%02x", b));
|
|
}
|
|
name = hexString.toString()+".py";
|
|
}catch (Exception e) {}
|
|
}
|
|
|
|
name = name.replace(" ", "_");
|
|
|
|
writeLog("Running python code `" + name + "`");
|
|
|
|
File pythonFile = new File("./pythonFiles", name);
|
|
|
|
code = "from external_tools import *\n\n"+code;
|
|
|
|
if(!pythonFile.exists())
|
|
{
|
|
try {
|
|
BufferedWriter writer = new BufferedWriter(new FileWriter(pythonFile));
|
|
writer.write(code);
|
|
writer.close();
|
|
}catch(IOException e) {}
|
|
}
|
|
|
|
File f = new File("./pythonFiles", "external_tools.py");
|
|
|
|
try {
|
|
String external_tools = generateExternalTools();
|
|
BufferedWriter bw = new BufferedWriter(new FileWriter(f));
|
|
bw.write(external_tools);
|
|
bw.flush();
|
|
bw.close();
|
|
}catch(IOException e) {}
|
|
|
|
try {
|
|
String containerId = dockerClient.createContainerCmd("python").withCmd("python", name).exec().getId();
|
|
dockerClient.copyArchiveToContainerCmd(containerId)
|
|
.withHostResource(pythonFile.getPath())
|
|
.exec();
|
|
|
|
dockerClient.copyArchiveToContainerCmd(containerId)
|
|
.withHostResource(f.getPath())
|
|
.exec();
|
|
|
|
dockerClient.startContainerCmd(containerId).exec();
|
|
|
|
GetContainerLog log = new GetContainerLog(dockerClient, containerId);
|
|
|
|
List<String> logs = new ArrayList<>();
|
|
|
|
do {
|
|
try {
|
|
Thread.sleep(2000);
|
|
} catch (InterruptedException e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
logs.addAll(log.getDockerLogs());
|
|
}
|
|
while (logs.isEmpty());
|
|
|
|
StringBuilder output = new StringBuilder();
|
|
|
|
for (String s : logs) {
|
|
output.append(s).append("\n");
|
|
}
|
|
|
|
//writeLog("Result from python: " + output.toString());
|
|
|
|
return new OllamaToolRespnce(name(), output.toString());
|
|
}
|
|
catch (Exception e) {
|
|
throw new OllamaToolErrorException(name(), "Docker unavalible");
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generates the external_tools.py file.<br>
|
|
* This is meant to provide the python code with all ExternalTools defined in the OllamaObject.
|
|
* @return The generated external_tools.py file
|
|
*/
|
|
private String generateExternalTools() {
|
|
StringBuilder code = new StringBuilder();
|
|
|
|
code.append("""
|
|
import socket
|
|
import json
|
|
|
|
HOST = "host.docker.internal"
|
|
PORT = 6050
|
|
|
|
def connect(data):
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
s.connect((HOST, PORT))
|
|
data = data + "\\n"
|
|
s.sendall(data.encode("utf-8"))
|
|
responce = s.recv(4096)
|
|
return responce.decode("utf-8")
|
|
|
|
|
|
""");
|
|
|
|
for(Pair<OllamaFunctionTool, String> funtionTool : core.getFuntionTools())
|
|
{
|
|
OllamaFunctionTool tool = funtionTool.getKey();
|
|
|
|
String name = tool.name();
|
|
|
|
code.append("def ").append(name).append("(");
|
|
|
|
ArrayList<String> args = new ArrayList<>();
|
|
|
|
boolean first = true;
|
|
try {
|
|
for (String argName : tool.parameters().getProperties().keySet()) {
|
|
args.add(argName);
|
|
if (!first) {
|
|
code.append(", ");
|
|
}
|
|
code.append(argName);
|
|
if (Arrays.stream(tool.parameters().getRequired()).noneMatch(required -> required.equals(argName))) {
|
|
code.append("=None");
|
|
}
|
|
first = false;
|
|
}
|
|
}catch (Exception e) {}
|
|
code.append("):\n");
|
|
code.append(" data = {\"function\":\"").append(tool.name()).append("\"");
|
|
if(args.size() > 0)
|
|
{
|
|
code.append(",\"arguments\":[");
|
|
}
|
|
first = true;
|
|
for(String str : args)
|
|
{
|
|
if(!first)
|
|
{
|
|
code.append(", ");
|
|
}
|
|
code.append("{\"name\": \"").append(str).append("\", \"value\": ").append(str).append("}");
|
|
first = false;
|
|
}
|
|
if(args.size() > 0)
|
|
{
|
|
code.append("]");
|
|
}
|
|
code.append("}\n");
|
|
code.append(" return connect(json.dumps(data))\n\n");
|
|
}
|
|
|
|
return code.toString();
|
|
}
|
|
|
|
/**
|
|
* A Helper class to get the logs from a docker container.
|
|
*/
|
|
public class GetContainerLog {
|
|
private DockerClient dockerClient;
|
|
private String containerId;
|
|
private int lastLogTime;
|
|
|
|
private static String nameOfLogger = "dockertest.PrintContainerLog";
|
|
private static Logger myLogger = Logger.getLogger(nameOfLogger);
|
|
|
|
/**
|
|
* Creates a new instance of {@link GetContainerLog}
|
|
* @param dockerClient The DockerClient instance
|
|
* @param containerId The container id
|
|
*/
|
|
public GetContainerLog(DockerClient dockerClient, String containerId) {
|
|
this.dockerClient = dockerClient;
|
|
this.containerId = containerId;
|
|
this.lastLogTime = (int) (System.currentTimeMillis() / 1000);
|
|
}
|
|
|
|
/**
|
|
* Gets the logs of the container.
|
|
* @return The logs of the container
|
|
*/
|
|
public List<String> getDockerLogs() {
|
|
|
|
final List<String> logs = new ArrayList<>();
|
|
|
|
LogContainerCmd logContainerCmd = dockerClient.logContainerCmd(containerId);
|
|
logContainerCmd.withStdOut(true).withStdErr(true);
|
|
logContainerCmd.withSince(lastLogTime); // UNIX timestamp (integer) to filter logs. Specifying a timestamp will only output log-entries since that timestamp.
|
|
// logContainerCmd.withTail(4); // get only the last 4 log entries
|
|
|
|
logContainerCmd.withTimestamps(true);
|
|
|
|
try {
|
|
logContainerCmd.exec(new ResultCallback.Adapter<Frame>() {
|
|
@Override
|
|
public void onNext(Frame item) {
|
|
logs.add(new String(item.getPayload()).trim());
|
|
}
|
|
}).awaitCompletion();
|
|
} catch (InterruptedException e) {
|
|
myLogger.severe("Interrupted Exception!" + e.getMessage());
|
|
}
|
|
|
|
lastLogTime = (int) (System.currentTimeMillis() / 1000) + 5; // assumes at least a 5 second wait between calls to getDockerLogs
|
|
|
|
return logs;
|
|
}
|
|
}
|
|
}
|