vowpal wabbit java: get raw predictions
I am using the Java API of Vowpal Wabbit to get predictions. I need the raw predictions (the same values that -r output.txt writes), but I couldn't find any such method in the VWMulticlassLearner class. I train my model from Python via the command line with the arguments below:
vw -f model_filepath -c --cache_file cache_filepath -k --csoaa 40 -b 24 -q cd -q .... -q n: --ignore a --ignore x
and we are using the code below in Java to get predictions:
VWMulticlassLearner multiclassLearner = VWLearners.create("-i ./data/train.model -t --quiet");
VWProbLearner probLearner = VWLearners.create("-i ./data/train.model -t --quiet --csoaa_ldf=mc --loss_function=logistic --probabilities");
Neither of these classes has a method that returns the raw predictions.
I want the same predictions as shown below:
$ echo ' .. sample string .. ' | vw -i data/train.model -t -r test -p /dev/stdout
creating quadratic features for pairs: cd ce cu cw de du dw eu ew uw n:
ignoring namespaces beginning with: a x
only testing
predictions = /dev/stdout
raw predictions = test
Num weight bits = 24
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile =
num sources = 1
average since example example current current current
loss last counter weight label predict features
39
0.000000 0.000000 1 1.0 known 39 171
finished run
number of examples per pass = 1
passes used = 1
weighted example sum = 1.000000
weighted label sum = 0.000000
average loss = 0.000000
total feature number = 171
$ cat test
0:1.05645 1:0.83437 2:-0.210798 3:-2.81048 4:-4.47558 5:-4.45883 6:-3.65177 7:-3.71191 8:-2.96008 9:-2.82846 10:-2.31816 11:0.925984 12:3.28547 13:5.20375 14:6.34244 15:6.13525 16:1.65726 17:1.22801 18:1.35034 19:3.27091 20:2.94066 21:-0.0276409 22:0.391437 23:1.267 24:-0.689573 25:0.0171876 26:3.12935 27:3.95045 28:3.86978 29:1.18468 30:0.0921049 31:0.436564 32:0.98946 33:1.00963 34:-0.265355 35:-3.02128 36:-2.52846 37:-2.8066 38:-3.50639 39:-4.6184
How can I get the values that are in the file test in Java as a method's return value? I don't want to read the file back in Java to get the result, because that would be slow.
Solution 1:[1]
I ended up using one of the abandoned PRs. Here is my working git patch:
diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
index 6b51c4d30..f3ccb6621 100644
--- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
+++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
@@ -11,3 +11,17 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict(JNI
JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline(JNIEnv *env, jobject obj, jobjectArray example_strings, jboolean learn, jlong vwPtr)
{ return base_predict<jint>(env, example_strings, learn, vwPtr, multiclass_predictor);
}
+
+jfloatArray multiclass_raw_predictor(example* vec, JNIEnv *env){
+ size_t num_values = vec->l.cs.costs.size();
+ jfloatArray j_labels = env->NewFloatArray(num_values);
+ for (int i=0 ; i<num_values; i++) {
+ jfloat f[] = { vec->l.cs.costs[i].partial_prediction };
+ env->SetFloatArrayRegion(j_labels, i, 1, (float*)f);
+ }
+ return j_labels;
+ }
+
+JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict(JNIEnv *env, jobject obj, jstring example_string, jboolean learn, jlong vwPtr){
+return base_predict<jfloatArray>(env, example_string, learn, vwPtr, multiclass_raw_predictor);
+}
diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
index 05204d53e..5610704fa 100644
--- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
+++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
@@ -24,6 +24,15 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict
JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline
(JNIEnv *, jobject, jobjectArray, jboolean, jlong);
+/*
+ * Class: vowpalWabbit_learner_VWMulticlassLearner
+ * Method: rawPredict
+ * Signature: (Ljava/lang/String;ZJ)[F
+ */
+JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict
+ (JNIEnv *, jobject, jstring, jboolean, jlong);
+
+
#ifdef __cplusplus
}
#endif
diff --git a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
index b506cfb25..bb3156351 100644
--- a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
+++ b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
@@ -13,4 +13,25 @@ final public class VWMulticlassLearner extends VWIntLearner {
@Override
protected native int predictMultiline(String[] example, boolean learn, long nativePointer);
+
+ protected native float[] rawPredict(String example, boolean learn, long nativePointer);
+
+ /**
+ * Get raw prediction output.
+ *
+ * @param example a single vw example string
+ * @return Raw prediction
+ */
+
+ public float[] rawPredict(final String example) {
+ lock.lock();
+ try {
+ if (isOpen()) {
+ return rawPredict(example, false, nativePointer);
+ }
+ throw new IllegalStateException("Already closed.");
+ } finally {
+ lock.unlock();
+ }
+ }
}
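Once the patch is applied, rawPredict(example) returns a float[] instead of a file. To check that it produces the same numbers as -r, the array can be compared against a parsed line of the raw prediction file. Below is a minimal sketch of such a check. The class RawPredictionUtil and its methods are hypothetical helpers, not part of the Vowpal Wabbit Java API, and the sketch assumes the array returned by the patched method is in the same order as the label:value pairs in the file. In the sample output above, the predicted label 39 is also the entry with the lowest raw value, so the sketch includes an argMin helper to find that position.

```java
import java.util.Arrays;

/** Hypothetical helper for checking raw predictions; not part of the VW Java API. */
final class RawPredictionUtil {

    private RawPredictionUtil() {}

    /**
     * Parses one line of "-r" output such as "0:1.05645 1:0.83437"
     * and returns the values in the order they appear on the line.
     */
    static float[] parseRawLine(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return new float[0];
        }
        String[] tokens = trimmed.split("\\s+");
        float[] values = new float[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            int colon = tokens[i].indexOf(':');
            if (colon < 0) {
                throw new IllegalArgumentException("Expected label:value, got " + tokens[i]);
            }
            // Keep only the value; the label before the colon is discarded.
            values[i] = Float.parseFloat(tokens[i].substring(colon + 1));
        }
        return values;
    }

    /** Returns the position of the smallest value, or -1 for an empty array. */
    static int argMin(float[] scores) {
        int best = -1;
        for (int i = 0; i < scores.length; i++) {
            if (best < 0 || scores[i] < scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // First four pairs of the sample raw prediction line from the question.
        float[] fromFile = parseRawLine("0:1.05645 1:0.83437 2:-0.210798 3:-2.81048");
        System.out.println(Arrays.toString(fromFile));
        System.out.println("position of lowest value: " + argMin(fromFile));
        // With the patched learner, the array to compare against would come from
        // learner.rawPredict(exampleString) (assumption: same ordering as the file).
    }
}
```

The position returned by argMin is an array index. Whether it equals the class label depends on how the labels were numbered during training, so it is worth confirming against a few known examples.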
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Nishant Kumar |